First DNS compile. Took DNS code from https://github.com/EmilHernvall/hermes.
This commit is contained in:
@@ -0,0 +1,462 @@
|
||||
//! a threadsafe cache for DNS information
|
||||
|
||||
extern crate serde;
|
||||
use std::clone::Clone;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use chrono::*;
|
||||
use derive_more::{Display, Error, From};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum CacheError {
|
||||
Io(std::io::Error),
|
||||
PoisonedLock,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, CacheError>;
|
||||
|
||||
pub enum CacheState {
|
||||
PositiveCache,
|
||||
NegativeCache,
|
||||
NotCached,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, Debug, Serialize, Deserialize)]
|
||||
pub struct RecordEntry {
|
||||
pub record: DnsRecord,
|
||||
pub timestamp: DateTime<Local>,
|
||||
}
|
||||
|
||||
impl PartialEq<RecordEntry> for RecordEntry {
|
||||
fn eq(&self, other: &RecordEntry) -> bool {
|
||||
self.record == other.record
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for RecordEntry {
|
||||
fn hash<H>(&self, state: &mut H)
|
||||
where
|
||||
H: Hasher,
|
||||
{
|
||||
self.record.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum RecordSet {
|
||||
NoRecords {
|
||||
qtype: QueryType,
|
||||
ttl: u32,
|
||||
timestamp: DateTime<Local>,
|
||||
},
|
||||
Records {
|
||||
qtype: QueryType,
|
||||
records: HashSet<RecordEntry>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DomainEntry {
|
||||
pub domain: String,
|
||||
pub record_types: HashMap<QueryType, RecordSet>,
|
||||
pub hits: u32,
|
||||
pub updates: u32,
|
||||
}
|
||||
|
||||
impl DomainEntry {
|
||||
pub fn new(domain: String) -> DomainEntry {
|
||||
DomainEntry {
|
||||
domain: domain,
|
||||
record_types: HashMap::new(),
|
||||
hits: 0,
|
||||
updates: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) {
|
||||
self.updates += 1;
|
||||
|
||||
let new_set = RecordSet::NoRecords {
|
||||
qtype: qtype,
|
||||
ttl: ttl,
|
||||
timestamp: Local::now(),
|
||||
};
|
||||
|
||||
self.record_types.insert(qtype, new_set);
|
||||
}
|
||||
|
||||
pub fn store_record(&mut self, rec: &DnsRecord) {
|
||||
self.updates += 1;
|
||||
|
||||
let entry = RecordEntry {
|
||||
record: rec.clone(),
|
||||
timestamp: Local::now(),
|
||||
};
|
||||
|
||||
if let Some(&mut RecordSet::Records {
|
||||
ref mut records, ..
|
||||
}) = self.record_types.get_mut(&rec.get_querytype())
|
||||
{
|
||||
if records.contains(&entry) {
|
||||
records.remove(&entry);
|
||||
}
|
||||
|
||||
records.insert(entry);
|
||||
return;
|
||||
}
|
||||
|
||||
let mut records = HashSet::new();
|
||||
records.insert(entry);
|
||||
|
||||
let new_set = RecordSet::Records {
|
||||
qtype: rec.get_querytype(),
|
||||
records: records,
|
||||
};
|
||||
|
||||
self.record_types.insert(rec.get_querytype(), new_set);
|
||||
}
|
||||
|
||||
pub fn get_cache_state(&self, qtype: QueryType) -> CacheState {
|
||||
match self.record_types.get(&qtype) {
|
||||
Some(&RecordSet::Records { ref records, .. }) => {
|
||||
let now = Local::now();
|
||||
|
||||
let mut valid_count = 0;
|
||||
for entry in records {
|
||||
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
|
||||
let expires = entry.timestamp + ttl_offset;
|
||||
if expires < now {
|
||||
continue;
|
||||
}
|
||||
|
||||
if entry.record.get_querytype() == qtype {
|
||||
valid_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if valid_count > 0 {
|
||||
CacheState::PositiveCache
|
||||
} else {
|
||||
CacheState::NotCached
|
||||
}
|
||||
}
|
||||
Some(&RecordSet::NoRecords { ttl, timestamp, .. }) => {
|
||||
let now = Local::now();
|
||||
let ttl_offset = Duration::seconds(ttl as i64);
|
||||
let expires = timestamp + ttl_offset;
|
||||
|
||||
if expires < now {
|
||||
CacheState::NotCached
|
||||
} else {
|
||||
CacheState::NegativeCache
|
||||
}
|
||||
}
|
||||
None => CacheState::NotCached,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_queryresult(&self, qtype: QueryType, result_vec: &mut Vec<DnsRecord>) {
|
||||
let now = Local::now();
|
||||
|
||||
let current_set = match self.record_types.get(&qtype) {
|
||||
Some(x) => x,
|
||||
None => return,
|
||||
};
|
||||
|
||||
if let RecordSet::Records { ref records, .. } = *current_set {
|
||||
for entry in records {
|
||||
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
|
||||
let expires = entry.timestamp + ttl_offset;
|
||||
if expires < now {
|
||||
continue;
|
||||
}
|
||||
|
||||
if entry.record.get_querytype() == qtype {
|
||||
result_vec.push(entry.record.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Cache {
|
||||
domain_entries: BTreeMap<String, Arc<DomainEntry>>,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new() -> Cache {
|
||||
Cache {
|
||||
domain_entries: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
|
||||
match self.domain_entries.get(qname) {
|
||||
Some(x) => x.get_cache_state(qtype),
|
||||
None => CacheState::NotCached,
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_queryresult(
|
||||
&mut self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
result_vec: &mut Vec<DnsRecord>,
|
||||
increment_stats: bool,
|
||||
) {
|
||||
if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
|
||||
if increment_stats {
|
||||
domain_entry.hits += 1
|
||||
}
|
||||
|
||||
domain_entry.fill_queryresult(qtype, result_vec);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lookup(&mut self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
match self.get_cache_state(qname, qtype) {
|
||||
CacheState::PositiveCache => {
|
||||
let mut qr = DnsPacket::new();
|
||||
self.fill_queryresult(qname, qtype, &mut qr.answers, true);
|
||||
self.fill_queryresult(qname, QueryType::NS, &mut qr.authorities, false);
|
||||
|
||||
Some(qr)
|
||||
}
|
||||
CacheState::NegativeCache => {
|
||||
let mut qr = DnsPacket::new();
|
||||
qr.header.rescode = ResultCode::NXDOMAIN;
|
||||
|
||||
Some(qr)
|
||||
}
|
||||
CacheState::NotCached => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store(&mut self, records: &[DnsRecord]) {
|
||||
for rec in records {
|
||||
let domain = match rec.get_domain() {
|
||||
Some(x) => x,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain).and_then(Arc::get_mut) {
|
||||
rs.store_record(rec);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut rs = DomainEntry::new(domain.clone());
|
||||
rs.store_record(rec);
|
||||
self.domain_entries.insert(domain.clone(), Arc::new(rs));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
return;
|
||||
}
|
||||
|
||||
let mut rs = DomainEntry::new(qname.to_string());
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
self.domain_entries.insert(qname.to_string(), Arc::new(rs));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SynchronizedCache {
|
||||
pub cache: RwLock<Cache>,
|
||||
}
|
||||
|
||||
impl SynchronizedCache {
|
||||
pub fn new() -> SynchronizedCache {
|
||||
SynchronizedCache {
|
||||
cache: RwLock::new(Cache::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
|
||||
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
let mut list = Vec::new();
|
||||
|
||||
for rs in cache.domain_entries.values() {
|
||||
list.push(rs.clone());
|
||||
}
|
||||
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
pub fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
let mut cache = match self.cache.write() {
|
||||
Ok(x) => x,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
cache.lookup(qname, qtype)
|
||||
}
|
||||
|
||||
pub fn store(&self, records: &[DnsRecord]) -> Result<()> {
|
||||
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
cache.store(records);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&self, qname: &str, qtype: QueryType, ttl: u32) -> Result<()> {
|
||||
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
cache.store_nxdomain(qname, qtype, ttl);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::dns::protocol::{DnsRecord, QueryType, ResultCode, TransientTtl};
|
||||
|
||||
#[test]
|
||||
fn test_cache() {
|
||||
let mut cache = Cache::new();
|
||||
|
||||
// Verify that no data is returned when nothing is present
|
||||
if cache.lookup("www.google.com", QueryType::A).is_some() {
|
||||
panic!()
|
||||
}
|
||||
|
||||
// Register a negative cache entry
|
||||
cache.store_nxdomain("www.google.com", QueryType::A, 3600);
|
||||
|
||||
// Verify that we get a response, with the NXDOMAIN flag set
|
||||
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
|
||||
assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);
|
||||
}
|
||||
|
||||
// Register a negative cache entry with no TTL
|
||||
cache.store_nxdomain("www.yahoo.com", QueryType::A, 0);
|
||||
|
||||
// And check that no such result is actually returned, since it's expired
|
||||
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!()
|
||||
}
|
||||
|
||||
// Now add some actual records
|
||||
let mut records = Vec::new();
|
||||
records.push(DnsRecord::A {
|
||||
domain: "www.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
records.push(DnsRecord::A {
|
||||
domain: "www.yahoo.com".to_string(),
|
||||
addr: "127.0.0.2".parse().unwrap(),
|
||||
ttl: TransientTtl(0),
|
||||
});
|
||||
records.push(DnsRecord::CNAME {
|
||||
domain: "www.microsoft.com".to_string(),
|
||||
host: "www.somecdn.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
cache.store(&records);
|
||||
|
||||
// Test for successful lookup
|
||||
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
|
||||
assert_eq!(records[0], packet.answers[0]);
|
||||
} else {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Test for failed lookup, since no CNAME's are known for this domain
|
||||
if cache.lookup("www.google.com", QueryType::CNAME).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Check for successful CNAME lookup
|
||||
if let Some(packet) = cache.lookup("www.microsoft.com", QueryType::CNAME) {
|
||||
assert_eq!(records[2], packet.answers[0]);
|
||||
} else {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// This lookup should fail, since it has expired due to the 0 second TTL
|
||||
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let mut records2 = Vec::new();
|
||||
records2.push(DnsRecord::A {
|
||||
domain: "www.yahoo.com".to_string(),
|
||||
addr: "127.0.0.2".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
cache.store(&records2);
|
||||
|
||||
// And now it should succeed, since the record has been store
|
||||
if !cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Check stat counter behavior
|
||||
assert_eq!(3, cache.domain_entries.len());
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.google.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
assert_eq!(
|
||||
2,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.google.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.yahoo.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
assert_eq!(
|
||||
3,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.yahoo.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.microsoft.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.microsoft.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user