From 2e1f05cadb4ee2a930086bd47efe15fc734e9e0c Mon Sep 17 00:00:00 2001 From: Revertron Date: Mon, 5 Jan 2026 16:50:20 +0100 Subject: [PATCH] Implemented memory limit for DNS cache. --- alfis.toml | 6 + src/dns/cache.rs | 291 +++++++++++++++++++++++++++++++++++++++++++-- src/dns/context.rs | 8 +- src/dns_utils.rs | 7 +- src/settings.rs | 12 +- 5 files changed, 308 insertions(+), 16 deletions(-) diff --git a/alfis.toml b/alfis.toml index 46153fb..5efb6ec 100644 --- a/alfis.toml +++ b/alfis.toml @@ -22,6 +22,12 @@ yggdrasil_only = false listen = "127.0.0.3:53" # How many threads to spawn by DNS server threads = 10 + +# DNS cache memory limit in megabytes (default: 100) +# Prevents unbounded cache growth in high-load environments +# Set to 0 for unlimited cache (not recommended for production) +cache_memory_limit_mb = 100 + # AdGuard DNS servers to filter ads and trackers forwarders = ["https://dns.adguard.com/dns-query"] #forwarders = ["94.140.14.14:53", "94.140.15.15:53"] diff --git a/src/dns/cache.rs b/src/dns/cache.rs index f476c63..2a366d0 100644 --- a/src/dns/cache.rs +++ b/src/dns/cache.rs @@ -2,16 +2,82 @@ extern crate serde; use std::clone::Clone; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; use std::sync::{Arc, RwLock}; +use lru::LruCache; + use chrono::*; use derive_more::{Display, Error, From}; +#[allow(unused_imports)] +use log::{debug, error, info, trace, warn}; use serde::{Deserialize, Serialize}; use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode}; +/// Estimate the memory size of a DNS record in bytes +fn estimate_dns_record_size(record: &DnsRecord) -> usize { + match record { + DnsRecord::A { domain, .. } => 56 + domain.len(), + DnsRecord::AAAA { domain, .. } => 68 + domain.len(), + DnsRecord::NS { domain, host, .. } | + DnsRecord::CNAME { domain, host, .. } => 64 + domain.len() + host.len(), + DnsRecord::MX { domain, host, .. } => 72 + domain.len() + host.len(), + DnsRecord::SRV { domain, host, .. } => 80 + domain.len() + host.len(), + DnsRecord::SOA { domain, m_name, r_name, .. } => + 120 + domain.len() + m_name.len() + r_name.len(), + DnsRecord::TXT { domain, data, .. } => 64 + domain.len() + data.len(), + DnsRecord::PTR { domain, data, .. } => 64 + domain.len() + data.len(), + DnsRecord::TLSA { domain, data, .. } => 80 + domain.len() + data.len(), + DnsRecord::HTTPS { domain, target, params, .. } => + 88 + domain.len() + target.len() + params.len(), + DnsRecord::UNKNOWN { domain, .. } => 64 + domain.len(), + DnsRecord::OPT { data, .. } => 48 + data.len(), + } +} + +/// Estimate the memory size of a domain entry in bytes +fn estimate_domain_entry_size(entry: &DomainEntry) -> usize { + let mut size = 0; + + // Base struct sizes + size += std::mem::size_of::(); // ~56 bytes + size += std::mem::size_of::>(); // 16 bytes + + // Domain string: 24 byte header + actual chars + size += 24 + entry.domain.len(); + + // HashMap base overhead + size += 24; + size += entry.record_types.len() * 32; // Bucket overhead per entry + + // Calculate size of each RecordSet + for (_qtype, record_set) in &entry.record_types { + size += std::mem::size_of::(); // 2 bytes + + match record_set { + RecordSet::NoRecords { .. } => { + size += 56; // Enum variant + timestamp + ttl + } + RecordSet::Records { records, .. } => { + size += 56; // Base enum variant + size += 24; // HashSet base + size += records.len() * 16; // Bucket overhead per record + + // Sum up all record sizes + for record_entry in records { + size += estimate_dns_record_size(&record_entry.record); + size += 32; // DateTime overhead + } + } + } + } + + size +} + #[derive(Debug, Display, From, Error)] pub enum CacheError { Io(std::io::Error), @@ -159,14 +225,61 @@ impl DomainEntry { } } -#[derive(Default)] pub struct Cache { - domain_entries: BTreeMap> + domain_entries: LruCache>, + current_memory_bytes: usize, + max_memory_bytes: usize } impl Cache { pub fn new() -> Cache { - Cache { domain_entries: BTreeMap::new() } + Cache::with_memory_limit(0) + } + + pub fn with_memory_limit(limit_mb: usize) -> Cache { + let max_memory_bytes = if limit_mb == 0 { + usize::MAX + } else { + limit_mb * 1024 * 1024 + }; + + // Estimate capacity: assume ~1KB per entry + let estimated_capacity = if limit_mb == 0 { + 100_000 // Default capacity for unlimited + } else { + limit_mb * 1000 + }; + + Cache { + domain_entries: LruCache::new(NonZeroUsize::new(estimated_capacity).unwrap()), + current_memory_bytes: 0, + max_memory_bytes, + } + } + + fn evict_to_limit(&mut self) -> usize { + if self.max_memory_bytes == usize::MAX { + return 0; // Unlimited + } + + let mut evicted = 0; + let target_memory = (self.max_memory_bytes * 90) / 100; // Evict to 90% + + while self.current_memory_bytes > target_memory { + if let Some((_, entry)) = self.domain_entries.pop_lru() { + let size = estimate_domain_entry_size(&entry); + self.current_memory_bytes = self.current_memory_bytes.saturating_sub(size); + evicted += 1; + } else { + break; + } + } + + if evicted > 0 { + info!("Evicted {} DNS cache entries (memory: {} bytes)", evicted, self.current_memory_bytes); + } + + evicted } fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState { @@ -218,39 +331,87 @@ impl Cache { // Store with a lowercase key for case-insensitive lookups let domain_lower = domain.to_lowercase(); + // Try to update existing entry if let Some(ref mut rs) = self.domain_entries.get_mut(&domain_lower).and_then(Arc::get_mut) { + let old_size = estimate_domain_entry_size(rs); rs.store_record(rec); + let new_size = estimate_domain_entry_size(rs); + + self.current_memory_bytes = self.current_memory_bytes + .saturating_sub(old_size) + .saturating_add(new_size); continue; } + // Insert new entry let mut rs = DomainEntry::new(domain_lower.clone()); rs.store_record(rec); - self.domain_entries.insert(domain_lower, Arc::new(rs)); + let entry_size = estimate_domain_entry_size(&rs); + + // Check if eviction needed + if self.current_memory_bytes + entry_size > self.max_memory_bytes { + self.evict_to_limit(); + } + + self.domain_entries.put(domain_lower, Arc::new(rs)); + self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size); } } pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) { // Store with lowercase key for case-insensitive lookups let qname_lower = qname.to_lowercase(); + + // Try to update existing entry if let Some(ref mut rs) = self.domain_entries.get_mut(&qname_lower).and_then(Arc::get_mut) { + let old_size = estimate_domain_entry_size(rs); rs.store_nxdomain(qtype, ttl); + let new_size = estimate_domain_entry_size(rs); + + self.current_memory_bytes = self.current_memory_bytes + .saturating_sub(old_size) + .saturating_add(new_size); return; } + // Insert new entry let mut rs = DomainEntry::new(qname_lower.clone()); rs.store_nxdomain(qtype, ttl); - self.domain_entries.insert(qname_lower, Arc::new(rs)); + let entry_size = estimate_domain_entry_size(&rs); + + // Check if eviction needed + if self.current_memory_bytes + entry_size > self.max_memory_bytes { + self.evict_to_limit(); + } + + self.domain_entries.put(qname_lower, Arc::new(rs)); + self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size); } } -#[derive(Default)] pub struct SynchronizedCache { pub cache: RwLock } impl SynchronizedCache { pub fn new() -> SynchronizedCache { - SynchronizedCache { cache: RwLock::new(Cache::new()) } + SynchronizedCache::with_memory_limit(0) + } + + pub fn with_memory_limit(limit_mb: usize) -> SynchronizedCache { + SynchronizedCache { + cache: RwLock::new(Cache::with_memory_limit(limit_mb)) + } + } + + pub fn get_memory_usage(&self) -> Result { + let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?; + Ok(cache.current_memory_bytes) + } + + pub fn get_entry_count(&self) -> Result { + let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?; + Ok(cache.domain_entries.len()) } pub fn list(&self) -> Result>> { @@ -258,7 +419,7 @@ impl SynchronizedCache { let mut list = Vec::new(); - for rs in cache.domain_entries.values() { + for (_, rs) in cache.domain_entries.iter() { list.push(rs.clone()); } @@ -390,4 +551,116 @@ mod tests { assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().updates); assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().hits); } + + #[test] + fn test_memory_limited_cache() { + let mut cache = Cache::with_memory_limit(1); // 1MB limit + + // Add many records until limit is hit + for i in 0..5000 { + let domain = format!("test{}.com", i); + let records = vec![DnsRecord::A { + domain: domain.clone(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600) + }]; + cache.store(&records); + } + + // Verify memory stayed under limit (with some tolerance) + let limit_bytes = 1024 * 1024; + let tolerance_bytes = limit_bytes * 110 / 100; // 110% tolerance + assert!( + cache.current_memory_bytes <= tolerance_bytes, + "Cache memory {} bytes exceeds limit with tolerance {} bytes", + cache.current_memory_bytes, tolerance_bytes + ); + + // Verify cache still works and has been evicted + assert!(cache.domain_entries.len() < 5000, "Cache should have evicted entries"); + assert!(cache.domain_entries.len() > 0, "Cache should not be empty"); + + // Most recent entries should still be present + assert!(cache.lookup("test4999.com", QueryType::A).is_some()); + } + + #[test] + fn test_unlimited_cache() { + let mut cache = Cache::with_memory_limit(0); // Unlimited + + for i in 0..1000 { + let domain = format!("test{}.com", i); + let records = vec![DnsRecord::A { + domain: domain.clone(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600) + }]; + cache.store(&records); + } + + // All entries should be present + assert_eq!(cache.domain_entries.len(), 1000); + assert_eq!(cache.max_memory_bytes, usize::MAX); + + // Verify lookups work for all entries + assert!(cache.lookup("test0.com", QueryType::A).is_some()); + assert!(cache.lookup("test500.com", QueryType::A).is_some()); + assert!(cache.lookup("test999.com", QueryType::A).is_some()); + } + + #[test] + fn test_lru_eviction_order() { + let mut cache = Cache::with_memory_limit(1); // Small limit to trigger eviction + + // Add initial batch of records + for i in 0..100 { + cache.store(&[DnsRecord::A { + domain: format!("domain{}.com", i), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600) + }]); + } + + // Access domain50 to make it recently used + let _ = cache.lookup("domain50.com", QueryType::A); + + // Add more records to trigger eviction + for i in 100..200 { + cache.store(&[DnsRecord::A { + domain: format!("domain{}.com", i), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600) + }]); + } + + // Most recently added entries should be present + assert!(cache.lookup("domain199.com", QueryType::A).is_some()); + + // Verify cache is respecting memory limit + let limit_bytes = 1024 * 1024; + let tolerance_bytes = limit_bytes * 110 / 100; + assert!(cache.current_memory_bytes <= tolerance_bytes); + } + + #[test] + fn test_nxdomain_memory_tracking() { + let mut cache = Cache::with_memory_limit(1); // 1MB limit + + // Store many NXDOMAIN responses + for i in 0..1000 { + let domain = format!("nonexistent{}.com", i); + cache.store_nxdomain(&domain, QueryType::A, 3600); + } + + // Verify memory tracking works for NXDOMAIN + assert!(cache.current_memory_bytes > 0); + assert!(cache.current_memory_bytes <= 1024 * 1024 * 110 / 100); + + // Verify NXDOMAIN responses work + if let Some(packet) = cache.lookup("nonexistent999.com", QueryType::A) { + assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode); + } else { + panic!("NXDOMAIN entry should be cached"); + } + } } diff --git a/src/dns/context.rs b/src/dns/context.rs index dc63e2c..6a32d52 100644 --- a/src/dns/context.rs +++ b/src/dns/context.rs @@ -61,13 +61,13 @@ pub struct ServerContext { impl Default for ServerContext { fn default() -> Self { - ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true) + ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true, 100) } } impl ServerContext { #[allow(unused_variables)] - pub fn new(dns_listen: String, bootstraps: Vec, enable_0x20: bool) -> ServerContext { + pub fn new(dns_listen: String, bootstraps: Vec, enable_0x20: bool, cache_limit_mb: usize) -> ServerContext { #[cfg(not(feature = "doh"))] let doh_client = None; #[cfg(feature = "doh")] @@ -75,7 +75,7 @@ impl ServerContext { ServerContext { authority: Authority::new(), - cache: SynchronizedCache::new(), + cache: SynchronizedCache::with_memory_limit(cache_limit_mb), filters: Vec::new(), old_client: Box::new(DnsNetworkClient::new_with_0x20(enable_0x20)), doh_client, @@ -129,7 +129,7 @@ pub mod tests { pub fn create_test_context(callback: Box) -> Arc { Arc::new(ServerContext { authority: Authority::new(), - cache: SynchronizedCache::new(), + cache: SynchronizedCache::with_memory_limit(0), // Unlimited for tests filters: Vec::new(), old_client: Box::new(DnsStubClient::new(callback)), doh_client: Some(Box::new(HttpsDnsClient::new(Vec::new()))), diff --git a/src/dns_utils.rs b/src/dns_utils.rs index 9cc65b4..8e06fc4 100644 --- a/src/dns_utils.rs +++ b/src/dns_utils.rs @@ -35,7 +35,12 @@ pub fn start_dns_server(context: &Arc>, settings: &Settings) -> b /// Creates DNS-context with all necessary settings fn create_server_context(context: Arc>, settings: &Settings) -> Arc { - let mut server_context = ServerContext::new(settings.dns.listen.clone(), settings.dns.bootstraps.clone(), settings.dns.enable_0x20); + let mut server_context = ServerContext::new( + settings.dns.listen.clone(), + settings.dns.bootstraps.clone(), + settings.dns.enable_0x20, + settings.dns.cache_memory_limit_mb + ); server_context.allow_recursive = true; server_context.resolve_strategy = match settings.dns.forwarders.is_empty() { true => ResolveStrategy::Recursive, diff --git a/src/settings.rs b/src/settings.rs index 349019f..c41aa61 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -97,7 +97,10 @@ pub struct Dns { pub hosts: Vec, /// Enable DNS 0x20 encoding (random case) for additional security against cache poisoning #[serde(default = "default_dns_0x20")] - pub enable_0x20: bool + pub enable_0x20: bool, + /// DNS cache memory limit in megabytes (default: 100MB, 0 = unlimited) + #[serde(default = "default_cache_memory_limit_mb")] + pub cache_memory_limit_mb: usize } impl Default for Dns { @@ -108,7 +111,8 @@ impl Default for Dns { forwarders: vec![String::from("94.140.14.14:53"), String::from("94.140.15.15:53")], bootstraps: default_dns_bootstraps(), hosts: Vec::new(), - enable_0x20: default_dns_0x20() + enable_0x20: default_dns_0x20(), + cache_memory_limit_mb: default_cache_memory_limit_mb() } } } @@ -178,6 +182,10 @@ fn default_dns_0x20() -> bool { true } +fn default_cache_memory_limit_mb() -> usize { + 100 // 100 MB default +} + #[cfg(test)] mod tests { use super::*;