diff --git a/src/blockchain/filter.rs b/src/blockchain/filter.rs index f423709..8bfaa5e 100644 --- a/src/blockchain/filter.rs +++ b/src/blockchain/filter.rs @@ -1,13 +1,12 @@ -use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, Mutex}; use std::time::Instant; #[allow(unused_imports)] use log::{debug, error, info, trace, warn}; -use rand::seq::SliceRandom; use crate::blockchain::transaction::DomainData; +use crate::commons::rtt_tracker::RttTracker; use crate::dns::filter::DnsFilter; use crate::dns::protocol::{DnsPacket, DnsQuestion, DnsRecord, QueryType, ResultCode, TransientTtl}; use crate::Context; @@ -16,31 +15,16 @@ use crate::dns::client::{DnsClient, DnsNetworkClient}; const NAME_SERVER: &str = "ns.alfis.name"; const SERVER_ADMIN: &str = "admin.alfis.name"; -/// Unbound-style RTT band width in milliseconds. -/// Servers within min_rtt + BAND are considered equally good. -const RTT_BAND_MS: f64 = 100.0; -/// EWMA smoothing factor: 87.5% history, 12.5% new measurement. -const EWMA_WEIGHT: f64 = 7.0 / 8.0; -/// Penalty RTT assigned on timeout/failure (ms). -const TIMEOUT_PENALTY_MS: f64 = 5000.0; -/// Stats older than this are expired so the server gets re-probed. -const STATS_EXPIRE_SECS: u64 = 900; - -struct NsStats { - rtt: f64, - last_update: Instant, -} - pub struct BlockchainFilter { context: Arc>, - ns_stats: Arc>>, + ns_tracker: Arc>, } impl BlockchainFilter { pub fn new(context: Arc>) -> Self { BlockchainFilter { context, - ns_stats: Arc::new(Mutex::new(HashMap::new())), + ns_tracker: Arc::new(RttTracker::new()), } } @@ -66,13 +50,12 @@ impl BlockchainFilter { have_zone } - fn lookup_from_ns(qname: &str, qtype: QueryType, servers: &[IpAddr], ns_stats: &Arc>>) -> Option { + fn lookup_from_ns(qname: &str, qtype: QueryType, servers: &[IpAddr], tracker: &RttTracker) -> Option { let mut dns_client = DnsNetworkClient::new(); dns_client.run().unwrap(); let timeout = std::time::Duration::from_secs(2); - // Build ordered server list using RTT banding - let ordered = Self::select_servers(servers, ns_stats); + let ordered = tracker.select_ordered(servers); for server in &ordered { let addr = SocketAddr::new(*server, 53); @@ -80,12 +63,12 @@ impl BlockchainFilter { match dns_client.send_udp_query(qname, qtype, addr, false, timeout) { Ok(res) => { let elapsed = start.elapsed().as_secs_f64() * 1000.0; - Self::update_ns_stats(ns_stats, *server, elapsed); + tracker.record_success(server, elapsed); dns_client.stop(); return Some(res); } Err(_) => { - Self::update_ns_stats(ns_stats, *server, TIMEOUT_PENALTY_MS); + tracker.record_failure(server); } } } @@ -93,68 +76,6 @@ impl BlockchainFilter { None } - /// Select servers using Unbound-style RTT banding. - /// Servers with no stats or expired stats are treated as preferred (to be probed). - /// Among known servers, those within min_rtt + RTT_BAND_MS are preferred. - /// Each group is shuffled, then preferred servers come first. - fn select_servers(servers: &[IpAddr], ns_stats: &Arc>>) -> Vec { - let now = Instant::now(); - let stats = ns_stats.lock().unwrap(); - - // Separate into known (with valid stats) and unknown - let mut known: Vec<(IpAddr, f64)> = Vec::new(); - let mut unknown: Vec = Vec::new(); - for &ip in servers { - match stats.get(&ip) { - Some(s) if now.duration_since(s.last_update).as_secs() < STATS_EXPIRE_SECS => { - known.push((ip, s.rtt)); - } - _ => { - unknown.push(ip); - } - } - } - drop(stats); - - let mut rng = rand::thread_rng(); - - if known.is_empty() { - // No stats yet — shuffle all and probe - unknown.shuffle(&mut rng); - return unknown; - } - - let min_rtt = known.iter().map(|(_, rtt)| *rtt).fold(f64::INFINITY, f64::min); - let band_threshold = min_rtt + RTT_BAND_MS; - - let mut preferred: Vec = Vec::new(); - let mut fallback: Vec = Vec::new(); - for (ip, rtt) in &known { - if *rtt <= band_threshold { - preferred.push(*ip); - } else { - fallback.push(*ip); - } - } - - // Unknown servers join the preferred group to get probed - preferred.extend(unknown); - preferred.shuffle(&mut rng); - fallback.shuffle(&mut rng); - preferred.extend(fallback); - preferred - } - - fn update_ns_stats(ns_stats: &Arc>>, ip: IpAddr, rtt_ms: f64) { - let mut stats = ns_stats.lock().unwrap(); - let entry = stats.entry(ip).or_insert(NsStats { - rtt: rtt_ms, - last_update: Instant::now(), - }); - entry.rtt = entry.rtt * EWMA_WEIGHT + rtt_ms * (1.0 - EWMA_WEIGHT); - entry.last_update = Instant::now(); - } - fn create_packet(&self, qname: &str, qtype: QueryType, zone: String, answers: Vec, ns_records: Vec, glue_records: Vec) -> Option { if !answers.is_empty() { // Create DnsPacket with answers @@ -187,7 +108,7 @@ impl BlockchainFilter { } } - fn resolve_by_ns(qname: &str, qtype: QueryType, top_domain: &String, data: &DomainData, recursive: bool, ns_stats: &Arc>>) -> (bool, Option) { + fn resolve_by_ns(qname: &str, qtype: QueryType, top_domain: &String, data: &DomainData, recursive: bool, tracker: &RttTracker) -> (bool, Option) { // First we search for NS records, collecting nameserver domains let mut hosts = Vec::new(); for record in data.records.iter() { @@ -251,7 +172,7 @@ impl BlockchainFilter { if !servers.is_empty() { trace!("Found NS servers for domain {}: {:?}", &qname, &servers); - let answer = BlockchainFilter::lookup_from_ns(qname, qtype, &servers, ns_stats); + let answer = BlockchainFilter::lookup_from_ns(qname, qtype, &servers, tracker); if let Some(packet) = &answer { trace!("Resolved {:?} from NS: {:?}", (qname, qtype), &packet.answers); } @@ -379,7 +300,7 @@ impl DnsFilter for BlockchainFilter { // Check if this domain has NS records and needs to resolve all records through them // But skip this if we're querying for NS records themselves - return them directly if qtype != QueryType::NS { - let (has_ns, result) = Self::resolve_by_ns(qname, qtype, &top_domain, &data, recursive, &self.ns_stats); + let (has_ns, result) = Self::resolve_by_ns(qname, qtype, &top_domain, &data, recursive, &self.ns_tracker); if has_ns { return result; } diff --git a/src/commons/mod.rs b/src/commons/mod.rs index 65f7a61..a7c37c3 100644 --- a/src/commons/mod.rs +++ b/src/commons/mod.rs @@ -10,6 +10,7 @@ use crate::dns::protocol::DnsRecord; pub mod constants; pub mod eventbus; +pub mod rtt_tracker; pub mod simplebus; /// Convert bytes array to HEX format diff --git a/src/commons/rtt_tracker.rs b/src/commons/rtt_tracker.rs new file mode 100644 index 0000000..436521b --- /dev/null +++ b/src/commons/rtt_tracker.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Mutex; +use std::time::Instant; + +use rand::seq::SliceRandom; + +/// Unbound-style RTT band width in milliseconds. +/// Servers within min_rtt + BAND are considered equally good. +const RTT_BAND_MS: f64 = 100.0; +/// EWMA smoothing factor: 87.5% history, 12.5% new measurement. +const EWMA_WEIGHT: f64 = 7.0 / 8.0; +/// Penalty RTT assigned on timeout/failure (ms). +const TIMEOUT_PENALTY_MS: f64 = 5000.0; +/// Stats older than this are expired so the server gets re-probed. +const STATS_EXPIRE_SECS: u64 = 900; + +struct RttStats { + rtt: f64, + last_update: Instant, +} + +/// Adaptive server selection using Unbound-style RTT banding. +/// +/// Tracks smoothed RTT per key and selects servers by grouping them into +/// a "preferred" band (within `RTT_BAND_MS` of the fastest known server) +/// and a "fallback" group. Unknown or expired servers are treated as +/// preferred so they get probed. +pub struct RttTracker { + stats: Mutex>, +} + +impl RttTracker { + pub fn new() -> Self { + RttTracker { + stats: Mutex::new(HashMap::new()), + } + } + + /// Returns `keys` reordered for adaptive selection. + /// + /// - Keys with no stats or expired stats go to the preferred group (to be probed). + /// - Known keys within `min_rtt + RTT_BAND_MS` go to the preferred group. + /// - The rest are fallback. + /// - Each group is shuffled; preferred comes first. + pub fn select_ordered(&self, keys: &[K]) -> Vec { + let now = Instant::now(); + let stats = self.stats.lock().unwrap(); + + let mut known: Vec<(K, f64)> = Vec::new(); + let mut unknown: Vec = Vec::new(); + for key in keys { + match stats.get(key) { + Some(s) if now.duration_since(s.last_update).as_secs() < STATS_EXPIRE_SECS => { + known.push((key.clone(), s.rtt)); + } + _ => { + unknown.push(key.clone()); + } + } + } + drop(stats); + + let mut rng = rand::thread_rng(); + + if known.is_empty() { + unknown.shuffle(&mut rng); + return unknown; + } + + let min_rtt = known.iter().map(|(_, rtt)| *rtt).fold(f64::INFINITY, f64::min); + let band_threshold = min_rtt + RTT_BAND_MS; + + let mut preferred: Vec = Vec::new(); + let mut fallback: Vec = Vec::new(); + for (key, rtt) in known { + if rtt <= band_threshold { + preferred.push(key); + } else { + fallback.push(key); + } + } + + preferred.extend(unknown); + preferred.shuffle(&mut rng); + fallback.shuffle(&mut rng); + preferred.extend(fallback); + preferred + } + + /// Record a successful query with the measured RTT in milliseconds. + pub fn record_success(&self, key: &K, rtt_ms: f64) { + self.update(key, rtt_ms); + } + + /// Record a failed/timed-out query, applying a penalty RTT. + pub fn record_failure(&self, key: &K) { + self.update(key, TIMEOUT_PENALTY_MS); + } + + fn update(&self, key: &K, rtt_ms: f64) { + let mut stats = self.stats.lock().unwrap(); + let entry = stats.entry(key.clone()).or_insert(RttStats { + rtt: rtt_ms, + last_update: Instant::now(), + }); + entry.rtt = entry.rtt * EWMA_WEIGHT + rtt_ms * (1.0 - EWMA_WEIGHT); + entry.last_update = Instant::now(); + } +} diff --git a/src/dns/context.rs b/src/dns/context.rs index 6a32d52..97d2ae0 100644 --- a/src/dns/context.rs +++ b/src/dns/context.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use derive_more::{Display, Error, From}; +use crate::commons::rtt_tracker::RttTracker; use crate::dns::authority::Authority; use crate::dns::cache::SynchronizedCache; use crate::dns::client::{DnsClient, DnsNetworkClient}; @@ -56,7 +57,8 @@ pub struct ServerContext { pub enable_tcp: bool, pub enable_api: bool, pub statistics: ServerStatistics, - pub zones_dir: &'static str + pub zones_dir: &'static str, + pub forwarder_tracker: Arc>, } impl Default for ServerContext { @@ -87,7 +89,8 @@ impl ServerContext { enable_tcp: true, enable_api: false, statistics: ServerStatistics { tcp_query_count: AtomicUsize::new(0), udp_query_count: AtomicUsize::new(0) }, - zones_dir: "zones" + zones_dir: "zones", + forwarder_tracker: Arc::new(RttTracker::new()), } } @@ -141,7 +144,8 @@ pub mod tests { enable_tcp: true, enable_api: false, statistics: ServerStatistics { tcp_query_count: AtomicUsize::new(0), udp_query_count: AtomicUsize::new(0) }, - zones_dir: "zones" + zones_dir: "zones", + forwarder_tracker: Arc::new(RttTracker::new()), }) } } diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs index 21d5ea9..067dd26 100644 --- a/src/dns/resolve.rs +++ b/src/dns/resolve.rs @@ -2,10 +2,10 @@ //! incoming queries use std::sync::Arc; +use std::time::Instant; use std::vec::Vec; use derive_more::{Display, Error, From}; -use rand::seq::IteratorRandom; use crate::dns::context::ServerContext; use crate::dns::protocol::{DnsPacket, QueryType, ResultCode}; @@ -85,38 +85,51 @@ impl DnsResolver for ForwardingDnsResolver { } fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { - let mut random = rand::thread_rng(); - let upstream = self.upstreams.iter().choose(&mut random).unwrap(); - let mut result = match self.context.cache.lookup(qname, qtype) { - None => { - if is_url(upstream) { - if let Some(client) = &self.context.doh_client { - client.send_query(qname, qtype, upstream, true)? - } else { - log::error!("This build doesn't support DoH"); - return Err(ResolveError::NoServerFound); - } + if let Some(packet) = self.context.cache.lookup(qname, qtype) { + return Ok(packet); + } + + let ordered = self.context.forwarder_tracker.select_ordered(&self.upstreams); + let mut last_err = ResolveError::NoServerFound; + + for upstream in &ordered { + let start = Instant::now(); + let query_result = if is_url(upstream) { + if let Some(client) = &self.context.doh_client { + client.send_query(qname, qtype, upstream, true) } else { - self.context.old_client.send_query(qname, qtype, upstream, true)? + log::error!("This build doesn't support DoH"); + continue; } - }, - Some(packet) => packet - }; + } else { + self.context.old_client.send_query(qname, qtype, upstream, true) + }; - self.context.cache.store(&result.answers)?; + match query_result { + Ok(mut result) => { + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + self.context.forwarder_tracker.record_success(upstream, elapsed); + self.context.cache.store(&result.answers)?; - // Fix domain names in answers to match original query case (DNS 0x20 may have randomized them) - let qname_lower = qname.to_lowercase(); - for answer in &mut result.answers { - if let Some(domain) = answer.get_domain() { - // Only fix if it matches the query (case-insensitive) - if domain.to_lowercase() == qname_lower { - answer.set_domain(qname.to_string()); + // Fix domain names in answers to match original query case + let qname_lower = qname.to_lowercase(); + for answer in &mut result.answers { + if let Some(domain) = answer.get_domain() { + if domain.to_lowercase() == qname_lower { + answer.set_domain(qname.to_string()); + } + } + } + return Ok(result); + } + Err(e) => { + self.context.forwarder_tracker.record_failure(upstream); + last_err = e.into(); } } } - Ok(result) + Err(last_err) } }