use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use data_encoding::{BASE32HEX, BASE64};
use rand::Rng;
use std::{collections::HashMap, cmp::max};
use std::io::{self, Cursor, Error, ErrorKind, Read, Seek, SeekFrom};
use std::fmt::{self, Display};
use strum_macros::EnumString;
/// DNS message OPCODE (the 4-bit operation field of the header).
///
/// Wire values are assigned by `DnsOpcode::encode`; value 3 has no variant.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DnsOpcode {
    /// Standard query (0).
    QUERY,
    /// Inverse query (1).
    IQUERY,
    /// Server status request (2).
    STATUS,
    /// Zone change notification (4).
    NOTIFY,
    /// Dynamic update (5).
    UPDATE,
    /// DNS Stateful Operations (6).
    DSO,
}
/// DNS response code, including the extended codes carried via EDNS(0).
///
/// Numeric codes are assigned by `DnsRcode::parse`; only the low 4 bits fit in the
/// message header (see `DnsRcode::encode`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DnsRcode {
    /// No error (0).
    NOERROR,
    /// Format error (1).
    FORMERR,
    /// Server failure (2).
    SERVFAIL,
    /// Non-existent domain (3).
    NXDOMAIN,
    /// Not implemented (4).
    NOTIMP,
    /// Query refused (5).
    REFUSED,
    /// Name exists when it should not (6).
    YXDOMAIN,
    /// RR set exists when it should not (7).
    YXRRSET,
    /// RR set that should exist does not (8).
    NXRRSET,
    /// Server not authoritative / not authorized (9).
    NOTAUTH,
    /// Name not contained in zone (10).
    NOTZONE,
    /// DSO-TYPE not implemented (11).
    DSOTYPENI,
    /// Code 16, shared by BADVERS and BADSIG.
    BADVERSBADSIG,
    /// Code 17.
    BADKEY,
    /// Code 18.
    BADTIME,
    /// Code 19.
    BADMODE,
    /// Code 20.
    BADNAME,
    /// Code 21.
    BADALG,
    /// Code 22.
    BADTRUNC,
    /// Code 23.
    BADCOOKIE,
}
/// Resource record TYPE values supported by this module.
///
/// Numeric codes are assigned by `DnsType::encode`. `EnumString` additionally lets a
/// variant be parsed from its name (e.g. `"AAAA"`).
#[derive(Clone, Copy, Debug, EnumString, PartialEq)]
pub enum DnsType {
    /// IPv4 address (1).
    A,
    /// Authoritative name server (2).
    NS,
    /// Canonical name (5).
    CNAME,
    /// Start of authority (6).
    SOA,
    /// Domain name pointer (12).
    PTR,
    /// Host information (13).
    HINFO,
    /// Mail exchange (15).
    MX,
    /// Text strings (16).
    TXT,
    /// Responsible person (17).
    RP,
    /// IPv6 address (28).
    AAAA,
    /// Service locator (33).
    SRV,
    /// Delegation name (39).
    DNAME,
    /// EDNS(0) pseudo-record (41).
    OPT,
    /// Delegation signer (43).
    DS,
    /// SSH key fingerprint (44).
    SSHFP,
    /// DNSSEC signature (46).
    RRSIG,
    /// Next secure record (47).
    NSEC,
    /// DNSSEC public key (48).
    DNSKEY,
    /// Hashed next secure record (50).
    NSEC3,
    /// TLS certificate association (52).
    TLSA,
    /// Certification authority authorization (257).
    CAA,
}
/// Resource record CLASS values supported by this module.
///
/// Numeric codes are assigned by `DnsClass::encode`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DnsClass {
    /// Internet (1).
    IN,
    /// Chaos (3).
    CH,
    /// Hesiod (4).
    HS,
    /// NONE (254).
    NONE,
    /// ANY (255).
    ANY,
}
/// The fixed 12-byte DNS message header: ID, flag word and the four section counts.
///
/// Derives `Debug`/`Clone`/`PartialEq` so headers can be logged, copied and compared;
/// every field type already implements these traits.
#[derive(Clone, Debug, PartialEq)]
pub struct DnsHeader {
    /// Message ID used to pair responses with queries.
    msg_id: u16,
    /// `true` for a response, `false` for a query.
    qr: bool,
    /// Kind of operation requested.
    opcode: DnsOpcode,
    /// Authoritative answer flag.
    aa: bool,
    /// Truncation flag.
    tc: bool,
    /// Recursion desired flag.
    rd: bool,
    /// Recursion available flag.
    ra: bool,
    /// Authentic data flag.
    ad: bool,
    /// Checking disabled flag.
    cd: bool,
    /// Response code; `None` for queries.
    rcode: Option<DnsRcode>,
    /// Number of entries in the question section.
    qdcount: u16,
    /// Number of records in the answer section.
    ancount: u16,
    /// Number of records in the authority section.
    nscount: u16,
    /// Number of records in the additional section.
    arcount: u16,
}
/// One entry of the question section: the queried name, type and class.
///
/// Derives `Debug`/`Clone`/`PartialEq`; all field types support them.
#[derive(Clone, Debug, PartialEq)]
pub struct DnsQuestion {
    /// Queried domain name (names produced by `parse` end with a '.').
    qname: String,
    /// Requested record type.
    qtype: DnsType,
    /// Requested record class.
    qclass: DnsClass,
}
/// A resource record, or an EDNS(0) OPT pseudo-record.
///
/// Ordinary records use `class` and `ttl`. OPT records leave those `None` and instead
/// use `payload_size`, `rcode`, `edns_version` and `flags`.
#[derive(Clone, Debug, PartialEq)]
pub struct DnsRecord {
    /// Owner name (empty string for the root).
    name: String,
    /// Record type.
    atype: DnsType,
    /// Record class; `None` for OPT records.
    class: Option<DnsClass>,
    /// Time to live in seconds; `None` for OPT records.
    ttl: Option<u32>,
    /// OPT only: advertised UDP payload size (stored in the CLASS slot on the wire).
    payload_size: Option<u16>,
    /// OPT only: full (possibly extended) response code.
    rcode: Option<DnsRcode>,
    /// OPT only: EDNS version.
    edns_version: Option<u8>,
    /// OPT only: EDNS flags by name (currently just "do").
    flags: Option<HashMap<&'static str, bool>>,
    /// Raw RDATA bytes.
    rdata: Vec<u8>,
    /// RDATA decoded into display fields according to `DnsType::rdata_schema`.
    parsed_rdata: Vec<String>,
}
/// A complete DNS message: header plus the question, answer, authority and
/// additional sections.
pub struct DnsMessage {
    /// Message header; public so callers can inspect IDs and flags.
    pub header: DnsHeader,
    /// Question section entries.
    questions: Vec<DnsQuestion>,
    /// Answer section records.
    answers: Vec<DnsRecord>,
    /// Authority section records.
    authoritative_answers: Vec<DnsRecord>,
    /// Additional section records (including any OPT pseudo-record).
    additional_answers: Vec<DnsRecord>,
}
impl DnsOpcode {
    /// Returns the numeric OPCODE written into the header.
    pub fn encode(&self) -> u8 {
        match *self {
            DnsOpcode::QUERY => 0,
            DnsOpcode::IQUERY => 1,
            DnsOpcode::STATUS => 2,
            DnsOpcode::NOTIFY => 4,
            DnsOpcode::UPDATE => 5,
            DnsOpcode::DSO => 6,
        }
    }

