Implemented rotating DNS upstreams from config. Fixed warnings.

This commit is contained in:
Revertron
2021-02-21 21:56:56 +01:00
parent daf1592341
commit 193275da7f
13 changed files with 48 additions and 40 deletions
+7 -7
View File
@@ -2,7 +2,7 @@
use std::io::Write;
use std::marker::{Send, Sync};
use std::net::{TcpStream, UdpSocket};
use std::net::{TcpStream, UdpSocket, ToSocketAddrs};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex};
@@ -32,7 +32,7 @@ pub trait DnsClient {
fn get_failed_count(&self) -> usize;
fn run(&self) -> Result<()>;
fn send_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result<DnsPacket>;
fn send_query(&self, qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket>;
}
/// The UDP client
@@ -84,7 +84,7 @@ impl DnsNetworkClient {
///
/// This is much simpler than using UDP, since the kernel will take care of
/// packet ordering, connection state, timeouts etc.
pub fn send_tcp_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result<DnsPacket> {
pub fn send_tcp_query<A: ToSocketAddrs>(&self, qname: &str, qtype: QueryType, server: A, recursive: bool) -> Result<DnsPacket> {
let _ = self.total_sent.fetch_add(1, Ordering::Release);
// Prepare request
@@ -125,7 +125,7 @@ impl DnsNetworkClient {
/// worker thread, and returned to this thread through a channel. Thus this
/// method is thread safe, and can be used from any number of threads in
/// parallel.
pub fn send_udp_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result<DnsPacket> {
pub fn send_udp_query<A: ToSocketAddrs>(&self, qname: &str, qtype: QueryType, server: A, recursive: bool) -> Result<DnsPacket> {
let _ = self.total_sent.fetch_add(1, Ordering::Release);
// Prepare request
@@ -268,7 +268,7 @@ impl DnsClient for DnsNetworkClient {
Ok(())
}
fn send_query(&self,qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result<DnsPacket> {
fn send_query(&self,qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket> {
let packet = self.send_udp_query(qname, qtype, server, recursive)?;
if !packet.header.truncated_message {
return Ok(packet);
@@ -284,7 +284,7 @@ pub mod tests {
use super::*;
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType};
pub type StubCallback = dyn Fn(&str, QueryType, (&str, u16), bool) -> Result<DnsPacket>;
pub type StubCallback = dyn Fn(&str, QueryType, &str, bool) -> Result<DnsPacket>;
pub struct DnsStubClient {
callback: Box<StubCallback>,
@@ -313,7 +313,7 @@ pub mod tests {
Ok(())
}
fn send_query(&self,qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result<DnsPacket> {
fn send_query(&self,qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket> {
(self.callback)(qname, qtype, server, recursive)
}
}
+3 -3
View File
@@ -38,7 +38,7 @@ impl ServerStatistics {
pub enum ResolveStrategy {
Recursive,
Forward { host: String, port: u16 },
Forward { upstreams: Vec<String> },
}
pub struct ServerContext {
@@ -101,8 +101,8 @@ impl ServerContext {
pub fn create_resolver(&self, ptr: Arc<ServerContext>) -> Box<dyn DnsResolver> {
match self.resolve_strategy {
ResolveStrategy::Recursive => Box::new(RecursiveDnsResolver::new(ptr)),
ResolveStrategy::Forward { ref host, port } => {
Box::new(ForwardingDnsResolver::new(ptr, (host.clone(), port)))
ResolveStrategy::Forward { ref upstreams } => {
Box::new(ForwardingDnsResolver::new(ptr, upstreams.clone()))
}
}
}
+13 -13
View File
@@ -68,12 +68,12 @@ pub trait DnsResolver {
/// This resolver uses an external DNS server to service a query
pub struct ForwardingDnsResolver {
context: Arc<ServerContext>,
server: (String, u16),
upstreams: Vec<String>,
}
impl ForwardingDnsResolver {
pub fn new(context: Arc<ServerContext>, server: (String, u16)) -> ForwardingDnsResolver {
ForwardingDnsResolver { context, server }
pub fn new(context: Arc<ServerContext>, upstreams: Vec<String>) -> ForwardingDnsResolver {
ForwardingDnsResolver { context, upstreams }
}
}
@@ -83,10 +83,11 @@ impl DnsResolver for ForwardingDnsResolver {
}
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket> {
let &(ref host, port) = &self.server;
let index: usize = rand::random::<usize>() % self.upstreams.len();
let upstream = self.upstreams[index].as_ref();
let result = match self.context.cache.lookup(qname, qtype) {
None => {
self.context.client.send_query(qname, qtype, (host.as_str(), port), true)?
self.context.client.send_query(qname, qtype, upstream, true)?
}
Some(packet) => packet
};
@@ -149,11 +150,11 @@ impl DnsResolver for RecursiveDnsResolver {
let ns_copy = ns.clone();
let server = (ns_copy.as_str(), 53);
let server = format!("{}:{}", ns_copy.as_str(), 53);
let response = self
.context
.client
.send_query(qname, qtype.clone(), server, false)?;
.send_query(qname, qtype.clone(), &server, false)?;
// If we've got an actual answer, we're done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
@@ -234,8 +235,7 @@ mod tests {
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
upstreams: vec![String::from("127.0.0.1:53")]
};
}
None => panic!(),
@@ -342,10 +342,10 @@ mod tests {
#[test]
fn test_recursive_resolver_match_order() {
let context = create_test_context(Box::new(|_, _, (server, _), _| {
let context = create_test_context(Box::new(|_, _, server, _| {
let mut packet = DnsPacket::new();
if server == "127.0.0.1" {
if server.starts_with("127.0.0.1") {
packet.header.id = 1;
packet.answers.push(DnsRecord::A {
@@ -355,7 +355,7 @@ mod tests {
});
return Ok(packet);
} else if server == "127.0.0.2" {
} else if server.starts_with("127.0.0.2") {
packet.header.id = 2;
packet.answers.push(DnsRecord::A {
@@ -365,7 +365,7 @@ mod tests {
});
return Ok(packet);
} else if server == "127.0.0.3" {
} else if server.starts_with("127.0.0.3") {
packet.header.id = 3;
packet.answers.push(DnsRecord::A {
+2 -4
View File
@@ -437,8 +437,7 @@ mod tests {
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
upstreams: vec![String::from("127.0.0.1:53")]
};
}
None => panic!(),
@@ -545,8 +544,7 @@ mod tests {
match Arc::get_mut(&mut context2) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
upstreams: vec![String::from("127.0.0.1:53")]
};
}
None => panic!(),