From 4b5e5112da17979db73002efb5bb7a5c97fc3bbf Mon Sep 17 00:00:00 2001 From: Revertron Date: Wed, 17 Feb 2021 16:59:40 +0100 Subject: [PATCH 1/2] First DNS compile. Took DNS code from https://github.com/EmilHernvall/hermes. --- Cargo.toml | 4 +- src/dns/authority.rs | 257 ++++++++++ src/dns/buffer.rs | 487 +++++++++++++++++++ src/dns/cache.rs | 462 ++++++++++++++++++ src/dns/client.rs | 400 +++++++++++++++ src/dns/context.rs | 140 ++++++ src/dns/mod.rs | 26 + src/dns/netutil.rs | 19 + src/dns/protocol.rs | 1096 ++++++++++++++++++++++++++++++++++++++++++ src/dns/resolve.rs | 569 ++++++++++++++++++++++ src/dns/server.rs | 608 +++++++++++++++++++++++ src/lib.rs | 1 + 12 files changed, 4068 insertions(+), 1 deletion(-) create mode 100644 src/dns/authority.rs create mode 100644 src/dns/buffer.rs create mode 100644 src/dns/cache.rs create mode 100644 src/dns/client.rs create mode 100644 src/dns/context.rs create mode 100644 src/dns/mod.rs create mode 100644 src/dns/netutil.rs create mode 100644 src/dns/protocol.rs create mode 100644 src/dns/resolve.rs create mode 100644 src/dns/server.rs diff --git a/Cargo.toml b/Cargo.toml index 31307ce..cac146d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,11 +20,13 @@ num-traits = "0.2" bincode = "1.2.0" groestl = "0.8.0" base64 = "0.11.0" -chrono = "0.4.9" +chrono = { version = "0.4.13", features = ["serde"] } rand = "0.7.2" sqlite = "0.25.3" uuid = { version = "0.8.2", features = ["serde", "v4"] } mio = { version = "0.7", features = ["os-poll", "net"] } +# for DNS from hermes +derive_more = "0.99.9" [build-dependencies] winres = "0.1" diff --git a/src/dns/authority.rs b/src/dns/authority.rs new file mode 100644 index 0000000..960b3d4 --- /dev/null +++ b/src/dns/authority.rs @@ -0,0 +1,257 @@ +//! contains the data store for local zones + +use std::collections::{BTreeMap, BTreeSet}; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +use derive_more::{Display, From, Error}; + +use crate::dns::buffer::{PacketBuffer, StreamPacketBuffer, VectorPacketBuffer}; +use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl}; + +#[derive(Debug, Display, From, Error)] +pub enum AuthorityError { + Buffer(crate::dns::buffer::BufferError), + Protocol(crate::dns::protocol::ProtocolError), + Io(std::io::Error), + PoisonedLock, +} + +type Result = std::result::Result; + +#[derive(Clone, Debug, Default)] +pub struct Zone { + pub domain: String, + pub m_name: String, + pub r_name: String, + pub serial: u32, + pub refresh: u32, + pub retry: u32, + pub expire: u32, + pub minimum: u32, + pub records: BTreeSet, +} + +impl Zone { + pub fn new(domain: String, m_name: String, r_name: String) -> Zone { + Zone { + domain: domain, + m_name: m_name, + r_name: r_name, + serial: 0, + refresh: 0, + retry: 0, + expire: 0, + minimum: 0, + records: BTreeSet::new(), + } + } + + pub fn add_record(&mut self, rec: &DnsRecord) -> bool { + self.records.insert(rec.clone()) + } + + pub fn delete_record(&mut self, rec: &DnsRecord) -> bool { + self.records.remove(rec) + } +} + +#[derive(Default)] +pub struct Zones { + zones: BTreeMap, +} + +impl<'a> Zones { + pub fn new() -> Zones { + Zones { + zones: BTreeMap::new(), + } + } + + pub fn load(&mut self) -> Result<()> { + let zones_dir = Path::new("zones").read_dir()?; + + for wrapped_filename in zones_dir { + let filename = match wrapped_filename { + Ok(x) => x, + Err(_) => continue, + }; + + let mut zone_file = match File::open(filename.path()) { + Ok(x) => x, + Err(_) => continue, + }; + + let mut buffer = StreamPacketBuffer::new(&mut zone_file); + + let mut zone = Zone::new(String::new(), String::new(), String::new()); + buffer.read_qname(&mut zone.domain)?; + buffer.read_qname(&mut zone.m_name)?; + buffer.read_qname(&mut zone.r_name)?; + zone.serial = buffer.read_u32()?; + zone.refresh = buffer.read_u32()?; + zone.retry = buffer.read_u32()?; + zone.expire = buffer.read_u32()?; + zone.minimum = buffer.read_u32()?; + + let record_count = buffer.read_u32()?; + + for _ in 0..record_count { + let rr = DnsRecord::read(&mut buffer)?; + zone.add_record(&rr); + } + + println!("Loaded zone {} with {} records", zone.domain, record_count); + + self.zones.insert(zone.domain.clone(), zone); + } + + Ok(()) + } + + pub fn save(&mut self) -> Result<()> { + let zones_dir = Path::new("zones"); + for zone in self.zones.values() { + let filename = zones_dir.join(Path::new(&zone.domain)); + let mut zone_file = match File::create(&filename) { + Ok(x) => x, + Err(_) => { + println!("Failed to save file {:?}", filename); + continue; + } + }; + + let mut buffer = VectorPacketBuffer::new(); + let _ = buffer.write_qname(&zone.domain); + let _ = buffer.write_qname(&zone.m_name); + let _ = buffer.write_qname(&zone.r_name); + let _ = buffer.write_u32(zone.serial); + let _ = buffer.write_u32(zone.refresh); + let _ = buffer.write_u32(zone.retry); + let _ = buffer.write_u32(zone.expire); + let _ = buffer.write_u32(zone.minimum); + let _ = buffer.write_u32(zone.records.len() as u32); + + for rec in &zone.records { + let _ = rec.write(&mut buffer); + } + + let _ = zone_file.write(&buffer.buffer[0..buffer.pos]); + } + + Ok(()) + } + + pub fn zones(&self) -> Vec<&Zone> { + self.zones.values().collect() + } + + pub fn add_zone(&mut self, zone: Zone) { + self.zones.insert(zone.domain.clone(), zone); + } + + pub fn get_zone(&'a self, domain: &str) -> Option<&'a Zone> { + self.zones.get(domain) + } + + pub fn get_zone_mut(&'a mut self, domain: &str) -> Option<&'a mut Zone> { + self.zones.get_mut(domain) + } +} + +#[derive(Default)] +pub struct Authority { + zones: RwLock, +} + +impl Authority { + pub fn new() -> Authority { + Authority { + zones: RwLock::new(Zones::new()), + } + } + + pub fn load(&self) -> Result<()> { + let mut zones = self + .zones + .write() + .map_err(|_| AuthorityError::PoisonedLock)?; + zones.load()?; + + Ok(()) + } + + pub fn query(&self, qname: &str, qtype: QueryType) -> Option { + let zones = match self.zones.read().ok() { + Some(x) => x, + None => return None, + }; + + let mut best_match = None; + for zone in zones.zones() { + if !qname.ends_with(&zone.domain) { + continue; + } + + if let Some((len, _)) = best_match { + if len < zone.domain.len() { + best_match = Some((zone.domain.len(), zone)); + } + } else { + best_match = Some((zone.domain.len(), zone)); + } + } + + let zone = match best_match { + Some((_, zone)) => zone, + None => return None, + }; + + let mut packet = DnsPacket::new(); + packet.header.authoritative_answer = true; + + for rec in &zone.records { + let domain = match rec.get_domain() { + Some(x) => x, + None => continue, + }; + + if &domain != qname { + continue; + } + + let rtype = rec.get_querytype(); + if qtype == rtype || (qtype == QueryType::A && rtype == QueryType::CNAME) { + packet.answers.push(rec.clone()); + } + } + + if packet.answers.is_empty() { + packet.header.rescode = ResultCode::NXDOMAIN; + + packet.authorities.push(DnsRecord::SOA { + domain: zone.domain.clone(), + m_name: zone.m_name.clone(), + r_name: zone.r_name.clone(), + serial: zone.serial, + refresh: zone.refresh, + retry: zone.retry, + expire: zone.expire, + minimum: zone.minimum, + ttl: TransientTtl(zone.minimum), + }); + } + + Some(packet) + } + + pub fn read(&self) -> LockResult> { + self.zones.read() + } + + pub fn write(&self) -> LockResult> { + self.zones.write() + } +} diff --git a/src/dns/buffer.rs b/src/dns/buffer.rs new file mode 100644 index 0000000..3d99632 --- /dev/null +++ b/src/dns/buffer.rs @@ -0,0 +1,487 @@ +//! buffers for use when writing and reading dns packets + +use std::collections::BTreeMap; +use std::io::Read; + +use derive_more::{Display, Error, From}; + +#[derive(Debug, Display, From, Error)] +pub enum BufferError { + Io(std::io::Error), + EndOfBuffer, +} + +type Result = std::result::Result; + +pub trait PacketBuffer { + fn read(&mut self) -> Result; + fn get(&mut self, pos: usize) -> Result; + fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]>; + fn write(&mut self, val: u8) -> Result<()>; + fn set(&mut self, pos: usize, val: u8) -> Result<()>; + fn pos(&self) -> usize; + fn seek(&mut self, pos: usize) -> Result<()>; + fn step(&mut self, steps: usize) -> Result<()>; + fn find_label(&self, label: &str) -> Option; + fn save_label(&mut self, label: &str, pos: usize); + + fn write_u8(&mut self, val: u8) -> Result<()> { + self.write(val)?; + + Ok(()) + } + + fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; + + Ok(()) + } + + fn write_u16(&mut self, val: u16) -> Result<()> { + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; + + Ok(()) + } + + fn write_u32(&mut self, val: u32) -> Result<()> { + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; + + Ok(()) + } + + fn write_qname(&mut self, qname: &str) -> Result<()> { + let split_str = qname.split('.').collect::>(); + + let mut jump_performed = false; + for (i, label) in split_str.iter().enumerate() { + let search_lbl = split_str[i..split_str.len()].join("."); + if let Some(prev_pos) = self.find_label(&search_lbl) { + let jump_inst = (prev_pos as u16) | 0xC000; + self.write_u16(jump_inst)?; + jump_performed = true; + + break; + } + + let pos = self.pos(); + self.save_label(&search_lbl, pos); + + let len = label.len(); + self.write_u8(len as u8)?; + for b in label.as_bytes() { + self.write_u8(*b)?; + } + } + + if !jump_performed { + self.write_u8(0)?; + } + + Ok(()) + } + + fn read_u16(&mut self) -> Result { + let res = ((self.read()? as u16) << 8) | (self.read()? as u16); + + Ok(res) + } + + fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); + + Ok(res) + } + + fn read_qname(&mut self, outstr: &mut String) -> Result<()> { + let mut pos = self.pos(); + let mut jumped = false; + + let mut delim = ""; + loop { + let len = self.get(pos)?; + + // A two byte sequence, where the two highest bits of the first byte is + // set, represents a offset relative to the start of the buffer. We + // handle this by jumping to the offset, setting a flag to indicate + // that we shouldn't update the shared buffer position once done. + if (len & 0xC0) > 0 { + // When a jump is performed, we only modify the shared buffer + // position once, and avoid making the change later on. + if !jumped { + self.seek(pos + 2)?; + } + + let b2 = self.get(pos + 1)? as u16; + let offset = (((len as u16) ^ 0xC0) << 8) | b2; + pos = offset as usize; + jumped = true; + continue; + } + + pos += 1; + + // Names are terminated by an empty label of length 0 + if len == 0 { + break; + } + + outstr.push_str(delim); + + let str_buffer = self.get_range(pos, len as usize)?; + outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); + + delim = "."; + + pos += len as usize; + } + + if !jumped { + self.seek(pos)?; + } + + Ok(()) + } +} + +#[derive(Default)] +pub struct VectorPacketBuffer { + pub buffer: Vec, + pub pos: usize, + pub label_lookup: BTreeMap, +} + +impl VectorPacketBuffer { + pub fn new() -> VectorPacketBuffer { + VectorPacketBuffer { + buffer: Vec::new(), + pos: 0, + label_lookup: BTreeMap::new(), + } + } +} + +impl PacketBuffer for VectorPacketBuffer { + fn find_label(&self, label: &str) -> Option { + self.label_lookup.get(label).cloned() + } + + fn save_label(&mut self, label: &str, pos: usize) { + self.label_lookup.insert(label.to_string(), pos); + } + + fn read(&mut self) -> Result { + let res = self.buffer[self.pos]; + self.pos += 1; + + Ok(res) + } + + fn get(&mut self, pos: usize) -> Result { + Ok(self.buffer[pos]) + } + + fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { + Ok(&self.buffer[start..start + len as usize]) + } + + fn write(&mut self, val: u8) -> Result<()> { + self.buffer.push(val); + self.pos += 1; + + Ok(()) + } + + fn set(&mut self, pos: usize, val: u8) -> Result<()> { + self.buffer[pos] = val; + + Ok(()) + } + + fn pos(&self) -> usize { + self.pos + } + + fn seek(&mut self, pos: usize) -> Result<()> { + self.pos = pos; + + Ok(()) + } + + fn step(&mut self, steps: usize) -> Result<()> { + self.pos += steps; + + Ok(()) + } +} + +pub struct StreamPacketBuffer<'a, T> +where + T: Read, +{ + pub stream: &'a mut T, + pub buffer: Vec, + pub pos: usize, +} + +impl<'a, T> StreamPacketBuffer<'a, T> +where + T: Read + 'a, +{ + pub fn new(stream: &'a mut T) -> StreamPacketBuffer<'_, T> { + StreamPacketBuffer { + stream: stream, + buffer: Vec::new(), + pos: 0, + } + } +} + +impl<'a, T> PacketBuffer for StreamPacketBuffer<'a, T> +where + T: Read + 'a, +{ + fn find_label(&self, _: &str) -> Option { + None + } + + fn save_label(&mut self, _: &str, _: usize) { + unimplemented!(); + } + + fn read(&mut self) -> Result { + while self.pos >= self.buffer.len() { + let mut local_buffer = [0; 1]; + self.stream.read(&mut local_buffer)?; + self.buffer.push(local_buffer[0]); + } + + let res = self.buffer[self.pos]; + self.pos += 1; + + Ok(res) + } + + fn get(&mut self, pos: usize) -> Result { + while pos >= self.buffer.len() { + let mut local_buffer = [0; 1]; + self.stream.read(&mut local_buffer)?; + self.buffer.push(local_buffer[0]); + } + + Ok(self.buffer[pos]) + } + + fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { + while start + len > self.buffer.len() { + let mut local_buffer = [0; 1]; + self.stream.read(&mut local_buffer)?; + self.buffer.push(local_buffer[0]); + } + + Ok(&self.buffer[start..start + len as usize]) + } + + fn write(&mut self, _: u8) -> Result<()> { + unimplemented!(); + } + + fn set(&mut self, _: usize, _: u8) -> Result<()> { + unimplemented!(); + } + + fn pos(&self) -> usize { + self.pos + } + + fn seek(&mut self, pos: usize) -> Result<()> { + self.pos = pos; + Ok(()) + } + + fn step(&mut self, steps: usize) -> Result<()> { + self.pos += steps; + Ok(()) + } +} + +pub struct BytePacketBuffer { + pub buf: [u8; 512], + pub pos: usize, +} + +impl BytePacketBuffer { + pub fn new() -> BytePacketBuffer { + BytePacketBuffer { + buf: [0; 512], + pos: 0, + } + } +} + +impl Default for BytePacketBuffer { + fn default() -> Self { + BytePacketBuffer::new() + } +} + +impl PacketBuffer for BytePacketBuffer { + fn find_label(&self, _: &str) -> Option { + None + } + + fn save_label(&mut self, _: &str, _: usize) {} + + fn read(&mut self) -> Result { + if self.pos >= 512 { + return Err(BufferError::EndOfBuffer); + } + let res = self.buf[self.pos]; + self.pos += 1; + + Ok(res) + } + + fn get(&mut self, pos: usize) -> Result { + if pos >= 512 { + return Err(BufferError::EndOfBuffer); + } + Ok(self.buf[pos]) + } + + fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { + if start + len >= 512 { + return Err(BufferError::EndOfBuffer); + } + Ok(&self.buf[start..start + len as usize]) + } + + fn write(&mut self, val: u8) -> Result<()> { + if self.pos >= 512 { + return Err(BufferError::EndOfBuffer); + } + self.buf[self.pos] = val; + self.pos += 1; + Ok(()) + } + + fn set(&mut self, pos: usize, val: u8) -> Result<()> { + self.buf[pos] = val; + + Ok(()) + } + + fn pos(&self) -> usize { + self.pos + } + + fn seek(&mut self, pos: usize) -> Result<()> { + self.pos = pos; + + Ok(()) + } + + fn step(&mut self, steps: usize) -> Result<()> { + self.pos += steps; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_qname() { + let mut buffer = VectorPacketBuffer::new(); + + let instr1 = "a.google.com".to_string(); + let instr2 = "b.google.com".to_string(); + + // First write the standard string + match buffer.write_qname(&instr1) { + Ok(_) => {} + Err(_) => panic!(), + } + + // Then we set up a slight variation with relies on a jump back to the data of + // the first name + let crafted_data = [0x01, b'b' as u8, 0xC0, 0x02]; + for b in &crafted_data { + match buffer.write_u8(*b) { + Ok(_) => {} + Err(_) => panic!(), + } + } + + // Reset the buffer position for reading + buffer.pos = 0; + + // Read the standard name + let mut outstr1 = String::new(); + match buffer.read_qname(&mut outstr1) { + Ok(_) => {} + Err(_) => panic!(), + } + + assert_eq!(instr1, outstr1); + + // Read the name with a jump + let mut outstr2 = String::new(); + match buffer.read_qname(&mut outstr2) { + Ok(_) => {} + Err(_) => panic!(), + } + + assert_eq!(instr2, outstr2); + + // Make sure we're now at the end of the buffer + assert_eq!(buffer.pos, buffer.buffer.len()); + } + + #[test] + fn test_write_qname() { + let mut buffer = VectorPacketBuffer::new(); + + match buffer.write_qname(&"ns1.google.com".to_string()) { + Ok(_) => {} + Err(_) => panic!(), + } + match buffer.write_qname(&"ns2.google.com".to_string()) { + Ok(_) => {} + Err(_) => panic!(), + } + + assert_eq!(22, buffer.pos()); + + match buffer.seek(0) { + Ok(_) => {} + Err(_) => panic!(), + } + + let mut str1 = String::new(); + match buffer.read_qname(&mut str1) { + Ok(_) => {} + Err(_) => panic!(), + } + + assert_eq!("ns1.google.com", str1); + + let mut str2 = String::new(); + match buffer.read_qname(&mut str2) { + Ok(_) => {} + Err(_) => panic!(), + } + + assert_eq!("ns2.google.com", str2); + } +} diff --git a/src/dns/cache.rs b/src/dns/cache.rs new file mode 100644 index 0000000..64f1332 --- /dev/null +++ b/src/dns/cache.rs @@ -0,0 +1,462 @@ +//! a threadsafe cache for DNS information + +extern crate serde; +use std::clone::Clone; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, RwLock}; + +use chrono::*; +use derive_more::{Display, Error, From}; +use serde::{Deserialize, Serialize}; + +use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode}; + +#[derive(Debug, Display, From, Error)] +pub enum CacheError { + Io(std::io::Error), + PoisonedLock, +} + +type Result = std::result::Result; + +pub enum CacheState { + PositiveCache, + NegativeCache, + NotCached, +} + +#[derive(Clone, Eq, Debug, Serialize, Deserialize)] +pub struct RecordEntry { + pub record: DnsRecord, + pub timestamp: DateTime, +} + +impl PartialEq for RecordEntry { + fn eq(&self, other: &RecordEntry) -> bool { + self.record == other.record + } +} + +impl Hash for RecordEntry { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + self.record.hash(state); + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum RecordSet { + NoRecords { + qtype: QueryType, + ttl: u32, + timestamp: DateTime, + }, + Records { + qtype: QueryType, + records: HashSet, + }, +} + +#[derive(Clone, Debug)] +pub struct DomainEntry { + pub domain: String, + pub record_types: HashMap, + pub hits: u32, + pub updates: u32, +} + +impl DomainEntry { + pub fn new(domain: String) -> DomainEntry { + DomainEntry { + domain: domain, + record_types: HashMap::new(), + hits: 0, + updates: 0, + } + } + + pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) { + self.updates += 1; + + let new_set = RecordSet::NoRecords { + qtype: qtype, + ttl: ttl, + timestamp: Local::now(), + }; + + self.record_types.insert(qtype, new_set); + } + + pub fn store_record(&mut self, rec: &DnsRecord) { + self.updates += 1; + + let entry = RecordEntry { + record: rec.clone(), + timestamp: Local::now(), + }; + + if let Some(&mut RecordSet::Records { + ref mut records, .. + }) = self.record_types.get_mut(&rec.get_querytype()) + { + if records.contains(&entry) { + records.remove(&entry); + } + + records.insert(entry); + return; + } + + let mut records = HashSet::new(); + records.insert(entry); + + let new_set = RecordSet::Records { + qtype: rec.get_querytype(), + records: records, + }; + + self.record_types.insert(rec.get_querytype(), new_set); + } + + pub fn get_cache_state(&self, qtype: QueryType) -> CacheState { + match self.record_types.get(&qtype) { + Some(&RecordSet::Records { ref records, .. }) => { + let now = Local::now(); + + let mut valid_count = 0; + for entry in records { + let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64); + let expires = entry.timestamp + ttl_offset; + if expires < now { + continue; + } + + if entry.record.get_querytype() == qtype { + valid_count += 1; + } + } + + if valid_count > 0 { + CacheState::PositiveCache + } else { + CacheState::NotCached + } + } + Some(&RecordSet::NoRecords { ttl, timestamp, .. }) => { + let now = Local::now(); + let ttl_offset = Duration::seconds(ttl as i64); + let expires = timestamp + ttl_offset; + + if expires < now { + CacheState::NotCached + } else { + CacheState::NegativeCache + } + } + None => CacheState::NotCached, + } + } + + pub fn fill_queryresult(&self, qtype: QueryType, result_vec: &mut Vec) { + let now = Local::now(); + + let current_set = match self.record_types.get(&qtype) { + Some(x) => x, + None => return, + }; + + if let RecordSet::Records { ref records, .. } = *current_set { + for entry in records { + let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64); + let expires = entry.timestamp + ttl_offset; + if expires < now { + continue; + } + + if entry.record.get_querytype() == qtype { + result_vec.push(entry.record.clone()); + } + } + } + } +} + +#[derive(Default)] +pub struct Cache { + domain_entries: BTreeMap>, +} + +impl Cache { + pub fn new() -> Cache { + Cache { + domain_entries: BTreeMap::new(), + } + } + + fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState { + match self.domain_entries.get(qname) { + Some(x) => x.get_cache_state(qtype), + None => CacheState::NotCached, + } + } + + fn fill_queryresult( + &mut self, + qname: &str, + qtype: QueryType, + result_vec: &mut Vec, + increment_stats: bool, + ) { + if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) { + if increment_stats { + domain_entry.hits += 1 + } + + domain_entry.fill_queryresult(qtype, result_vec); + } + } + + pub fn lookup(&mut self, qname: &str, qtype: QueryType) -> Option { + match self.get_cache_state(qname, qtype) { + CacheState::PositiveCache => { + let mut qr = DnsPacket::new(); + self.fill_queryresult(qname, qtype, &mut qr.answers, true); + self.fill_queryresult(qname, QueryType::NS, &mut qr.authorities, false); + + Some(qr) + } + CacheState::NegativeCache => { + let mut qr = DnsPacket::new(); + qr.header.rescode = ResultCode::NXDOMAIN; + + Some(qr) + } + CacheState::NotCached => None, + } + } + + pub fn store(&mut self, records: &[DnsRecord]) { + for rec in records { + let domain = match rec.get_domain() { + Some(x) => x, + None => continue, + }; + + if let Some(ref mut rs) = self.domain_entries.get_mut(&domain).and_then(Arc::get_mut) { + rs.store_record(rec); + continue; + } + + let mut rs = DomainEntry::new(domain.clone()); + rs.store_record(rec); + self.domain_entries.insert(domain.clone(), Arc::new(rs)); + } + } + + pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) { + if let Some(ref mut rs) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) { + rs.store_nxdomain(qtype, ttl); + return; + } + + let mut rs = DomainEntry::new(qname.to_string()); + rs.store_nxdomain(qtype, ttl); + self.domain_entries.insert(qname.to_string(), Arc::new(rs)); + } +} + +#[derive(Default)] +pub struct SynchronizedCache { + pub cache: RwLock, +} + +impl SynchronizedCache { + pub fn new() -> SynchronizedCache { + SynchronizedCache { + cache: RwLock::new(Cache::new()), + } + } + + pub fn list(&self) -> Result>> { + let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?; + + let mut list = Vec::new(); + + for rs in cache.domain_entries.values() { + list.push(rs.clone()); + } + + Ok(list) + } + + pub fn lookup(&self, qname: &str, qtype: QueryType) -> Option { + let mut cache = match self.cache.write() { + Ok(x) => x, + Err(_) => return None, + }; + + cache.lookup(qname, qtype) + } + + pub fn store(&self, records: &[DnsRecord]) -> Result<()> { + let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?; + + cache.store(records); + + Ok(()) + } + + pub fn store_nxdomain(&self, qname: &str, qtype: QueryType, ttl: u32) -> Result<()> { + let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?; + + cache.store_nxdomain(qname, qtype, ttl); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + use crate::dns::protocol::{DnsRecord, QueryType, ResultCode, TransientTtl}; + + #[test] + fn test_cache() { + let mut cache = Cache::new(); + + // Verify that no data is returned when nothing is present + if cache.lookup("www.google.com", QueryType::A).is_some() { + panic!() + } + + // Register a negative cache entry + cache.store_nxdomain("www.google.com", QueryType::A, 3600); + + // Verify that we get a response, with the NXDOMAIN flag set + if let Some(packet) = cache.lookup("www.google.com", QueryType::A) { + assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode); + } + + // Register a negative cache entry with no TTL + cache.store_nxdomain("www.yahoo.com", QueryType::A, 0); + + // And check that no such result is actually returned, since it's expired + if cache.lookup("www.yahoo.com", QueryType::A).is_some() { + panic!() + } + + // Now add some actual records + let mut records = Vec::new(); + records.push(DnsRecord::A { + domain: "www.google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + records.push(DnsRecord::A { + domain: "www.yahoo.com".to_string(), + addr: "127.0.0.2".parse().unwrap(), + ttl: TransientTtl(0), + }); + records.push(DnsRecord::CNAME { + domain: "www.microsoft.com".to_string(), + host: "www.somecdn.com".to_string(), + ttl: TransientTtl(3600), + }); + + cache.store(&records); + + // Test for successful lookup + if let Some(packet) = cache.lookup("www.google.com", QueryType::A) { + assert_eq!(records[0], packet.answers[0]); + } else { + panic!(); + } + + // Test for failed lookup, since no CNAME's are known for this domain + if cache.lookup("www.google.com", QueryType::CNAME).is_some() { + panic!(); + } + + // Check for successful CNAME lookup + if let Some(packet) = cache.lookup("www.microsoft.com", QueryType::CNAME) { + assert_eq!(records[2], packet.answers[0]); + } else { + panic!(); + } + + // This lookup should fail, since it has expired due to the 0 second TTL + if cache.lookup("www.yahoo.com", QueryType::A).is_some() { + panic!(); + } + + let mut records2 = Vec::new(); + records2.push(DnsRecord::A { + domain: "www.yahoo.com".to_string(), + addr: "127.0.0.2".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + cache.store(&records2); + + // And now it should succeed, since the record has been store + if !cache.lookup("www.yahoo.com", QueryType::A).is_some() { + panic!(); + } + + // Check stat counter behavior + assert_eq!(3, cache.domain_entries.len()); + assert_eq!( + 1, + cache + .domain_entries + .get(&"www.google.com".to_string()) + .unwrap() + .hits + ); + assert_eq!( + 2, + cache + .domain_entries + .get(&"www.google.com".to_string()) + .unwrap() + .updates + ); + assert_eq!( + 1, + cache + .domain_entries + .get(&"www.yahoo.com".to_string()) + .unwrap() + .hits + ); + assert_eq!( + 3, + cache + .domain_entries + .get(&"www.yahoo.com".to_string()) + .unwrap() + .updates + ); + assert_eq!( + 1, + cache + .domain_entries + .get(&"www.microsoft.com".to_string()) + .unwrap() + .updates + ); + assert_eq!( + 1, + cache + .domain_entries + .get(&"www.microsoft.com".to_string()) + .unwrap() + .hits + ); + } +} diff --git a/src/dns/client.rs b/src/dns/client.rs new file mode 100644 index 0000000..7603743 --- /dev/null +++ b/src/dns/client.rs @@ -0,0 +1,400 @@ +//! client for sending DNS queries to other servers + +use std::io::Write; +use std::marker::{Send, Sync}; +use std::net::{TcpStream, UdpSocket}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::{channel, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread::{sleep, Builder}; +use std::time::Duration as SleepDuration; + +use chrono::*; +use derive_more::{Display, Error, From}; + +use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer}; +use crate::dns::netutil::{read_packet_length, write_packet_length}; +use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType}; + +#[derive(Debug, Display, From, Error)] +pub enum ClientError { + Protocol(crate::dns::protocol::ProtocolError), + Io(std::io::Error), + PoisonedLock, + LookupFailed, + TimeOut, +} + +type Result = std::result::Result; + +pub trait DnsClient { + fn get_sent_count(&self) -> usize; + fn get_failed_count(&self) -> usize; + + fn run(&self) -> Result<()>; + fn send_query( + &self, + qname: &str, + qtype: QueryType, + server: (&str, u16), + recursive: bool, + ) -> Result; +} + +/// The UDP client +/// +/// This includes a fair bit of synchronization due to the stateless nature of UDP. +/// When many queries are sent in parallell, the response packets can come back +/// in any order. For that reason, we fire off replies on the sending thread, but +/// handle replies on a single thread. A channel is created for every response, +/// and the caller will block on the channel until the a response is received. +pub struct DnsNetworkClient { + total_sent: AtomicUsize, + total_failed: AtomicUsize, + + /// Counter for assigning packet ids + seq: AtomicUsize, + + /// The listener socket + socket: UdpSocket, + + /// Queries in progress + pending_queries: Arc>>, +} + +/// A query in progress. This struct holds the `id` if the request, and a channel +/// endpoint for returning a response back to the thread from which the query +/// was posed. +struct PendingQuery { + seq: u16, + timestamp: DateTime, + tx: Sender>, +} + +unsafe impl Send for DnsNetworkClient {} +unsafe impl Sync for DnsNetworkClient {} + +impl DnsNetworkClient { + pub fn new(port: u16) -> DnsNetworkClient { + DnsNetworkClient { + total_sent: AtomicUsize::new(0), + total_failed: AtomicUsize::new(0), + seq: AtomicUsize::new(0), + socket: UdpSocket::bind(("0.0.0.0", port)).unwrap(), + pending_queries: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Send a DNS query using TCP transport + /// + /// This is much simpler than using UDP, since the kernel will take care of + /// packet ordering, connection state, timeouts etc. + pub fn send_tcp_query( + &self, + qname: &str, + qtype: QueryType, + server: (&str, u16), + recursive: bool, + ) -> Result { + let _ = self.total_sent.fetch_add(1, Ordering::Release); + + // Prepare request + let mut packet = DnsPacket::new(); + + packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16; + if packet.header.id + 1 == 0xFFFF { + self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst); + } + + packet.header.questions = 1; + packet.header.recursion_desired = recursive; + + packet.questions.push(DnsQuestion::new(qname.into(), qtype)); + + // Send query + let mut req_buffer = BytePacketBuffer::new(); + packet.write(&mut req_buffer, 0xFFFF)?; + + let mut socket = TcpStream::connect(server)?; + + write_packet_length(&mut socket, req_buffer.pos())?; + socket.write(&req_buffer.buf[0..req_buffer.pos])?; + socket.flush()?; + + let _ = read_packet_length(&mut socket)?; + + let mut stream_buffer = StreamPacketBuffer::new(&mut socket); + let packet = DnsPacket::from_buffer(&mut stream_buffer)?; + + Ok(packet) + } + + /// Send a DNS query using UDP transport + /// + /// This will construct a query packet, and fire it off to the specified server. + /// The query is sent from the callee thread, but responses are read on a + /// worker thread, and returned to this thread through a channel. Thus this + /// method is thread safe, and can be used from any number of threads in + /// parallell. + pub fn send_udp_query( + &self, + qname: &str, + qtype: QueryType, + server: (&str, u16), + recursive: bool, + ) -> Result { + let _ = self.total_sent.fetch_add(1, Ordering::Release); + + // Prepare request + let mut packet = DnsPacket::new(); + + packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16; + if packet.header.id + 1 == 0xFFFF { + self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst); + } + + packet.header.questions = 1; + packet.header.recursion_desired = recursive; + + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + + // Create a return channel, and add a `PendingQuery` to the list of lookups + // in progress + let (tx, rx) = channel(); + { + let mut pending_queries = self + .pending_queries + .lock() + .map_err(|_| ClientError::PoisonedLock)?; + pending_queries.push(PendingQuery { + seq: packet.header.id, + timestamp: Local::now(), + tx: tx, + }); + } + + // Send query + let mut req_buffer = BytePacketBuffer::new(); + packet.write(&mut req_buffer, 512)?; + self.socket + .send_to(&req_buffer.buf[0..req_buffer.pos], server)?; + + // Wait for response + match rx.recv() { + Ok(Some(qr)) => Ok(qr), + Ok(None) => { + let _ = self.total_failed.fetch_add(1, Ordering::Release); + Err(ClientError::TimeOut) + } + Err(_) => { + let _ = self.total_failed.fetch_add(1, Ordering::Release); + Err(ClientError::LookupFailed) + } + } + } +} + +impl DnsClient for DnsNetworkClient { + fn get_sent_count(&self) -> usize { + self.total_sent.load(Ordering::Acquire) + } + + fn get_failed_count(&self) -> usize { + self.total_failed.load(Ordering::Acquire) + } + + /// The run method launches a worker thread. Unless this thread is running, no + /// responses will ever be generated, and clients will just block indefinitely. + fn run(&self) -> Result<()> { + // Start the thread for handling incoming responses + { + let socket_copy = self.socket.try_clone()?; + let pending_queries_lock = self.pending_queries.clone(); + + Builder::new() + .name("DnsNetworkClient-worker-thread".into()) + .spawn(move || { + loop { + // Read data into a buffer + let mut res_buffer = BytePacketBuffer::new(); + match socket_copy.recv_from(&mut res_buffer.buf) { + Ok(_) => {} + Err(_) => { + continue; + } + } + + // Construct a DnsPacket from buffer, skipping the packet if parsing + // failed + let packet = match DnsPacket::from_buffer(&mut res_buffer) { + Ok(packet) => packet, + Err(err) => { + println!( + "DnsNetworkClient failed to parse packet with error: {}", + err + ); + continue; + } + }; + + // Acquire a lock on the pending_queries list, and search for a + // matching PendingQuery to which to deliver the response. + if let Ok(mut pending_queries) = pending_queries_lock.lock() { + let mut matched_query = None; + for (i, pending_query) in pending_queries.iter().enumerate() { + if pending_query.seq == packet.header.id { + // Matching query found, send the response + let _ = pending_query.tx.send(Some(packet.clone())); + + // Mark this index for removal from list + matched_query = Some(i); + + break; + } + } + + if let Some(idx) = matched_query { + pending_queries.remove(idx); + } else { + println!("Discarding response for: {:?}", packet.questions[0]); + } + } + } + })?; + } + + // Start the thread for timing out requests + { + let pending_queries_lock = self.pending_queries.clone(); + + Builder::new() + .name("DnsNetworkClient-timeout-thread".into()) + .spawn(move || { + let timeout = Duration::seconds(1); + loop { + if let Ok(mut pending_queries) = pending_queries_lock.lock() { + let mut finished_queries = Vec::new(); + for (i, pending_query) in pending_queries.iter().enumerate() { + let expires = pending_query.timestamp + timeout; + if expires < Local::now() { + let _ = pending_query.tx.send(None); + finished_queries.push(i); + } + } + + // Remove `PendingQuery` objects from the list, in reverse order + for idx in finished_queries.iter().rev() { + pending_queries.remove(*idx); + } + } + + sleep(SleepDuration::from_millis(100)); + } + })?; + } + + Ok(()) + } + + fn send_query( + &self, + qname: &str, + qtype: QueryType, + server: (&str, u16), + recursive: bool, + ) -> Result { + let packet = self.send_udp_query(qname, qtype, server, recursive)?; + if !packet.header.truncated_message { + return Ok(packet); + } + + println!("Truncated response - resending as TCP"); + self.send_tcp_query(qname, qtype, server, recursive) + } +} + +#[cfg(test)] +pub mod tests { + + use super::*; + use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType}; + + pub type StubCallback = dyn Fn(&str, QueryType, (&str, u16), bool) -> Result; + + pub struct DnsStubClient { + callback: Box, + } + + impl<'a> DnsStubClient { + pub fn new(callback: Box) -> DnsStubClient { + DnsStubClient { callback: callback } + } + } + + unsafe impl Send for DnsStubClient {} + unsafe impl Sync for DnsStubClient {} + + impl DnsClient for DnsStubClient { + fn get_sent_count(&self) -> usize { + 0 + } + + fn get_failed_count(&self) -> usize { + 0 + } + + fn run(&self) -> Result<()> { + Ok(()) + } + + fn send_query( + &self, + qname: &str, + qtype: QueryType, + server: (&str, u16), + recursive: bool, + ) -> Result { + (self.callback)(qname, qtype, server, recursive) + } + } + + #[test] + pub fn test_udp_client() { + let client = DnsNetworkClient::new(31456); + client.run().unwrap(); + + let res = client + .send_udp_query("google.com", QueryType::A, ("8.8.8.8", 53), true) + .unwrap(); + + assert_eq!(res.questions[0].name, "google.com"); + assert!(res.answers.len() > 0); + + match res.answers[0] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("google.com", domain); + } + _ => panic!(), + } + } + + #[test] + pub fn test_tcp_client() { + let client = DnsNetworkClient::new(31457); + let res = client + .send_tcp_query("google.com", QueryType::A, ("8.8.8.8", 53), true) + .unwrap(); + + assert_eq!(res.questions[0].name, "google.com"); + assert!(res.answers.len() > 0); + + match res.answers[0] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("google.com", domain); + } + _ => panic!(), + } + } +} diff --git a/src/dns/context.rs b/src/dns/context.rs new file mode 100644 index 0000000..0490c55 --- /dev/null +++ b/src/dns/context.rs @@ -0,0 +1,140 @@ +//! The `ServerContext in this thread holds the common state across the server + +use std::fs; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use derive_more::{Display, Error, From}; + +use crate::dns::authority::Authority; +use crate::dns::cache::SynchronizedCache; +use crate::dns::client::{DnsClient, DnsNetworkClient}; +use crate::dns::resolve::{DnsResolver, ForwardingDnsResolver, RecursiveDnsResolver}; + +#[derive(Debug, Display, From, Error)] +pub enum ContextError { + Authority(crate::dns::authority::AuthorityError), + Client(crate::dns::client::ClientError), + Io(std::io::Error), +} + +type Result = std::result::Result; + +pub struct ServerStatistics { + pub tcp_query_count: AtomicUsize, + pub udp_query_count: AtomicUsize, +} + +impl ServerStatistics { + pub fn get_tcp_query_count(&self) -> usize { + self.tcp_query_count.load(Ordering::Acquire) + } + + pub fn get_udp_query_count(&self) -> usize { + self.udp_query_count.load(Ordering::Acquire) + } +} + +pub enum ResolveStrategy { + Recursive, + Forward { host: String, port: u16 }, +} + +pub struct ServerContext { + pub authority: Authority, + pub cache: SynchronizedCache, + pub client: Box, + pub dns_port: u16, + pub api_port: u16, + pub resolve_strategy: ResolveStrategy, + pub allow_recursive: bool, + pub enable_udp: bool, + pub enable_tcp: bool, + pub enable_api: bool, + pub statistics: ServerStatistics, + pub zones_dir: &'static str +} + +impl Default for ServerContext { + fn default() -> Self { + ServerContext::new() + } +} + +impl ServerContext { + pub fn new() -> ServerContext { + ServerContext { + authority: Authority::new(), + cache: SynchronizedCache::new(), + client: Box::new(DnsNetworkClient::new(34255)), + dns_port: 53, + api_port: 5380, + resolve_strategy: ResolveStrategy::Recursive, + allow_recursive: true, + enable_udp: true, + enable_tcp: true, + enable_api: true, + statistics: ServerStatistics { + tcp_query_count: AtomicUsize::new(0), + udp_query_count: AtomicUsize::new(0), + }, + zones_dir: "zones", + } + } + + pub fn initialize(&mut self) -> Result<()> { + // Create zones directory if it doesn't exist + fs::create_dir_all(self.zones_dir)?; + + // Start UDP client thread + self.client.run()?; + + // Load authority data + self.authority.load()?; + + Ok(()) + } + + pub fn create_resolver(&self, ptr: Arc) -> Box { + match self.resolve_strategy { + ResolveStrategy::Recursive => Box::new(RecursiveDnsResolver::new(ptr)), + ResolveStrategy::Forward { ref host, port } => { + Box::new(ForwardingDnsResolver::new(ptr, (host.clone(), port))) + } + } + } +} + +#[cfg(test)] +pub mod tests { + + use std::sync::atomic::AtomicUsize; + use std::sync::Arc; + + use crate::dns::authority::Authority; + use crate::dns::cache::SynchronizedCache; + + use crate::dns::client::tests::{DnsStubClient, StubCallback}; + + use super::*; + + pub fn create_test_context(callback: Box) -> Arc { + Arc::new(ServerContext { + authority: Authority::new(), + cache: SynchronizedCache::new(), + client: Box::new(DnsStubClient::new(callback)), + dns_port: 53, + api_port: 5380, + resolve_strategy: ResolveStrategy::Recursive, + allow_recursive: true, + enable_udp: true, + enable_tcp: true, + enable_api: true, + statistics: ServerStatistics { + tcp_query_count: AtomicUsize::new(0), + udp_query_count: AtomicUsize::new(0), + }, + zones_dir: "zones", + }) + } +} diff --git a/src/dns/mod.rs b/src/dns/mod.rs new file mode 100644 index 0000000..1783419 --- /dev/null +++ b/src/dns/mod.rs @@ -0,0 +1,26 @@ +/* +Copyright 2018 Emil Hernvall + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +*/ + +//! The dns module implements the DNS protocol and the related functions + +pub mod authority; +pub mod buffer; +pub mod cache; +pub mod client; +pub mod context; +pub mod protocol; +pub mod resolve; +pub mod server; + +mod netutil; diff --git a/src/dns/netutil.rs b/src/dns/netutil.rs new file mode 100644 index 0000000..ed2a7bc --- /dev/null +++ b/src/dns/netutil.rs @@ -0,0 +1,19 @@ +use std::io::{Read, Result, Write}; +use std::net::TcpStream; + +pub fn read_packet_length(stream: &mut TcpStream) -> Result { + let mut len_buffer = [0; 2]; + stream.read(&mut len_buffer)?; + + Ok(((len_buffer[0] as u16) << 8) | (len_buffer[1] as u16)) +} + +pub fn write_packet_length(stream: &mut TcpStream, len: usize) -> Result<()> { + let mut len_buffer = [0; 2]; + len_buffer[0] = (len >> 8) as u8; + len_buffer[1] = (len & 0xFF) as u8; + + stream.write(&len_buffer)?; + + Ok(()) +} diff --git a/src/dns/protocol.rs b/src/dns/protocol.rs new file mode 100644 index 0000000..b22c665 --- /dev/null +++ b/src/dns/protocol.rs @@ -0,0 +1,1096 @@ +//! implements the DNS protocol in a transport agnostic fashion + +//use std::io::{Error, ErrorKind}; +use std::cmp::Ordering; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +use derive_more::{Display, Error, From}; +use rand::random; +use serde::{Deserialize, Serialize}; + +use crate::dns::buffer::{PacketBuffer, VectorPacketBuffer}; + +#[derive(Debug, Display, From, Error)] +pub enum ProtocolError { + Buffer(crate::dns::buffer::BufferError), + Io(std::io::Error), +} + +type Result = std::result::Result; + +/// `QueryType` represents the requested Record Type of a query +/// +/// The specific type UNKNOWN that an integer parameter in order to retain the +/// id of an unknown query when compiling the reply. An integer can be converted +/// to a querytype using the `from_num` function, and back to an integer using +/// the `to_num` method. +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy, Serialize, Deserialize)] +pub enum QueryType { + UNKNOWN(u16), + A, // 1 + NS, // 2 + CNAME, // 5 + SOA, // 6 + MX, // 15 + TXT, // 16 + AAAA, // 28 + SRV, // 33 + OPT, // 41 +} + +impl QueryType { + pub fn to_num(&self) -> u16 { + match *self { + QueryType::UNKNOWN(x) => x, + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::SOA => 6, + QueryType::MX => 15, + QueryType::TXT => 16, + QueryType::AAAA => 28, + QueryType::SRV => 33, + QueryType::OPT => 41, + } + } + + pub fn from_num(num: u16) -> QueryType { + match num { + 1 => QueryType::A, + 2 => QueryType::NS, + 5 => QueryType::CNAME, + 6 => QueryType::SOA, + 15 => QueryType::MX, + 16 => QueryType::TXT, + 28 => QueryType::AAAA, + 33 => QueryType::SRV, + 41 => QueryType::OPT, + _ => QueryType::UNKNOWN(num), + } + } +} + +#[derive(Copy, Clone, Debug, Eq, Ord, Serialize, Deserialize)] +pub struct TransientTtl(pub u32); + +impl PartialEq for TransientTtl { + fn eq(&self, _: &TransientTtl) -> bool { + true + } +} + +impl PartialOrd for TransientTtl { + fn partial_cmp(&self, _: &TransientTtl) -> Option { + Some(Ordering::Equal) + } +} + +impl Hash for TransientTtl { + fn hash(&self, _: &mut H) + where + H: Hasher, + { + // purposely left empty + } +} + +/// `DnsRecord` is the primary representation of a DNS record +/// +/// This enumeration is used for reading as well as writing records, from network +/// and from disk (for storage of authority data). +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum DnsRecord { + UNKNOWN { + domain: String, + qtype: u16, + data_len: u16, + ttl: TransientTtl, + }, // 0 + A { + domain: String, + addr: Ipv4Addr, + ttl: TransientTtl, + }, // 1 + NS { + domain: String, + host: String, + ttl: TransientTtl, + }, // 2 + CNAME { + domain: String, + host: String, + ttl: TransientTtl, + }, // 5 + SOA { + domain: String, + m_name: String, + r_name: String, + serial: u32, + refresh: u32, + retry: u32, + expire: u32, + minimum: u32, + ttl: TransientTtl, + }, // 6 + MX { + domain: String, + priority: u16, + host: String, + ttl: TransientTtl, + }, // 15 + TXT { + domain: String, + data: String, + ttl: TransientTtl, + }, // 16 + AAAA { + domain: String, + addr: Ipv6Addr, + ttl: TransientTtl, + }, // 28 + SRV { + domain: String, + priority: u16, + weight: u16, + port: u16, + host: String, + ttl: TransientTtl, + }, // 33 + OPT { + packet_len: u16, + flags: u32, + data: String, + }, // 41 +} + +impl DnsRecord { + pub fn read(buffer: &mut T) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; + + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let class = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; + + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); + + Ok(DnsRecord::A { + domain: domain, + addr: addr, + ttl: TransientTtl(ttl), + }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); + + Ok(DnsRecord::AAAA { + domain: domain, + addr: addr, + ttl: TransientTtl(ttl), + }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; + + Ok(DnsRecord::NS { + domain: domain, + host: ns, + ttl: TransientTtl(ttl), + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; + + Ok(DnsRecord::CNAME { + domain: domain, + host: cname, + ttl: TransientTtl(ttl), + }) + } + QueryType::SRV => { + let priority = buffer.read_u16()?; + let weight = buffer.read_u16()?; + let port = buffer.read_u16()?; + + let mut srv = String::new(); + buffer.read_qname(&mut srv)?; + + Ok(DnsRecord::SRV { + domain: domain, + priority: priority, + weight: weight, + port: port, + host: srv, + ttl: TransientTtl(ttl), + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; + + Ok(DnsRecord::MX { + domain: domain, + priority: priority, + host: mx, + ttl: TransientTtl(ttl), + }) + } + QueryType::SOA => { + let mut m_name = String::new(); + buffer.read_qname(&mut m_name)?; + + let mut r_name = String::new(); + buffer.read_qname(&mut r_name)?; + + let serial = buffer.read_u32()?; + let refresh = buffer.read_u32()?; + let retry = buffer.read_u32()?; + let expire = buffer.read_u32()?; + let minimum = buffer.read_u32()?; + + Ok(DnsRecord::SOA { + domain: domain, + m_name: m_name, + r_name: r_name, + serial: serial, + refresh: refresh, + retry: retry, + expire: expire, + minimum: minimum, + ttl: TransientTtl(ttl), + }) + } + QueryType::TXT => { + let mut txt = String::new(); + + let cur_pos = buffer.pos(); + txt.push_str(&String::from_utf8_lossy( + buffer.get_range(cur_pos, data_len as usize)?, + )); + + buffer.step(data_len as usize)?; + + Ok(DnsRecord::TXT { + domain: domain, + data: txt, + ttl: TransientTtl(ttl), + }) + } + QueryType::OPT => { + let mut data = String::new(); + + let cur_pos = buffer.pos(); + data.push_str(&String::from_utf8_lossy( + buffer.get_range(cur_pos, data_len as usize)?, + )); + buffer.step(data_len as usize)?; + + Ok(DnsRecord::OPT { + packet_len: class, + flags: ttl, + data: data, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; + + Ok(DnsRecord::UNKNOWN { + domain: domain, + qtype: qtype_num, + data_len: data_len, + ttl: TransientTtl(ttl), + }) + } + } + } + + pub fn write(&self, buffer: &mut T) -> Result { + let start_pos = buffer.pos(); + + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::NS { + ref domain, + ref host, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::SRV { + ref domain, + priority, + weight, + port, + ref host, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::SRV.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_u16(weight)?; + buffer.write_u16(port)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::SOA { + ref domain, + ref m_name, + ref r_name, + serial, + refresh, + retry, + expire, + minimum, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::SOA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(m_name)?; + buffer.write_qname(r_name)?; + buffer.write_u32(serial)?; + buffer.write_u32(refresh)?; + buffer.write_u32(retry)?; + buffer.write_u32(expire)?; + buffer.write_u32(minimum)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::TXT { + ref domain, + ref data, + ttl: TransientTtl(ttl), + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::TXT.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(data.len() as u16)?; + + for b in data.as_bytes() { + buffer.write_u8(*b)?; + } + } + DnsRecord::OPT { .. } => {} + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } + + pub fn get_querytype(&self) -> QueryType { + match *self { + DnsRecord::A { .. } => QueryType::A, + DnsRecord::AAAA { .. } => QueryType::AAAA, + DnsRecord::NS { .. } => QueryType::NS, + DnsRecord::CNAME { .. } => QueryType::CNAME, + DnsRecord::SRV { .. } => QueryType::SRV, + DnsRecord::MX { .. } => QueryType::MX, + DnsRecord::UNKNOWN { qtype, .. } => QueryType::UNKNOWN(qtype), + DnsRecord::SOA { .. } => QueryType::SOA, + DnsRecord::TXT { .. } => QueryType::TXT, + DnsRecord::OPT { .. } => QueryType::OPT, + } + } + + pub fn get_domain(&self) -> Option { + match *self { + DnsRecord::A { ref domain, .. } + | DnsRecord::AAAA { ref domain, .. } + | DnsRecord::NS { ref domain, .. } + | DnsRecord::CNAME { ref domain, .. } + | DnsRecord::SRV { ref domain, .. } + | DnsRecord::MX { ref domain, .. } + | DnsRecord::UNKNOWN { ref domain, .. } + | DnsRecord::SOA { ref domain, .. } + | DnsRecord::TXT { ref domain, .. } => Some(domain.clone()), + DnsRecord::OPT { .. } => None, + } + } + + pub fn get_ttl(&self) -> u32 { + match *self { + DnsRecord::A { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::AAAA { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::NS { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::CNAME { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::SRV { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::MX { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::UNKNOWN { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::SOA { + ttl: TransientTtl(ttl), + .. + } + | DnsRecord::TXT { + ttl: TransientTtl(ttl), + .. + } => ttl, + DnsRecord::OPT { .. } => 0, + } + } +} + +/// The result code for a DNS query, as described in the specification +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResultCode { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, +} + +impl Default for ResultCode { + fn default() -> Self { + ResultCode::NOERROR + } +} + +impl ResultCode { + pub fn from_num(num: u8) -> ResultCode { + match num { + 1 => ResultCode::FORMERR, + 2 => ResultCode::SERVFAIL, + 3 => ResultCode::NXDOMAIN, + 4 => ResultCode::NOTIMP, + 5 => ResultCode::REFUSED, + 0 | _ => ResultCode::NOERROR, + } + } +} + +/// Representation of a DNS header +#[derive(Clone, Debug, Default)] +pub struct DnsHeader { + pub id: u16, // 16 bits + + pub recursion_desired: bool, // 1 bit + pub truncated_message: bool, // 1 bit + pub authoritative_answer: bool, // 1 bit + pub opcode: u8, // 4 bits + pub response: bool, // 1 bit + + pub rescode: ResultCode, // 4 bits + pub checking_disabled: bool, // 1 bit + pub authed_data: bool, // 1 bit + pub z: bool, // 1 bit + pub recursion_available: bool, // 1 bit + + pub questions: u16, // 16 bits + pub answers: u16, // 16 bits + pub authoritative_entries: u16, // 16 bits + pub resource_entries: u16, // 16 bits +} + +impl DnsHeader { + pub fn new() -> DnsHeader { + DnsHeader { + id: 0, + + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, + + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, + + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } + } + + pub fn write(&self, buffer: &mut T) -> Result<()> { + buffer.write_u16(self.id)?; + + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; + + buffer.write_u8( + (self.rescode.clone() as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; + + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; + + Ok(()) + } + + pub fn binary_len(&self) -> usize { + 12 + } + + pub fn read(&mut self, buffer: &mut T) -> Result<()> { + self.id = buffer.read_u16()?; + + let flags = buffer.read_u16()?; + let a = (flags >> 8) as u8; + let b = (flags & 0xFF) as u8; + self.recursion_desired = (a & (1 << 0)) > 0; + self.truncated_message = (a & (1 << 1)) > 0; + self.authoritative_answer = (a & (1 << 2)) > 0; + self.opcode = (a >> 3) & 0x0F; + self.response = (a & (1 << 7)) > 0; + + self.rescode = ResultCode::from_num(b & 0x0F); + self.checking_disabled = (b & (1 << 4)) > 0; + self.authed_data = (b & (1 << 5)) > 0; + self.z = (b & (1 << 6)) > 0; + self.recursion_available = (b & (1 << 7)) > 0; + + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; + + // Return the constant header size + Ok(()) + } +} + +impl fmt::Display for DnsHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DnsHeader:\n")?; + write!(f, "\tid: {0}\n", self.id)?; + + write!(f, "\trecursion_desired: {0}\n", self.recursion_desired)?; + write!(f, "\ttruncated_message: {0}\n", self.truncated_message)?; + write!( + f, + "\tauthoritative_answer: {0}\n", + self.authoritative_answer + )?; + write!(f, "\topcode: {0}\n", self.opcode)?; + write!(f, "\tresponse: {0}\n", self.response)?; + + write!(f, "\trescode: {:?}\n", self.rescode)?; + write!(f, "\tchecking_disabled: {0}\n", self.checking_disabled)?; + write!(f, "\tauthed_data: {0}\n", self.authed_data)?; + write!(f, "\tz: {0}\n", self.z)?; + write!(f, "\trecursion_available: {0}\n", self.recursion_available)?; + + write!(f, "\tquestions: {0}\n", self.questions)?; + write!(f, "\tanswers: {0}\n", self.answers)?; + write!( + f, + "\tauthoritative_entries: {0}\n", + self.authoritative_entries + )?; + write!(f, "\tresource_entries: {0}\n", self.resource_entries)?; + + Ok(()) + } +} + +/// Representation of a DNS question +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsQuestion { + pub name: String, + pub qtype: QueryType, +} + +impl DnsQuestion { + pub fn new(name: String, qtype: QueryType) -> DnsQuestion { + DnsQuestion { + name: name, + qtype: qtype, + } + } + + pub fn binary_len(&self) -> usize { + self.name + .split('.') + .map(|x| x.len() + 1) + .fold(1, |x, y| x + y) + } + + pub fn write(&self, buffer: &mut T) -> Result<()> { + buffer.write_qname(&self.name)?; + + let typenum = self.qtype.to_num(); + buffer.write_u16(typenum)?; + buffer.write_u16(1)?; + + Ok(()) + } + + pub fn read(&mut self, buffer: &mut T) -> Result<()> { + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype + let _ = buffer.read_u16()?; // class + + Ok(()) + } +} + +impl fmt::Display for DnsQuestion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DnsQuestion:\n")?; + write!(f, "\tname: {0}\n", self.name)?; + write!(f, "\trecord type: {:?}\n", self.qtype)?; + + Ok(()) + } +} + +/// Representation of a complete DNS packet +/// +/// This is the work horse of the server. A DNS packet can be read and written +/// in a single operation, and is used both by the network facing components and +/// internally by the resolver, cache and authority. +#[derive(Clone, Debug, Default)] +pub struct DnsPacket { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub resources: Vec, +} + +impl DnsPacket { + pub fn new() -> DnsPacket { + DnsPacket { + header: DnsHeader::new(), + questions: Vec::new(), + answers: Vec::new(), + authorities: Vec::new(), + resources: Vec::new(), + } + } + + pub fn from_buffer(buffer: &mut T) -> Result { + let mut result = DnsPacket::new(); + result.header.read(buffer)?; + + for _ in 0..result.header.questions { + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; + result.questions.push(question); + } + + for _ in 0..result.header.answers { + let rec = DnsRecord::read(buffer)?; + result.answers.push(rec); + } + for _ in 0..result.header.authoritative_entries { + let rec = DnsRecord::read(buffer)?; + result.authorities.push(rec); + } + for _ in 0..result.header.resource_entries { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } + + Ok(result) + } + + #[allow(dead_code)] + pub fn print(&self) { + println!("{}", self.header); + + println!("questions:"); + for x in &self.questions { + println!("\t{:?}", x); + } + + println!("answers:"); + for x in &self.answers { + println!("\t{:?}", x); + } + + println!("authorities:"); + for x in &self.authorities { + println!("\t{:?}", x); + } + + println!("resources:"); + for x in &self.resources { + println!("\t{:?}", x); + } + } + + pub fn get_ttl_from_soa(&self) -> Option { + for answer in &self.authorities { + if let DnsRecord::SOA { minimum, .. } = *answer { + return Some(minimum); + } + } + + None + } + + pub fn get_random_a(&self) -> Option { + if !self.answers.is_empty() { + let idx = random::() % self.answers.len(); + let a_record = &self.answers[idx]; + if let DnsRecord::A { ref addr, .. } = *a_record { + return Some(addr.to_string()); + } + } + + None + } + + pub fn get_unresolved_cnames(&self) -> Vec { + let mut unresolved = Vec::new(); + for answer in &self.answers { + let mut matched = false; + if let DnsRecord::CNAME { ref host, .. } = *answer { + for answer2 in &self.answers { + if let DnsRecord::A { ref domain, .. } = *answer2 { + if domain == host { + matched = true; + break; + } + } + } + } + + if !matched { + unresolved.push(answer.clone()); + } + } + + unresolved + } + + pub fn get_resolved_ns(&self, qname: &str) -> Option { + let mut new_authorities = Vec::new(); + for auth in &self.authorities { + if let DnsRecord::NS { + ref domain, + ref host, + .. + } = *auth + { + if !qname.ends_with(domain) { + continue; + } + + for rsrc in &self.resources { + if let DnsRecord::A { + ref domain, + ref addr, + ttl: TransientTtl(ttl), + } = *rsrc + { + if domain != host { + continue; + } + + let rec = DnsRecord::A { + domain: host.clone(), + addr: *addr, + ttl: TransientTtl(ttl), + }; + + new_authorities.push(rec); + } + } + } + } + + if !new_authorities.is_empty() { + let idx = random::() % new_authorities.len(); + if let DnsRecord::A { addr, .. } = new_authorities[idx] { + return Some(addr.to_string()); + } + } + + None + } + + pub fn get_unresolved_ns(&self, qname: &str) -> Option { + let mut new_authorities = Vec::new(); + for auth in &self.authorities { + if let DnsRecord::NS { + ref domain, + ref host, + .. + } = *auth + { + if !qname.ends_with(domain) { + continue; + } + + new_authorities.push(host); + } + } + + if !new_authorities.is_empty() { + let idx = random::() % new_authorities.len(); + return Some(new_authorities[idx].clone()); + } + + None + } + + pub fn write(&mut self, buffer: &mut T, max_size: usize) -> Result<()> { + let mut test_buffer = VectorPacketBuffer::new(); + + let mut size = self.header.binary_len(); + for ref question in &self.questions { + size += question.binary_len(); + question.write(&mut test_buffer)?; + } + + let mut record_count = self.answers.len() + self.authorities.len() + self.resources.len(); + + for (i, rec) in self + .answers + .iter() + .chain(self.authorities.iter()) + .chain(self.resources.iter()) + .enumerate() + { + size += rec.write(&mut test_buffer)?; + if size > max_size { + record_count = i; + self.header.truncated_message = true; + break; + } else if i < self.answers.len() { + self.header.answers += 1; + } else if i < self.answers.len() + self.authorities.len() { + self.header.authoritative_entries += 1; + } else { + self.header.resource_entries += 1; + } + } + + self.header.questions = self.questions.len() as u16; + + self.header.write(buffer)?; + + for question in &self.questions { + question.write(buffer)?; + } + + for rec in self + .answers + .iter() + .chain(self.authorities.iter()) + .chain(self.resources.iter()) + .take(record_count) + { + rec.write(buffer)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::dns::buffer::{PacketBuffer, VectorPacketBuffer}; + + #[test] + fn test_packet() { + let mut packet = DnsPacket::new(); + packet.header.id = 1337; + packet.header.response = true; + + packet + .questions + .push(DnsQuestion::new("google.com".to_string(), QueryType::NS)); + //packet.answers.push(DnsRecord::A("ns1.google.com".to_string(), "127.0.0.1".parse::().unwrap(), 3600)); + packet.answers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: TransientTtl(3600), + }); + packet.answers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns2.google.com".to_string(), + ttl: TransientTtl(3600), + }); + packet.answers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns3.google.com".to_string(), + ttl: TransientTtl(3600), + }); + packet.answers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns4.google.com".to_string(), + ttl: TransientTtl(3600), + }); + + let mut buffer = VectorPacketBuffer::new(); + packet.write(&mut buffer, 0xFFFF).unwrap(); + + buffer.seek(0).unwrap(); + + let parsed_packet = DnsPacket::from_buffer(&mut buffer).unwrap(); + + assert_eq!(packet.questions[0], parsed_packet.questions[0]); + assert_eq!(packet.answers[0], parsed_packet.answers[0]); + assert_eq!(packet.answers[1], parsed_packet.answers[1]); + assert_eq!(packet.answers[2], parsed_packet.answers[2]); + assert_eq!(packet.answers[3], parsed_packet.answers[3]); + } +} diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs new file mode 100644 index 0000000..6bcfbc3 --- /dev/null +++ b/src/dns/resolve.rs @@ -0,0 +1,569 @@ +//! resolver implementations implementing different strategies for answering +//! incoming queries + +use std::sync::Arc; +use std::vec::Vec; + +use derive_more::{Display, Error, From}; + +use crate::dns::context::ServerContext; +use crate::dns::protocol::{DnsPacket, QueryType, ResultCode}; + +#[derive(Debug, Display, From, Error)] +pub enum ResolveError { + Client(crate::dns::client::ClientError), + Cache(crate::dns::cache::CacheError), + Io(std::io::Error), + NoServerFound, +} + +type Result = std::result::Result; + +pub trait DnsResolver { + fn get_context(&self) -> Arc; + + fn resolve(&mut self, qname: &str, qtype: QueryType, recursive: bool) -> Result { + if let QueryType::UNKNOWN(_) = qtype { + let mut packet = DnsPacket::new(); + packet.header.rescode = ResultCode::NOTIMP; + return Ok(packet); + } + + let context = self.get_context(); + + if let Some(qr) = context.authority.query(qname, qtype) { + return Ok(qr); + } + + if !recursive || !context.allow_recursive { + let mut packet = DnsPacket::new(); + packet.header.rescode = ResultCode::REFUSED; + return Ok(packet); + } + + if let Some(qr) = context.cache.lookup(qname, qtype) { + return Ok(qr); + } + + if qtype == QueryType::A || qtype == QueryType::AAAA { + if let Some(qr) = context.cache.lookup(qname, QueryType::CNAME) { + return Ok(qr); + } + } + + self.perform(qname, qtype) + } + + fn perform(&mut self, qname: &str, qtype: QueryType) -> Result; +} + +/// A Forwarding DNS Resolver +/// +/// This resolver uses an external DNS server to service a query +pub struct ForwardingDnsResolver { + context: Arc, + server: (String, u16), +} + +impl ForwardingDnsResolver { + pub fn new(context: Arc, server: (String, u16)) -> ForwardingDnsResolver { + ForwardingDnsResolver { + context: context, + server: server, + } + } +} + +impl DnsResolver for ForwardingDnsResolver { + fn get_context(&self) -> Arc { + self.context.clone() + } + + fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { + let &(ref host, port) = &self.server; + let result = self + .context + .client + .send_query(qname, qtype, (host.as_str(), port), true)?; + + self.context.cache.store(&result.answers)?; + + Ok(result) + } +} + +/// A Recursive DNS resolver +/// +/// This resolver can answer any request using the root servers of the internet +pub struct RecursiveDnsResolver { + context: Arc, +} + +impl RecursiveDnsResolver { + pub fn new(context: Arc) -> RecursiveDnsResolver { + RecursiveDnsResolver { context: context } + } +} + +impl DnsResolver for RecursiveDnsResolver { + fn get_context(&self) -> Arc { + self.context.clone() + } + + fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { + // Find the closest name server by splitting the label and progessively + // moving towards the root servers. I.e. check "google.com", then "com", + // and finally "". + let mut tentative_ns = None; + + let labels = qname.split('.').collect::>(); + for lbl_idx in 0..labels.len() + 1 { + let domain = labels[lbl_idx..].join("."); + + match self + .context + .cache + .lookup(&domain, QueryType::NS) + .and_then(|qr| qr.get_unresolved_ns(&domain)) + .and_then(|ns| self.context.cache.lookup(&ns, QueryType::A)) + .and_then(|qr| qr.get_random_a()) + { + Some(addr) => { + tentative_ns = Some(addr); + break; + } + None => continue, + } + } + + let mut ns = tentative_ns.ok_or_else(|| ResolveError::NoServerFound)?; + + // Start querying name servers + loop { + println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + + let ns_copy = ns.clone(); + + let server = (ns_copy.as_str(), 53); + let response = self + .context + .client + .send_query(qname, qtype.clone(), server, false)?; + + // If we've got an actual answer, we're done! + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + let _ = self.context.cache.store(&response.answers); + let _ = self.context.cache.store(&response.authorities); + let _ = self.context.cache.store(&response.resources); + return Ok(response.clone()); + } + + if response.header.rescode == ResultCode::NXDOMAIN { + if let Some(ttl) = response.get_ttl_from_soa() { + let _ = self.context.cache.store_nxdomain(qname, qtype, ttl); + } + return Ok(response.clone()); + } + + // Otherwise, try to find a new nameserver based on NS and a + // corresponding A record in the additional section + if let Some(new_ns) = response.get_resolved_ns(qname) { + // If there is such a record, we can retry the loop with that NS + ns = new_ns.clone(); + let _ = self.context.cache.store(&response.answers); + let _ = self.context.cache.store(&response.authorities); + let _ = self.context.cache.store(&response.resources); + + continue; + } + + // If not, we'll have to resolve the ip of a NS record + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => return Ok(response.clone()), + }; + + // Recursively resolve the NS + let recursive_response = self.resolve(&new_ns_name, QueryType::A, true)?; + + // Pick a random IP and restart + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns.clone(); + } else { + return Ok(response.clone()); + } + } + } +} + +#[cfg(test)] +mod tests { + + use std::sync::Arc; + + use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl}; + + use super::*; + + use crate::dns::context::tests::create_test_context; + use crate::dns::context::ResolveStrategy; + + #[test] + fn test_forwarding_resolver() { + let mut context = create_test_context(Box::new(|qname, _, _, _| { + let mut packet = DnsPacket::new(); + + if qname == "google.com" { + packet.answers.push(DnsRecord::A { + domain: "google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + } else { + packet.header.rescode = ResultCode::NXDOMAIN; + } + + Ok(packet) + })); + + match Arc::get_mut(&mut context) { + Some(mut ctx) => { + ctx.resolve_strategy = ResolveStrategy::Forward { + host: "127.0.0.1".to_string(), + port: 53, + }; + } + None => panic!(), + } + + let mut resolver = context.create_resolver(context.clone()); + + // First verify that we get a match back + { + let res = match resolver.resolve("google.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(1, res.answers.len()); + + match res.answers[0] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("google.com", domain); + } + _ => panic!(), + } + }; + + // Do the same lookup again, and verify that it's present in the cache + // and that the counter has been updated + { + let res = match resolver.resolve("google.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(1, res.answers.len()); + + let list = match context.cache.list() { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(1, list.len()); + + assert_eq!("google.com", list[0].domain); + assert_eq!(1, list[0].record_types.len()); + assert_eq!(1, list[0].hits); + }; + + // Do a failed lookup + { + let res = match resolver.resolve("yahoo.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(0, res.answers.len()); + assert_eq!(ResultCode::NXDOMAIN, res.header.rescode); + }; + } + + #[test] + fn test_recursive_resolver_with_no_nameserver() { + let context = create_test_context(Box::new(|_, _, _, _| { + let mut packet = DnsPacket::new(); + packet.header.rescode = ResultCode::NXDOMAIN; + Ok(packet) + })); + + let mut resolver = context.create_resolver(context.clone()); + + // Expect failure when no name servers are available + if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { + panic!(); + } + } + + #[test] + fn test_recursive_resolver_with_missing_a_record() { + let context = create_test_context(Box::new(|_, _, _, _| { + let mut packet = DnsPacket::new(); + packet.header.rescode = ResultCode::NXDOMAIN; + Ok(packet) + })); + + let mut resolver = context.create_resolver(context.clone()); + + // Expect failure when no name servers are available + if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { + panic!(); + } + + // Insert name server, but no corresponding A record + let mut nameservers = Vec::new(); + nameservers.push(DnsRecord::NS { + domain: "".to_string(), + host: "a.myroot.net".to_string(), + ttl: TransientTtl(3600), + }); + + let _ = context.cache.store(&nameservers); + + if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { + panic!(); + } + } + + #[test] + fn test_recursive_resolver_match_order() { + let context = create_test_context(Box::new(|_, _, (server, _), _| { + let mut packet = DnsPacket::new(); + + if server == "127.0.0.1" { + packet.header.id = 1; + + packet.answers.push(DnsRecord::A { + domain: "a.google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + return Ok(packet); + } else if server == "127.0.0.2" { + packet.header.id = 2; + + packet.answers.push(DnsRecord::A { + domain: "b.google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + return Ok(packet); + } else if server == "127.0.0.3" { + packet.header.id = 3; + + packet.answers.push(DnsRecord::A { + domain: "c.google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + return Ok(packet); + } + + packet.header.id = 999; + packet.header.rescode = ResultCode::NXDOMAIN; + Ok(packet) + })); + + let mut resolver = context.create_resolver(context.clone()); + + // Expect failure when no name servers are available + if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) { + panic!(); + } + + // Insert root servers + { + let mut nameservers = Vec::new(); + nameservers.push(DnsRecord::NS { + domain: "".to_string(), + host: "a.myroot.net".to_string(), + ttl: TransientTtl(3600), + }); + nameservers.push(DnsRecord::A { + domain: "a.myroot.net".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + let _ = context.cache.store(&nameservers); + } + + match resolver.resolve("google.com", QueryType::A, true) { + Ok(packet) => { + assert_eq!(1, packet.header.id); + } + Err(_) => panic!(), + } + + // Insert TLD servers + { + let mut nameservers = Vec::new(); + nameservers.push(DnsRecord::NS { + domain: "com".to_string(), + host: "a.mytld.net".to_string(), + ttl: TransientTtl(3600), + }); + nameservers.push(DnsRecord::A { + domain: "a.mytld.net".to_string(), + addr: "127.0.0.2".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + let _ = context.cache.store(&nameservers); + } + + match resolver.resolve("google.com", QueryType::A, true) { + Ok(packet) => { + assert_eq!(2, packet.header.id); + } + Err(_) => panic!(), + } + + // Insert authoritative servers + { + let mut nameservers = Vec::new(); + nameservers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: TransientTtl(3600), + }); + nameservers.push(DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: "127.0.0.3".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + let _ = context.cache.store(&nameservers); + } + + match resolver.resolve("google.com", QueryType::A, true) { + Ok(packet) => { + assert_eq!(3, packet.header.id); + } + Err(_) => panic!(), + } + } + + #[test] + fn test_recursive_resolver_successfully() { + let context = create_test_context(Box::new(|qname, _, _, _| { + let mut packet = DnsPacket::new(); + + if qname == "google.com" { + packet.answers.push(DnsRecord::A { + domain: "google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + } else { + packet.header.rescode = ResultCode::NXDOMAIN; + + packet.authorities.push(DnsRecord::SOA { + domain: "google.com".to_string(), + r_name: "google.com".to_string(), + m_name: "google.com".to_string(), + serial: 0, + refresh: 3600, + retry: 3600, + expire: 3600, + minimum: 3600, + ttl: TransientTtl(3600), + }); + } + + Ok(packet) + })); + + let mut resolver = context.create_resolver(context.clone()); + + // Insert name servers + let mut nameservers = Vec::new(); + nameservers.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: TransientTtl(3600), + }); + nameservers.push(DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(3600), + }); + + let _ = context.cache.store(&nameservers); + + // Check that we can successfully resolve + { + let res = match resolver.resolve("google.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(1, res.answers.len()); + + match res.answers[0] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("google.com", domain); + } + _ => panic!(), + } + }; + + // And that we won't find anything for a domain that isn't present + { + let res = match resolver.resolve("foobar.google.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(ResultCode::NXDOMAIN, res.header.rescode); + assert_eq!(0, res.answers.len()); + }; + + // Perform another successful query, that should hit the cache + { + let res = match resolver.resolve("google.com", QueryType::A, true) { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(1, res.answers.len()); + }; + + // Now check that the cache is used, and that the statistics is correct + { + let list = match context.cache.list() { + Ok(x) => x, + Err(_) => panic!(), + }; + + assert_eq!(3, list.len()); + + // Check statistics for google entry + assert_eq!("google.com", list[1].domain); + + // Should have a NS record and an A record for a total of 2 record types + assert_eq!(2, list[1].record_types.len()); + + // Should have been hit two times for NS google.com and once for + // A google.com + assert_eq!(3, list[1].hits); + + assert_eq!("ns1.google.com", list[2].domain); + assert_eq!(1, list[2].record_types.len()); + assert_eq!(2, list[2].hits); + }; + } +} diff --git a/src/dns/server.rs b/src/dns/server.rs new file mode 100644 index 0000000..f7db44d --- /dev/null +++ b/src/dns/server.rs @@ -0,0 +1,608 @@ +//! UDP and TCP server implementations for DNS + +use std::collections::VecDeque; +use std::io::Write; +use std::net::SocketAddr; +use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket}; +use std::sync::atomic::Ordering; +use std::sync::mpsc::{channel, Sender}; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread::Builder; + +use derive_more::{Display, Error, From}; +use rand::random; + +use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer, VectorPacketBuffer}; +use crate::dns::context::ServerContext; +use crate::dns::netutil::{read_packet_length, write_packet_length}; +use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode}; +use crate::dns::resolve::DnsResolver; + +#[derive(Debug, Display, From, Error)] +pub enum ServerError { + Io(std::io::Error), +} + +type Result = std::result::Result; + +macro_rules! return_or_report { + ( $x:expr, $message:expr ) => { + match $x { + Ok(res) => res, + Err(_) => { + println!($message); + return; + } + } + }; +} + +macro_rules! ignore_or_report { + ( $x:expr, $message:expr ) => { + match $x { + Ok(_) => {} + Err(_) => { + println!($message); + return; + } + }; + }; +} + +/// Common trait for DNS servers +pub trait DnsServer { + /// Initialize the server and start listenening + /// + /// This method should _NOT_ block. Rather, servers are expected to spawn a new + /// thread to handle requests and return immediately. + fn run_server(self) -> Result<()>; +} + +/// Utility function for resolving domains referenced in for example CNAME or SRV +/// records. This usually spares the client from having to perform additional +/// lookups. +fn resolve_cnames( + lookup_list: &[DnsRecord], + results: &mut Vec, + resolver: &mut Box, + depth: u16, +) { + if depth > 10 { + return; + } + + for ref rec in lookup_list { + match **rec { + DnsRecord::CNAME { ref host, .. } | DnsRecord::SRV { ref host, .. } => { + if let Ok(result2) = resolver.resolve(host, QueryType::A, true) { + let new_unmatched = result2.get_unresolved_cnames(); + results.push(result2); + + resolve_cnames(&new_unmatched, results, resolver, depth + 1); + } + } + _ => {} + } + } +} + +/// Perform the actual work for a query +/// +/// Incoming requests are validated to make sure they are well formed and adhere +/// to the server configuration. If so, the request will be passed on to the +/// active resolver and a query will be performed. It will also resolve some +/// possible references within the query, such as CNAME hosts. +/// +/// This function will always return a valid packet, even if the request could not +/// be performed, since we still want to send something back to the client. +pub fn execute_query(context: Arc, request: &DnsPacket) -> DnsPacket { + let mut packet = DnsPacket::new(); + packet.header.id = request.header.id; + packet.header.recursion_available = context.allow_recursive; + packet.header.response = true; + + if request.header.recursion_desired && !context.allow_recursive { + packet.header.rescode = ResultCode::REFUSED; + } else if request.questions.is_empty() { + packet.header.rescode = ResultCode::FORMERR; + } else { + let mut results = Vec::new(); + + let question = &request.questions[0]; + packet.questions.push(question.clone()); + + let mut resolver = context.create_resolver(context.clone()); + let rescode = match resolver.resolve( + &question.name, + question.qtype, + request.header.recursion_desired, + ) { + Ok(result) => { + let rescode = result.header.rescode; + + let unmatched = result.get_unresolved_cnames(); + results.push(result); + + resolve_cnames(&unmatched, &mut results, &mut resolver, 0); + + rescode + } + Err(err) => { + println!( + "Failed to resolve {:?} {}: {:?}", + question.qtype, question.name, err + ); + ResultCode::SERVFAIL + } + }; + + packet.header.rescode = rescode; + + for result in results { + for rec in result.answers { + packet.answers.push(rec); + } + for rec in result.authorities { + packet.authorities.push(rec); + } + for rec in result.resources { + packet.resources.push(rec); + } + } + } + + packet +} + +/// The UDP server +/// +/// Accepts DNS queries through UDP, and uses the `ServerContext` to determine +/// how to service the request. Packets are read on a single thread, after which +/// a new thread is spawned to service the request asynchronously. +pub struct DnsUdpServer { + context: Arc, + request_queue: Arc>>, + request_cond: Arc, + thread_count: usize, +} + +impl DnsUdpServer { + pub fn new(context: Arc, thread_count: usize) -> DnsUdpServer { + DnsUdpServer { + context: context, + request_queue: Arc::new(Mutex::new(VecDeque::new())), + request_cond: Arc::new(Condvar::new()), + thread_count: thread_count, + } + } +} + +impl DnsServer for DnsUdpServer { + /// Launch the server + /// + /// This method takes ownership of the server, preventing the method from + /// being called multiple times. + fn run_server(self) -> Result<()> { + // Bind the socket + let socket = UdpSocket::bind(("0.0.0.0", self.context.dns_port))?; + + // Spawn threads for handling requests + for thread_id in 0..self.thread_count { + let socket_clone = match socket.try_clone() { + Ok(x) => x, + Err(e) => { + println!("Failed to clone socket when starting UDP server: {:?}", e); + continue; + } + }; + + let context = self.context.clone(); + let request_cond = self.request_cond.clone(); + let request_queue = self.request_queue.clone(); + + let name = "DnsUdpServer-request-".to_string() + &thread_id.to_string(); + let _ = Builder::new().name(name).spawn(move || { + loop { + // Acquire lock, and wait on the condition until data is + // available. Then proceed with popping an entry of the queue. + let (src, request) = match request_queue + .lock() + .ok() + .and_then(|x| request_cond.wait(x).ok()) + .and_then(|mut x| x.pop_front()) + { + Some(x) => x, + None => { + println!("Not expected to happen!"); + continue; + } + }; + + let mut size_limit = 512; + + // Check for EDNS + if request.resources.len() == 1 { + if let DnsRecord::OPT { packet_len, .. } = request.resources[0] { + size_limit = packet_len as usize; + } + } + + // Create a response buffer, and ask the context for an appropriate + // resolver + let mut res_buffer = VectorPacketBuffer::new(); + + let mut packet = execute_query(context.clone(), &request); + let _ = packet.write(&mut res_buffer, size_limit); + + // Fire off the response + let len = res_buffer.pos(); + let data = return_or_report!( + res_buffer.get_range(0, len), + "Failed to get buffer data" + ); + ignore_or_report!( + socket_clone.send_to(data, src), + "Failed to send response packet" + ); + } + })?; + } + + // Start servicing requests + let _ = Builder::new() + .name("DnsUdpServer-incoming".into()) + .spawn(move || { + loop { + let _ = self + .context + .statistics + .udp_query_count + .fetch_add(1, Ordering::Release); + + // Read a query packet + let mut req_buffer = BytePacketBuffer::new(); + let (_, src) = match socket.recv_from(&mut req_buffer.buf) { + Ok(x) => x, + Err(e) => { + println!("Failed to read from UDP socket: {:?}", e); + continue; + } + }; + + // Parse it + let request = match DnsPacket::from_buffer(&mut req_buffer) { + Ok(x) => x, + Err(e) => { + println!("Failed to parse UDP query packet: {:?}", e); + continue; + } + }; + + // Acquire lock, add request to queue, and notify waiting threads + // using the condition. + match self.request_queue.lock() { + Ok(mut queue) => { + queue.push_back((src, request)); + self.request_cond.notify_one(); + } + Err(e) => { + println!("Failed to send UDP request for processing: {}", e); + } + } + } + })?; + + Ok(()) + } +} + +/// TCP DNS server +pub struct DnsTcpServer { + context: Arc, + senders: Vec>, + thread_count: usize, +} + +impl DnsTcpServer { + pub fn new(context: Arc, thread_count: usize) -> DnsTcpServer { + DnsTcpServer { + context: context, + senders: Vec::new(), + thread_count: thread_count, + } + } +} + +impl DnsServer for DnsTcpServer { + fn run_server(mut self) -> Result<()> { + let socket = TcpListener::bind(("0.0.0.0", self.context.dns_port))?; + + // Spawn threads for handling requests, and create the channels + for thread_id in 0..self.thread_count { + let (tx, rx) = channel(); + self.senders.push(tx); + + let context = self.context.clone(); + + let name = "DnsTcpServer-request-".to_string() + &thread_id.to_string(); + let _ = Builder::new().name(name).spawn(move || { + loop { + let mut stream = match rx.recv() { + Ok(x) => x, + Err(_) => continue, + }; + + let _ = context + .statistics + .tcp_query_count + .fetch_add(1, Ordering::Release); + + // When DNS packets are sent over TCP, they're prefixed with a two byte + // length. We don't really need to know the length in advance, so we + // just move past it and continue reading as usual + ignore_or_report!( + read_packet_length(&mut stream), + "Failed to read query packet length" + ); + + let request = { + let mut stream_buffer = StreamPacketBuffer::new(&mut stream); + return_or_report!( + DnsPacket::from_buffer(&mut stream_buffer), + "Failed to read query packet" + ) + }; + + let mut res_buffer = VectorPacketBuffer::new(); + + let mut packet = execute_query(context.clone(), &request); + ignore_or_report!( + packet.write(&mut res_buffer, 0xFFFF), + "Failed to write packet to buffer" + ); + + // As is the case for incoming queries, we need to send a 2 byte length + // value before handing of the actual packet. + let len = res_buffer.pos(); + ignore_or_report!( + write_packet_length(&mut stream, len), + "Failed to write packet size" + ); + + // Now we can go ahead and write the actual packet + let data = return_or_report!( + res_buffer.get_range(0, len), + "Failed to get packet data" + ); + + ignore_or_report!(stream.write(data), "Failed to write response packet"); + + ignore_or_report!(stream.shutdown(Shutdown::Both), "Failed to shutdown socket"); + } + })?; + } + + let _ = Builder::new() + .name("DnsTcpServer-incoming".into()) + .spawn(move || { + for wrap_stream in socket.incoming() { + let stream = match wrap_stream { + Ok(stream) => stream, + Err(err) => { + println!("Failed to accept TCP connection: {:?}", err); + continue; + } + }; + + // Hand it off to a worker thread + let thread_no = random::() % self.thread_count; + match self.senders[thread_no].send(stream) { + Ok(_) => {} + Err(e) => { + println!( + "Failed to send TCP request for processing on thread {}: {}", + thread_no, e + ); + } + } + } + })?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use std::net::Ipv4Addr; + use std::sync::Arc; + + use crate::dns::protocol::{ + DnsPacket, DnsQuestion, DnsRecord, QueryType, ResultCode, TransientTtl, + }; + + use super::*; + + use crate::dns::context::tests::create_test_context; + use crate::dns::context::ResolveStrategy; + + fn build_query(qname: &str, qtype: QueryType) -> DnsPacket { + let mut query_packet = DnsPacket::new(); + query_packet.header.recursion_desired = true; + + query_packet + .questions + .push(DnsQuestion::new(qname.into(), qtype)); + + query_packet + } + + #[test] + fn test_execute_query() { + // Construct a context to execute some queries successfully + let mut context = create_test_context(Box::new(|qname, qtype, _, _| { + let mut packet = DnsPacket::new(); + + if qname == "google.com" { + packet.answers.push(DnsRecord::A { + domain: "google.com".to_string(), + addr: "127.0.0.1".parse::().unwrap(), + ttl: TransientTtl(3600), + }); + } else if qname == "www.facebook.com" && qtype == QueryType::CNAME { + packet.answers.push(DnsRecord::CNAME { + domain: "www.facebook.com".to_string(), + host: "cdn.facebook.com".to_string(), + ttl: TransientTtl(3600), + }); + packet.answers.push(DnsRecord::A { + domain: "cdn.facebook.com".to_string(), + addr: "127.0.0.1".parse::().unwrap(), + ttl: TransientTtl(3600), + }); + } else if qname == "www.microsoft.com" && qtype == QueryType::CNAME { + packet.answers.push(DnsRecord::CNAME { + domain: "www.microsoft.com".to_string(), + host: "cdn.microsoft.com".to_string(), + ttl: TransientTtl(3600), + }); + } else if qname == "cdn.microsoft.com" && qtype == QueryType::A { + packet.answers.push(DnsRecord::A { + domain: "cdn.microsoft.com".to_string(), + addr: "127.0.0.1".parse::().unwrap(), + ttl: TransientTtl(3600), + }); + } else { + packet.header.rescode = ResultCode::NXDOMAIN; + } + + Ok(packet) + })); + + match Arc::get_mut(&mut context) { + Some(mut ctx) => { + ctx.resolve_strategy = ResolveStrategy::Forward { + host: "127.0.0.1".to_string(), + port: 53, + }; + } + None => panic!(), + } + + // A successful resolve + { + let res = execute_query(context.clone(), &build_query("google.com", QueryType::A)); + assert_eq!(1, res.answers.len()); + + match res.answers[0] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("google.com", domain); + } + _ => panic!(), + } + }; + + // A successful resolve, that also resolves a CNAME without recursive lookup + { + let res = execute_query( + context.clone(), + &build_query("www.facebook.com", QueryType::CNAME), + ); + assert_eq!(2, res.answers.len()); + + match res.answers[0] { + DnsRecord::CNAME { ref domain, .. } => { + assert_eq!("www.facebook.com", domain); + } + _ => panic!(), + } + + match res.answers[1] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("cdn.facebook.com", domain); + } + _ => panic!(), + } + }; + + // A successful resolve, that also resolves a CNAME through recursive lookup + { + let res = execute_query( + context.clone(), + &build_query("www.microsoft.com", QueryType::CNAME), + ); + assert_eq!(2, res.answers.len()); + + match res.answers[0] { + DnsRecord::CNAME { ref domain, .. } => { + assert_eq!("www.microsoft.com", domain); + } + _ => panic!(), + } + + match res.answers[1] { + DnsRecord::A { ref domain, .. } => { + assert_eq!("cdn.microsoft.com", domain); + } + _ => panic!(), + } + }; + + // An unsuccessful resolve, but without any error + { + let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A)); + assert_eq!(ResultCode::NXDOMAIN, res.header.rescode); + assert_eq!(0, res.answers.len()); + }; + + // Disable recursive resolves to generate a failure + match Arc::get_mut(&mut context) { + Some(mut ctx) => { + ctx.allow_recursive = false; + } + None => panic!(), + } + + // This should generate an error code, since recursive resolves are + // no longer allowed + { + let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A)); + assert_eq!(ResultCode::REFUSED, res.header.rescode); + assert_eq!(0, res.answers.len()); + }; + + // Send a query without a question, which should fail with an error code + { + let query_packet = DnsPacket::new(); + let res = execute_query(context.clone(), &query_packet); + assert_eq!(ResultCode::FORMERR, res.header.rescode); + assert_eq!(0, res.answers.len()); + }; + + // Now construct a context where the dns client will return a failure + let mut context2 = create_test_context(Box::new(|_, _, _, _| { + Err(crate::dns::client::ClientError::Io(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Fail", + ))) + })); + + match Arc::get_mut(&mut context2) { + Some(mut ctx) => { + ctx.resolve_strategy = ResolveStrategy::Forward { + host: "127.0.0.1".to_string(), + port: 53, + }; + } + None => panic!(), + } + + // We expect this to set the server failure rescode + { + let res = execute_query(context2.clone(), &build_query("yahoo.com", QueryType::A)); + assert_eq!(ResultCode::SERVFAIL, res.header.rescode); + assert_eq!(0, res.answers.len()); + }; + } +} diff --git a/src/lib.rs b/src/lib.rs index 1dcbd3a..9f55fd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,4 +17,5 @@ pub mod miner; pub mod context; pub mod event; pub mod p2p; +pub mod dns; From d135204af78abfa9d9dce59ccebea16e9ac10da6 Mon Sep 17 00:00:00 2001 From: Revertron Date: Fri, 19 Feb 2021 16:41:43 +0100 Subject: [PATCH 2/2] Implemented DNS on blockchain. Beautified a lot of code, fixed some things. --- Cargo.toml | 6 +- alfis.cfg | 9 ++- src/blockchain/blockchain.rs | 27 ++++++- src/blockchain/filter.rs | 54 ++++++++++++++ src/blockchain/mod.rs | 1 + src/blockchain/transaction.rs | 6 ++ src/context.rs | 47 +----------- src/dns/cache.rs | 60 +++------------- src/dns/client.rs | 70 ++++-------------- src/dns/context.rs | 8 ++- src/dns/filter.rs | 16 +++++ src/dns/mod.rs | 1 + src/dns/protocol.rs | 53 +++++++------- src/dns/resolve.rs | 23 +++--- src/dns/server.rs | 88 ++++++----------------- src/lib.rs | 5 +- src/main.rs | 131 ++++++++++++++++++++++++++++------ src/miner.rs | 1 - src/p2p/network.rs | 5 +- src/settings.rs | 60 ++++++++++++++++ src/simplebus.rs | 1 + src/webview/bulma.css | 8 +++ src/webview/index.html | 74 ++++++++++++++++--- src/webview/scripts.js | 80 ++++++++++++++++++++- 24 files changed, 539 insertions(+), 295 deletions(-) create mode 100644 src/blockchain/filter.rs create mode 100644 src/dns/filter.rs create mode 100644 src/settings.rs diff --git a/Cargo.toml b/Cargo.toml index cac146d..fb580eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,10 +4,12 @@ version = "0.1.0" authors = ["Revertron "] edition = "2018" build = "build.rs" -#![windows_subsystem = "windows"] +homepage = "https://alfis.name" +repository = "https://github.com/Revertron/Alfis" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +getopts = "0.2.21" rust-crypto = "^0.2" num_cpus = "1.13.0" byteorder = "1.3.2" @@ -37,4 +39,4 @@ serde_derive = "1.0.27" [package.metadata.winres] ProductName="ALFIS" -FileDescription="Alternative Free Identity System for independent DNS and more." \ No newline at end of file +FileDescription="Alternative Free Identity System" \ No newline at end of file diff --git a/alfis.cfg b/alfis.cfg index 5a2d805..7aa6f63 100644 --- a/alfis.cfg +++ b/alfis.cfg @@ -10,5 +10,12 @@ "127.0.0.1:10000", "127.0.0.1:10001", "127.0.0.1:10002" - ] + ], + "dns": { + "port": 53, + "forwarders": [ + "1.1.1.1", + "8.8.8.8" + ] + } } \ No newline at end of file diff --git a/src/blockchain/blockchain.rs b/src/blockchain/blockchain.rs index 8737677..99cb996 100644 --- a/src/blockchain/blockchain.rs +++ b/src/blockchain/blockchain.rs @@ -1,6 +1,7 @@ use sqlite::{Connection, State, Statement}; -use crate::{Block, Bytes, Keystore, Transaction, Settings}; +use crate::{Block, Bytes, Keystore, Transaction}; +use crate::settings::Settings; const DB_NAME: &str = "blockchain.db"; @@ -175,6 +176,30 @@ impl Blockchain { true } + pub fn get_domain_info(&self, domain: &str) -> Option { + if domain.is_empty() { + return None; + } + let identity_hash = Transaction::hash_identity(domain); + + let mut statement = self.db.prepare("SELECT * FROM transactions WHERE identity = ? ORDER BY id DESC LIMIT 1;").unwrap(); + statement.bind(1, identity_hash.as_bytes()).expect("Error in bind"); + while let State::Row = statement.next().unwrap() { + let identity = Bytes::from_bytes(statement.read::>(1).unwrap().as_slice()); + let confirmation = Bytes::from_bytes(statement.read::>(2).unwrap().as_slice()); + let method = statement.read::(3).unwrap(); + let data = statement.read::(4).unwrap(); + let pub_key = Bytes::from_bytes(statement.read::>(5).unwrap().as_slice()); + let signature = Bytes::from_bytes(statement.read::>(6).unwrap().as_slice()); + let transaction = Transaction { identity, confirmation, method, data, pub_key, signature }; + println!("Got transaction: {:?}", &transaction); + if transaction.check_for(domain) { + return Some(transaction.data); + } + } + None + } + pub fn last_block(&self) -> Option { self.last_block.clone() } diff --git a/src/blockchain/filter.rs b/src/blockchain/filter.rs new file mode 100644 index 0000000..9be95ad --- /dev/null +++ b/src/blockchain/filter.rs @@ -0,0 +1,54 @@ +use crate::Context; +use std::sync::{Mutex, Arc}; +use crate::dns::filter::DnsFilter; +use crate::dns::protocol::{DnsPacket, QueryType, DnsRecord, DnsQuestion}; + +pub struct BlockchainFilter { + context: Arc> +} + +impl BlockchainFilter { + pub fn new(context: Arc>) -> Self { + BlockchainFilter { context } + } +} + +impl DnsFilter for BlockchainFilter { + fn lookup(&self, qname: &str, qtype: QueryType) -> Option { + let data = self.context.lock().unwrap().blockchain.get_domain_info(qname); + match data { + None => { println!("Not found info for domain {}", &qname); } + Some(data) => { + let records: Vec = match serde_json::from_str(&data) { + Err(_) => { return None; } + Ok(records) => { records } + }; + let mut answers: Vec = Vec::new(); + for mut record in records { + if record.get_querytype() == qtype { + match &mut record { + // TODO make it for all types of records + DnsRecord::A { domain, .. } | DnsRecord::AAAA { domain, .. } if domain == "@" => { + *domain = String::from(qname); + } + _ => () + } + + answers.push(record); + } + } + if !answers.is_empty() { + // Create DnsPacket + let mut packet = DnsPacket::new(); + packet.questions.push(DnsQuestion::new(String::from(qname), qtype)); + for answer in answers { + packet.answers.push(answer); + } + return Some(packet); + } + } + } + + None + } +} \ No newline at end of file diff --git a/src/blockchain/mod.rs b/src/blockchain/mod.rs index 545461f..f330957 100644 --- a/src/blockchain/mod.rs +++ b/src/blockchain/mod.rs @@ -1,6 +1,7 @@ pub mod transaction; pub mod block; pub mod blockchain; +pub mod filter; pub use transaction::Transaction; pub use block::Block; diff --git a/src/blockchain/transaction.rs b/src/blockchain/transaction.rs index 1d12e43..e588e96 100644 --- a/src/blockchain/transaction.rs +++ b/src/blockchain/transaction.rs @@ -67,6 +67,12 @@ impl Transaction { digest.result(&mut buf); Bytes::from_bytes(&buf) } + + pub fn check_for(&self, domain: &str) -> bool { + let hash = Self::hash_identity(&domain); + let confirmation = Self::hash_with_key(&domain, &self.pub_key); + self.identity.eq(&hash) && self.confirmation.eq(&confirmation) + } } impl fmt::Debug for Transaction { diff --git a/src/context.rs b/src/context.rs index 035cf01..2c168d9 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,9 +1,6 @@ -use crate::{Keystore, Blockchain, Bus, Bytes}; +use crate::{Blockchain, Bus, Keystore}; use crate::event::Event; -use serde::{Serialize, Deserialize}; -use std::fs::File; -use std::io::Read; -use std::sync::MutexGuard; +use crate::settings::Settings; pub struct Context { pub settings: Settings, @@ -43,44 +40,4 @@ impl Context { pub fn get_blockchain(&self) -> &Blockchain { &self.blockchain } -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Settings { - pub origin: String, - pub version: u32, - pub key_file: String, - pub listen: String, - pub public: bool, - pub peers: Vec -} - -impl Settings { - pub fn new>(settings: S) -> serde_json::Result { - serde_json::from_str(&settings.into()) - } - - pub fn load(file_name: &str) -> Option { - match File::open(file_name) { - Ok(mut file) => { - let mut text = String::new(); - file.read_to_string(&mut text).unwrap(); - let loaded = serde_json::from_str(&text); - return if loaded.is_ok() { - Some(loaded.unwrap()) - } else { - None - } - }, - Err(..) => None - } - } - - pub fn get_origin(&self) -> Bytes { - if self.origin.eq("") { - return Bytes::zero32(); - } - let origin = crate::from_hex(&self.origin).expect("Wrong origin in settings"); - Bytes::from_bytes(origin.as_slice()) - } } \ No newline at end of file diff --git a/src/dns/cache.rs b/src/dns/cache.rs index 64f1332..1b2dc81 100644 --- a/src/dns/cache.rs +++ b/src/dns/cache.rs @@ -39,25 +39,15 @@ impl PartialEq for RecordEntry { } impl Hash for RecordEntry { - fn hash(&self, state: &mut H) - where - H: Hasher, - { + fn hash(&self, state: &mut H) where H: Hasher { self.record.hash(state); } } #[derive(Clone, Debug, Serialize, Deserialize)] pub enum RecordSet { - NoRecords { - qtype: QueryType, - ttl: u32, - timestamp: DateTime, - }, - Records { - qtype: QueryType, - records: HashSet, - }, + NoRecords { qtype: QueryType, ttl: u32, timestamp: DateTime }, + Records { qtype: QueryType, records: HashSet }, } #[derive(Clone, Debug)] @@ -70,22 +60,13 @@ pub struct DomainEntry { impl DomainEntry { pub fn new(domain: String) -> DomainEntry { - DomainEntry { - domain: domain, - record_types: HashMap::new(), - hits: 0, - updates: 0, - } + DomainEntry { domain, record_types: HashMap::new(), hits: 0, updates: 0 } } pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) { self.updates += 1; - let new_set = RecordSet::NoRecords { - qtype: qtype, - ttl: ttl, - timestamp: Local::now(), - }; + let new_set = RecordSet::NoRecords { qtype, ttl, timestamp: Local::now() }; self.record_types.insert(qtype, new_set); } @@ -93,15 +74,9 @@ impl DomainEntry { pub fn store_record(&mut self, rec: &DnsRecord) { self.updates += 1; - let entry = RecordEntry { - record: rec.clone(), - timestamp: Local::now(), - }; + let entry = RecordEntry { record: rec.clone(), timestamp: Local::now() }; - if let Some(&mut RecordSet::Records { - ref mut records, .. - }) = self.record_types.get_mut(&rec.get_querytype()) - { + if let Some(&mut RecordSet::Records { ref mut records, .. }) = self.record_types.get_mut(&rec.get_querytype()) { if records.contains(&entry) { records.remove(&entry); } @@ -113,10 +88,7 @@ impl DomainEntry { let mut records = HashSet::new(); records.insert(entry); - let new_set = RecordSet::Records { - qtype: rec.get_querytype(), - records: records, - }; + let new_set = RecordSet::Records { qtype: rec.get_querytype(), records }; self.record_types.insert(rec.get_querytype(), new_set); } @@ -191,9 +163,7 @@ pub struct Cache { impl Cache { pub fn new() -> Cache { - Cache { - domain_entries: BTreeMap::new(), - } + Cache { domain_entries: BTreeMap::new() } } fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState { @@ -203,13 +173,7 @@ impl Cache { } } - fn fill_queryresult( - &mut self, - qname: &str, - qtype: QueryType, - result_vec: &mut Vec, - increment_stats: bool, - ) { + fn fill_queryresult(&mut self,qname: &str, qtype: QueryType, result_vec: &mut Vec, increment_stats: bool) { if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) { if increment_stats { domain_entry.hits += 1 @@ -275,9 +239,7 @@ pub struct SynchronizedCache { impl SynchronizedCache { pub fn new() -> SynchronizedCache { - SynchronizedCache { - cache: RwLock::new(Cache::new()), - } + SynchronizedCache { cache: RwLock::new(Cache::new()) } } pub fn list(&self) -> Result>> { diff --git a/src/dns/client.rs b/src/dns/client.rs index 7603743..ff034d6 100644 --- a/src/dns/client.rs +++ b/src/dns/client.rs @@ -32,13 +32,7 @@ pub trait DnsClient { fn get_failed_count(&self) -> usize; fn run(&self) -> Result<()>; - fn send_query( - &self, - qname: &str, - qtype: QueryType, - server: (&str, u16), - recursive: bool, - ) -> Result; + fn send_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result; } /// The UDP client @@ -72,6 +66,7 @@ struct PendingQuery { } unsafe impl Send for DnsNetworkClient {} + unsafe impl Sync for DnsNetworkClient {} impl DnsNetworkClient { @@ -89,13 +84,7 @@ impl DnsNetworkClient { /// /// This is much simpler than using UDP, since the kernel will take care of /// packet ordering, connection state, timeouts etc. - pub fn send_tcp_query( - &self, - qname: &str, - qtype: QueryType, - server: (&str, u16), - recursive: bool, - ) -> Result { + pub fn send_tcp_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result { let _ = self.total_sent.fetch_add(1, Ordering::Release); // Prepare request @@ -135,14 +124,8 @@ impl DnsNetworkClient { /// The query is sent from the callee thread, but responses are read on a /// worker thread, and returned to this thread through a channel. Thus this /// method is thread safe, and can be used from any number of threads in - /// parallell. - pub fn send_udp_query( - &self, - qname: &str, - qtype: QueryType, - server: (&str, u16), - recursive: bool, - ) -> Result { + /// parallel. + pub fn send_udp_query(&self, qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result { let _ = self.total_sent.fetch_add(1, Ordering::Release); // Prepare request @@ -156,30 +139,20 @@ impl DnsNetworkClient { packet.header.questions = 1; packet.header.recursion_desired = recursive; - packet - .questions - .push(DnsQuestion::new(qname.to_string(), qtype)); + packet.questions.push(DnsQuestion::new(qname.to_string(), qtype)); // Create a return channel, and add a `PendingQuery` to the list of lookups // in progress let (tx, rx) = channel(); { - let mut pending_queries = self - .pending_queries - .lock() - .map_err(|_| ClientError::PoisonedLock)?; - pending_queries.push(PendingQuery { - seq: packet.header.id, - timestamp: Local::now(), - tx: tx, - }); + let mut pending_queries = self.pending_queries.lock().map_err(|_| ClientError::PoisonedLock)?; + pending_queries.push(PendingQuery { seq: packet.header.id, timestamp: Local::now(), tx }); } // Send query let mut req_buffer = BytePacketBuffer::new(); packet.write(&mut req_buffer, 512)?; - self.socket - .send_to(&req_buffer.buf[0..req_buffer.pos], server)?; + self.socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; // Wait for response match rx.recv() { @@ -231,10 +204,7 @@ impl DnsClient for DnsNetworkClient { let packet = match DnsPacket::from_buffer(&mut res_buffer) { Ok(packet) => packet, Err(err) => { - println!( - "DnsNetworkClient failed to parse packet with error: {}", - err - ); + println!("DnsNetworkClient failed to parse packet with error: {:?}", err); continue; } }; @@ -298,13 +268,7 @@ impl DnsClient for DnsNetworkClient { Ok(()) } - fn send_query( - &self, - qname: &str, - qtype: QueryType, - server: (&str, u16), - recursive: bool, - ) -> Result { + fn send_query(&self,qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result { let packet = self.send_udp_query(qname, qtype, server, recursive)?; if !packet.header.truncated_message { return Ok(packet); @@ -317,7 +281,6 @@ impl DnsClient for DnsNetworkClient { #[cfg(test)] pub mod tests { - use super::*; use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType}; @@ -329,11 +292,12 @@ pub mod tests { impl<'a> DnsStubClient { pub fn new(callback: Box) -> DnsStubClient { - DnsStubClient { callback: callback } + DnsStubClient { callback } } } unsafe impl Send for DnsStubClient {} + unsafe impl Sync for DnsStubClient {} impl DnsClient for DnsStubClient { @@ -349,13 +313,7 @@ pub mod tests { Ok(()) } - fn send_query( - &self, - qname: &str, - qtype: QueryType, - server: (&str, u16), - recursive: bool, - ) -> Result { + fn send_query(&self,qname: &str, qtype: QueryType, server: (&str, u16), recursive: bool) -> Result { (self.callback)(qname, qtype, server, recursive) } } diff --git a/src/dns/context.rs b/src/dns/context.rs index 0490c55..9e74a3d 100644 --- a/src/dns/context.rs +++ b/src/dns/context.rs @@ -10,6 +10,7 @@ use crate::dns::authority::Authority; use crate::dns::cache::SynchronizedCache; use crate::dns::client::{DnsClient, DnsNetworkClient}; use crate::dns::resolve::{DnsResolver, ForwardingDnsResolver, RecursiveDnsResolver}; +use crate::dns::filter::DnsFilter; #[derive(Debug, Display, From, Error)] pub enum ContextError { @@ -43,6 +44,7 @@ pub enum ResolveStrategy { pub struct ServerContext { pub authority: Authority, pub cache: SynchronizedCache, + pub filters: Vec>, pub client: Box, pub dns_port: u16, pub api_port: u16, @@ -66,6 +68,7 @@ impl ServerContext { ServerContext { authority: Authority::new(), cache: SynchronizedCache::new(), + filters: Vec::new(), client: Box::new(DnsNetworkClient::new(34255)), dns_port: 53, api_port: 5380, @@ -73,7 +76,7 @@ impl ServerContext { allow_recursive: true, enable_udp: true, enable_tcp: true, - enable_api: true, + enable_api: false, statistics: ServerStatistics { tcp_query_count: AtomicUsize::new(0), udp_query_count: AtomicUsize::new(0), @@ -122,6 +125,7 @@ pub mod tests { Arc::new(ServerContext { authority: Authority::new(), cache: SynchronizedCache::new(), + filters: Vec::new(), client: Box::new(DnsStubClient::new(callback)), dns_port: 53, api_port: 5380, @@ -129,7 +133,7 @@ pub mod tests { allow_recursive: true, enable_udp: true, enable_tcp: true, - enable_api: true, + enable_api: false, statistics: ServerStatistics { tcp_query_count: AtomicUsize::new(0), udp_query_count: AtomicUsize::new(0), diff --git a/src/dns/filter.rs b/src/dns/filter.rs new file mode 100644 index 0000000..26120f9 --- /dev/null +++ b/src/dns/filter.rs @@ -0,0 +1,16 @@ +use crate::dns::protocol::{QueryType, DnsPacket}; + +pub trait DnsFilter { + fn lookup(&self, qname: &str, qtype: QueryType) -> Option; +} + +pub struct DummyFilter { + +} + +#[allow(unused_variables)] +impl DnsFilter for DummyFilter { + fn lookup(&self, qname: &str, qtype: QueryType) -> Option { + None + } +} \ No newline at end of file diff --git a/src/dns/mod.rs b/src/dns/mod.rs index 1783419..e4eb8f3 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -22,5 +22,6 @@ pub mod context; pub mod protocol; pub mod resolve; pub mod server; +pub mod filter; mod netutil; diff --git a/src/dns/protocol.rs b/src/dns/protocol.rs index b22c665..4d38f47 100644 --- a/src/dns/protocol.rs +++ b/src/dns/protocol.rs @@ -188,8 +188,8 @@ impl DnsRecord { ); Ok(DnsRecord::A { - domain: domain, - addr: addr, + domain, + addr, ttl: TransientTtl(ttl), }) } @@ -210,8 +210,8 @@ impl DnsRecord { ); Ok(DnsRecord::AAAA { - domain: domain, - addr: addr, + domain, + addr, ttl: TransientTtl(ttl), }) } @@ -220,7 +220,7 @@ impl DnsRecord { buffer.read_qname(&mut ns)?; Ok(DnsRecord::NS { - domain: domain, + domain, host: ns, ttl: TransientTtl(ttl), }) @@ -230,7 +230,7 @@ impl DnsRecord { buffer.read_qname(&mut cname)?; Ok(DnsRecord::CNAME { - domain: domain, + domain, host: cname, ttl: TransientTtl(ttl), }) @@ -244,10 +244,10 @@ impl DnsRecord { buffer.read_qname(&mut srv)?; Ok(DnsRecord::SRV { - domain: domain, - priority: priority, - weight: weight, - port: port, + domain, + priority, + weight, + port, host: srv, ttl: TransientTtl(ttl), }) @@ -258,8 +258,8 @@ impl DnsRecord { buffer.read_qname(&mut mx)?; Ok(DnsRecord::MX { - domain: domain, - priority: priority, + domain, + priority, host: mx, ttl: TransientTtl(ttl), }) @@ -278,14 +278,14 @@ impl DnsRecord { let minimum = buffer.read_u32()?; Ok(DnsRecord::SOA { - domain: domain, - m_name: m_name, - r_name: r_name, - serial: serial, - refresh: refresh, - retry: retry, - expire: expire, - minimum: minimum, + domain, + m_name, + r_name, + serial, + refresh, + retry, + expire, + minimum, ttl: TransientTtl(ttl), }) } @@ -300,7 +300,7 @@ impl DnsRecord { buffer.step(data_len as usize)?; Ok(DnsRecord::TXT { - domain: domain, + domain, data: txt, ttl: TransientTtl(ttl), }) @@ -317,16 +317,16 @@ impl DnsRecord { Ok(DnsRecord::OPT { packet_len: class, flags: ttl, - data: data, + data, }) } QueryType::UNKNOWN(_) => { buffer.step(data_len as usize)?; Ok(DnsRecord::UNKNOWN { - domain: domain, + domain, qtype: qtype_num, - data_len: data_len, + data_len, ttl: TransientTtl(ttl), }) } @@ -755,10 +755,7 @@ pub struct DnsQuestion { impl DnsQuestion { pub fn new(name: String, qtype: QueryType) -> DnsQuestion { - DnsQuestion { - name: name, - qtype: qtype, - } + DnsQuestion { name, qtype } } pub fn binary_len(&self) -> usize { diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs index 6bcfbc3..16e5b71 100644 --- a/src/dns/resolve.rs +++ b/src/dns/resolve.rs @@ -51,6 +51,12 @@ pub trait DnsResolver { } } + for filter in self.get_context().filters.iter() { + if let Some(packet) = filter.lookup(qname, qtype) { + return Ok(packet); + } + } + self.perform(qname, qtype) } @@ -67,10 +73,7 @@ pub struct ForwardingDnsResolver { impl ForwardingDnsResolver { pub fn new(context: Arc, server: (String, u16)) -> ForwardingDnsResolver { - ForwardingDnsResolver { - context: context, - server: server, - } + ForwardingDnsResolver { context, server } } } @@ -81,10 +84,12 @@ impl DnsResolver for ForwardingDnsResolver { fn perform(&mut self, qname: &str, qtype: QueryType) -> Result { let &(ref host, port) = &self.server; - let result = self - .context - .client - .send_query(qname, qtype, (host.as_str(), port), true)?; + let result = match self.context.cache.lookup(qname, qtype) { + None => { + self.context.client.send_query(qname, qtype, (host.as_str(), port), true)? + } + Some(packet) => packet + }; self.context.cache.store(&result.answers)?; @@ -101,7 +106,7 @@ pub struct RecursiveDnsResolver { impl RecursiveDnsResolver { pub fn new(context: Arc) -> RecursiveDnsResolver { - RecursiveDnsResolver { context: context } + RecursiveDnsResolver { context } } } diff --git a/src/dns/server.rs b/src/dns/server.rs index f7db44d..903fff1 100644 --- a/src/dns/server.rs +++ b/src/dns/server.rs @@ -59,8 +59,7 @@ pub trait DnsServer { } /// Utility function for resolving domains referenced in for example CNAME or SRV -/// records. This usually spares the client from having to perform additional -/// lookups. +/// records. This usually spares the client from having to perform additional lookups. fn resolve_cnames( lookup_list: &[DnsRecord], results: &mut Vec, @@ -112,11 +111,7 @@ pub fn execute_query(context: Arc, request: &DnsPacket) -> DnsPac packet.questions.push(question.clone()); let mut resolver = context.create_resolver(context.clone()); - let rescode = match resolver.resolve( - &question.name, - question.qtype, - request.header.recursion_desired, - ) { + let rescode = match resolver.resolve(&question.name, question.qtype, request.header.recursion_desired) { Ok(result) => { let rescode = result.header.rescode; @@ -128,10 +123,7 @@ pub fn execute_query(context: Arc, request: &DnsPacket) -> DnsPac rescode } Err(err) => { - println!( - "Failed to resolve {:?} {}: {:?}", - question.qtype, question.name, err - ); + println!("Failed to resolve {:?} {}: {:?}", question.qtype, question.name, err); ResultCode::SERVFAIL } }; @@ -169,10 +161,10 @@ pub struct DnsUdpServer { impl DnsUdpServer { pub fn new(context: Arc, thread_count: usize) -> DnsUdpServer { DnsUdpServer { - context: context, + context, request_queue: Arc::new(Mutex::new(VecDeque::new())), request_cond: Arc::new(Condvar::new()), - thread_count: thread_count, + thread_count, } } } @@ -180,11 +172,10 @@ impl DnsUdpServer { impl DnsServer for DnsUdpServer { /// Launch the server /// - /// This method takes ownership of the server, preventing the method from - /// being called multiple times. + /// This method takes ownership of the server, preventing the method from being called multiple times. fn run_server(self) -> Result<()> { // Bind the socket - let socket = UdpSocket::bind(("0.0.0.0", self.context.dns_port))?; + let socket = UdpSocket::bind(("[::]", self.context.dns_port))?; // Spawn threads for handling requests for thread_id in 0..self.thread_count { @@ -227,8 +218,7 @@ impl DnsServer for DnsUdpServer { } } - // Create a response buffer, and ask the context for an appropriate - // resolver + // Create a response buffer, and ask the context for an appropriate resolver let mut res_buffer = VectorPacketBuffer::new(); let mut packet = execute_query(context.clone(), &request); @@ -236,14 +226,8 @@ impl DnsServer for DnsUdpServer { // Fire off the response let len = res_buffer.pos(); - let data = return_or_report!( - res_buffer.get_range(0, len), - "Failed to get buffer data" - ); - ignore_or_report!( - socket_clone.send_to(data, src), - "Failed to send response packet" - ); + let data = return_or_report!(res_buffer.get_range(0, len), "Failed to get buffer data"); + ignore_or_report!(socket_clone.send_to(data, src), "Failed to send response packet"); } })?; } @@ -253,11 +237,7 @@ impl DnsServer for DnsUdpServer { .name("DnsUdpServer-incoming".into()) .spawn(move || { loop { - let _ = self - .context - .statistics - .udp_query_count - .fetch_add(1, Ordering::Release); + let _ = self.context.statistics.udp_query_count.fetch_add(1, Ordering::Release); // Read a query packet let mut req_buffer = BytePacketBuffer::new(); @@ -278,8 +258,7 @@ impl DnsServer for DnsUdpServer { } }; - // Acquire lock, add request to queue, and notify waiting threads - // using the condition. + // Acquire lock, add request to queue, and notify waiting threads using the condition. match self.request_queue.lock() { Ok(mut queue) => { queue.push_back((src, request)); @@ -305,17 +284,13 @@ pub struct DnsTcpServer { impl DnsTcpServer { pub fn new(context: Arc, thread_count: usize) -> DnsTcpServer { - DnsTcpServer { - context: context, - senders: Vec::new(), - thread_count: thread_count, - } + DnsTcpServer { context, senders: Vec::new(), thread_count } } } impl DnsServer for DnsTcpServer { fn run_server(mut self) -> Result<()> { - let socket = TcpListener::bind(("0.0.0.0", self.context.dns_port))?; + let socket = TcpListener::bind(("[::]", self.context.dns_port))?; // Spawn threads for handling requests, and create the channels for thread_id in 0..self.thread_count { @@ -332,48 +307,30 @@ impl DnsServer for DnsTcpServer { Err(_) => continue, }; - let _ = context - .statistics - .tcp_query_count - .fetch_add(1, Ordering::Release); + let _ = context.statistics.tcp_query_count.fetch_add(1, Ordering::Release); // When DNS packets are sent over TCP, they're prefixed with a two byte // length. We don't really need to know the length in advance, so we // just move past it and continue reading as usual - ignore_or_report!( - read_packet_length(&mut stream), - "Failed to read query packet length" - ); + ignore_or_report!(read_packet_length(&mut stream), "Failed to read query packet length"); let request = { let mut stream_buffer = StreamPacketBuffer::new(&mut stream); - return_or_report!( - DnsPacket::from_buffer(&mut stream_buffer), - "Failed to read query packet" - ) + return_or_report!(DnsPacket::from_buffer(&mut stream_buffer), "Failed to read query packet") }; let mut res_buffer = VectorPacketBuffer::new(); let mut packet = execute_query(context.clone(), &request); - ignore_or_report!( - packet.write(&mut res_buffer, 0xFFFF), - "Failed to write packet to buffer" - ); + ignore_or_report!(packet.write(&mut res_buffer, 0xFFFF), "Failed to write packet to buffer"); // As is the case for incoming queries, we need to send a 2 byte length // value before handing of the actual packet. let len = res_buffer.pos(); - ignore_or_report!( - write_packet_length(&mut stream, len), - "Failed to write packet size" - ); + ignore_or_report!(write_packet_length(&mut stream, len), "Failed to write packet size"); // Now we can go ahead and write the actual packet - let data = return_or_report!( - res_buffer.get_range(0, len), - "Failed to get packet data" - ); + let data = return_or_report!(res_buffer.get_range(0, len), "Failed to get packet data"); ignore_or_report!(stream.write(data), "Failed to write response packet"); @@ -399,10 +356,7 @@ impl DnsServer for DnsTcpServer { match self.senders[thread_no].send(stream) { Ok(_) => {} Err(e) => { - println!( - "Failed to send TCP request for processing on thread {}: {}", - thread_no, e - ); + println!("Failed to send TCP request for processing on thread {}: {}", thread_no, e); } } } diff --git a/src/lib.rs b/src/lib.rs index 9f55fd1..22d0a56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,13 +3,13 @@ pub use blockchain::transaction::Transaction; pub use crate::blockchain::Blockchain; pub use crate::context::Context; -pub use crate::context::Settings; +pub use settings::Settings; pub use crate::keys::Bytes; pub use crate::keys::Keystore; pub use crate::simplebus::*; pub use crate::utils::*; -mod blockchain; +pub mod blockchain; pub mod utils; pub mod simplebus; pub mod keys; @@ -18,4 +18,5 @@ pub mod context; pub mod event; pub mod p2p; pub mod dns; +pub mod settings; diff --git a/src/main.rs b/src/main.rs index 7fd8605..9c43e5b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,22 +2,31 @@ extern crate web_view; extern crate tinyfiledialogs as tfd; +use std::env; use std::sync::{Arc, Mutex}; -use std::sync::atomic::{AtomicBool, Ordering, AtomicUsize}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::thread; +use std::time::Duration; use rand::RngCore; -use serde::{Deserialize}; +use serde::Deserialize; use web_view::*; +use getopts::Options; -use alfis::{Blockchain, Bytes, Context, Keystore, Settings, Transaction, Block}; +use alfis::{Blockchain, Bytes, Context, Keystore, Transaction}; use alfis::event::Event; use alfis::miner::Miner; use alfis::p2p::Network; +use alfis::settings::Settings; +use alfis::dns::context::{ServerContext, ResolveStrategy}; +use alfis::dns::server::{DnsServer, DnsUdpServer, DnsTcpServer}; +use alfis::dns::protocol::DnsRecord; +use alfis::blockchain::filter::BlockchainFilter; extern crate serde; extern crate serde_json; +#[allow(dead_code)] const ONE_YEAR: u16 = 365; const GENESIS_ZONE: &str = "ygg"; const GENESIS_ZONE_DIFFICULTY: u16 = 20; @@ -26,6 +35,27 @@ const SETTINGS_FILENAME: &str = "alfis.cfg"; fn main() { println!("ALFIS 0.1.0"); + let args: Vec = env::args().collect(); + let program = args[0].clone(); + + let mut opts = Options::new(); + opts.optflag("h","help", "Print this help menu"); + opts.optflag("n","nogui","Run without graphic user interface"); + opts.optopt("c","config","Path to config file", ""); + + let opt_matches = match opts.parse(&args[1..]) { + Ok(m) => m, + Err(f) => panic!(f.to_string()), + }; + + if opt_matches.opt_present("h") { + let brief = format!("Usage: {} [options]", program); + print!("{}", opts.usage(&brief)); + return; + } + + let no_gui = opt_matches.opt_present("n"); + let settings = Settings::load(SETTINGS_FILENAME).expect("Error loading settings"); let keystore: Keystore = match Keystore::from_file(&settings.key_file, "") { None => { @@ -39,7 +69,9 @@ fn main() { None => { println!("No blocks found in DB"); } Some(block) => { println!("Loaded DB with origin {:?}", &block.hash); } } + let settings_copy = settings.clone(); let context: Arc> = Arc::new(Mutex::new(Context::new(settings, keystore, blockchain))); + start_dns_server(&context, &settings_copy); let mut miner_obj = Miner::new(context.clone()); miner_obj.start_mining_thread(); @@ -49,7 +81,32 @@ fn main() { network.start().expect("Error starting network component"); create_genesis_if_needed(&context, &miner); - run_interface(context.clone(), miner.clone()); + if no_gui { + let sleep = Duration::from_millis(1000); + loop { + thread::sleep(sleep); + } + } else { + run_interface(context.clone(), miner.clone()); + } +} + +fn start_dns_server(context: &Arc>, settings: &Settings) { + let server_context = create_server_context(context.clone(), &settings); + + if server_context.enable_udp { + let udp_server = DnsUdpServer::new(server_context.clone(), 20); + if let Err(e) = udp_server.run_server() { + println!("Failed to bind UDP listener: {:?}", e); + } + } + + if server_context.enable_tcp { + let tcp_server = DnsTcpServer::new(server_context.clone(), 20); + if let Err(e) = tcp_server.run_server() { + println!("Failed to bind TCP listener: {:?}", e); + } + } } fn create_genesis_if_needed(context: &Arc>, miner: &Arc>) { @@ -85,7 +142,6 @@ fn run_interface(context: Arc>, miner: Arc>) { Loaded => { web_view.eval("showMiningIndicator(false);").expect("Error evaluating!"); let handle = web_view.handle(); - let context_copy = context.clone(); let mut c = context.lock().unwrap(); c.bus.register(move |_uuid, e| { println!("Got event from bus {:?}", &e); @@ -103,8 +159,7 @@ fn run_interface(context: Arc>, miner: Arc>) { if !eval.is_empty() { println!("Evaluating {}", &eval); handle.dispatch(move |web_view| { - web_view.eval(&eval.replace("\\", "\\\\")).expect("Error evaluating!"); - return WVResult::Ok(()); + web_view.eval(&eval.replace("\\", "\\\\")) }).expect("Error dispatching!"); } true @@ -154,19 +209,22 @@ fn run_interface(context: Arc>, miner: Arc>) { let available = c.get_blockchain().is_domain_available(&name, &c.get_keystore()); web_view.eval(&format!("domainAvailable({})", available)).expect("Error evaluating!"); } - CreateDomain { name, records, tags } => { - let keystore = { - let guard = context.lock().unwrap(); - guard.get_keystore() - }; - create_domain(miner.clone(), name, records, &keystore); + CreateDomain { name, records, .. } => { + println!("Got records: {}", records); + if serde_json::from_str::>(&records).is_ok() { + let keystore = { + let guard = context.lock().unwrap(); + guard.get_keystore() + }; + create_domain(miner.clone(), name, records, &keystore); + } else { + println!("Error in DNS records for domain!"); + web_view.eval(&format!("showWarning('{}');", "Something wrong with your records! Please, correct the error and try again.")); + } } - ChangeDomain { name, records, tags } => { - let keystore = { context.lock().unwrap().get_keystore() }; - // TODO - } - RenewDomain { name, days } => {} - TransferDomain { name, owner } => {} + ChangeDomain { .. } => {} + RenewDomain { .. } => {} + TransferDomain { .. } => {} StopMining => { context.lock().unwrap().bus.post(Event::ActionStopMining); } @@ -192,7 +250,7 @@ fn create_domain>(miner: Arc>, name: S, data: S, ke println!("Generating domain {}", name); //let rec_vector: Vec = records.into().trim().split("\n").map(|s| s.trim()).map(String::from).collect(); //let tags_vector: Vec = tags.into().trim().split(",").map(|s| s.trim()).map(String::from).collect(); - let transaction = { create_transaction(keystore, name, "domain".into(), data.into()) }; + let transaction = create_transaction(keystore, name, "domain".into(), data.into()); let mut miner_guard = miner.lock().unwrap(); miner_guard.add_transaction(transaction); } @@ -259,6 +317,23 @@ fn generate_key(difficulty: usize, mining: Arc) -> Option } } +fn create_server_context(context: Arc>, settings: &Settings) -> Arc { + let mut server_context = ServerContext::new(); + server_context.allow_recursive = true; + server_context.dns_port = settings.dns.port; + server_context.resolve_strategy = match settings.dns.forwarders.is_empty() { + true => { ResolveStrategy::Recursive } + false => { ResolveStrategy::Forward { host: settings.dns.forwarders[0].clone(), port: 53 }} // TODO refactor to use more resolvers + }; + server_context.filters.push(Box::new(BlockchainFilter::new(context))); + match server_context.initialize() { + Ok(_) => {} + Err(e) => { panic!("Server failed to initialize: {:?}", e); } + } + + Arc::new(server_context) +} + #[derive(Deserialize)] #[serde(tag = "cmd", rename_all = "camelCase")] pub enum Cmd { @@ -281,3 +356,19 @@ fn inline_style(s: &str) -> String { fn inline_script(s: &str) -> String { format!(r#""#, s) } + +#[cfg(test)] +mod tests { + use alfis::dns::protocol::{DnsRecord, TransientTtl}; + + #[test] + fn record_to_string() { + let record = DnsRecord::A { + domain: "google.com".to_string(), + addr: "127.0.0.1".parse().unwrap(), + ttl: TransientTtl(300) + }; + println!("Record is {:?}", &record); + println!("Record in JSON is {}", serde_json::to_string(&record).unwrap()); + } +} \ No newline at end of file diff --git a/src/miner.rs b/src/miner.rs index 264a718..cf86700 100644 --- a/src/miner.rs +++ b/src/miner.rs @@ -10,7 +10,6 @@ use num_cpus; use crate::{Block, Bytes, Context, hash_is_good, Transaction}; use crate::event::Event; -use std::ops::DerefMut; pub struct Miner { context: Arc>, diff --git a/src/p2p/network.rs b/src/p2p/network.rs index 453822d..b30663f 100644 --- a/src/p2p/network.rs +++ b/src/p2p/network.rs @@ -13,7 +13,6 @@ use mio::net::{TcpListener, TcpStream}; use crate::{Context, Block, p2p::Message, p2p::State, p2p::Peer, p2p::Peers}; use std::net::{SocketAddr, IpAddr, SocketAddrV4, Shutdown}; -use std::ops::DerefMut; const SERVER: Token = Token(0); const POLL_TIMEOUT: Option = Some(Duration::from_millis(3000)); @@ -95,8 +94,8 @@ impl Network { None => {} Some(mut peer) => { let stream = peer.get_stream(); - poll.registry().deregister(stream); - stream.shutdown(Shutdown::Both); + let _ = poll.registry().deregister(stream); + let _ = stream.shutdown(Shutdown::Both); println!("Peer connection {:?} has shut down", &peer.get_addr()); } } diff --git a/src/settings.rs b/src/settings.rs new file mode 100644 index 0000000..ec1d9c0 --- /dev/null +++ b/src/settings.rs @@ -0,0 +1,60 @@ +use std::fs::File; +use std::io::Read; + +use serde::{Deserialize, Serialize}; + +use crate::Bytes; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Settings { + pub origin: String, + pub version: u32, + pub key_file: String, + pub listen: String, + pub public: bool, + pub peers: Vec, + #[serde(default)] + pub dns: Dns +} + +impl Settings { + pub fn new>(settings: S) -> serde_json::Result { + serde_json::from_str(&settings.into()) + } + + pub fn load(file_name: &str) -> Option { + match File::open(file_name) { + Ok(mut file) => { + let mut text = String::new(); + file.read_to_string(&mut text).unwrap(); + let loaded = serde_json::from_str(&text); + return if loaded.is_ok() { + Some(loaded.unwrap()) + } else { + None + } + }, + Err(..) => None + } + } + + pub fn get_origin(&self) -> Bytes { + if self.origin.eq("") { + return Bytes::zero32(); + } + let origin = crate::from_hex(&self.origin).expect("Wrong origin in settings"); + Bytes::from_bytes(origin.as_slice()) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Dns { + pub port: u16, + pub forwarders: Vec +} + +impl Default for Dns { + fn default() -> Self { + Dns { port: 53, forwarders: Vec::new() } + } +} \ No newline at end of file diff --git a/src/simplebus.rs b/src/simplebus.rs index c0ed4dd..a8f281d 100644 --- a/src/simplebus.rs +++ b/src/simplebus.rs @@ -27,6 +27,7 @@ impl Bus { } } +#[cfg(test)] mod tests { use std::sync::{Arc, Mutex}; use std::thread; diff --git a/src/webview/bulma.css b/src/webview/bulma.css index 994a87b..c262663 100644 --- a/src/webview/bulma.css +++ b/src/webview/bulma.css @@ -10838,4 +10838,12 @@ label.panel-block:hover { html { overflow: hidden; +} + +.notification { + position: absolute; + z-index: 100; + width: 50%; + top: 10pt; + right: 10pt; } \ No newline at end of file diff --git a/src/webview/index.html b/src/webview/index.html index ad2a62c..f7b00d9 100644 --- a/src/webview/index.html +++ b/src/webview/index.html @@ -28,8 +28,66 @@ + + + +
-
+
-
- -
- -
-
-
@@ -108,7 +159,14 @@
+
+ +
+
+
+ +
diff --git a/src/webview/scripts.js b/src/webview/scripts.js index 35b551a..c3695f2 100644 --- a/src/webview/scripts.js +++ b/src/webview/scripts.js @@ -1,3 +1,66 @@ +var recordsBuffer = []; + +function addRecord(record) { + recordsBuffer.push(record); + refresh_records_list(); +} + +function delRecord(index) { + recordsBuffer.splice(index, 1); + refresh_records_list(); +} + +function refresh_records_list() { + var buf = ""; + if (recordsBuffer.length > 0) { + buf = "\n"; + } + function getInput(text) { + return ''; + } + + function makeRecord(value, index, array) { + buf += "
\n"; + buf += "
" + getInput(value.domain) + "
\n"; + buf += "
" + getInput(value.type) + "
\n"; + buf += "
" + getInput(value.ttl) + "
\n"; + buf += "
" + getInput(value.addr) + "
\n"; + buf += "
\n
\n"; + buf += "
"; + } + + recordsBuffer.forEach(makeRecord); + document.getElementById("domain_records").innerHTML = buf; +} + +function showNewRecordDialog() { + button_positive = document.getElementById("new_record_positive_button"); + button_positive.onclick = function() { + addRecord(get_record_from_dialog()); // It will refresh list + dialog = document.getElementById("new_record_dialog"); + dialog.className = "modal"; + }; + + button_negative = document.getElementById("new_record_negative_button"); + button_negative.onclick = function() { + dialog = document.getElementById("new_record_dialog"); + dialog.className = "modal"; + refresh_records_list(); + } + + dialog = document.getElementById("new_record_dialog"); + dialog.className = "modal is-active"; +} + +function get_record_from_dialog() { + record_name = document.getElementById("record_name").value; + record_type = document.getElementById("record_type").value; + record_ttl = parseInt(document.getElementById("record_ttl").value); + record_data = document.getElementById("record_data").value; + return { type: record_type, domain: record_name, ttl: record_ttl, addr: record_data } +} + function onLoad() { external.invoke(JSON.stringify({cmd: 'loaded'})); } @@ -37,7 +100,8 @@ function saveKey() { function createDomain() { new_domain = document.getElementById("new_domain").value; - new_dom_records = document.getElementById("new_domain_records").value; + //new_dom_records = document.getElementById("new_domain_records").value; + new_dom_records = JSON.stringify(recordsBuffer); new_dom_tags = document.getElementById("new_domain_tags").value; external.invoke(JSON.stringify({cmd: 'createDomain', name: new_domain, records: new_dom_records, tags: new_dom_tags})); } @@ -102,6 +166,20 @@ function showModalDialog(text, callback) { dialog.className = "modal is-active"; } +function showWarning(text) { + warning = document.getElementById("notification_warning"); + message = document.getElementById("warning_text"); + message.innerHTML = text; + + warning.className = "notification is-warning"; + button = document.getElementById("close"); + button.onclick = function() { + message.value = ""; + warning.className = "notification is-warning is-hidden"; + } + setTimeout(button.onclick, 5000); +} + function showMiningIndicator(visible) { indicator = document.getElementById("mining_indicator"); if (visible) {