Implemented memory limit for DNS cache.
This commit is contained in:
@@ -22,6 +22,12 @@ yggdrasil_only = false
|
||||
listen = "127.0.0.3:53"
|
||||
# How many threads to spawn by DNS server
|
||||
threads = 10
|
||||
|
||||
# DNS cache memory limit in megabytes (default: 100)
|
||||
# Prevents unbounded cache growth in high-load environments
|
||||
# Set to 0 for unlimited cache (not recommended for production)
|
||||
cache_memory_limit_mb = 100
|
||||
|
||||
# AdGuard DNS servers to filter ads and trackers
|
||||
forwarders = ["https://dns.adguard.com/dns-query"]
|
||||
#forwarders = ["94.140.14.14:53", "94.140.15.15:53"]
|
||||
|
||||
+282
-9
@@ -2,16 +2,82 @@
|
||||
|
||||
extern crate serde;
|
||||
use std::clone::Clone;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use lru::LruCache;
|
||||
|
||||
use chrono::*;
|
||||
use derive_more::{Display, Error, From};
|
||||
#[allow(unused_imports)]
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
|
||||
|
||||
/// Estimate the memory size of a DNS record in bytes
|
||||
fn estimate_dns_record_size(record: &DnsRecord) -> usize {
|
||||
match record {
|
||||
DnsRecord::A { domain, .. } => 56 + domain.len(),
|
||||
DnsRecord::AAAA { domain, .. } => 68 + domain.len(),
|
||||
DnsRecord::NS { domain, host, .. } |
|
||||
DnsRecord::CNAME { domain, host, .. } => 64 + domain.len() + host.len(),
|
||||
DnsRecord::MX { domain, host, .. } => 72 + domain.len() + host.len(),
|
||||
DnsRecord::SRV { domain, host, .. } => 80 + domain.len() + host.len(),
|
||||
DnsRecord::SOA { domain, m_name, r_name, .. } =>
|
||||
120 + domain.len() + m_name.len() + r_name.len(),
|
||||
DnsRecord::TXT { domain, data, .. } => 64 + domain.len() + data.len(),
|
||||
DnsRecord::PTR { domain, data, .. } => 64 + domain.len() + data.len(),
|
||||
DnsRecord::TLSA { domain, data, .. } => 80 + domain.len() + data.len(),
|
||||
DnsRecord::HTTPS { domain, target, params, .. } =>
|
||||
88 + domain.len() + target.len() + params.len(),
|
||||
DnsRecord::UNKNOWN { domain, .. } => 64 + domain.len(),
|
||||
DnsRecord::OPT { data, .. } => 48 + data.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate the memory size of a domain entry in bytes
|
||||
fn estimate_domain_entry_size(entry: &DomainEntry) -> usize {
|
||||
let mut size = 0;
|
||||
|
||||
// Base struct sizes
|
||||
size += std::mem::size_of::<DomainEntry>(); // ~56 bytes
|
||||
size += std::mem::size_of::<Arc<DomainEntry>>(); // 16 bytes
|
||||
|
||||
// Domain string: 24 byte header + actual chars
|
||||
size += 24 + entry.domain.len();
|
||||
|
||||
// HashMap base overhead
|
||||
size += 24;
|
||||
size += entry.record_types.len() * 32; // Bucket overhead per entry
|
||||
|
||||
// Calculate size of each RecordSet
|
||||
for (_qtype, record_set) in &entry.record_types {
|
||||
size += std::mem::size_of::<QueryType>(); // 2 bytes
|
||||
|
||||
match record_set {
|
||||
RecordSet::NoRecords { .. } => {
|
||||
size += 56; // Enum variant + timestamp + ttl
|
||||
}
|
||||
RecordSet::Records { records, .. } => {
|
||||
size += 56; // Base enum variant
|
||||
size += 24; // HashSet base
|
||||
size += records.len() * 16; // Bucket overhead per record
|
||||
|
||||
// Sum up all record sizes
|
||||
for record_entry in records {
|
||||
size += estimate_dns_record_size(&record_entry.record);
|
||||
size += 32; // DateTime<Local> overhead
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum CacheError {
|
||||
Io(std::io::Error),
|
||||
@@ -159,14 +225,61 @@ impl DomainEntry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Cache {
|
||||
domain_entries: BTreeMap<String, Arc<DomainEntry>>
|
||||
domain_entries: LruCache<String, Arc<DomainEntry>>,
|
||||
current_memory_bytes: usize,
|
||||
max_memory_bytes: usize
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new() -> Cache {
|
||||
Cache { domain_entries: BTreeMap::new() }
|
||||
Cache::with_memory_limit(0)
|
||||
}
|
||||
|
||||
pub fn with_memory_limit(limit_mb: usize) -> Cache {
|
||||
let max_memory_bytes = if limit_mb == 0 {
|
||||
usize::MAX
|
||||
} else {
|
||||
limit_mb * 1024 * 1024
|
||||
};
|
||||
|
||||
// Estimate capacity: assume ~1KB per entry
|
||||
let estimated_capacity = if limit_mb == 0 {
|
||||
100_000 // Default capacity for unlimited
|
||||
} else {
|
||||
limit_mb * 1000
|
||||
};
|
||||
|
||||
Cache {
|
||||
domain_entries: LruCache::new(NonZeroUsize::new(estimated_capacity).unwrap()),
|
||||
current_memory_bytes: 0,
|
||||
max_memory_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
fn evict_to_limit(&mut self) -> usize {
|
||||
if self.max_memory_bytes == usize::MAX {
|
||||
return 0; // Unlimited
|
||||
}
|
||||
|
||||
let mut evicted = 0;
|
||||
let target_memory = (self.max_memory_bytes * 90) / 100; // Evict to 90%
|
||||
|
||||
while self.current_memory_bytes > target_memory {
|
||||
if let Some((_, entry)) = self.domain_entries.pop_lru() {
|
||||
let size = estimate_domain_entry_size(&entry);
|
||||
self.current_memory_bytes = self.current_memory_bytes.saturating_sub(size);
|
||||
evicted += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
info!("Evicted {} DNS cache entries (memory: {} bytes)", evicted, self.current_memory_bytes);
|
||||
}
|
||||
|
||||
evicted
|
||||
}
|
||||
|
||||
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
|
||||
@@ -218,39 +331,87 @@ impl Cache {
|
||||
// Store with a lowercase key for case-insensitive lookups
|
||||
let domain_lower = domain.to_lowercase();
|
||||
|
||||
// Try to update existing entry
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain_lower).and_then(Arc::get_mut) {
|
||||
let old_size = estimate_domain_entry_size(rs);
|
||||
rs.store_record(rec);
|
||||
let new_size = estimate_domain_entry_size(rs);
|
||||
|
||||
self.current_memory_bytes = self.current_memory_bytes
|
||||
.saturating_sub(old_size)
|
||||
.saturating_add(new_size);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Insert new entry
|
||||
let mut rs = DomainEntry::new(domain_lower.clone());
|
||||
rs.store_record(rec);
|
||||
self.domain_entries.insert(domain_lower, Arc::new(rs));
|
||||
let entry_size = estimate_domain_entry_size(&rs);
|
||||
|
||||
// Check if eviction needed
|
||||
if self.current_memory_bytes + entry_size > self.max_memory_bytes {
|
||||
self.evict_to_limit();
|
||||
}
|
||||
|
||||
self.domain_entries.put(domain_lower, Arc::new(rs));
|
||||
self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
|
||||
// Store with lowercase key for case-insensitive lookups
|
||||
let qname_lower = qname.to_lowercase();
|
||||
|
||||
// Try to update existing entry
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(&qname_lower).and_then(Arc::get_mut) {
|
||||
let old_size = estimate_domain_entry_size(rs);
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
let new_size = estimate_domain_entry_size(rs);
|
||||
|
||||
self.current_memory_bytes = self.current_memory_bytes
|
||||
.saturating_sub(old_size)
|
||||
.saturating_add(new_size);
|
||||
return;
|
||||
}
|
||||
|
||||
// Insert new entry
|
||||
let mut rs = DomainEntry::new(qname_lower.clone());
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
self.domain_entries.insert(qname_lower, Arc::new(rs));
|
||||
let entry_size = estimate_domain_entry_size(&rs);
|
||||
|
||||
// Check if eviction needed
|
||||
if self.current_memory_bytes + entry_size > self.max_memory_bytes {
|
||||
self.evict_to_limit();
|
||||
}
|
||||
|
||||
self.domain_entries.put(qname_lower, Arc::new(rs));
|
||||
self.current_memory_bytes = self.current_memory_bytes.saturating_add(entry_size);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SynchronizedCache {
|
||||
pub cache: RwLock<Cache>
|
||||
}
|
||||
|
||||
impl SynchronizedCache {
|
||||
pub fn new() -> SynchronizedCache {
|
||||
SynchronizedCache { cache: RwLock::new(Cache::new()) }
|
||||
SynchronizedCache::with_memory_limit(0)
|
||||
}
|
||||
|
||||
pub fn with_memory_limit(limit_mb: usize) -> SynchronizedCache {
|
||||
SynchronizedCache {
|
||||
cache: RwLock::new(Cache::with_memory_limit(limit_mb))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_memory_usage(&self) -> Result<usize> {
|
||||
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
|
||||
Ok(cache.current_memory_bytes)
|
||||
}
|
||||
|
||||
pub fn get_entry_count(&self) -> Result<usize> {
|
||||
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
|
||||
Ok(cache.domain_entries.len())
|
||||
}
|
||||
|
||||
pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
|
||||
@@ -258,7 +419,7 @@ impl SynchronizedCache {
|
||||
|
||||
let mut list = Vec::new();
|
||||
|
||||
for rs in cache.domain_entries.values() {
|
||||
for (_, rs) in cache.domain_entries.iter() {
|
||||
list.push(rs.clone());
|
||||
}
|
||||
|
||||
@@ -390,4 +551,116 @@ mod tests {
|
||||
assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().updates);
|
||||
assert_eq!(1, cache.domain_entries.get(&"www.microsoft.com".to_string()).unwrap().hits);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_limited_cache() {
|
||||
let mut cache = Cache::with_memory_limit(1); // 1MB limit
|
||||
|
||||
// Add many records until limit is hit
|
||||
for i in 0..5000 {
|
||||
let domain = format!("test{}.com", i);
|
||||
let records = vec![DnsRecord::A {
|
||||
domain: domain.clone(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600)
|
||||
}];
|
||||
cache.store(&records);
|
||||
}
|
||||
|
||||
// Verify memory stayed under limit (with some tolerance)
|
||||
let limit_bytes = 1024 * 1024;
|
||||
let tolerance_bytes = limit_bytes * 110 / 100; // 110% tolerance
|
||||
assert!(
|
||||
cache.current_memory_bytes <= tolerance_bytes,
|
||||
"Cache memory {} bytes exceeds limit with tolerance {} bytes",
|
||||
cache.current_memory_bytes, tolerance_bytes
|
||||
);
|
||||
|
||||
// Verify cache still works and has been evicted
|
||||
assert!(cache.domain_entries.len() < 5000, "Cache should have evicted entries");
|
||||
assert!(cache.domain_entries.len() > 0, "Cache should not be empty");
|
||||
|
||||
// Most recent entries should still be present
|
||||
assert!(cache.lookup("test4999.com", QueryType::A).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unlimited_cache() {
|
||||
let mut cache = Cache::with_memory_limit(0); // Unlimited
|
||||
|
||||
for i in 0..1000 {
|
||||
let domain = format!("test{}.com", i);
|
||||
let records = vec![DnsRecord::A {
|
||||
domain: domain.clone(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600)
|
||||
}];
|
||||
cache.store(&records);
|
||||
}
|
||||
|
||||
// All entries should be present
|
||||
assert_eq!(cache.domain_entries.len(), 1000);
|
||||
assert_eq!(cache.max_memory_bytes, usize::MAX);
|
||||
|
||||
// Verify lookups work for all entries
|
||||
assert!(cache.lookup("test0.com", QueryType::A).is_some());
|
||||
assert!(cache.lookup("test500.com", QueryType::A).is_some());
|
||||
assert!(cache.lookup("test999.com", QueryType::A).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lru_eviction_order() {
|
||||
let mut cache = Cache::with_memory_limit(1); // Small limit to trigger eviction
|
||||
|
||||
// Add initial batch of records
|
||||
for i in 0..100 {
|
||||
cache.store(&[DnsRecord::A {
|
||||
domain: format!("domain{}.com", i),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600)
|
||||
}]);
|
||||
}
|
||||
|
||||
// Access domain50 to make it recently used
|
||||
let _ = cache.lookup("domain50.com", QueryType::A);
|
||||
|
||||
// Add more records to trigger eviction
|
||||
for i in 100..200 {
|
||||
cache.store(&[DnsRecord::A {
|
||||
domain: format!("domain{}.com", i),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600)
|
||||
}]);
|
||||
}
|
||||
|
||||
// Most recently added entries should be present
|
||||
assert!(cache.lookup("domain199.com", QueryType::A).is_some());
|
||||
|
||||
// Verify cache is respecting memory limit
|
||||
let limit_bytes = 1024 * 1024;
|
||||
let tolerance_bytes = limit_bytes * 110 / 100;
|
||||
assert!(cache.current_memory_bytes <= tolerance_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nxdomain_memory_tracking() {
|
||||
let mut cache = Cache::with_memory_limit(1); // 1MB limit
|
||||
|
||||
// Store many NXDOMAIN responses
|
||||
for i in 0..1000 {
|
||||
let domain = format!("nonexistent{}.com", i);
|
||||
cache.store_nxdomain(&domain, QueryType::A, 3600);
|
||||
}
|
||||
|
||||
// Verify memory tracking works for NXDOMAIN
|
||||
assert!(cache.current_memory_bytes > 0);
|
||||
assert!(cache.current_memory_bytes <= 1024 * 1024 * 110 / 100);
|
||||
|
||||
// Verify NXDOMAIN responses work
|
||||
if let Some(packet) = cache.lookup("nonexistent999.com", QueryType::A) {
|
||||
assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);
|
||||
} else {
|
||||
panic!("NXDOMAIN entry should be cached");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+4
-4
@@ -61,13 +61,13 @@ pub struct ServerContext {
|
||||
|
||||
impl Default for ServerContext {
|
||||
fn default() -> Self {
|
||||
ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true)
|
||||
ServerContext::new(String::from("0.0.0.0:53"), Vec::new(), true, 100)
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerContext {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(dns_listen: String, bootstraps: Vec<String>, enable_0x20: bool) -> ServerContext {
|
||||
pub fn new(dns_listen: String, bootstraps: Vec<String>, enable_0x20: bool, cache_limit_mb: usize) -> ServerContext {
|
||||
#[cfg(not(feature = "doh"))]
|
||||
let doh_client = None;
|
||||
#[cfg(feature = "doh")]
|
||||
@@ -75,7 +75,7 @@ impl ServerContext {
|
||||
|
||||
ServerContext {
|
||||
authority: Authority::new(),
|
||||
cache: SynchronizedCache::new(),
|
||||
cache: SynchronizedCache::with_memory_limit(cache_limit_mb),
|
||||
filters: Vec::new(),
|
||||
old_client: Box::new(DnsNetworkClient::new_with_0x20(enable_0x20)),
|
||||
doh_client,
|
||||
@@ -129,7 +129,7 @@ pub mod tests {
|
||||
pub fn create_test_context(callback: Box<StubCallback>) -> Arc<ServerContext> {
|
||||
Arc::new(ServerContext {
|
||||
authority: Authority::new(),
|
||||
cache: SynchronizedCache::new(),
|
||||
cache: SynchronizedCache::with_memory_limit(0), // Unlimited for tests
|
||||
filters: Vec::new(),
|
||||
old_client: Box::new(DnsStubClient::new(callback)),
|
||||
doh_client: Some(Box::new(HttpsDnsClient::new(Vec::new()))),
|
||||
|
||||
+6
-1
@@ -35,7 +35,12 @@ pub fn start_dns_server(context: &Arc<Mutex<Context>>, settings: &Settings) -> b
|
||||
|
||||
/// Creates DNS-context with all necessary settings
|
||||
fn create_server_context(context: Arc<Mutex<Context>>, settings: &Settings) -> Arc<ServerContext> {
|
||||
let mut server_context = ServerContext::new(settings.dns.listen.clone(), settings.dns.bootstraps.clone(), settings.dns.enable_0x20);
|
||||
let mut server_context = ServerContext::new(
|
||||
settings.dns.listen.clone(),
|
||||
settings.dns.bootstraps.clone(),
|
||||
settings.dns.enable_0x20,
|
||||
settings.dns.cache_memory_limit_mb
|
||||
);
|
||||
server_context.allow_recursive = true;
|
||||
server_context.resolve_strategy = match settings.dns.forwarders.is_empty() {
|
||||
true => ResolveStrategy::Recursive,
|
||||
|
||||
+10
-2
@@ -97,7 +97,10 @@ pub struct Dns {
|
||||
pub hosts: Vec<String>,
|
||||
/// Enable DNS 0x20 encoding (random case) for additional security against cache poisoning
|
||||
#[serde(default = "default_dns_0x20")]
|
||||
pub enable_0x20: bool
|
||||
pub enable_0x20: bool,
|
||||
/// DNS cache memory limit in megabytes (default: 100MB, 0 = unlimited)
|
||||
#[serde(default = "default_cache_memory_limit_mb")]
|
||||
pub cache_memory_limit_mb: usize
|
||||
}
|
||||
|
||||
impl Default for Dns {
|
||||
@@ -108,7 +111,8 @@ impl Default for Dns {
|
||||
forwarders: vec![String::from("94.140.14.14:53"), String::from("94.140.15.15:53")],
|
||||
bootstraps: default_dns_bootstraps(),
|
||||
hosts: Vec::new(),
|
||||
enable_0x20: default_dns_0x20()
|
||||
enable_0x20: default_dns_0x20(),
|
||||
cache_memory_limit_mb: default_cache_memory_limit_mb()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,6 +182,10 @@ fn default_dns_0x20() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_cache_memory_limit_mb() -> usize {
|
||||
100 // 100 MB default
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user