    /// Maps a numeric OPCODE back to its variant.
    ///
    /// # Errors
    /// `ErrorKind::InvalidInput` for 3 and for anything above 6.
    pub fn parse(val: u8) -> io::Result<DnsOpcode> {
        // Reverse lookup through `encode`, so the two directions cannot drift apart.
        const KNOWN: [DnsOpcode; 6] = [
            DnsOpcode::QUERY,
            DnsOpcode::IQUERY,
            DnsOpcode::STATUS,
            DnsOpcode::NOTIFY,
            DnsOpcode::UPDATE,
            DnsOpcode::DSO,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|candidate| candidate.encode() == val)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::InvalidInput,
                    format!("Invalid opcode: valid are 0 to 2 and 4 to 6, got {}", val),
                )
            })
    }
}
impl Display for DnsOpcode {
    /// Displays the opcode by its variant name (e.g. `QUERY`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
impl DnsRcode {
    /// Returns the 4-bit RCODE as it is stored in the DNS header.
    ///
    /// Extended codes (16-23) keep only their low nibble here (e.g. `BADKEY` (17) gives 1);
    /// their upper bits belong in an EDNS(0) OPT record.
    pub fn encode(&self) -> u8 {
        let full_code: u8 = match *self {
            DnsRcode::NOERROR => 0,
            DnsRcode::FORMERR => 1,
            DnsRcode::SERVFAIL => 2,
            DnsRcode::NXDOMAIN => 3,
            DnsRcode::NOTIMP => 4,
            DnsRcode::REFUSED => 5,
            DnsRcode::YXDOMAIN => 6,
            DnsRcode::YXRRSET => 7,
            DnsRcode::NXRRSET => 8,
            DnsRcode::NOTAUTH => 9,
            DnsRcode::NOTZONE => 10,
            DnsRcode::DSOTYPENI => 11,
            DnsRcode::BADVERSBADSIG => 16,
            DnsRcode::BADKEY => 17,
            DnsRcode::BADTIME => 18,
            DnsRcode::BADMODE => 19,
            DnsRcode::BADNAME => 20,
            DnsRcode::BADALG => 21,
            DnsRcode::BADTRUNC => 22,
            DnsRcode::BADCOOKIE => 23,
        };
        full_code & 0b1111
    }

