//! resolver implementations implementing different strategies for answering //! incoming queries use std::sync::Arc; use std::time::Instant; use std::vec::Vec; use derive_more::{Display, Error, From}; use crate::dns::context::ServerContext; use crate::dns::protocol::{DnsPacket, QueryType, ResultCode}; #[derive(Debug, Display, From, Error)] pub enum ResolveError { Client(crate::dns::client::ClientError), Cache(crate::dns::cache::CacheError), Io(std::io::Error), NoServerFound } type Result = std::result::Result; pub trait DnsResolver { fn get_context(&self) -> Arc; fn resolve(&mut self, qname: &str, qtype: QueryType, recursive: bool) -> Result { if let QueryType::UNKNOWN(_) = qtype { let mut packet = DnsPacket::new(); packet.header.rescode = ResultCode::NOTIMP; return Ok(packet); } let context = self.get_context(); if let Some(qr) = context.authority.query(qname, qtype) { return Ok(qr); } if !recursive || !context.allow_recursive { let mut packet = DnsPacket::new(); packet.header.rescode = ResultCode::REFUSED; return Ok(packet); } if let Some(qr) = context.cache.lookup(qname, qtype) { return Ok(qr); } if qtype == QueryType::A || qtype == QueryType::AAAA { if let Some(qr) = context.cache.lookup(qname, QueryType::CNAME) { return Ok(qr); } } for filter in context.filters.iter() { if let Some(packet) = filter.lookup(qname, qtype, recursive) { context.cache.store(&packet.answers)?; return Ok(packet); } } self.perform(qname, qtype) } fn perform(&mut self, qname: &str, qtype: QueryType) -> Result; } /// A Forwarding DNS Resolver /// /// This resolver uses an external DNS server to service a query pub struct ForwardingDnsResolver { context: Arc, upstreams: Vec } impl ForwardingDnsResolver { pub fn new(context: Arc, upstreams: Vec) -> ForwardingDnsResolver { ForwardingDnsResolver { context, upstreams } } } impl DnsResolver for ForwardingDnsResolver { fn get_context(&self) -> Arc { Arc::clone(&self.context) } fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { if let Some(packet) = self.context.cache.lookup(qname, qtype) { return Ok(packet); } let ordered = self.context.forwarder_tracker.select_ordered(&self.upstreams); let mut last_err = ResolveError::NoServerFound; for upstream in &ordered { let start = Instant::now(); let query_result = if is_url(upstream) { if let Some(client) = &self.context.doh_client { client.send_query(qname, qtype, upstream, true) } else { log::error!("This build doesn't support DoH"); continue; } } else { self.context.old_client.send_query(qname, qtype, upstream, true) }; match query_result { Ok(mut result) => { let elapsed = start.elapsed().as_secs_f64() * 1000.0; self.context.forwarder_tracker.record_success(upstream, elapsed); self.context.cache.store(&result.answers)?; // Fix domain names in answers to match original query case let qname_lower = qname.to_lowercase(); for answer in &mut result.answers { if let Some(domain) = answer.get_domain() { if domain.to_lowercase() == qname_lower { answer.set_domain(qname.to_string()); } } } return Ok(result); } Err(e) => { self.context.forwarder_tracker.record_failure(upstream); last_err = e.into(); } } } Err(last_err) } } /// A Recursive DNS resolver /// /// This resolver can answer any request using the root servers of the internet pub struct RecursiveDnsResolver { context: Arc } impl RecursiveDnsResolver { pub fn new(context: Arc) -> RecursiveDnsResolver { RecursiveDnsResolver { context } } } impl DnsResolver for RecursiveDnsResolver { fn get_context(&self) -> Arc { Arc::clone(&self.context) } fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { // Find the closest name server by splitting the label and progessively // moving towards the root servers. I.e. check "google.com", then "com", // and finally "". let mut tentative_ns = None; let labels = qname.split('.').collect::>(); for lbl_idx in 0..labels.len() + 1 { let domain = labels[lbl_idx..].join("."); match self .context .cache .lookup(&domain, QueryType::NS) .and_then(|qr| qr.get_unresolved_ns(&domain)) .and_then(|ns| self.context.cache.lookup(&ns, QueryType::A)) .and_then(|qr| qr.get_random_a()) { Some(addr) => { tentative_ns = Some(addr); break; } None => continue } } let mut ns = tentative_ns.ok_or(ResolveError::NoServerFound)?; // Start querying name servers loop { println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); let ns_copy = ns.clone(); let server = format!("{}:{}", ns_copy.as_str(), 53); let response = self.context.old_client.send_query(qname, qtype, &server, false)?; // If we've got an actual answer, we're done! if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { let _ = self.context.cache.store(&response.answers); let _ = self.context.cache.store(&response.authorities); let _ = self.context.cache.store(&response.resources); // Fix domain names in answers to match original query case let qname_lower = qname.to_lowercase(); let mut fixed_response = response; for answer in &mut fixed_response.answers { if let Some(domain) = answer.get_domain() { if domain.to_lowercase() == qname_lower { answer.set_domain(qname.to_string()); } } } return Ok(fixed_response); } if response.header.rescode == ResultCode::NXDOMAIN { if let Some(ttl) = response.get_ttl_from_soa() { let _ = self.context.cache.store_nxdomain(qname, qtype, ttl); } return Ok(response); } // Otherwise, try to find a new nameserver based on NS and a // corresponding A record in the additional section if let Some(new_ns) = response.get_resolved_ns(qname) { // If there is such a record, we can retry the loop with that NS ns = new_ns.clone(); let _ = self.context.cache.store(&response.answers); let _ = self.context.cache.store(&response.authorities); let _ = self.context.cache.store(&response.resources); continue; } // If not, we'll have to resolve the ip of a NS record let new_ns_name = match response.get_unresolved_ns(qname) { Some(x) => x, None => { // Fix domain names before returning let qname_lower = qname.to_lowercase(); let mut fixed_response = response; for answer in &mut fixed_response.answers { if let Some(domain) = answer.get_domain() { if domain.to_lowercase() == qname_lower { answer.set_domain(qname.to_string()); } } } return Ok(fixed_response); } }; // Recursively resolve the NS let recursive_response = self.resolve(&new_ns_name, QueryType::A, true)?; // Pick a random IP and restart if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns.clone(); } else { // Fix domain names before returning let qname_lower = qname.to_lowercase(); let mut fixed_response = response; for answer in &mut fixed_response.answers { if let Some(domain) = answer.get_domain() { if domain.to_lowercase() == qname_lower { answer.set_domain(qname.to_string()); } } } return Ok(fixed_response); } } } } fn is_url(url: &str) -> bool { url.starts_with("https://") } #[cfg(test)] mod tests { use std::sync::Arc; use super::*; use crate::dns::context::tests::create_test_context; use crate::dns::context::ResolveStrategy; use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl}; #[test] fn test_forwarding_resolver() { let mut context = create_test_context(Box::new(|qname, _, _, _| { let mut packet = DnsPacket::new(); if qname == "google.com" { packet.answers.push(DnsRecord::A { domain: "google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); } else { packet.header.rescode = ResultCode::NXDOMAIN; } Ok(packet) })); match Arc::get_mut(&mut context) { Some(ctx) => { ctx.resolve_strategy = ResolveStrategy::Forward { upstreams: vec![String::from("127.0.0.1:53")] }; } None => panic!() } let mut resolver = context.create_resolver(Arc::clone(&context)); // First, verify that we get a match back { let res = match resolver.resolve("google.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(1, res.answers.len()); match res.answers[0] { DnsRecord::A { ref domain, .. } => { assert_eq!("google.com", domain); } _ => panic!() } }; // Do the same lookup again and verify that it's present in the cache // and that the counter has been updated { let res = match resolver.resolve("google.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(1, res.answers.len()); let list = match context.cache.list() { Ok(x) => x, Err(_) => panic!() }; assert_eq!(1, list.len()); assert_eq!("google.com", list[0].domain); assert_eq!(1, list[0].record_types.len()); assert_eq!(1, list[0].hits); }; // Do a failed lookup { let res = match resolver.resolve("yahoo.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(0, res.answers.len()); assert_eq!(ResultCode::NXDOMAIN, res.header.rescode); }; } #[test] fn test_recursive_resolver_with_no_nameserver() { let context = create_test_context(Box::new(|_, _, _, _| { let mut packet = DnsPacket::new(); packet.header.rescode = ResultCode::NXDOMAIN; Ok(packet) })); let mut resolver = context.create_resolver(Arc::clone(&context)); // Expect failure when no name servers are available if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { panic!(); } } #[test] fn test_recursive_resolver_with_missing_a_record() { let context = create_test_context(Box::new(|_, _, _, _| { let mut packet = DnsPacket::new(); packet.header.rescode = ResultCode::NXDOMAIN; Ok(packet) })); let mut resolver = context.create_resolver(Arc::clone(&context)); // Expect failure when no name servers are available if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { panic!(); } // Insert name server, but no corresponding A record let mut nameservers = Vec::new(); nameservers.push(DnsRecord::NS { domain: "".to_string(), host: "a.myroot.net".to_string(), ttl: TransientTtl(3600) }); let _ = context.cache.store(&nameservers); if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { panic!(); } } #[test] fn test_recursive_resolver_match_order() { let context = create_test_context(Box::new(|_, _, server, _| { let mut packet = DnsPacket::new(); if server.starts_with("127.0.0.1") { packet.header.id = 1; packet.answers.push(DnsRecord::A { domain: "a.google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); return Ok(packet); } else if server.starts_with("127.0.0.2") { packet.header.id = 2; packet.answers.push(DnsRecord::A { domain: "b.google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); return Ok(packet); } else if server.starts_with("127.0.0.3") { packet.header.id = 3; packet.answers.push(DnsRecord::A { domain: "c.google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); return Ok(packet); } packet.header.id = 999; packet.header.rescode = ResultCode::NXDOMAIN; Ok(packet) })); let mut resolver = context.create_resolver(Arc::clone(&context)); // Expect failure when no name servers are available if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { panic!(); } // Insert root servers { let mut nameservers = Vec::new(); nameservers.push(DnsRecord::NS { domain: "".to_string(), host: "a.myroot.net".to_string(), ttl: TransientTtl(3600) }); nameservers.push(DnsRecord::A { domain: "a.myroot.net".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); let _ = context.cache.store(&nameservers); } match resolver.resolve("google.com", QueryType::A, true) { Ok(packet) => { assert_eq!(1, packet.header.id); } Err(_) => panic!() } // Insert TLD servers { let mut nameservers = Vec::new(); nameservers.push(DnsRecord::NS { domain: "com".to_string(), host: "a.mytld.net".to_string(), ttl: TransientTtl(3600) }); nameservers.push(DnsRecord::A { domain: "a.mytld.net".to_string(), addr: "127.0.0.2".parse().unwrap(), ttl: TransientTtl(3600) }); let _ = context.cache.store(&nameservers); } match resolver.resolve("google.com", QueryType::A, true) { Ok(packet) => { assert_eq!(2, packet.header.id); } Err(_) => panic!() } // Insert authoritative servers { let mut nameservers = Vec::new(); nameservers.push(DnsRecord::NS { domain: "google.com".to_string(), host: "ns1.google.com".to_string(), ttl: TransientTtl(3600) }); nameservers.push(DnsRecord::A { domain: "ns1.google.com".to_string(), addr: "127.0.0.3".parse().unwrap(), ttl: TransientTtl(3600) }); let _ = context.cache.store(&nameservers); } match resolver.resolve("google.com", QueryType::A, true) { Ok(packet) => { assert_eq!(3, packet.header.id); } Err(_) => panic!() } } #[test] fn test_recursive_resolver_successfully() { let context = create_test_context(Box::new(|qname, _, _, _| { let mut packet = DnsPacket::new(); if qname == "google.com" { packet.answers.push(DnsRecord::A { domain: "google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); } else { packet.header.rescode = ResultCode::NXDOMAIN; packet.authorities.push(DnsRecord::SOA { domain: "google.com".to_string(), r_name: "google.com".to_string(), m_name: "google.com".to_string(), serial: 0, refresh: 3600, retry: 3600, expire: 3600, minimum: 3600, ttl: TransientTtl(3600) }); } Ok(packet) })); let mut resolver = context.create_resolver(Arc::clone(&context)); // Insert name servers let mut nameservers = Vec::new(); nameservers.push(DnsRecord::NS { domain: "google.com".to_string(), host: "ns1.google.com".to_string(), ttl: TransientTtl(3600) }); nameservers.push(DnsRecord::A { domain: "ns1.google.com".to_string(), addr: "127.0.0.1".parse().unwrap(), ttl: TransientTtl(3600) }); let _ = context.cache.store(&nameservers); // Check that we can successfully resolve { let res = match resolver.resolve("google.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(1, res.answers.len()); match res.answers[0] { DnsRecord::A { ref domain, .. } => { assert_eq!("google.com", domain); } _ => panic!() } }; // And that we won't find anything for a domain that isn't present { let res = match resolver.resolve("foobar.google.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(ResultCode::NXDOMAIN, res.header.rescode); assert_eq!(0, res.answers.len()); }; // Perform another successful query, that should hit the cache { let res = match resolver.resolve("google.com", QueryType::A, true) { Ok(x) => x, Err(_) => panic!() }; assert_eq!(1, res.answers.len()); }; // Now check that the cache is used, and that the statistics is correct { let list = match context.cache.list() { Ok(x) => x, Err(_) => panic!() }; assert_eq!(3, list.len()); // Find entries by domain name (LRU order may vary) let google_entry = list.iter().find(|e| e.domain == "google.com").expect("google.com entry"); let ns1_entry = list.iter().find(|e| e.domain == "ns1.google.com").expect("ns1.google.com entry"); let foobar_entry = list.iter().find(|e| e.domain == "foobar.google.com").expect("foobar.google.com NXDOMAIN entry"); // google.com should have a NS record and an A record for a total of 2 record types assert_eq!(2, google_entry.record_types.len()); // Should have been hit two times for NS google.com and once for // A google.com assert_eq!(3, google_entry.hits); assert_eq!(1, ns1_entry.record_types.len()); assert_eq!(2, ns1_entry.hits); // foobar.google.com should be a cached NXDOMAIN with 0 hits assert_eq!(0, foobar_entry.hits); }; } }