Implemented memory limit for DNS cache.

This commit is contained in:
Revertron
2026-01-05 16:50:20 +01:00
parent 09c1cd5ddc
commit 2e1f05cadb
5 changed files with 308 additions and 16 deletions
+6
View File
@@ -22,6 +22,12 @@ yggdrasil_only = false
listen = "127.0.0.3:53"
# How many threads to spawn by DNS server
threads = 10
# DNS cache memory limit in megabytes (default: 100)
# Prevents unbounded cache growth in high-load environments
# Set to 0 for unlimited cache (not recommended for production)
cache_memory_limit_mb = 100
# AdGuard DNS servers to filter ads and trackers
forwarders = ["https://dns.adguard.com/dns-query"]
#forwarders = ["94.140.14.14:53", "94.140.15.15:53"]
+282 -9
View File
@@ -2,16 +2,82 @@
extern crate serde;
use std::clone::Clone;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
use lru::LruCache;
use chrono::*;
use derive_more::{Display, Error, From};
#[allow(unused_imports)]
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
/// Estimate the memory size of a DNS record in bytes
fn estimate_dns_record_size(record: &DnsRecord) -> usize {
match record {
DnsRecord::A { domain, .. } => 56 + domain.len(),
DnsRecord::AAAA { domain, .. } => 68 + domain.len(),
DnsRecord::NS { domain, host, .. } |
DnsRecord::CNAME { domain, host, .. } => 64 + domain.len() + host.len(),
DnsRecord::MX { domain, host, .. } => 72 + domain.len() + host.len(),
DnsRecord::SRV { domain, host, .. } => 80 + domain.len() + host.len(),
DnsRecord::SOA { domain, m_name, r_name, .. } =>
120 + domain.len() + m_name.len() + r_name.len(),
DnsRecord::TXT { domain, data, .. } => 64 + domain.len() + data.len(),
DnsRecord::PTR { domain, data, .. } => 64 + domain.len() + data.len(),
DnsRecord::TLSA { domain, data, .. } => 80 + domain.len() + data.len(),
DnsRecord::HTTPS { domain, target, params, .. } =>
88 + domain.len() + target.len() + params.len(),
DnsRecord::UNKNOWN { domain, .. } => 64 + domain.len(),
DnsRecord::OPT { data, .. } => 48 + data.len(),
}
}
/// Estimate the memory size of a domain entry in bytes
fn estimate_domain_entry_size(entry: &DomainEntry) -> usize {
let mut size = 0;
// Base struct sizes
size += std::mem::size_of::<DomainEntry>(); // ~56 bytes
size += std::mem::size_of::<Arc<DomainEntry>>(); // 16 bytes
// Domain string: 24 byte header + actual chars
size += 24 + entry.domain.len();
// HashMap base overhead
size += 24;
size += entry.record_types.len() * 32; // Bucket overhead per entry
// Calculate size of each RecordSet
for (_qtype, record_set) in &entry.record_types {
size += std::mem::size_of::<QueryType>(); // 2 bytes
match record_set {
RecordSet::NoRecords { .. } => {
size += 56; // Enum variant + timestamp + ttl
}
RecordSet::Records { records, .. } => {
size += 56; // Base enum variant
size += 24; // HashSet base
size += records.len() * 16; // Bucket overhead per record
// Sum up all record sizes
for record_entry in records {
size += estimate_dns_record_size(&record_entry.record);
size += 32; // DateTime<Local> overhead
}
}
}
}
size
}
#[derive(Debug, Display, From, Error)]
pub enum CacheError {
Io(std::io::Error),
@@ -159,14 +225,61 @@ impl DomainEntry {
}
}
#[derive(Default)]
pub struct Cache {
domain_entries: BTreeMap<String, Arc<DomainEntry>>
domain_entries: LruCache<String, Arc<DomainEntry>>,
current_memory_bytes: usize,
max_memory_bytes: usize
}
impl Cache {
pub fn new() -> Cache {
Cache { domain_entries: BTreeMap::new() }
Cache::with_memory_limit(0)
}
pub fn with_memory_limit(limit_mb: usize) -> Cache {
let max_memory_bytes = if limit_mb == 0 {
usize::MAX
} else {
limit_mb * 1024 * 1024
};
// Estimate capacity: assume ~1KB per entry
let estimated_capacity = if limit_mb == 0 {
100_000 // Default capacity for unlimited
} else {
limit_mb * 1000
};
Cache {
domain_entries: LruCache::new(NonZeroUsize::new(estimated_capacity).unwrap()),
current_memory_bytes: 0,
max_memory_bytes,
}
}
fn evict_to_limit(&mut self) -> usize {
if self.max_memory_bytes == usize::MAX {
return 0; // Unlimited
}
let mut evicted = 0;
let target_memory = (self.max_memory_bytes * 90) / 100; // Evict to 90%
while self.current_memory_bytes > target_memory {
if let Some((_, entry)) = self.domain_entries.pop_lru() {
let size = estimate_domain_entry_size(&entry);
self.current_memory_bytes = self.current_memory_bytes.saturating_sub(size);
evicted += 1;
} else {
break;
}
}
if evicted > 0 {
info!("Evicted {} DNS cache entries (memory: {} bytes)", evicted, self.current_memory_bytes);
}
evicted
}
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
@@ -218,39 +331,87 @@ impl Cache {
// Store with a lowercase key for case-insensitive lookups
let domain_lower = domain.to_lowercase();
// Try to update existing entry
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain_lower).and_then(Arc::get_mut) {
let old_size = estimate_domain_entry_size(rs);
rs.store_record(rec);
let new_size = estimate_domain_entry_size(rs);
self.current_memory_bytes = self.current_memory_bytes
.saturating_sub(old_size)
.saturating_add(new_size);
continue;
}
// Insert new entry
let mut rs = DomainEntry::new(domain_lower.clone());
rs.store_record(rec);
self.domain_entries.insert(domain_lower, Arc::new(rs));
let entry_size = estimate_domain_entry_size(&rs);
// Check if eviction needed
if self.current_memory_bytes + entry_size > self.max_memory_bytes {
self.evict_to_limit();
}
self.domain_entries.put(domain_lower, Arc::new(rs));
self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size);
}
}
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
// Store with lowercase key for case-insensitive lookups
let qname_lower = qname.to_lowercase();
// Try to update existing entry
if let Some(ref mut rs) = self.domain_entries.get_mut(&qname_lower).and_then(Arc::get_mut) {
let old_size = estimate_domain_entry_size(rs);
rs.store_nxdomain(qtype, ttl);
let new_size = estimate_domain_entry_size(rs);
self.current_memory_bytes = self.current_memory_bytes
.saturating_sub(old_size)
.saturating_add(new_size);
return;
}
// Insert new entry
let mut rs = DomainEntry::new(qname_lower.clone());
rs.store_nxdomain(qtype, ttl);
self.domain_entries.insert(qname_lower, Arc::new(rs));
let entry_size = estimate_domain_entry_size(&rs);
// Check if eviction needed
if self.current_memory_bytes + entry_size > self.max_memory_bytes {
self.evict_to_limit();
}
self.domain_entries.put(qname_lower, Arc::new(rs));
self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size);
}
}
#[derive(Default)]
pub struct SynchronizedCache {
pub cache: RwLock<Cache>
}
impl SynchronizedCache {
pub fn new() -> SynchronizedCache {
SynchronizedCache { cache: RwLock::new(Cache::new()) }
SynchronizedCache::with_memory_limit(0)
}
pub fn with_memory_limit(limit_mb: usize) -> SynchronizedCache {
SynchronizedCache {
cache: RwLock::new(Cache::with_memory_limit(limit_mb))
}
}
pub fn get_memory_usage(&self) -> Result<usize> {
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
Ok(cache.current_memory_bytes)
}
pub fn get_entry_count(&self) -> Result<usize> {
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
Ok(cache.domain_entries.len())
}
pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
@@ -258,7 +419,7 @@ impl SynchronizedCache {
let mut list = Vec::new();
for rs in cache.domain_entries.values() {
for (_, rs) in cache.domain_entries.iter() {
list.push(rs.clone());
}
@@ -390,4 +551,116 @@ mod tests {
assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().updates);
assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().hits);
}
#[test]
fn test_memory_limited_cache() {
let mut cache = Cache::with_memory_limit(1); // 1MB limit
// Add many records until limit is hit
for i in 0..5000 {
let domain = format!("test{}.com", i);
let records = vec![DnsRecord::A {
domain: domain.clone(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600)
}];
cache.store(&records);
}
// Verify memory stayed under limit (with some tolerance)
let limit_bytes = 1024 * 1024;
let tolerance_bytes = limit_bytes * 110 / 100; // 110% tolerance
assert!(
cache.current_memory_bytes <= tolerance_bytes,
"Cache memory {} bytes exceeds limit with tolerance {} bytes",
cache.current_memory_bytes, tolerance_bytes
);
// Verify cache still works and has been evicted
assert!(cache.domain_entries.len() < 5000, "Cache should have evicted entries");
assert!(cache.domain_entries.len() > 0, "Cache should not be empty");
// Most recent entries should still be present
assert!(cache.lookup("test4999.com", QueryType::A).is_some());
}
#[test]
fn test_unlimited_cache() {
let mut cache = Cache::with_memory_limit(0); // Unlimited
for i in 0..1000 {
let domain = format!("test{}.com", i);
let records = vec![DnsRecord::A {
domain: domain.clone(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600)
}];
cache.store(&records);
}
// All entries should be present
assert_eq!(cache.domain_entries.len(), 1000);
assert_eq!(cache.max_memory_bytes, usize::MAX);
// Verify lookups work for all entries
assert!(cache.lookup("test0.com", QueryType::A).is_some());
assert!(cache.lookup("test500.com", QueryType::A).is_some());
assert!(cache.lookup("test999.com", QueryType::A).is_some());
}
#[test]
fn test_lru_eviction_order() {
let mut cache = Cache::with_memory_limit(1); // Small limit to trigger eviction
// Add initial batch of records
for i in 0..100 {
cache.store(&[DnsRecord::A {
domain: format!("domain{}.com", i),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600)
}]);
}
// Access domain50 to make it recently used
let _ = cache.lookup("domain50.com", QueryType::A);
// Add more records to trigger eviction
for i in 100..200 {
cache.store(&[DnsRecord::A {
domain: format!("domain{}.com", i),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600)
}]);
}
// Most recently added entries should be present
assert!(cache.lookup("domain199.com", QueryType::A).is_some());
// Verify cache is respecting memory limit
let limit_bytes = 1024 * 1024;
let tolerance_bytes = limit_bytes * 110 / 100;
assert!(cache.current_memory_bytes <= tolerance_bytes);
}
#[test]
fn test_nxdomain_memory_tracking() {
let mut cache = Cache::with_memory_limit(1); // 1MB limit
// Store many NXDOMAIN responses
for i in 0..1000 {
let domain = format!("nonexistent{}.com", i);
cache.store_nxdomain(&domain, QueryType::A, 3600);
}
// Verify memory tracking works for NXDOMAIN
assert!(cache.current_memory_bytes > 0);
assert!(cache.current_memory_bytes <= 1024 * 1024 * 110 / 100);
// Verify NXDOMAIN responses work
if let Some(packet) = cache.lookup("nonexistent999.com", QueryType::A) {
assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);
} else {
panic!("NXDOMAIN entry should be cached");
}
}
}
+4 -4
View File
@@ -61,13 +61,13 @@ pub struct ServerContext {
impl Default for ServerContext {
fn default() -> Self {
ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true)
ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true, 100)
}
}
impl ServerContext {
#[allow(unused_variables)]
pub fn new(dns_listen: String, bootstraps: Vec<String>, enable_0x20: bool) -> ServerContext {
pub fn new(dns_listen: String, bootstraps: Vec<String>, enable_0x20: bool, cache_limit_mb: usize) -> ServerContext {
#[cfg(not(feature = "doh"))]
let doh_client = None;
#[cfg(feature = "doh")]
@@ -75,7 +75,7 @@ impl ServerContext {
ServerContext {
authority: Authority::new(),
cache: SynchronizedCache::new(),
cache: SynchronizedCache::with_memory_limit(cache_limit_mb),
filters: Vec::new(),
old_client: Box::new(DnsNetworkClient::new_with_0x20(enable_0x20)),
doh_client,
@@ -129,7 +129,7 @@ pub mod tests {
pub fn create_test_context(callback: Box<StubCallback>) -> Arc<ServerContext> {
Arc::new(ServerContext {
authority: Authority::new(),
cache: SynchronizedCache::new(),
cache: SynchronizedCache::with_memory_limit(0), // Unlimited for tests
filters: Vec::new(),
old_client: Box::new(DnsStubClient::new(callback)),
doh_client: Some(Box::new(HttpsDnsClient::new(Vec::new()))),
+6 -1
View File
@@ -35,7 +35,12 @@ pub fn start_dns_server(context: &Arc<Mutex<Context>>, settings: &Settings) -> b
/// Creates DNS-context with all necessary settings
fn create_server_context(context: Arc<Mutex<Context>>, settings: &Settings) -> Arc<ServerContext> {
let mut server_context = ServerContext::new(settings.dns.listen.clone(), settings.dns.bootstraps.clone(), settings.dns.enable_0x20);
let mut server_context = ServerContext::new(
settings.dns.listen.clone(),
settings.dns.bootstraps.clone(),
settings.dns.enable_0x20,
settings.dns.cache_memory_limit_mb
);
server_context.allow_recursive = true;
server_context.resolve_strategy = match settings.dns.forwarders.is_empty() {
true => ResolveStrategy::Recursive,
+10 -2
View File
@@ -97,7 +97,10 @@ pub struct Dns {
pub hosts: Vec<String>,
/// Enable DNS 0x20 encoding (random case) for additional security against cache poisoning
#[serde(default = "default_dns_0x20")]
pub enable_0x20: bool
pub enable_0x20: bool,
/// DNS cache memory limit in megabytes (default: 100MB, 0 = unlimited)
#[serde(default = "default_cache_memory_limit_mb")]
pub cache_memory_limit_mb: usize
}
impl Default for Dns {
@@ -108,7 +111,8 @@ impl Default for Dns {
forwarders: vec![String::from("94.140.14.14:53"), String::from("94.140.15.15:53")],
bootstraps: default_dns_bootstraps(),
hosts: Vec::new(),
enable_0x20: default_dns_0x20()
enable_0x20: default_dns_0x20(),
cache_memory_limit_mb: default_cache_memory_limit_mb()
}
}
}
@@ -178,6 +182,10 @@ fn default_dns_0x20() -> bool {
true
}
fn default_cache_memory_limit_mb() -> usize {
100 // 100 MB default
}
#[cfg(test)]
mod tests {
use super::*;