Code reformatting.

This commit is contained in:
Revertron
2021-06-09 20:36:36 +02:00
parent 2d12fd0447
commit d513c29cfe
40 changed files with 546 additions and 815 deletions
+39 -64
View File
@@ -5,17 +5,17 @@ use std::sync::Arc;
use std::vec::Vec;
use derive_more::{Display, Error, From};
use rand::seq::IteratorRandom;
use crate::dns::context::ServerContext;
use crate::dns::protocol::{DnsPacket, QueryType, ResultCode};
use rand::seq::IteratorRandom;
#[derive(Debug, Display, From, Error)]
pub enum ResolveError {
Client(crate::dns::client::ClientError),
Cache(crate::dns::cache::CacheError),
Io(std::io::Error),
NoServerFound,
NoServerFound
}
type Result<T> = std::result::Result<T, ResolveError>;
@@ -69,7 +69,7 @@ pub trait DnsResolver {
/// This resolver uses an external DNS server to service a query
pub struct ForwardingDnsResolver {
context: Arc<ServerContext>,
upstreams: Vec<String>,
upstreams: Vec<String>
}
impl ForwardingDnsResolver {
@@ -87,9 +87,7 @@ impl DnsResolver for ForwardingDnsResolver {
let mut random = rand::thread_rng();
let upstream = self.upstreams.iter().choose(&mut random).unwrap();
let result = match self.context.cache.lookup(qname, qtype) {
None => {
self.context.client.send_query(qname, qtype, upstream, true)?
}
None => self.context.client.send_query(qname, qtype, upstream, true)?,
Some(packet) => packet
};
@@ -103,7 +101,7 @@ impl DnsResolver for ForwardingDnsResolver {
///
/// This resolver can answer any request using the root servers of the internet
pub struct RecursiveDnsResolver {
context: Arc<ServerContext>,
context: Arc<ServerContext>
}
impl RecursiveDnsResolver {
@@ -139,7 +137,7 @@ impl DnsResolver for RecursiveDnsResolver {
tentative_ns = Some(addr);
break;
}
None => continue,
None => continue
}
}
@@ -152,10 +150,7 @@ impl DnsResolver for RecursiveDnsResolver {
let ns_copy = ns.clone();
let server = format!("{}:{}", ns_copy.as_str(), 53);
let response = self
.context
.client
.send_query(qname, qtype.clone(), &server, false)?;
let response = self.context.client.send_query(qname, qtype.clone(), &server, false)?;
// If we've got an actual answer, we're done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
@@ -187,7 +182,7 @@ impl DnsResolver for RecursiveDnsResolver {
// If not, we'll have to resolve the ip of a NS record
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => return Ok(response.clone()),
None => return Ok(response.clone())
};
// Recursively resolve the NS
@@ -208,12 +203,10 @@ mod tests {
use std::sync::Arc;
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
use super::*;
use crate::dns::context::tests::create_test_context;
use crate::dns::context::ResolveStrategy;
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
#[test]
fn test_forwarding_resolver() {
@@ -224,7 +217,7 @@ mod tests {
packet.answers.push(DnsRecord::A {
domain: "google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
} else {
packet.header.rescode = ResultCode::NXDOMAIN;
@@ -235,11 +228,9 @@ mod tests {
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
upstreams: vec![String::from("127.0.0.1:53")]
};
ctx.resolve_strategy = ResolveStrategy::Forward { upstreams: vec![String::from("127.0.0.1:53")] };
}
None => panic!(),
None => panic!()
}
let mut resolver = context.create_resolver(Arc::clone(&context));
@@ -248,7 +239,7 @@ mod tests {
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(1, res.answers.len());
@@ -257,7 +248,7 @@ mod tests {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
_ => panic!()
}
};
@@ -266,14 +257,14 @@ mod tests {
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(1, res.answers.len());
let list = match context.cache.list() {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(1, list.len());
@@ -287,7 +278,7 @@ mod tests {
{
let res = match resolver.resolve("yahoo.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(0, res.answers.len());
@@ -328,11 +319,7 @@ mod tests {
// Insert name server, but no corresponding A record
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "".to_string(),
host: "a.myroot.net".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::NS { domain: "".to_string(), host: "a.myroot.net".to_string(), ttl: TransientTtl(3600) });
let _ = context.cache.store(&nameservers);
@@ -352,7 +339,7 @@ mod tests {
packet.answers.push(DnsRecord::A {
domain: "a.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
return Ok(packet);
@@ -362,7 +349,7 @@ mod tests {
packet.answers.push(DnsRecord::A {
domain: "b.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
return Ok(packet);
@@ -372,7 +359,7 @@ mod tests {
packet.answers.push(DnsRecord::A {
domain: "c.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
return Ok(packet);
@@ -393,15 +380,11 @@ mod tests {
// Insert root servers
{
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "".to_string(),
host: "a.myroot.net".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::NS { domain: "".to_string(), host: "a.myroot.net".to_string(), ttl: TransientTtl(3600) });
nameservers.push(DnsRecord::A {
domain: "a.myroot.net".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
let _ = context.cache.store(&nameservers);
@@ -411,21 +394,17 @@ mod tests {
Ok(packet) => {
assert_eq!(1, packet.header.id);
}
Err(_) => panic!(),
Err(_) => panic!()
}
// Insert TLD servers
{
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "com".to_string(),
host: "a.mytld.net".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::NS { domain: "com".to_string(), host: "a.mytld.net".to_string(), ttl: TransientTtl(3600) });
nameservers.push(DnsRecord::A {
domain: "a.mytld.net".to_string(),
addr: "127.0.0.2".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
let _ = context.cache.store(&nameservers);
@@ -435,7 +414,7 @@ mod tests {
Ok(packet) => {
assert_eq!(2, packet.header.id);
}
Err(_) => panic!(),
Err(_) => panic!()
}
// Insert authoritative servers
@@ -444,12 +423,12 @@ mod tests {
nameservers.push(DnsRecord::NS {
domain: "google.com".to_string(),
host: "ns1.google.com".to_string(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
nameservers.push(DnsRecord::A {
domain: "ns1.google.com".to_string(),
addr: "127.0.0.3".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
let _ = context.cache.store(&nameservers);
@@ -459,7 +438,7 @@ mod tests {
Ok(packet) => {
assert_eq!(3, packet.header.id);
}
Err(_) => panic!(),
Err(_) => panic!()
}
}
@@ -472,7 +451,7 @@ mod tests {
packet.answers.push(DnsRecord::A {
domain: "google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
} else {
packet.header.rescode = ResultCode::NXDOMAIN;
@@ -486,7 +465,7 @@ mod tests {
retry: 3600,
expire: 3600,
minimum: 3600,
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
}
@@ -497,15 +476,11 @@ mod tests {
// Insert name servers
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "google.com".to_string(),
host: "ns1.google.com".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::NS { domain: "google.com".to_string(), host: "ns1.google.com".to_string(), ttl: TransientTtl(3600) });
nameservers.push(DnsRecord::A {
domain: "ns1.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
ttl: TransientTtl(3600)
});
let _ = context.cache.store(&nameservers);
@@ -514,7 +489,7 @@ mod tests {
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(1, res.answers.len());
@@ -523,7 +498,7 @@ mod tests {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
_ => panic!()
}
};
@@ -531,7 +506,7 @@ mod tests {
{
let res = match resolver.resolve("foobar.google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
@@ -542,7 +517,7 @@ mod tests {
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(1, res.answers.len());
@@ -552,7 +527,7 @@ mod tests {
{
let list = match context.cache.list() {
Ok(x) => x,
Err(_) => panic!(),
Err(_) => panic!()
};
assert_eq!(3, list.len());