Many DNS fixes!

This commit is contained in:
Revertron
2025-10-22 22:55:38 +02:00
parent a9d7ec1093
commit d2b7080c96
8 changed files with 544 additions and 550 deletions
Generated
+332 -471
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -34,7 +34,7 @@ rand = { package = "rand", version = "0.8.5" }
sqlite = "0.36.0"
uuid = { version = "1.11.0", features = ["serde", "v4"] }
mio = { version = "1.0.0", features = ["os-poll", "net"] }
ureq = { version = "2.10", optional = true }
ureq = { version = "3.0.10", optional = true }
lru = "0.12"
derive_more = { version = "1.0.0", features = ["display", "error", "from"] }
lazy_static = "1.5.0"
+111 -10
View File
@@ -60,20 +60,27 @@ impl BlockchainFilter {
None
}
fn create_packet(&self, qname: &str, qtype: QueryType, zone: String, answers: Vec<DnsRecord>) -> Option<DnsPacket> {
fn create_packet(&self, qname: &str, qtype: QueryType, zone: String, answers: Vec<DnsRecord>, ns_records: Vec<DnsRecord>, glue_records: Vec<DnsRecord>) -> Option<DnsPacket> {
if !answers.is_empty() {
// Create DnsPacket
// Create DnsPacket with answers
let mut packet = DnsPacket::new();
packet.header.authoritative_answer = true;
packet.questions.push(DnsQuestion::new(String::from(qname), qtype));
for answer in answers {
packet.answers.push(answer);
}
packet.authorities.push(DnsRecord::NS { domain: zone, host: String::from(NAME_SERVER), ttl: TransientTtl(600) });
// Add NS records to authority section
for ns_record in ns_records {
packet.authorities.push(ns_record);
}
// Add GLUE records to additional section (resources)
for glue_record in glue_records {
packet.resources.push(glue_record);
}
//trace!("Returning packet: {:?}", &packet);
Some(packet)
} else {
// Create DnsPacket
// Create DnsPacket without answers
let mut packet = DnsPacket::new();
packet.header.authoritative_answer = true;
packet.header.rescode = ResultCode::NXDOMAIN;
@@ -85,7 +92,7 @@ impl BlockchainFilter {
}
}
fn resolve_by_ns(qname: &str, qtype: QueryType, top_domain: &String, data: &DomainData) -> (bool, Option<DnsPacket>) {
fn resolve_by_ns(qname: &str, qtype: QueryType, top_domain: &String, data: &DomainData, recursive: bool) -> (bool, Option<DnsPacket>) {
// First we search for NS records, collecting nameserver domains
let mut hosts = Vec::new();
for record in data.records.iter() {
@@ -103,7 +110,27 @@ impl BlockchainFilter {
return (false, None);
}
// Searching glue records
// If non-recursive, return a referral response with NS and GLUE records
if !recursive {
trace!("Non-recursive query for delegated domain {}, returning referral", qname);
let ns_records = BlockchainFilter::get_ns_records(data, top_domain);
let glue_records = BlockchainFilter::get_glue_records(data, top_domain, &hosts);
let mut packet = DnsPacket::new();
packet.header.authoritative_answer = false; // Not authoritative for the answer, but for the zone
packet.questions.push(DnsQuestion::new(String::from(qname), qtype));
// Add NS records to authority section
for ns_record in ns_records {
packet.authorities.push(ns_record);
}
// Add GLUE records to additional section (resources)
for glue_record in glue_records {
packet.resources.push(glue_record);
}
return (true, Some(packet));
}
// For recursive queries, search for glue records to query external servers
let mut servers = Vec::new();
for record in data.records.iter() {
match &record {
@@ -138,10 +165,71 @@ impl BlockchainFilter {
(false, None)
}
/// Extract NS records from domain data and return them
fn get_ns_records(data: &DomainData, top_domain: &str) -> Vec<DnsRecord> {
data.records.iter()
.filter_map(|record| {
if let DnsRecord::NS { domain, host, ttl } = record {
if domain == "@" {
return Some(DnsRecord::NS {
domain: String::from(top_domain),
host: host.clone(),
ttl: *ttl
});
}
}
None
})
.collect()
}
/// Extract GLUE records (A/AAAA records for NS hosts within the same domain)
fn get_glue_records(data: &DomainData, top_domain: &str, ns_hosts: &[String]) -> Vec<DnsRecord> {
let mut glue_records = Vec::new();
for record in data.records.iter() {
match record {
DnsRecord::A { domain, addr, ttl } => {
let full_domain = if domain == "@" {
String::from(top_domain)
} else {
format!("{}.{}", domain, top_domain)
};
if ns_hosts.iter().any(|ns| ns == &full_domain) {
glue_records.push(DnsRecord::A {
domain: full_domain,
addr: addr.clone(),
ttl: *ttl
});
}
}
DnsRecord::AAAA { domain, addr, ttl } => {
let full_domain = if domain == "@" {
String::from(top_domain)
} else {
format!("{}.{}", domain, top_domain)
};
if ns_hosts.iter().any(|ns| ns == &full_domain) {
glue_records.push(DnsRecord::AAAA {
domain: full_domain,
addr: addr.clone(),
ttl: *ttl
});
}
}
_ => {}
}
}
glue_records
}
}
impl DnsFilter for BlockchainFilter {
fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
fn lookup(&self, qname: &str, qtype: QueryType, recursive: bool) -> Option<DnsPacket> {
let top_domain;
let subdomain;
let parts: Vec<&str> = qname.rsplitn(3, '.').collect();
@@ -192,7 +280,7 @@ impl DnsFilter for BlockchainFilter {
};
// Check if this domain has NS records and needs to resolve all records through them
let (has_ns, result) = Self::resolve_by_ns(qname, qtype, &top_domain, &data);
let (has_ns, result) = Self::resolve_by_ns(qname, qtype, &top_domain, &data, recursive);
if has_ns {
return result;
}
@@ -237,7 +325,7 @@ impl DnsFilter for BlockchainFilter {
let mut domain_exists = !answers.is_empty() || subdomain.is_empty();
if answers.is_empty() {
// If there are no records found we search for *.domain.tld record
for mut record in data.records {
for mut record in data.records.iter_mut() {
let record_domain = record.get_domain().unwrap_or(String::new());
if record.get_querytype() == qtype && record_domain == "*" {
match &mut record {
@@ -263,7 +351,20 @@ impl DnsFilter for BlockchainFilter {
}
}
if let Some(mut packet) = self.create_packet(qname, qtype, zone, answers) {
// Extract NS records and GLUE records for the response
let ns_records = BlockchainFilter::get_ns_records(&data, &top_domain);
let ns_hosts: Vec<String> = ns_records.iter()
.filter_map(|record| {
if let DnsRecord::NS { host, .. } = record {
Some(host.clone())
} else {
None
}
})
.collect();
let glue_records = BlockchainFilter::get_glue_records(&data, &top_domain, &ns_hosts);
if let Some(mut packet) = self.create_packet(qname, qtype, zone, answers, ns_records, glue_records) {
if domain_exists && packet.answers.is_empty() {
packet.header.rescode = ResultCode::NOERROR;
}
+9 -7
View File
@@ -226,7 +226,7 @@ where T: Read {
}
impl<'a, T> StreamPacketBuffer<'a, T> where T: Read + 'a {
pub fn new(stream: &'a mut T) -> StreamPacketBuffer<'_, T> {
pub fn new(stream: &'a mut T) -> StreamPacketBuffer<'a, T> {
StreamPacketBuffer {
stream,
buffer: Vec::new(),
@@ -300,14 +300,16 @@ impl<'a, T> PacketBuffer for StreamPacketBuffer<'a, T> where T: Read + 'a {
}
}
const BUF_SIZE: usize = 4096;
pub struct BytePacketBuffer {
pub buf: [u8; 512],
pub buf: [u8; BUF_SIZE],
pub pos: usize
}
impl BytePacketBuffer {
pub fn new() -> BytePacketBuffer {
BytePacketBuffer { buf: [0; 512], pos: 0 }
BytePacketBuffer { buf: [0; BUF_SIZE], pos: 0 }
}
}
@@ -319,7 +321,7 @@ impl Default for BytePacketBuffer {
impl PacketBuffer for BytePacketBuffer {
fn read(&mut self) -> Result<u8> {
if self.pos >= 512 {
if self.pos >= BUF_SIZE {
return Err(BufferError::EndOfBuffer);
}
let res = self.buf[self.pos];
@@ -329,21 +331,21 @@ impl PacketBuffer for BytePacketBuffer {
}
fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= 512 {
if pos >= BUF_SIZE {
return Err(BufferError::EndOfBuffer);
}
Ok(self.buf[pos])
}
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= 512 {
if start + len >= BUF_SIZE {
return Err(BufferError::EndOfBuffer);
}
Ok(&self.buf[start..start + len as usize])
}
fn write(&mut self, val: u8) -> Result<()> {
if self.pos >= 512 {
if self.pos >= BUF_SIZE {
return Err(BufferError::EndOfBuffer);
}
self.buf[self.pos] = val;
+58 -28
View File
@@ -4,7 +4,7 @@ use std::io::Write;
#[cfg(feature = "doh")]
use std::io::Read;
use std::marker::{Send, Sync};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs, UdpSocket};
use std::net::{Ipv4Addr, SocketAddr, TcpStream, ToSocketAddrs, UdpSocket};
#[cfg(feature = "doh")]
use std::net::IpAddr;
#[cfg(feature = "doh")]
@@ -32,6 +32,11 @@ use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType};
use crate::dns::protocol::DnsRecord;
#[cfg(feature = "doh")]
use lru::LruCache;
use ureq::Agent;
use ureq::config::Config;
use ureq::http::Uri;
use ureq::unversioned::resolver::{ArrayVec, ResolvedSocketAddrs, Resolver};
use ureq::unversioned::transport::{DefaultConnector, NextTimeout};
#[derive(Debug, Display, From, Error)]
pub enum ClientError {
@@ -387,7 +392,7 @@ impl DnsClient for DnsNetworkClient {
#[cfg(feature = "doh")]
pub struct HttpsDnsClient {
agent: ureq::Agent,
agent: Agent,
/// Counter for assigning packet ids
seq: AtomicUsize,
}
@@ -402,34 +407,61 @@ impl HttpsDnsClient {
.collect::<Vec<SocketAddr>>();
trace!("Using bootstraps: {:?}", &servers);
let cache: LruCache<String, Vec<SocketAddr>> = LruCache::new(NonZeroUsize::new(10).unwrap());
let cache = RwLock::new(cache);
let agent = ureq::AgentBuilder::new()
let agent_config = Agent::config_builder()
.user_agent(&client_name)
.timeout(std::time::Duration::from_secs(3))
.timeout_global(Some(std::time::Duration::from_secs(3)))
.max_idle_connections_per_host(4)
.max_idle_connections(16)
.resolver(move |addr: &str| {
let addr = match addr.find(':') {
Some(index) => addr[0..index].to_string(),
None => addr.to_string()
.build();
let agent = Agent::with_parts(agent_config, DefaultConnector::default(), BootstrapResolver::new(servers));
Self { agent, seq: AtomicUsize::new(1) }
}
}
#[derive(Debug)]
struct BootstrapResolver {
servers: Vec<SocketAddr>,
cache: RwLock<LruCache<String, Vec<SocketAddr>>>
}
impl BootstrapResolver {
pub fn new(servers: Vec<SocketAddr>) -> Self {
let cache: LruCache<String, Vec<SocketAddr>> = LruCache::new(NonZeroUsize::new(10).unwrap());
let cache = RwLock::new(cache);
Self { servers, cache }
}
}
impl Resolver for BootstrapResolver {
// TODO use timeout parameter
fn resolve(&self, uri: &Uri, _config: &Config, _timeout: NextTimeout) -> std::result::Result<ResolvedSocketAddrs, ureq::Error> {
let domain = uri.host().unwrap_or("localhost");
let port = uri.port_u16().unwrap_or(443);
let addr = match domain.find(':') {
Some(index) => domain[0..index].to_string(),
None => domain.to_string()
};
trace!("Resolving {}", addr);
if let Some(addrs) = cache.write().unwrap().get(&addr) {
if let Some(addrs) = self.cache.write().unwrap().get(&addr) {
trace!("Found bootstrap ip in cache");
return Ok(addrs.clone());
let mut results: ResolvedSocketAddrs = ArrayVec::from_fn(|_| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
for addr in addrs {
results.push(addr.to_owned());
}
return Ok(results);
}
let port = 10000 + (rand::random::<u16>() % 50000);
let mut dns_client = DnsNetworkClient::new(port);
let client_port = 10000 + (rand::random::<u16>() % 50000);
let mut dns_client = DnsNetworkClient::new(client_port);
dns_client.run().unwrap();
let mut result: Vec<IpAddr> = Vec::new();
for server in &servers {
let mut results: ResolvedSocketAddrs = ArrayVec::from_fn(|_| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
for server in &self.servers {
if let Ok(res) = dns_client.send_udp_query(&addr, QueryType::A, server, true) {
for answer in &res.answers {
if let DnsRecord::A { addr, .. } = answer {
results.push(SocketAddr::new(IpAddr::V4(*addr), port));
result.push(IpAddr::V4(*addr))
}
}
@@ -437,6 +469,7 @@ impl HttpsDnsClient {
if let Ok(res) = dns_client.send_udp_query(&addr, QueryType::AAAA, server, true) {
for answer in &res.answers {
if let DnsRecord::AAAA { addr, .. } = answer {
results.push(SocketAddr::new(IpAddr::V6(*addr), port));
result.push(IpAddr::V6(*addr))
}
}
@@ -448,14 +481,11 @@ impl HttpsDnsClient {
result.dedup();
let addrs = result
.into_iter()
.map(|ip| SocketAddr::new(ip, 443))
.map(|ip| SocketAddr::new(ip, port))
.collect::<Vec<_>>();
trace!("Resolved addresses: {:?}", &addrs);
cache.write().unwrap().put(addr, addrs.clone());
Ok(addrs)
})
.build();
Self { agent, seq: AtomicUsize::new(1) }
self.cache.write().unwrap().put(addr, addrs.clone());
Ok(results)
}
}
@@ -497,20 +527,20 @@ impl DnsClient for HttpsDnsClient {
let response = self.agent
.post(doh_url)
.set("Content-Type", "application/dns-message")
.send_bytes(req_buffer.buffer.as_slice());
.header("Content-Type", "application/dns-message")
.send(req_buffer.buffer.as_slice());
match response {
Ok(response) => {
match response.status() {
match response.status().as_u16() {
200 => {
match response.header("Content-Length") {
match response.headers().get("Content-Length") {
None => warn!("No 'Content-Length' header in DoH response!"),
Some(str) => {
match str.parse::<usize>() {
match str.to_str().unwrap_or("0").parse::<usize>() {
Ok(size) => {
let mut bytes: Vec<u8> = Vec::with_capacity(size);
response.into_reader()
response.into_body().into_reader()
.take(4096)
.read_to_end(&mut bytes)?;
let mut buffer = VectorPacketBuffer::new();
+2 -2
View File
@@ -1,14 +1,14 @@
use crate::dns::protocol::{DnsPacket, QueryType};
pub trait DnsFilter {
fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket>;
fn lookup(&self, qname: &str, qtype: QueryType, recursive: bool) -> Option<DnsPacket>;
}
pub struct DummyFilter {}
#[allow(unused_variables)]
impl DnsFilter for DummyFilter {
fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
fn lookup(&self, qname: &str, qtype: QueryType, recursive: bool) -> Option<DnsPacket> {
None
}
}
+1 -1
View File
@@ -53,7 +53,7 @@ impl HostsFilter {
}
impl DnsFilter for HostsFilter {
fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
fn lookup(&self, qname: &str, qtype: QueryType, _recursive: bool) -> Option<DnsPacket> {
let mut packet = DnsPacket::new();
if let Some(list) = self.hosts.get(qname) {
for addr in list {
+1 -1
View File
@@ -53,7 +53,7 @@ pub trait DnsResolver {
}
for filter in context.filters.iter() {
if let Some(packet) = filter.lookup(qname, qtype) {
if let Some(packet) = filter.lookup(qname, qtype, recursive) {
context.cache.store(&packet.answers)?;
return Ok(packet);
}