Removed async dnsclient in favor of internal implementation.

This commit is contained in:
Revertron
2021-09-15 16:06:10 +02:00
parent 88bca10fbc
commit 6eb185f76a
3 changed files with 83 additions and 562 deletions
+62 -14
View File
@@ -3,7 +3,7 @@
use std::io::{Write, Read};
use std::marker::{Send, Sync};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs, UdpSocket, IpAddr};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicUsize, Ordering, AtomicBool};
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{sleep, Builder};
@@ -17,8 +17,7 @@ use log::{debug, error, info, trace, warn};
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
use crate::dns::netutil::{read_packet_length, write_packet_length};
use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType};
use dnsclient::UpstreamServer;
use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType, DnsRecord};
use lru::LruCache;
#[derive(Debug, Display, From, Error)]
@@ -37,6 +36,7 @@ pub trait DnsClient {
fn get_failed_count(&self) -> usize;
fn run(&self) -> Result<()>;
fn stop(&mut self);
fn send_query(&self, qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket>;
}
@@ -61,7 +61,10 @@ pub struct DnsNetworkClient {
socket_ipv6: UdpSocket,
/// Queries in progress
pending_queries: Arc<Mutex<Vec<PendingQuery>>>
pending_queries: Arc<Mutex<Vec<PendingQuery>>>,
/// Stopping handle
stopped: Arc<AtomicBool>
}
/// A query in progress. This struct holds the `id` if the request, and a channel
@@ -85,7 +88,8 @@ impl DnsNetworkClient {
seq: AtomicUsize::new(0),
socket_ipv4: UdpSocket::bind(format!("0.0.0.0:{}", port)).expect("Error binding IPv4"),
socket_ipv6: UdpSocket::bind(format!("[::]:{}", port + 1)).expect("Error binding IPv6"),
pending_queries: Arc::new(Mutex::new(Vec::new()))
pending_queries: Arc::new(Mutex::new(Vec::new())),
stopped: Arc::new(AtomicBool::new(false))
}
}
@@ -197,15 +201,22 @@ impl DnsClient for DnsNetworkClient {
/// The run method launches a worker thread. Unless this thread is running, no
/// responses will ever be generated, and clients will just block indefinitely.
fn run(&self) -> Result<()> {
let timeout = Some(std::time::Duration::from_millis(500));
// Start the thread for handling incoming responses
{
let socket_copy = self.socket_ipv4.try_clone()?;
let _ = socket_copy.set_read_timeout(timeout);
let pending_queries_lock = self.pending_queries.clone();
let stopped = Arc::clone(&self.stopped);
Builder::new()
.name("DnsNetworkClient-worker-thread".into())
.spawn(move || {
loop {
if stopped.load(Ordering::SeqCst) {
break;
}
// Read data into a buffer
let mut res_buffer = BytePacketBuffer::new();
match socket_copy.recv_from(&mut res_buffer.buf) {
@@ -253,12 +264,18 @@ impl DnsClient for DnsNetworkClient {
// Start the same thread for IPv6
{
let socket_copy = self.socket_ipv6.try_clone()?;
let _ = socket_copy.set_read_timeout(timeout);
let pending_queries_lock = self.pending_queries.clone();
let stopped = Arc::clone(&self.stopped);
Builder::new()
.name("DnsNetworkClient-worker-thread".into())
.spawn(move || {
loop {
if stopped.load(Ordering::SeqCst) {
break;
}
// Read data into a buffer
let mut res_buffer = BytePacketBuffer::new();
match socket_copy.recv_from(&mut res_buffer.buf) {
@@ -306,12 +323,16 @@ impl DnsClient for DnsNetworkClient {
// Start the thread for timing out requests
{
let pending_queries_lock = self.pending_queries.clone();
let stopped = Arc::clone(&self.stopped);
Builder::new()
.name("DnsNetworkClient-timeout-thread".into())
.spawn(move || {
let timeout = Duration::seconds(10);
let timeout = Duration::seconds(5);
loop {
if stopped.load(Ordering::SeqCst) {
break;
}
if let Ok(mut pending_queries) = pending_queries_lock.lock() {
let mut finished_queries = Vec::new();
for (i, pending_query) in pending_queries.iter().enumerate() {
@@ -336,6 +357,10 @@ impl DnsClient for DnsNetworkClient {
Ok(())
}
fn stop(&mut self) {
self.stopped.store(true, Ordering::SeqCst);
}
fn send_query(&self, qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket> {
let packet = self.send_udp_query(qname, qtype, server, recursive)?;
if !packet.header.truncated_message {
@@ -359,11 +384,9 @@ impl HttpsDnsClient {
let servers = bootstraps
.iter()
.filter_map(|addr| addr.parse().ok())
.map(|addr: SocketAddr| UpstreamServer::new(addr.clone()))
.collect::<Vec<_>>();
.collect::<Vec<SocketAddr>>();
trace!("Using bootstraps: {:?}", &servers);
let dns_client = dnsclient::sync::DNSClient::new(servers);
let cache: LruCache<String, Vec<SocketAddr>> = LruCache::new(10);
let cache = RwLock::new(cache);
@@ -383,13 +406,30 @@ impl HttpsDnsClient {
return Ok(addrs.clone());
}
let port = 10000 + (rand::random::<u16>() % 50000);
let mut dns_client = DnsNetworkClient::new(port);
dns_client.run().unwrap();
let mut result: Vec<IpAddr> = Vec::new();
if let Ok(addrs) = dns_client.query_a(&addr) {
result.extend(addrs.into_iter().map(|ip| IpAddr::V4(ip)))
}
if let Ok(addrs) = dns_client.query_aaaa(&addr) {
result.extend(addrs.into_iter().map(|ip| IpAddr::V6(ip)));
for server in &servers {
if let Ok(res) = dns_client.send_udp_query(&addr, QueryType::A, server, true) {
for answer in &res.answers {
match answer {
DnsRecord::A { addr, .. } => result.push(IpAddr::V4(addr.clone())),
_ => {}
}
}
}
if let Ok(res) = dns_client.send_udp_query(&addr, QueryType::AAAA, server, true) {
for answer in &res.answers {
match answer {
DnsRecord::AAAA { addr, .. } => result.push(IpAddr::V6(addr.clone())),
_ => {}
}
}
}
}
dns_client.stop();
let addrs = result
.into_iter()
@@ -420,6 +460,10 @@ impl DnsClient for HttpsDnsClient {
Ok(())
}
fn stop(&mut self) {
debug!("Stopped DoH client");
}
fn send_query(&self, qname: &str, qtype: QueryType, doh_url: &str, recursive: bool) -> Result<DnsPacket> {
// Create DnsPacket
let mut packet = DnsPacket::new();
@@ -509,6 +553,10 @@ pub mod tests {
Ok(())
}
fn stop(&mut self) {
// Nothing
}
fn send_query(&self, qname: &str, qtype: QueryType, server: &str, recursive: bool) -> Result<DnsPacket> {
(self.callback)(qname, qtype, server, recursive)
}