    /// Maps a full (possibly extended) RCODE value to its variant.
    ///
    /// # Errors
    /// `ErrorKind::InvalidInput` for 12-15 and for anything above 23.
    pub fn parse(val: u16) -> io::Result<DnsRcode> {
        let rcode = match val {
            0 => DnsRcode::NOERROR,
            1 => DnsRcode::FORMERR,
            2 => DnsRcode::SERVFAIL,
            3 => DnsRcode::NXDOMAIN,
            4 => DnsRcode::NOTIMP,
            5 => DnsRcode::REFUSED,
            6 => DnsRcode::YXDOMAIN,
            7 => DnsRcode::YXRRSET,
            8 => DnsRcode::NXRRSET,
            9 => DnsRcode::NOTAUTH,
            10 => DnsRcode::NOTZONE,
            11 => DnsRcode::DSOTYPENI,
            16 => DnsRcode::BADVERSBADSIG,
            17 => DnsRcode::BADKEY,
            18 => DnsRcode::BADTIME,
            19 => DnsRcode::BADMODE,
            20 => DnsRcode::BADNAME,
            21 => DnsRcode::BADALG,
            22 => DnsRcode::BADTRUNC,
            23 => DnsRcode::BADCOOKIE,
            unknown => {
                return Err(Error::new(
                    ErrorKind::InvalidInput,
                    format!(
                        "Invalid rcode: valid are values from 0 to 11 and 16 to 23, got {}",
                        unknown
                    ),
                ));
            }
        };
        Ok(rcode)
    }
}
impl Display for DnsRcode {
    /// Displays the response code by its variant name (e.g. `NXDOMAIN`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
impl DnsType {
pub fn encode(&self) -> u16 {
match self {
DnsType::A => 1,
DnsType::NS => 2,
DnsType::CNAME => 5,
DnsType::SOA => 6,
DnsType::PTR => 12,
DnsType::HINFO => 13,
DnsType::MX => 15,
DnsType::TXT => 16,
DnsType::RP => 17,
DnsType::AAAA => 28,
DnsType::SRV => 33,
DnsType::DNAME => 39,
DnsType::OPT => 41,
DnsType::DS => 43,
DnsType::SSHFP => 44,
DnsType::RRSIG => 46,
DnsType::NSEC => 47,
DnsType::DNSKEY => 48,
DnsType::NSEC3 => 50,
DnsType::TLSA => 52,
DnsType::CAA => 257
}
}
pub fn parse(val: u16) -> io::Result<DnsType> {
Ok(match val {
1 => DnsType::A,
2 => DnsType::NS,
5 => DnsType::CNAME,
6 => DnsType::SOA,
12 => DnsType::PTR,
13 => DnsType::HINFO,
15 => DnsType::MX,
16 => DnsType::TXT,
17 => DnsType::RP,
28 => DnsType::AAAA,
33 => DnsType::SRV,
39 => DnsType::DNAME,
41 => DnsType::OPT,
43 => DnsType::DS,
44 => DnsType::SSHFP,
46 => DnsType::RRSIG,
47 => DnsType::NSEC,
48 => DnsType::DNSKEY,
50 => DnsType::NSEC3,
52 => DnsType::TLSA,
257 => DnsType::CAA,
x => return Err(Error::new(
ErrorKind::InvalidInput,
format!("Unknown or unimplemented DNS TYPE with number {}.", x)
))
})
}
pub fn rdata_schema(&self) -> &str {
match self {
DnsType::A => "ip4",
DnsType::NS
| DnsType::CNAME
| DnsType::DNAME
| DnsType::PTR
| DnsType::DNAME => "qname",
DnsType::SOA => "qname qname u32 u32 u32 u32 u32",
DnsType::HINFO => "string string",
DnsType::MX => "u16 qname",
DnsType::TXT => "text",
DnsType::RP => "qname qname",
DnsType::AAAA => "ip6",
DnsType::SRV => "u16 u16 u16 qname",
DnsType::OPT => "options",
DnsType::DS => "u16 u8 u8 hex",
DnsType::SSHFP => "u8 u8 hex",
DnsType::RRSIG => "qtype u8 u8 u32 u32 u32 u16 qname base64",
DnsType::NSEC => "qname types",
DnsType::DNSKEY => "u16 u8 u8 base64",
DnsType::NSEC3 => "u8 u8 u16 salt hash types",
DnsType::TLSA => "u8 u8 u8 hex",
DnsType::CAA => "u8 property"
}
}
}
impl Display for DnsType {
    /// Displays the record type by its variant name (e.g. `AAAA`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
impl DnsClass {
    /// Returns the numeric CLASS code used on the wire.
    pub fn encode(&self) -> u16 {
        match *self {
            DnsClass::IN => 1,
            DnsClass::CH => 3,
            DnsClass::HS => 4,
            DnsClass::NONE => 254,
            DnsClass::ANY => 255,
        }
    }

    /// Maps a numeric CLASS code to its variant.
    ///
    /// # Errors
    /// `ErrorKind::InvalidInput` for codes without a variant.
    pub fn parse(val: u16) -> io::Result<DnsClass> {
        let known = [DnsClass::IN, DnsClass::CH, DnsClass::HS, DnsClass::NONE, DnsClass::ANY];
        known
            .iter()
            .copied()
            .find(|class| class.encode() == val)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::InvalidInput,
                    format!("Unknown DNS CLASS with number {}.", val),
                )
            })
    }
}
impl Display for DnsClass {
    /// Displays the class by its variant name (e.g. `IN`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
impl DnsHeader {
    /// Builds a response header (`qr` set) with the given flags, RCODE and section counts.
    pub fn new_response_header(msg_id: u16, opcode: DnsOpcode, aa: bool, tc: bool, rd: bool,
                               ra: bool, ad: bool, cd: bool, rcode: DnsRcode, qdcount: u16,
                               ancount: u16, nscount: u16, arcount: u16) -> DnsHeader {
        DnsHeader {
            msg_id,
            qr: true,
            opcode,
            aa,
            tc,
            rd,
            ra,
            ad,
            cd,
            rcode: Some(rcode),
            qdcount,
            ancount,
            nscount,
            arcount,
        }
    }

    /// Builds a query header: `qr`, `aa` and `ra` cleared, no RCODE, no answer or
    /// authority records, and `arcount` set to 1 when `edns` is true (room for one OPT
    /// record, which the caller is expected to add).
    pub fn new_query_header(msg_id: u16, opcode: DnsOpcode, tc: bool, rd: bool, ad: bool, cd: bool,
                            edns: bool, qdcount: u16) -> DnsHeader {
        DnsHeader {
            msg_id,
            qr: false,
            opcode,
            aa: false,
            tc,
            rd,
            ra: false,
            ad,
            cd,
            rcode: None,
            qdcount,
            ancount: 0,
            nscount: 0,
            arcount: if edns { 1 } else { 0 },
        }
    }

    /// Serializes the header into its 12-byte wire form (six big-endian u16 words).
    ///
    /// A missing RCODE (query) is written as 0.
    pub fn encode(&self) -> io::Result<Vec<u8>> {
        let flag = |set: bool, bit: u16| if set { 1u16 << bit } else { 0 };
        // Only 4 bits of OPCODE and RCODE fit in the flag word; masking keeps an
        // out-of-range value from spilling into neighbouring flag bits.
        let opcode = (self.opcode.encode() as u16 & 0b1111) << 11;
        let rcode = self.rcode.map_or(0, |r| r.encode() as u16 & 0b1111);
        let flags = flag(self.qr, 15)
            | opcode
            | flag(self.aa, 10)
            | flag(self.tc, 9)
            | flag(self.rd, 8)
            | flag(self.ra, 7)
            | flag(self.ad, 5)
            | flag(self.cd, 4)
            | rcode;

        let mut res = Vec::with_capacity(12);
        res.write_u16::<NetworkEndian>(self.msg_id)?;
        res.write_u16::<NetworkEndian>(flags)?;
        res.write_u16::<NetworkEndian>(self.qdcount)?;
        res.write_u16::<NetworkEndian>(self.ancount)?;
        res.write_u16::<NetworkEndian>(self.nscount)?;
        res.write_u16::<NetworkEndian>(self.arcount)?;
        Ok(res)
    }

    /// Reads a 12-byte header from the cursor.
    ///
    /// The RCODE is decoded only for responses; for queries it is `None`, and its
    /// (meaningless) bits are no longer validated.
    ///
    /// # Errors
    /// Propagates short reads, an unknown OPCODE, or (responses only) an unknown RCODE.
    pub fn parse(header: &mut Cursor<&[u8]>) -> io::Result<DnsHeader> {
        let msg_id = header.read_u16::<NetworkEndian>()?;
        let flags = header.read_u16::<NetworkEndian>()?;
        let bit = |n: u16| flags & (1 << n) != 0;
        let qr = bit(15);
        let opcode = DnsOpcode::parse(((flags >> 11) & 0b1111) as u8)?;
        let rcode = if qr {
            Some(DnsRcode::parse(flags & 0b1111)?)
        } else {
            None
        };
        Ok(DnsHeader {
            msg_id,
            qr,
            opcode,
            aa: bit(10),
            tc: bit(9),
            rd: bit(8),
            ra: bit(7),
            ad: bit(5),
            cd: bit(4),
            rcode,
            // Struct-literal fields are evaluated in order, so the counts are read in
            // wire order: QD, AN, NS, AR.
            qdcount: header.read_u16::<NetworkEndian>()?,
            ancount: header.read_u16::<NetworkEndian>()?,
            nscount: header.read_u16::<NetworkEndian>()?,
            arcount: header.read_u16::<NetworkEndian>()?,
        })
    }

    /// One-line summary: id, opcode, rcode (responses only) and the set flags.
    ///
    /// Example: `id: 7, opcode: QUERY, rcode: NOERROR, flags: rd ra`. With no flags set it
    /// ends in `flags:`.
    pub fn info_str(&self) -> String {
        let mut s = match self.rcode {
            Some(rcode) => format!("id: {}, opcode: {}, rcode: {}, flags:",
                                   self.msg_id, self.opcode, rcode),
            None => format!("id: {}, opcode: {}, flags:", self.msg_id, self.opcode),
        };
        let flags = [
            ("aa", self.aa),
            ("tc", self.tc),
            ("rd", self.rd),
            ("ra", self.ra),
            ("ad", self.ad),
            ("cd", self.cd),
        ];
        for (name, set) in flags.iter() {
            if *set {
                s.push(' ');
                s.push_str(name);
            }
        }
        s
    }
}
impl Display for DnsHeader {
    /// Renders `DNS Response (...)` or `DNS Query (...)` around `info_str()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = if self.qr { "DNS Response" } else { "DNS Query" };
        write!(f, "{} ({})", kind, self.info_str())
    }
}
impl DnsQuestion {
    /// Creates a question for `domain` with the given type and class.
    pub fn new(domain: &str, qtype: DnsType, qclass: DnsClass) -> DnsQuestion {
        DnsQuestion {
            qname: String::from(domain),
            qtype,
            qclass,
        }
    }

    /// Serializes the question as QNAME followed by big-endian QTYPE and QCLASS.
    pub fn encode(&self) -> io::Result<Vec<u8>> {
        let mut wire = DnsMessage::encode_qname(&self.qname)?;
        for field in [self.qtype.encode(), self.qclass.encode()].iter() {
            wire.write_u16::<NetworkEndian>(*field)?;
        }
        Ok(wire)
    }

    /// Reads one question entry (QNAME, QTYPE, QCLASS) from the cursor.
    pub fn parse(msg: &mut Cursor<&[u8]>) -> io::Result<DnsQuestion> {
        let qname = DnsMessage::parse_qname(msg)?;
        let raw_type = msg.read_u16::<NetworkEndian>()?;
        let qtype = DnsType::parse(raw_type)?;
        let raw_class = msg.read_u16::<NetworkEndian>()?;
        let qclass = DnsClass::parse(raw_class)?;
        Ok(DnsQuestion { qname, qtype, qclass })
    }

    /// Formats the question as a zone-file-like line, right-padding the name with spaces
    /// to at least `owner_len` bytes.
    pub fn as_padded_string(&self, owner_len: usize) -> String {
        let padding = " ".repeat(owner_len.saturating_sub(self.qname.len()));
        format!("{}{}\t \t{}\t{}", self.qname, padding, self.qclass, self.qtype)
    }
}
impl Display for DnsQuestion {
    /// Renders the question with its name, type and class.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "DNS Question for '{name}' (type: {kind}, class: {class})",
            name = self.qname,
            kind = self.qtype,
            class = self.qclass
        )
    }
}
impl DnsRecord {
    /// Builds an EDNS(0) OPT pseudo-record.
    ///
    /// The record has the root (empty) name, no class or TTL, EDNS version 0, RCODE
    /// `NOERROR`, empty RDATA and a single `"do"` flag set to `do_flag`.
    pub fn new_opt_record(payload_size: u16, do_flag: bool) -> DnsRecord {
        let mut flags = HashMap::new();
        flags.insert("do", do_flag);
        DnsRecord {
            name: String::new(),
            atype: DnsType::OPT,
            class: None,
            ttl: None,
            payload_size: Some(payload_size),
            rcode: Some(DnsRcode::NOERROR),
            edns_version: Some(0),
            flags: Some(flags),
            rdata: Vec::new(),
            parsed_rdata: Vec::new(),
        }
    }

    /// Serializes the record to wire format.
    ///
    /// Records with a class are written as NAME TYPE CLASS TTL RDLENGTH RDATA. Records
    /// without a class are written as OPT pseudo-records: the CLASS slot carries the
    /// payload size, and the TTL slot carries extended RCODE (8 bits), EDNS version
    /// (8 bits) and flags (16 bits, DO = top bit).
    ///
    /// # Errors
    /// Propagates `encode_qname` errors. Returns `InvalidInput` when a field the layout
    /// needs is `None`, or when the RDATA exceeds 65535 bytes.
    pub fn encode(&self) -> io::Result<Vec<u8>> {
        let mut record = DnsMessage::encode_qname(self.name.as_str())?;
        record.write_u16::<NetworkEndian>(self.atype.encode())?;
        match self.class {
            Some(class) => {
                let ttl = self.ttl.ok_or_else(|| DnsRecord::missing_field("ttl"))?;
                record.write_u16::<NetworkEndian>(class.encode())?;
                record.write_u32::<NetworkEndian>(ttl)?;
            }
            None => {
                let payload_size = self
                    .payload_size
                    .ok_or_else(|| DnsRecord::missing_field("payload_size"))?;
                let rcode = self.rcode.ok_or_else(|| DnsRecord::missing_field("rcode"))?;
                let version = self
                    .edns_version
                    .ok_or_else(|| DnsRecord::missing_field("edns_version"))?;
                let do_flag = self
                    .flags
                    .as_ref()
                    .and_then(|flags| flags.get("do"))
                    .copied()
                    .unwrap_or(false);
                record.write_u16::<NetworkEndian>(payload_size)?;
                // `DnsRcode::encode` returns only the low 4 header bits, so masking its
                // result for the upper bits always produced 0. Derive them explicitly.
                record.write_u8(DnsRecord::extended_rcode_bits(rcode))?;
                record.write_u8(version)?;
                record.write_u16::<NetworkEndian>(if do_flag { 1 << 15 } else { 0 })?;
            }
        }
        if self.rdata.len() > u16::MAX as usize {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                format!("RDATA too long: expected up to {} bytes, got {}",
                        u16::MAX, self.rdata.len()),
            ));
        }
        record.write_u16::<NetworkEndian>(self.rdata.len() as u16)?;
        record.extend_from_slice(&self.rdata);
        Ok(record)
    }

    /// Reads one resource record from the cursor.
    ///
    /// `rcode` is the header RCODE. It is only used for OPT records, where it is combined
    /// with the extended RCODE bits.
    pub fn parse(msg: &mut Cursor<&[u8]>, rcode: DnsRcode) -> io::Result<DnsRecord> {
        let name = DnsMessage::parse_qname(msg)?;
        let atype = DnsType::parse(msg.read_u16::<NetworkEndian>()?)?;
        if atype == DnsType::OPT {
            return DnsRecord::parse_opt_record(msg, name, rcode);
        }
        let class = DnsClass::parse(msg.read_u16::<NetworkEndian>()?)?;
        let ttl = msg.read_u32::<NetworkEndian>()?;
        let (rdata, parsed_rdata) = DnsRecord::read_rdata(msg, atype)?;
        Ok(DnsRecord {
            name,
            atype,
            class: Some(class),
            ttl: Some(ttl),
            payload_size: None,
            rcode: None,
            edns_version: None,
            flags: None,
            rdata,
            parsed_rdata,
        })
    }

    /// Decodes `rdlength` bytes of RDATA at the cursor into display strings, one per
    /// token of `atype.rdata_schema()`.
    ///
    /// Length-dependent fields ("text", "hex", "base64", "property", "options", "types")
    /// are bounded by the RDATA end, computed from the cursor position. Malformed lengths
    /// therefore cannot underflow or overflow counters.
    ///
    /// # Errors
    /// Short reads, unknown nested types/labels, or an unknown schema token.
    pub fn parse_rdata(atype: &DnsType, msg: &mut Cursor<&[u8]>, rdlength: u16) -> io::Result<Vec<String>> {
        // Bytes left before the end of the RDATA (0 once the cursor has passed it).
        fn remaining(msg: &Cursor<&[u8]>, end: u64) -> u64 {
            end.saturating_sub(msg.position())
        }

        let rdata_start = msg.position();
        let rdata_end = rdata_start + u64::from(rdlength);
        let mut res = Vec::new();
        for token in atype.rdata_schema().split(' ') {
            let field = match token {
                "u8" => msg.read_u8()?.to_string(),
                "u16" => msg.read_u16::<NetworkEndian>()?.to_string(),
                "u32" => msg.read_u32::<NetworkEndian>()?.to_string(),
                "qname" => DnsMessage::parse_qname(msg)?,
                "string" => DnsMessage::parse_string(msg)?,
                "ip4" => {
                    let mut octets = [0u8; 4];
                    msg.read_exact(&mut octets)?;
                    format!("{}.{}.{}.{}", octets[0], octets[1], octets[2], octets[3])
                }
                "ip6" => {
                    let mut groups = Vec::with_capacity(8);
                    for _ in 0..8 {
                        groups.push(format!("{:x}", msg.read_u16::<NetworkEndian>()?));
                    }
                    groups.join(":")
                }
                "text" => {
                    // Count wire bytes via the cursor: parse_string maps bytes >= 0x80 to
                    // multi-byte UTF-8 chars, so String::len() would over-count.
                    let mut text = String::new();
                    while msg.position() < rdata_end {
                        text.push_str(&DnsMessage::parse_string(msg)?);
                    }
                    text
                }
                "hex" => {
                    let mut bytes = vec![0u8; remaining(msg, rdata_end) as usize];
                    msg.read_exact(&mut bytes)?;
                    DnsRecord::to_hex(&bytes)
                }
                "qtype" => DnsType::parse(msg.read_u16::<NetworkEndian>()?)?.to_string(),
                "base64" => {
                    let mut data = vec![0u8; remaining(msg, rdata_end) as usize];
                    msg.read_exact(&mut data)?;
                    BASE64.encode(&data)
                }
                "types" => {
                    let consumed = msg
                        .position()
                        .saturating_sub(rdata_start)
                        .min(u64::from(rdlength)) as u16;
                    let bitmap = DnsRecord::parse_nsec_type_bitmap(msg, consumed, rdlength)?;
                    DnsRecord::interpret_nsec_type_bitmap(bitmap)?
                }
                "salt" => {
                    let salt_len = msg.read_u8()? as usize;
                    let mut salt = vec![0u8; salt_len];
                    msg.read_exact(&mut salt)?;
                    DnsRecord::to_hex(&salt)
                }
                "hash" => {
                    let hash_len = msg.read_u8()? as usize;
                    let mut hash = vec![0u8; hash_len];
                    msg.read_exact(&mut hash)?;
                    BASE32HEX.encode(&hash)
                }
                "property" => {
                    // CAA: tag length, tag, then the value fills the rest of the RDATA.
                    let tag_len = msg.read_u8()?;
                    let mut property = String::new();
                    for _ in 0..tag_len {
                        property.push(msg.read_u8()? as char);
                    }
                    property.push(' ');
                    for _ in 0..remaining(msg, rdata_end) {
                        property.push(msg.read_u8()? as char);
                    }
                    property
                }
                "options" => {
                    // EDNS options: code (u16), length (u16), value bytes; repeated.
                    let mut options = Vec::new();
                    while msg.position() < rdata_end {
                        let code = msg.read_u16::<NetworkEndian>()?;
                        let len = msg.read_u16::<NetworkEndian>()? as usize;
                        let mut value = vec![0u8; len];
                        msg.read_exact(&mut value)?;
                        options.push(format!("{}: {}", code, DnsRecord::to_hex(&value)));
                    }
                    options.join(", ")
                }
                other => {
                    return Err(Error::new(
                        ErrorKind::Other,
                        format!("unknown rdata schema: {}", other),
                    ));
                }
            };
            res.push(field);
        }
        Ok(res)
    }

    /// Returns the mnemonic for a DNSSEC algorithm number.
    ///
    /// # Errors
    /// `ErrorKind::Other` for numbers not listed here.
    pub fn interpret_dnssec_algorithm(algorithm: u8) -> io::Result<String> {
        Ok(match algorithm {
            0 => "DELETE",
            1 => "RSAMD5",
            2 => "DH",
            3 => "DSA",
            5 => "RSASHA1",
            6 => "DSA-NSEC3-SHA1 (DSA)",
            7 => "RSASHA1-NSEC3-SHA1 (RSASHA1)",
            8 => "RSASHA256",
            10 => "RSASHA512",
            12 => "ECC-GOST",
            13 => "ECDSAP256SHA256",
            14 => "ECDSAP384SHA384",
            15 => "ED25519",
            16 => "ED448",
            252 => "INDIRECT",
            253 => "PRIVATEDNS",
            254 => "PRIVATEOID",
            x => {
                return Err(Error::new(
                    ErrorKind::Other,
                    format!("unknown DNSSEC algorithm: {}", x),
                ))
            }
        }
        .to_string())
    }

    /// Renders type numbers as space-separated names, using `TYPE<n>` for unknown ones.
    pub fn interpret_nsec_type_bitmap(types: Vec<u16>) -> io::Result<String> {
        let names: Vec<String> = types
            .into_iter()
            .map(|t| match DnsType::parse(t) {
                Ok(known) => known.to_string(),
                Err(_) => format!("TYPE{}", t),
            })
            .collect();
        Ok(names.join(" "))
    }

    /// Decodes an NSEC/NSEC3 type bitmap from the cursor.
    ///
    /// `len_read` bytes of the RDATA have already been consumed; windows are read until
    /// `rdlength` is reached. Each window is: window number (u8), bitmap length (u8),
    /// then that many bitmap bytes, most significant bit first.
    ///
    /// # Errors
    /// Short reads, or a bitmap length above 32 bytes (the maximum a window can
    /// describe). Larger lengths previously overflowed the u8 arithmetic.
    pub fn parse_nsec_type_bitmap(msg: &mut Cursor<&[u8]>, len_read: u16, rdlength: u16)
                                  -> io::Result<Vec<u16>> {
        // u32 accounting so that `2 + bitmap_len` and the running total cannot overflow.
        let mut consumed = u32::from(len_read);
        let mut available_types = Vec::new();
        while consumed < u32::from(rdlength) {
            let window = u16::from(msg.read_u8()?);
            let bitmap_len = msg.read_u8()?;
            if bitmap_len > 32 {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!("Invalid NSEC type bitmap length: expected up to 32 bytes, got {}",
                            bitmap_len),
                ));
            }
            for i in 0..u16::from(bitmap_len) {
                let byte = msg.read_u8()?;
                for j in 0..8u16 {
                    if byte & (0b1000_0000 >> j) != 0 {
                        available_types.push((window << 8) | (i * 8 + j));
                    }
                }
            }
            consumed += 2 + u32::from(bitmap_len);
        }
        Ok(available_types)
    }

    /// Formats the record as a zone-file-like line, right-padding the owner name with
    /// spaces to at least `owner_len` bytes.
    ///
    /// TTL and class are left empty for OPT records, which have neither (this used to
    /// panic).
    pub fn as_padded_string(&self, owner_len: usize) -> String {
        let mut owner = self.name.clone();
        while owner.len() < owner_len {
            owner.push(' ');
        }
        let ttl = self.ttl.map_or_else(String::new, |t| t.to_string());
        let class = self.class.map_or_else(String::new, |c| c.to_string());
        format!(
            "{}\t{}\t{}\t{}\t\t{}", owner, ttl, class, self.atype,
            DnsRecord::format_parsed_rdata(&self.parsed_rdata)
        )
    }

    /// Parses the remainder of an OPT pseudo-record, after NAME and TYPE.
    ///
    /// Wire layout: CLASS = payload size (u16); TTL = extended RCODE (u8), version (u8),
    /// flags (u16); then RDLENGTH and RDATA. The TTL slot is read once, as a u32, and then
    /// split. Previously version and flags were read again from the following bytes,
    /// which consumed RDLENGTH/RDATA by mistake.
    fn parse_opt_record(msg: &mut Cursor<&[u8]>, name: String, rcode: DnsRcode)
                        -> io::Result<DnsRecord> {
        if !name.is_empty() {
            return Err(Error::new(
                ErrorKind::Other, "OPT record must have the root as name"
            ));
        }
        let payload_size = msg.read_u16::<NetworkEndian>()?;
        let ttl = msg.read_u32::<NetworkEndian>()?;
        let ext_rcode = (ttl >> 24) as u8;
        let version = (ttl >> 16) as u8;
        let do_flag = ttl & (1 << 15) != 0;
        // The full RCODE is the 8 extended bits followed by the 4 header bits.
        let rcode = match ext_rcode {
            0 => rcode,
            x => DnsRcode::parse(((x as u16) << 4) | rcode.encode() as u16)?,
        };
        let mut flags = HashMap::new();
        flags.insert("do", do_flag);
        let (rdata, parsed_rdata) = DnsRecord::read_rdata(msg, DnsType::OPT)?;
        Ok(DnsRecord {
            name,
            atype: DnsType::OPT,
            class: None,
            ttl: None,
            payload_size: Some(payload_size),
            rcode: Some(rcode),
            edns_version: Some(version),
            flags: Some(flags),
            rdata,
            parsed_rdata,
        })
    }

    /// Reads RDLENGTH and the RDATA after it, returning the raw bytes and the parsed
    /// fields.
    ///
    /// The cursor always ends up just past the RDATA, even if the schema-driven parser
    /// consumed a different number of bytes. This keeps the next record aligned.
    fn read_rdata(msg: &mut Cursor<&[u8]>, atype: DnsType) -> io::Result<(Vec<u8>, Vec<String>)> {
        let rdlength = msg.read_u16::<NetworkEndian>()?;
        let start = msg.position();
        let mut rdata = vec![0u8; rdlength as usize];
        msg.read_exact(&mut rdata)?;
        msg.set_position(start);
        let parsed = DnsRecord::parse_rdata(&atype, msg, rdlength)?;
        msg.set_position(start + u64::from(rdlength));
        Ok((rdata, parsed))
    }

    /// Upper 8 bits of the 12-bit extended RCODE (the OPT TTL's first byte).
    ///
    /// Codes 0-11 have upper bits 0; codes 16-23 all have upper bits 1.
    fn extended_rcode_bits(rcode: DnsRcode) -> u8 {
        match rcode {
            DnsRcode::NOERROR | DnsRcode::FORMERR | DnsRcode::SERVFAIL | DnsRcode::NXDOMAIN
            | DnsRcode::NOTIMP | DnsRcode::REFUSED | DnsRcode::YXDOMAIN | DnsRcode::YXRRSET
            | DnsRcode::NXRRSET | DnsRcode::NOTAUTH | DnsRcode::NOTZONE
            | DnsRcode::DSOTYPENI => 0,
            DnsRcode::BADVERSBADSIG | DnsRcode::BADKEY | DnsRcode::BADTIME | DnsRcode::BADMODE
            | DnsRcode::BADNAME | DnsRcode::BADALG | DnsRcode::BADTRUNC
            | DnsRcode::BADCOOKIE => 1,
        }
    }

    /// Error returned by `encode` when a field its layout needs is missing.
    fn missing_field(field: &str) -> Error {
        Error::new(
            ErrorKind::InvalidInput,
            format!("DNS record is missing its {} field", field),
        )
    }

    /// Lower-case hexadecimal rendering of `bytes`, two digits per byte.
    fn to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02x}", b)).collect()
    }

    /// Joins parsed RDATA fields with single spaces.
    fn format_parsed_rdata(parsed_rdata: &[String]) -> String {
        parsed_rdata.join(" ")
    }
}
impl Display for DnsRecord {
    /// Renders a one-line description.
    ///
    /// Records with a class show name, type, class, TTL and RDATA. Class-less (OPT)
    /// records show EDNS version, payload size, the set flags and RDATA.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rdata = DnsRecord::format_parsed_rdata(&self.parsed_rdata);
        match self.class {
            Some(class) => write!(
                f,
                "DNS Record for '{}' (type: {}, class: {}, ttl: {}, rdata: {})",
                self.name,
                self.atype,
                class,
                self.ttl.unwrap(),
                rdata
            ),
            None => {
                let version = self.edns_version.unwrap();
                let payload = self.payload_size.unwrap();
                let mut flag_list = String::from("flags:");
                for (flag, enabled) in self.flags.as_ref().unwrap() {
                    if *enabled {
                        flag_list.push(' ');
                        flag_list.push_str(flag);
                    }
                }
                write!(
                    f,
                    "DNS OPT Record (EDNS version: {}, payload size: {}, {}, rdata: {})",
                    version, payload, flag_list, rdata
                )
            }
        }
    }
}
impl DnsMessage {
pub fn new_query(domain: &str, qtype: DnsType, opcode: DnsOpcode, rd: bool, ad: bool, cd: bool,
edns: bool, do_flag: bool, bufsize: u16) -> io::Result<DnsMessage> {
if do_flag && !edns {
return Err(Error::new(
ErrorKind::InvalidInput, "do flag set but not using EDNS(0)"
));
}
let msg_id = rand::thread_rng().gen_range(0, 1u32 << 16) as u16;
let mut additional_answers = Vec::new();
if edns {
additional_answers.push(DnsRecord::new_opt_record(bufsize, do_flag));
}
Ok(DnsMessage {
header: DnsHeader::new_query_header(msg_id, opcode, false, rd, ad, cd, edns, 1),
questions: vec![DnsQuestion::new(domain, qtype, DnsClass::IN)],
answers: Vec::new(),
authoritative_answers: Vec::new(),
additional_answers
})
}
pub fn new_response(msg_id: u16, opcode: DnsOpcode, aa: bool, tc: bool, rd: bool, ra: bool,
ad: bool, cd: bool, rcode: DnsRcode, questions: Vec<DnsQuestion>,
answers: Vec<DnsRecord>, authoritative_answers: Vec<DnsRecord>,
additional_answers: Vec<DnsRecord>) -> DnsMessage {
DnsMessage {
header: DnsHeader::new_response_header(
msg_id, opcode, aa, tc, rd, ra, ad, cd, rcode,
questions.len() as u16,
answers.len() as u16,
authoritative_answers.len() as u16,
additional_answers.len() as u16
),
questions,
answers,
authoritative_answers,
additional_answers
}
}
pub fn encode(&self) -> io::Result<Vec<u8>> {
let mut res = self.header.encode()?;
for question in &self.questions {
res.append(&mut question.encode()?);
}
for record in &self.answers {
res.append(&mut record.encode()?);
}
for record in &self.authoritative_answers {
res.append(&mut record.encode()?);
}
for record in &self.additional_answers {
res.append(&mut record.encode()?);
}
Ok(res)
}
pub fn parse(msg: &mut Cursor<&[u8]>) -> io::Result<DnsMessage> {
let mut header = DnsHeader::parse(msg)?;
if header.tc {
return Err(Error::new(
ErrorKind::Other, "Received truncated message. Support for truncated \
messages has not been implemented yet."
));
}
let qdcount = header.qdcount;
let ancount = header.ancount;
let nscount = header.nscount;
let arcount = header.arcount;
let questions = DnsMessage::parse_questions(msg, qdcount)?;
let mut answers = Vec::new();
let mut authoritative_answers = Vec::new();
let mut additional_answers = Vec::new();
if ancount > 0 {
answers = DnsMessage::parse_records(
msg, ancount, header.rcode.unwrap()
)?;
}
if nscount > 0 {
authoritative_answers = DnsMessage::parse_records(
msg, nscount, header.rcode.unwrap()
)?;
}
if arcount > 0 {
additional_answers = DnsMessage::parse_records(
msg, arcount, header.rcode.unwrap()
)?;
}
for answer in &additional_answers {
if answer.rcode.is_some() {
header.rcode = answer.rcode;
}
}
Ok(DnsMessage {
header,
questions,
answers,
authoritative_answers,
additional_answers
})
}
pub fn encode_qname(domain: &str) -> io::Result<Vec<u8>> {
if domain.bytes().len() > 255 {
return Err(
Error::new(
ErrorKind::InvalidInput,
format!("Invalid domain: expected up to 255 bytes, got {}",
domain.bytes().len())
)
);
}
if domain.eq("") {
return Ok(vec![0]);
}
let mut res = Vec::new();
for label in domain.split('.') {
if label.bytes().len() > 63 {
return Err(
Error::new(
ErrorKind::InvalidInput,
format!("Invalid label in domain: expected up to 63 bytes, got {}",
label.bytes().len())
)
);
}
res.write_u8(label.len() as u8)?;
label.bytes().for_each(|b| res.push(b));
}
res.write_u8(0)?;
Ok(res)
}
pub fn parse_string(msg: &mut Cursor<&[u8]>) -> io::Result<String> {
let length = msg.read_u8()?;
let mut res = String::new();
for _i in 0..length {
res.push(msg.read_u8()? as char);
}
Ok(res)
}
pub fn parse_qname(msg: &mut Cursor<&[u8]>) -> io::Result<String> {
let mut domain = String::new();
let mut c = msg.read_u8()?;
while c != 0 {
if (c & 0b11000000) != 0 {
c &= 0b00111111;
let offset = ((c as u16) << 8) + (msg.read_u8()? as u16);
let pos_after_pointer = msg.position() as i64;
msg.seek(SeekFrom::Start(offset as u64))?;
domain.push_str(DnsMessage::parse_qname(msg)?.as_str());
msg.seek(SeekFrom::Start(pos_after_pointer as u64))?;
return Ok(domain);
} else if (c & 0b01000000) != 0 || (c & 0b10000000) != 0 {
return Err(Error::new(
ErrorKind::Other,
"Unsupported label type (extended or invalid)"
));
}
for _i in 0..c {
domain.push(msg.read_u8()? as char);
}
domain.push('.');
c = msg.read_u8()?;
}
Ok(domain)
}
fn parse_questions(msg: &mut Cursor<&[u8]>, qdcount: u16) -> io::Result<Vec<DnsQuestion>> {
let mut questions = Vec::with_capacity(qdcount as usize);
for _i in 0..qdcount {
questions.push(DnsQuestion::parse(msg)?);
}
Ok(questions)
}
fn parse_records(msg: &mut Cursor<&[u8]>, ancount: u16, rcode: DnsRcode) -> io::Result<Vec<DnsRecord>> {
let mut answers = Vec::with_capacity(ancount as usize);
for _i in 0..ancount {
answers.push(DnsRecord::parse(msg, rcode)?);
}
Ok(answers)
}
}
impl Display for DnsMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut res = String::new();
let mut additional_answers = self.additional_answers.clone();
let mut opt = None;
let mut opt_index = 0;
let mut max_owner_len = 0;
for question in &self.questions {
max_owner_len = max(max_owner_len, question.qname.len());
}
for answer in &self.answers {
max_owner_len = max(max_owner_len, answer.name.len());
}
for answer in &self.authoritative_answers {
max_owner_len = max(max_owner_len, answer.name.len());
}
for (i, answer) in additional_answers.iter().enumerate() {
max_owner_len = max(max_owner_len, answer.name.len());
if answer.atype == DnsType::OPT {
opt = Some(answer);
opt_index = i;
}
}
if let Some(_) = opt {
additional_answers.remove(opt_index);
}
res.push_str(format!("Header:\n\t{}\n\n", self.header.info_str()).as_str());
res.push_str("Question Section:\n");
for question in &self.questions {
res.push('\t');
res.push_str(question.as_padded_string(max_owner_len).as_str());
res.push('\n');
}
res.push('\n');
if !self.answers.is_empty() {
res.push_str("Answer Section:\n");
for answer in &self.answers {
res.push('\t');
res.push_str(answer.as_padded_string(max_owner_len).as_str());
res.push('\n');
}
res.push('\n');
}
if !self.authoritative_answers.is_empty() {
res.push_str("Authoritative Section:\n");
for answer in &self.authoritative_answers {
res.push('\t');
res.push_str(answer.as_padded_string(max_owner_len).as_str());
res.push('\n');
}
res.push('\n');
}
if !additional_answers.is_empty() {
res.push_str("Additional Section:\n");
for answer in &additional_answers {
res.push('\t');
res.push_str(answer.as_padded_string(max_owner_len).as_str());
res.push('\n');
}
}
while res.chars().nth(res.len() - 1).unwrap() == '\n' {
res.remove(res.len() - 1);
}
write!(f, "{}", res)
}
}