sync_resolve/
message.rs

1//! Utilities for composing, decoding, and encoding messages.
2
3use std::borrow::Cow::{self, Borrowed, Owned};
4use std::cell::Cell;
5use std::default::Default;
6use std::fmt;
7use std::io::{Cursor, Read, Write};
8use std::mem::{transmute, zeroed};
9use std::slice::Iter;
10use std::str::from_utf8_unchecked;
11use std::vec::IntoIter;
12
13use rand::random;
14
15use crate::idna;
16use crate::record::{Class, Record, RecordType};
17
18/// Maximum size of a DNS message, in bytes.
19pub const MESSAGE_LIMIT: usize = 0xffff;
20
21/// Maximum length of a name segment (i.e. a `.`-separated identifier).
22pub const LABEL_LIMIT: usize = 63;
23
24/// Maximum total length of a name, in encoded format.
25pub const NAME_LIMIT: usize = 255;
26
27/// An error response code received in a response message.
28#[derive(Copy, Clone, Debug, Eq, PartialEq)]
29pub struct DnsError(pub RCode);
30
31impl fmt::Display for DnsError {
32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33        f.write_str(self.0.get_error())
34    }
35}
36
/// Describes why a DNS message could not be decoded.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// Unread data was left over at the end of the message
    ExtraneousData,
    /// The message ended before the expected data was read
    ShortMessage,
    /// The data was invalid and could not be decoded
    InvalidMessage,
    /// A name in the message was invalid
    InvalidName,
}

impl fmt::Display for DecodeError {
    /// Writes a short lowercase description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            DecodeError::ExtraneousData => "extraneous data",
            DecodeError::ShortMessage => "short message",
            DecodeError::InvalidMessage => "invalid message",
            DecodeError::InvalidName => "invalid name",
        };
        f.write_str(text)
    }
}
60
/// Describes why a DNS message could not be encoded.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EncodeError {
    /// A name or label was too long or contained invalid characters
    InvalidName,
    /// Message exceeded given buffer or `MESSAGE_LIMIT` bytes
    TooLong,
}

impl fmt::Display for EncodeError {
    /// Writes a short lowercase description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            EncodeError::InvalidName => "invalid name value",
            EncodeError::TooLong => "message too long",
        };
        f.write_str(text)
    }
}
78
/// Reads a single DNS message from a series of bytes.
///
/// The reader wraps a `Cursor` over the borrowed message bytes; the cursor
/// position records how far into the message reading has progressed.
pub struct MsgReader<'a> {
    /// Cursor over the raw message bytes.
    data: Cursor<&'a [u8]>,
}
83
84impl<'a> MsgReader<'a> {
85    /// Constructs a new message reader.
86    pub fn new(data: &[u8]) -> MsgReader {
87        MsgReader {
88            data: Cursor::new(data),
89        }
90    }
91
92    /// Constructs a new message reader, which will read from `data`,
93    /// beginning at `offset`.
94    pub fn with_offset(data: &[u8], offset: usize) -> MsgReader {
95        let mut cur = Cursor::new(data);
96        cur.set_position(offset as u64);
97        MsgReader { data: cur }
98    }
99
100    /// Returns the number of bytes remaining in the message.
101    pub fn remaining(&self) -> usize {
102        self.data.get_ref().len() - self.data.position() as usize
103    }
104
105    /// Reads a number of bytes equal to the length of the given buffer.
106    /// Returns `Err(ShortMessage)` if there are not enough bytes remaining.
107    pub fn read(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
108        match self.data.read(buf) {
109            Ok(n) if n == buf.len() => Ok(()),
110            _ => Err(DecodeError::ShortMessage),
111        }
112    }
113
114    /// Reads a single byte from the message.
115    pub fn read_byte(&mut self) -> Result<u8, DecodeError> {
116        let mut buf = [0];
117        self.read(&mut buf)?;
118        Ok(buf[0])
119    }
120
121    /// Reads all remaining bytes.
122    pub fn read_to_end(&mut self) -> Result<Vec<u8>, DecodeError> {
123        let mut res = vec![0_u8; self.remaining()];
124        self.read(&mut res)?;
125        Ok(res)
126    }
127
128    /// Read a character-string.
129    ///
130    /// According to RFC 1035:
131    ///
132    /// > <character-string> is a single length octet followed
133    /// > by that number of characters. <character-string> is
134    /// > treated as binary information, and can be up to 256
135    /// > characters in length (including the length octet).
136    pub fn read_character_string(&mut self) -> Result<Vec<u8>, DecodeError> {
137        let length_octet = self.read_byte()? as usize;
138        let mut res = vec![0_u8; length_octet];
139        self.read(&mut res)?;
140        Ok(res)
141    }
142
143    /// Reads a big-endian unsigned 16 bit integer.
144    pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
145        let mut buf = [0; 2];
146        self.read(&mut buf)?;
147        Ok(u16::from_be_bytes(buf))
148    }
149
150    /// Reads a big-endian unsigned 32 bit integer.
151    pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
152        let mut buf = [0; 4];
153        self.read(&mut buf)?;
154        Ok(u32::from_be_bytes(buf))
155    }
156
157    /// Reads `n` bytes, which are inserted at the end of the given buffer.
158    pub fn read_into(&mut self, buf: &mut Vec<u8>, n: usize) -> Result<(), DecodeError> {
159        let len = buf.len();
160        buf.resize(len + n, 0);
161        self.read(&mut buf[len..])
162    }
163
164    /// Reads a name from the message.
165    pub fn read_name(&mut self) -> Result<String, DecodeError> {
166        // Start position, used to check against pointer references
167        let start_pos = self.data.position();
168        // Offset to return to if we've finished parsing a pointer reference
169        let mut restore = None;
170
171        let mut res = String::new();
172        let mut total_read = 0;
173
174        loop {
175            let len = self.read_byte()?;
176
177            if len == 0 {
178                if total_read + 1 > NAME_LIMIT {
179                    return Err(DecodeError::InvalidName);
180                }
181                break;
182            }
183
184            // If the length flag starts with "11", it will be followed by a
185            // pointer reference. For more information see RFC 1035 section
186            // 4.1.4 (Message compression).
187            // Prefix "00" means "no compression". Prefixes 0x01 and 0x10
188            // are reserved for future use.
189            let compressed = match len >> 6 {
190                0b11 => true,
191                0b00 => false,
192                _ => return Err(DecodeError::InvalidMessage),
193            };
194
195            if compressed {
196                // The beginning of a pointer reference. 14 bit denote the
197                // offset from the start of the message.
198                let hi = (len & 0b00111111) as u64;
199                let lo = self.read_byte()? as u64;
200                let offset = (hi << 8) | lo;
201
202                // To prevent an infinite loop, we require the pointer to
203                // point before the start of this name.
204                if offset >= start_pos {
205                    return Err(DecodeError::InvalidName);
206                }
207
208                if restore.is_none() {
209                    restore = Some(self.data.position());
210                }
211
212                self.data.set_position(offset);
213                continue;
214            }
215
216            if total_read + 1 + len as usize > NAME_LIMIT {
217                return Err(DecodeError::InvalidName);
218            }
219            total_read += 1 + len as usize;
220
221            self.read_segment(&mut res, len as usize)?;
222        }
223
224        if res.is_empty() {
225            res.push('.');
226        } else {
227            res.shrink_to_fit();
228        }
229
230        if let Some(pos) = restore {
231            self.data.set_position(pos);
232        }
233
234        Ok(res)
235    }
236
237    fn read_segment(&mut self, buf: &mut String, len: usize) -> Result<(), DecodeError> {
238        let mut bytes = [0; 64];
239
240        self.read(&mut bytes[..len])?;
241
242        let seg = &bytes[..len];
243
244        if !seg.is_ascii() {
245            return Err(DecodeError::InvalidName);
246        }
247
248        // We just verified this was ASCII, so it's safe.
249        let s = unsafe { from_utf8_unchecked(seg) };
250
251        if !is_valid_segment(s) {
252            return Err(DecodeError::InvalidName);
253        }
254
255        let label = match idna::to_unicode(s) {
256            Ok(s) => s,
257            Err(_) => return Err(DecodeError::InvalidName),
258        };
259
260        buf.push_str(&label);
261        buf.push('.');
262        Ok(())
263    }
264
265    fn consume(&mut self, n: u64) {
266        let p = self.data.position();
267        self.data.set_position(p + n);
268    }
269
270    /// Called at the end of message parsing. Returns `Err(ExtraneousData)`
271    /// if there are any unread bytes remaining.
272    fn finish(self) -> Result<(), DecodeError> {
273        if self.remaining() == 0 {
274            Ok(())
275        } else {
276            Err(DecodeError::ExtraneousData)
277        }
278    }
279
280    /// Reads a message header
281    fn read_header(&mut self) -> Result<FullHeader, DecodeError> {
282        let mut buf = [0; 12];
283
284        self.read(&mut buf)?;
285
286        let hdr: HeaderData = unsafe { transmute(buf) };
287
288        let id = u16::from_be(hdr.id);
289
290        // 1 bit: query or response flag
291        let qr = hdr.flags0 & 0b10000000;
292        // 4 bits: opcode
293        let op = hdr.flags0 & 0b01111000;
294        // 1 bit: authoritative answer flag
295        let aa = hdr.flags0 & 0b00000100;
296        // 1 bit: truncation flag
297        let tc = hdr.flags0 & 0b00000010;
298        // 1 bit: recursion desired flag
299        let rd = hdr.flags0 & 0b00000001;
300
301        // 1 bit: recursion available flag
302        let ra = hdr.flags1 & 0b10000000;
303        // 3 bits: reserved for future use
304        //     = hdr.flags1 & 0b01110000;
305        // 4 bits: response code
306        let rc = hdr.flags1 & 0b00001111;
307
308        let qd_count = u16::from_be(hdr.qd_count);
309        let an_count = u16::from_be(hdr.an_count);
310        let ns_count = u16::from_be(hdr.ns_count);
311        let ar_count = u16::from_be(hdr.ar_count);
312
313        Ok(FullHeader {
314            id,
315            qr: if qr == 0 { Qr::Query } else { Qr::Response },
316            op: OpCode::from_u8(op),
317            authoritative: aa != 0,
318            truncated: tc != 0,
319            recursion_desired: rd != 0,
320            recursion_available: ra != 0,
321            rcode: RCode::from_u8(rc),
322            qd_count,
323            an_count,
324            ns_count,
325            ar_count,
326        })
327    }
328
329    /// Reads a question item
330    fn read_question(&mut self) -> Result<Question, DecodeError> {
331        let name = self.read_name()?;
332
333        let mut buf = [0; 4];
334
335        self.read(&mut buf)?;
336
337        let msg: QuestionData = unsafe { transmute(buf) };
338
339        let q_type = u16::from_be(msg.q_type);
340        let q_class = u16::from_be(msg.q_class);
341
342        Ok(Question {
343            name,
344            q_type: RecordType::from_u16(q_type),
345            q_class: Class::from_u16(q_class),
346        })
347    }
348
349    /// Reads a resource record item
350    fn read_resource(&mut self) -> Result<Resource<'a>, DecodeError> {
351        let name = self.read_name()?;
352
353        let mut buf = [0; 10];
354
355        self.read(&mut buf)?;
356
357        let msg: ResourceData = unsafe { transmute(buf) };
358
359        let r_type = u16::from_be(msg.r_type);
360        let r_class = u16::from_be(msg.r_class);
361        let ttl = u32::from_be(msg.ttl);
362        let length = u16::from_be(msg.length);
363
364        let data = *self.data.get_ref();
365        let offset = self.data.position() as usize;
366
367        let r_data = &data[..offset + length as usize];
368        self.consume(length as u64);
369
370        Ok(Resource {
371            name,
372            r_type: RecordType::from_u16(r_type),
373            r_class: Class::from_u16(r_class),
374            ttl,
375            data: Borrowed(r_data),
376            offset,
377        })
378    }
379}
380
/// Writes a single DNS message as a series of bytes.
///
/// The writer wraps a `Cursor` over a caller-provided mutable byte slice;
/// the cursor position records how many bytes have been written so far.
pub struct MsgWriter<'a> {
    /// Cursor over the output buffer.
    data: Cursor<&'a mut [u8]>,
}
385
386impl<'a> MsgWriter<'a> {
387    /// Constructs a new message writer that will write into the given byte
388    /// slice.
389    pub fn new(data: &mut [u8]) -> MsgWriter {
390        MsgWriter {
391            data: Cursor::new(data),
392        }
393    }
394
395    /// Returns the number of bytes written so far.
396    pub fn written(&self) -> usize {
397        self.data.position() as usize
398    }
399
400    /// Returns a subslice of the wrapped byte slice that contains only the
401    /// bytes written.
402    pub fn into_bytes(self) -> &'a [u8] {
403        let n = self.written();
404        &self.data.into_inner()[..n]
405    }
406
407    /// Writes a series of bytes to the message. Returns `Err(TooLong)` if the
408    /// whole buffer cannot be written.
409    pub fn write(&mut self, data: &[u8]) -> Result<(), EncodeError> {
410        if self.written() + data.len() > MESSAGE_LIMIT {
411            // No matter the size of the buffer,
412            // we always want to stop at the hard-coded message limit.
413            Err(EncodeError::TooLong)
414        } else {
415            self.data.write_all(data).map_err(|_| EncodeError::TooLong)
416        }
417    }
418
419    /// Write a character string, as defined by RFC 1035.
420    pub fn write_character_string(&mut self, data: &[u8]) -> Result<(), EncodeError> {
421        let len = data.len();
422
423        if len > 255 {
424            Err(EncodeError::TooLong)
425        } else {
426            self.write_byte(len as u8)?;
427            self.write(data)
428        }
429    }
430
431    /// Writes a name to the message.
432    pub fn write_name(&mut self, name: &str) -> Result<(), EncodeError> {
433        if !is_valid_name(name) {
434            Err(EncodeError::InvalidName)
435        } else if name == "." {
436            self.write_byte(0)
437        } else {
438            let mut total_len = 0;
439
440            for seg in name.split('.') {
441                let seg = match idna::to_ascii(seg) {
442                    Ok(seg) => seg,
443                    Err(_) => return Err(EncodeError::InvalidName),
444                };
445
446                if !is_valid_segment(&seg) {
447                    return Err(EncodeError::InvalidName);
448                }
449
450                if seg.len() > LABEL_LIMIT {
451                    return Err(EncodeError::InvalidName);
452                }
453
454                // Add the size octet and the segment length
455                total_len += 1 + seg.len();
456
457                if total_len > NAME_LIMIT {
458                    return Err(EncodeError::InvalidName);
459                }
460
461                self.write_byte(seg.len() as u8)?;
462                self.write(seg.as_bytes())?;
463            }
464
465            if !name.ends_with('.') {
466                if total_len + 1 > NAME_LIMIT {
467                    return Err(EncodeError::InvalidName);
468                }
469                self.write_byte(0)?;
470            }
471
472            Ok(())
473        }
474    }
475
476    /// Writes a single byte to the message.
477    pub fn write_byte(&mut self, data: u8) -> Result<(), EncodeError> {
478        self.write(&[data])
479    }
480
481    /// Writes an unsigned 16 bit integer in big-endian format.
482    pub fn write_u16(&mut self, data: u16) -> Result<(), EncodeError> {
483        let data: [u8; 2] = data.to_be_bytes();
484        self.write(&data)
485    }
486
487    /// Writes an unsigned 32 bit integer in big-endian format.
488    pub fn write_u32(&mut self, data: u32) -> Result<(), EncodeError> {
489        let data: [u8; 4] = data.to_be_bytes();
490        self.write(&data)
491    }
492
493    /// Writes a message header
494    fn write_header(&mut self, header: &FullHeader) -> Result<(), EncodeError> {
495        let mut hdr: HeaderData = unsafe { zeroed() };
496
497        // 2 bytes: message ID
498        hdr.id = header.id.to_be();
499
500        // 1 bit: query or response flag
501        hdr.flags0 |= (header.qr as u8 & 1) << 7;
502        // 4 bits: opcode
503        hdr.flags0 |= (header.op.to_u8() & 0b1111) << 3;
504        // 1 bit: authoritative answer flag
505        hdr.flags0 |= (header.authoritative as u8) << 2;
506        // 1 bit: truncation flag
507        hdr.flags0 |= (header.truncated as u8) << 1;
508        // 1 bit: recursion desired flag
509        hdr.flags0 |= header.recursion_desired as u8;
510
511        // 1 bit: recursion available flag
512        hdr.flags1 |= (header.recursion_available as u8) << 7;
513        // 3 bits: reserved for future use
514        // .flags1 |= (0 as u8 & 0b111) << 4;
515        // 4 bits: response code
516        hdr.flags1 |= header.rcode.to_u8() & 0b1111;
517
518        hdr.qd_count = header.qd_count.to_be();
519        hdr.an_count = header.an_count.to_be();
520        hdr.ns_count = header.ns_count.to_be();
521        hdr.ar_count = header.ar_count.to_be();
522
523        let buf: [u8; 12] = unsafe { transmute(hdr) };
524
525        self.write(&buf)
526    }
527
528    /// Writes a question item
529    fn write_question(&mut self, question: &Question) -> Result<(), EncodeError> {
530        self.write_name(&question.name)?;
531
532        let mut qd: QuestionData = unsafe { zeroed() };
533
534        qd.q_type = question.q_type.to_u16().to_be();
535        qd.q_class = question.q_class.to_u16().to_be();
536
537        let buf: [u8; 4] = unsafe { transmute(qd) };
538
539        self.write(&buf)
540    }
541
542    /// Writes a resource record item
543    fn write_resource(&mut self, resource: &Resource) -> Result<(), EncodeError> {
544        self.write_name(&resource.name)?;
545
546        let mut rd: ResourceData = unsafe { zeroed() };
547
548        let rdata = resource.get_rdata();
549
550        rd.r_type = resource.r_type.to_u16().to_be();
551        rd.r_class = resource.r_class.to_u16().to_be();
552        rd.ttl = resource.ttl.to_be();
553        rd.length = to_u16(rdata.len())?.to_be();
554
555        let buf: [u8; 10] = unsafe { transmute(rd) };
556
557        self.write(&buf)?;
558        self.write(rdata)
559    }
560}
561
562/// Returns a sequential ID value from a thread-local random starting value.
563pub fn generate_id() -> u16 {
564    // It's not really necessary for these to be sequential, but it avoids the
565    // 1-in-65536 chance of producing the same random number twice in a row.
566    thread_local!(static ID: Cell<u16> = Cell::new(random()));
567    ID.with(|id| {
568        let value = id.get();
569        id.set(value.wrapping_add(1));
570        value
571    })
572}
573
/// Returns whether `name` has the structure of a hostname.
/// Only the structure is validated here: the name must be non-empty, must
/// not contain an empty label (`..`), and may start with `.` only when it is
/// a single character long. The characters within labels are not checked.
fn is_valid_name(name: &str) -> bool {
    if name.is_empty() || name.contains("..") {
        return false;
    }

    // A leading dot is acceptable only when the name is one character long.
    name.len() == 1 || !name.starts_with('.')
}
581
/// Returns whether `s` is acceptable as a single name segment.
/// This is a basic sanity check, not as strict as internet DNS servers will
/// be: the segment must not start or end with `-` and must not contain `.`,
/// whitespace, or control characters. If an invalid name is given, a DNS
/// server will respond that it doesn't exist, anyway.
fn is_valid_segment(s: &str) -> bool {
    if s.starts_with('-') || s.ends_with('-') {
        return false;
    }

    !s.chars()
        .any(|c| c == '.' || c.is_whitespace() || c.is_control())
}
591
/// Represents a DNS message.
///
/// A message consists of a header, a list of questions, and three lists of
/// resource records (answer, authority, additional). Resource records may
/// borrow their data for the lifetime `'a`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Message<'a> {
    /// Describes the content of the remainder of the message.
    pub header: Header,
    /// Carries the question of query type messages.
    pub question: Vec<Question>,
    /// Resource records that answer the query
    pub answer: Vec<Resource<'a>>,
    /// Resource records that point to an authoritative name server
    pub authority: Vec<Resource<'a>>,
    /// Resource records that relate to the query, but are not strictly
    /// answers for the question.
    pub additional: Vec<Resource<'a>>,
}
607
608impl<'a> Message<'a> {
609    /// Constructs a new `Message` with a random id value.
610    pub fn new() -> Message<'a> {
611        Message {
612            header: Header::new(),
613            ..Default::default()
614        }
615    }
616
617    /// Constructs a new `Message` with the given id value.
618    pub fn with_id(id: u16) -> Message<'a> {
619        Message {
620            header: Header::with_id(id),
621            ..Default::default()
622        }
623    }
624
625    /// Decodes a message from a series of bytes.
626    pub fn decode(data: &[u8]) -> Result<Message, DecodeError> {
627        let mut r = MsgReader::new(data);
628
629        let header = r.read_header()?;
630        let mut msg = Message {
631            header: header.to_header(),
632            // TODO: Cap these values to prevent abuse?
633            question: Vec::with_capacity(header.qd_count as usize),
634            answer: Vec::with_capacity(header.an_count as usize),
635            authority: Vec::with_capacity(header.ns_count as usize),
636            additional: Vec::with_capacity(header.ar_count as usize),
637        };
638
639        for _ in 0..header.qd_count {
640            msg.question.push(r.read_question()?);
641        }
642
643        for _ in 0..header.an_count {
644            msg.answer.push(r.read_resource()?);
645        }
646
647        for _ in 0..header.ns_count {
648            msg.authority.push(r.read_resource()?);
649        }
650
651        for _ in 0..header.ar_count {
652            msg.additional.push(r.read_resource()?);
653        }
654
655        r.finish()?;
656        Ok(msg)
657    }
658
659    /// Encodes a message to a series of bytes. On success, returns a subslice
660    /// of the given buffer containing only the encoded message bytes.
661    pub fn encode<'buf>(&self, buf: &'buf mut [u8]) -> Result<&'buf [u8], EncodeError> {
662        let mut w = MsgWriter::new(buf);
663        let hdr = &self.header;
664
665        let header = FullHeader {
666            id: hdr.id,
667            qr: hdr.qr,
668            op: hdr.op,
669            authoritative: hdr.authoritative,
670            truncated: hdr.truncated,
671            recursion_desired: hdr.recursion_desired,
672            recursion_available: hdr.recursion_available,
673            rcode: hdr.rcode,
674            qd_count: to_u16(self.question.len())?,
675            an_count: to_u16(self.answer.len())?,
676            ns_count: to_u16(self.authority.len())?,
677            ar_count: to_u16(self.additional.len())?,
678        };
679
680        w.write_header(&header)?;
681
682        for q in &self.question {
683            w.write_question(q)?;
684        }
685        for r in &self.answer {
686            w.write_resource(r)?;
687        }
688        for r in &self.authority {
689            w.write_resource(r)?;
690        }
691        for r in &self.additional {
692            w.write_resource(r)?;
693        }
694
695        Ok(w.into_bytes())
696    }
697
698    /// Returns a `DnsError` if the message response code is an error.
699    pub fn get_error(&self) -> Result<(), DnsError> {
700        if self.header.rcode == RCode::NoError {
701            Ok(())
702        } else {
703            Err(DnsError(self.header.rcode))
704        }
705    }
706
707    /// Returns an iterator over the records in this message.
708    pub fn records(&self) -> RecordIter {
709        RecordIter {
710            iters: [
711                self.answer.iter(),
712                self.authority.iter(),
713                self.additional.iter(),
714            ],
715        }
716    }
717
718    /// Consumes the message and returns an iterator over its records.
719    pub fn into_records(self) -> RecordIntoIter<'a> {
720        RecordIntoIter {
721            iters: [
722                self.answer.into_iter(),
723                self.authority.into_iter(),
724                self.additional.into_iter(),
725            ],
726        }
727    }
728}
729
730/// Yields `&Resource` items from a Message.
731pub struct RecordIter<'a> {
732    iters: [Iter<'a, Resource<'a>>; 3],
733}
734
735impl<'a> Iterator for RecordIter<'a> {
736    type Item = &'a Resource<'a>;
737
738    fn next(&mut self) -> Option<&'a Resource<'a>> {
739        self.iters[0]
740            .next()
741            .or_else(|| self.iters[1].next())
742            .or_else(|| self.iters[2].next())
743    }
744}
745
746/// Yields `Resource` items from a Message.
747pub struct RecordIntoIter<'a> {
748    iters: [IntoIter<Resource<'a>>; 3],
749}
750
751impl<'a> Iterator for RecordIntoIter<'a> {
752    type Item = Resource<'a>;
753
754    fn next(&mut self) -> Option<Resource<'a>> {
755        self.iters[0]
756            .next()
757            .or_else(|| self.iters[1].next())
758            .or_else(|| self.iters[2].next())
759    }
760}
761
762/// Represents a message header.
763#[derive(Copy, Clone, Debug, Eq, PartialEq)]
764pub struct Header {
765    /// Transaction ID; corresponding replies will have the same ID.
766    pub id: u16,
767    /// Query or response
768    pub qr: Qr,
769    /// Kind of query
770    pub op: OpCode,
771    /// In a response, indicates that the responding name server is an
772    /// authority for the domain name in question section.
773    pub authoritative: bool,
774    /// Indicates whether the message was truncated due to length greater than
775    /// that permitted on the transmission channel.
776    pub truncated: bool,
777    /// In a query, directs the name server to pursue the query recursively.
778    pub recursion_desired: bool,
779    /// In a response, indicates whether recursive queries are available on the
780    /// name server.
781    pub recursion_available: bool,
782    /// Response code
783    pub rcode: RCode,
784}
785
786impl Header {
787    /// Constructs a new `Header` with a random id value.
788    pub fn new() -> Header {
789        Header {
790            id: generate_id(),
791            ..Default::default()
792        }
793    }
794
795    /// Constructs a new `Header` with the given id value.
796    pub fn with_id(id: u16) -> Header {
797        Header {
798            id,
799            ..Default::default()
800        }
801    }
802}
803
804impl Default for Header {
805    fn default() -> Header {
806        Header {
807            id: 0,
808            qr: Qr::Query,
809            op: OpCode::Query,
810            authoritative: false,
811            truncated: false,
812            recursion_desired: false,
813            recursion_available: false,
814            rcode: RCode::NoError,
815        }
816    }
817}
818
819/// Contains all header data decoded from a message.
820#[derive(Copy, Clone, Debug, Eq, PartialEq)]
821struct FullHeader {
822    pub id: u16,
823    pub qr: Qr,
824    pub op: OpCode,
825    pub authoritative: bool,
826    pub truncated: bool,
827    pub recursion_desired: bool,
828    pub recursion_available: bool,
829    pub rcode: RCode,
830    pub qd_count: u16,
831    pub an_count: u16,
832    pub ns_count: u16,
833    pub ar_count: u16,
834}
835
836impl FullHeader {
837    fn to_header(self) -> Header {
838        Header {
839            id: self.id,
840            qr: self.qr,
841            op: self.op,
842            authoritative: self.authoritative,
843            truncated: self.truncated,
844            recursion_desired: self.recursion_desired,
845            recursion_available: self.recursion_available,
846            rcode: self.rcode,
847        }
848    }
849}
850
851impl Default for FullHeader {
852    fn default() -> FullHeader {
853        FullHeader {
854            id: 0,
855            qr: Qr::Query,
856            op: OpCode::Query,
857            authoritative: false,
858            truncated: false,
859            recursion_desired: false,
860            recursion_available: false,
861            rcode: RCode::NoError,
862            qd_count: 0,
863            an_count: 0,
864            ns_count: 0,
865            ar_count: 0,
866        }
867    }
868}
869
870/// Represents a question item.
871#[derive(Clone, Debug, Eq, PartialEq)]
872pub struct Question {
873    /// Query name
874    pub name: String,
875    /// Query type
876    pub q_type: RecordType,
877    /// Query class
878    pub q_class: Class,
879}
880
881impl Question {
882    /// Constructs a new `Question`.
883    pub fn new(name: String, q_type: RecordType, q_class: Class) -> Question {
884        Question {
885            name,
886            q_type,
887            q_class,
888        }
889    }
890}
891
892/// Represents a resource record item.
893#[derive(Clone, Debug, Eq, PartialEq)]
894pub struct Resource<'a> {
895    /// Resource name
896    pub name: String,
897    /// Resource type
898    pub r_type: RecordType,
899    /// Resource class
900    pub r_class: Class,
901    /// Time-to-live
902    pub ttl: u32,
903    /// Message data, up to and including resource record data
904    data: Cow<'a, [u8]>,
905    /// Beginning of rdata within `data`
906    offset: usize,
907}
908
909impl<'a> Resource<'a> {
910    /// Constructs a new `Resource`.
911    pub fn new(name: String, r_type: RecordType, r_class: Class, ttl: u32) -> Resource<'a> {
912        Resource {
913            name,
914            r_type,
915            r_class,
916            ttl,
917            data: Owned(Vec::new()),
918            offset: 0,
919        }
920    }
921
922    /// Returns resource data.
923    pub fn get_rdata(&self) -> &[u8] {
924        &self.data[self.offset..]
925    }
926
927    /// Decodes resource data into the given `Record` type.
928    pub fn read_rdata<R: Record>(&self) -> Result<R, DecodeError> {
929        let mut r = MsgReader::with_offset(&self.data, self.offset);
930        let res = Record::decode(&mut r)?;
931        r.finish()?;
932        Ok(res)
933    }
934
935    /// Encodes resource data from the given `Record` type.
936    pub fn write_rdata<R: Record>(&mut self, record: &R) -> Result<(), EncodeError> {
937        let mut buf = [0; MESSAGE_LIMIT];
938        let mut w = MsgWriter::new(&mut buf[..]);
939        record.encode(&mut w)?;
940        self.data = Owned(w.into_bytes().to_vec());
941        self.offset = 0;
942        Ok(())
943    }
944}
945
/// Indicates a message is either a query or response.
///
/// The discriminants are the values of the one-bit QR field of the header.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Qr {
    /// Query
    Query = 0,
    /// Response
    Response = 1,
}
955
/// Represents the kind of message query.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum OpCode {
    /// Query
    Query,
    /// Status
    Status,
    /// Notify
    Notify,
    /// Update
    Update,
    /// Unrecognized opcode
    Other(u8),
}

impl OpCode {
    /// Converts a `u8` to an `OpCode`. Values without a named variant are
    /// wrapped in `OpCode::Other`.
    pub fn from_u8(u: u8) -> OpCode {
        // Numeric codes of the named variants.
        const KNOWN: [(u8, OpCode); 4] = [
            (0, OpCode::Query),
            (2, OpCode::Status),
            (4, OpCode::Notify),
            (5, OpCode::Update),
        ];

        KNOWN
            .iter()
            .find(|&&(code, _)| code == u)
            .map(|&(_, op)| op)
            .unwrap_or(OpCode::Other(u))
    }

    /// Converts an `OpCode` to a `u8`. `OpCode::Other` yields the value it
    /// wraps.
    pub fn to_u8(&self) -> u8 {
        match self {
            OpCode::Query => 0,
            OpCode::Status => 2,
            OpCode::Notify => 4,
            OpCode::Update => 5,
            OpCode::Other(n) => *n,
        }
    }
}
994
/// Represents the response code of a message.
///
/// Codes `0` through `5` have dedicated variants; every other value is
/// carried verbatim in [`RCode::Other`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum RCode {
    /// No error condition (code `0`).
    NoError,
    /// The server was unable to interpret the query (code `1`).
    FormatError,
    /// The name server was unable to process the query due to a failure of
    /// the name server (code `2`).
    ServerFailure,
    /// Name referenced in query does not exist (code `3`).
    NameError,
    /// Requested query kind is not supported by name server (code `4`).
    NotImplemented,
    /// The name server refuses to perform the specified operation for policy
    /// reasons (code `5`).
    Refused,
    /// Unknown response code, holding the raw numeric value.
    Other(u8),
}

impl RCode {
    /// Named response codes, ordered so that each variant's index in the
    /// array equals its numeric code.
    const NAMED: [RCode; 6] = [
        RCode::NoError,
        RCode::FormatError,
        RCode::ServerFailure,
        RCode::NameError,
        RCode::NotImplemented,
        RCode::Refused,
    ];

    /// Returns an error string for the response code.
    ///
    /// All `RCode::Other` values share the same generic message.
    pub fn get_error(&self) -> &'static str {
        match self {
            RCode::Other(_) => "unknown response code",
            RCode::Refused => "refused",
            RCode::NotImplemented => "not implemented",
            RCode::NameError => "no such name",
            RCode::ServerFailure => "server failure",
            RCode::FormatError => "format error",
            RCode::NoError => "no error",
        }
    }

    /// Converts a `u8` to an `RCode`.
    ///
    /// Values above `5` are returned as `RCode::Other`.
    pub fn from_u8(u: u8) -> RCode {
        RCode::NAMED
            .get(usize::from(u))
            .cloned()
            .unwrap_or(RCode::Other(u))
    }

    /// Converts an `RCode` to a `u8`.
    ///
    /// `RCode::Other(n)` yields `n` unchanged.
    pub fn to_u8(&self) -> u8 {
        match self {
            RCode::Other(raw) => *raw,
            RCode::NoError => 0,
            RCode::FormatError => 1,
            RCode::ServerFailure => 2,
            RCode::NameError => 3,
            RCode::NotImplemented => 4,
            RCode::Refused => 5,
        }
    }
}
1056
/// Fixed-size header fields in a packed C layout.
///
/// With `#[repr(C, packed)]` the fields are stored in declaration order with
/// no padding, so the struct occupies 2 + 1 + 1 + 2 + 2 + 2 + 2 = 12 bytes.
// NOTE(review): presumably mirrors the on-the-wire DNS message header, with
// `flags0`/`flags1` holding the packed flag bits and the `*_count` fields the
// sizes of the four message sections -- confirm at the use sites, including
// how byte order is handled for the multi-byte fields.
#[derive(Copy, Clone, Debug)]
#[repr(C, packed)]
struct HeaderData {
    id: u16,
    flags0: u8,
    flags1: u8,
    qd_count: u16,
    an_count: u16,
    ns_count: u16,
    ar_count: u16,
}
1068
/// Fixed-size fields of a question in a packed C layout.
///
/// With `#[repr(C, packed)]` the two `u16` fields are stored back to back,
/// for a total of 4 bytes. The variable-length name is not part of this
/// struct (see the placeholder comment on the first field slot).
// NOTE(review): presumably these fields follow the encoded name on the wire
// -- confirm at the use sites, including byte-order handling.
#[derive(Copy, Clone, Debug)]
#[repr(C, packed)]
struct QuestionData {
    // name: String, -- dynamically sized
    q_type: u16,
    q_class: u16,
}
1076
/// Fixed-size fields of a resource record in a packed C layout.
///
/// With `#[repr(C, packed)]` the fields are stored in declaration order with
/// no padding, so the struct occupies 2 + 2 + 4 + 2 = 10 bytes. The
/// variable-length name is not part of this struct (see the placeholder
/// comment on the first field slot).
// NOTE(review): `length` is presumably the byte length of the record data
// that follows these fields -- confirm at the use sites, including
// byte-order handling for the multi-byte fields.
#[derive(Copy, Clone, Debug)]
#[repr(C, packed)]
struct ResourceData {
    // name: String, -- dynamically sized
    r_type: u16,
    r_class: u16,
    ttl: u32,
    length: u16,
}
1086
1087fn to_u16(n: usize) -> Result<u16, EncodeError> {
1088    if n > u16::max_value() as usize {
1089        Err(EncodeError::TooLong)
1090    } else {
1091        Ok(n as u16)
1092    }
1093}
1094
// Unit tests covering name encoding/decoding (including non-ASCII names),
// whole-message round-tripping, primitive reads/writes and name validation.
#[cfg(test)]
mod test {
    use super::{
        is_valid_name, EncodeError, Header, Message, MsgReader, MsgWriter, OpCode, Qr, Question,
        RCode, MESSAGE_LIMIT,
    };
    use crate::record::{Class, RecordType};

    // Non-ASCII names must be written in their ASCII `xn--` form and read
    // back as the original Unicode text.
    #[test]
    fn test_idna_name() {
        let mut buf = [0; 64];
        let mut w = MsgWriter::new(&mut buf);

        w.write_name("bücher.de.").unwrap();
        w.write_name("ουτοπία.δπθ.gr.").unwrap();

        let bytes = w.into_bytes();

        // Each label is preceded by its length byte; each name ends with a
        // zero byte.
        assert_eq!(
            bytes,
            &b"\
            \x0dxn--bcher-kva\x02de\x00\
            \x0exn--kxae4bafwg\x09xn--pxaix\x02gr\x00\
            "[..]
        );

        let mut r = MsgReader::new(&bytes);

        assert_eq!(r.read_name().as_ref().map(|s| &s[..]), Ok("bücher.de."));
        assert_eq!(
            r.read_name().as_ref().map(|s| &s[..]),
            Ok("ουτοπία.δπθ.gr.")
        );
    }

    // Encodes a single-question query, checks the exact bytes produced, then
    // decodes those bytes and expects an identical message.
    #[test]
    fn test_message() {
        let msg = Message {
            header: Header {
                id: 0xabcd,
                qr: Qr::Query,
                op: OpCode::Query,
                authoritative: false,
                truncated: false,
                recursion_desired: true,
                recursion_available: true,
                rcode: RCode::NoError,
            },
            question: vec![Question::new(
                "foo.bar.com.".to_owned(),
                RecordType::A,
                Class::Internet,
            )],
            answer: Vec::new(),
            authority: Vec::new(),
            additional: Vec::new(),
        };

        let mut buf = [0; 64];
        let bytes = msg.encode(&mut buf).unwrap();

        // Expected layout: the id, two flag bytes (the only flags enabled
        // above are recursion_desired and recursion_available), four 16-bit
        // counts of which only the first is non-zero (one question), the
        // length-prefixed name, and finally two 16-bit values of 1 for the
        // question's type and class.
        assert_eq!(
            bytes,
            &[
                0xab, 0xcd, 0b00000001, 0b10000000, 0, 1, 0, 0, 0, 0, 0, 0, 3, b'f', b'o', b'o', 3,
                b'b', b'a', b'r', 3, b'c', b'o', b'm', 0, 0, 1, 0, 1
            ][..]
        );

        let msg2 = Message::decode(&bytes).unwrap();

        assert_eq!(msg, msg2);
    }

    // Writes bytes, integers and names, checks the raw output, then reads
    // everything back in the same order.
    #[test]
    fn test_primitives() {
        let mut buf = [0; 64];
        let mut w = MsgWriter::new(&mut buf);

        w.write_byte(0x11).unwrap();
        w.write_u16(0x2233).unwrap();
        w.write_u32(0x44556677).unwrap();
        w.write_name("alpha.bravo.charlie").unwrap();
        w.write_name("delta.echo.foxtrot.").unwrap();
        w.write_name(".").unwrap();

        // Invalid names are rejected: the empty string and a name whose
        // first segment is overly long.
        assert_eq!(w.write_name(""), Err(EncodeError::InvalidName));
        assert_eq!(
            w.write_name("ohmyglobhowdidthisgethereiamnotgoodwithcomputerrrrrrrrrrrrrrrrrr.org"),
            Err(EncodeError::InvalidName)
        );

        let bytes = w.into_bytes();

        // Only the successful writes appear in the output; multi-byte
        // integers are written most-significant byte first.
        assert_eq!(
            bytes,
            &b"\
            \x11\
            \x22\x33\
            \x44\x55\x66\x77\
            \x05alpha\x05bravo\x07charlie\x00\
            \x05delta\x04echo\x07foxtrot\x00\
            \x00"[..]
        );

        let mut r = MsgReader::new(&bytes);

        assert_eq!(r.read_byte(), Ok(0x11));
        assert_eq!(r.read_u16(), Ok(0x2233));
        assert_eq!(r.read_u32(), Ok(0x44556677));
        // A name written without a trailing dot is read back with one.
        assert_eq!(
            r.read_name().as_ref().map(|s| &s[..]),
            Ok("alpha.bravo.charlie.")
        );
        assert_eq!(
            r.read_name().as_ref().map(|s| &s[..]),
            Ok("delta.echo.foxtrot.")
        );
        assert_eq!(r.read_name().as_ref().map(|s| &s[..]), Ok("."));
    }

    // Name fixtures for `test_encode_name`. The `LONGEST_*` names are
    // expected to be accepted by `write_name`; the `TOO_LONG_*` names (one
    // character longer in the last long segment) and `TOO_LONG_SEGMENT` are
    // expected to be rejected with `EncodeError::InvalidName`.
    const LONGEST_NAME: &'static str = "aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaa\
         .com";
    const LONGEST_NAME_DOT: &'static str = "aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaa\
         .com.";
    const TOO_LONG_NAME: &'static str = "aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         .com";
    const TOO_LONG_NAME_DOT: &'static str = "aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaa\
         .com.";
    const TOO_LONG_SEGMENT: &'static str = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\
         aaaaaaaaaaaaaa.com";

    // Checks the name-length boundaries: the longest names round-trip (both
    // forms read back with a trailing dot), while over-long names and an
    // over-long segment are rejected.
    #[test]
    fn test_encode_name() {
        let mut buf = [0; MESSAGE_LIMIT];
        let mut w = MsgWriter::new(&mut buf);

        w.write_name(LONGEST_NAME).unwrap();
        w.write_name(LONGEST_NAME_DOT).unwrap();

        let bytes = w.into_bytes();
        let mut r = MsgReader::new(&bytes);

        assert_eq!(r.read_name().as_ref().map(|s| &s[..]), Ok(LONGEST_NAME_DOT));
        assert_eq!(r.read_name().as_ref().map(|s| &s[..]), Ok(LONGEST_NAME_DOT));

        // Start over with a fresh buffer and writer for the rejection cases.
        let mut buf = [0; MESSAGE_LIMIT];
        let mut w = MsgWriter::new(&mut buf);

        assert_eq!(w.write_name(TOO_LONG_NAME), Err(EncodeError::InvalidName));
        assert_eq!(
            w.write_name(TOO_LONG_NAME_DOT),
            Err(EncodeError::InvalidName)
        );
        assert_eq!(
            w.write_name(TOO_LONG_SEGMENT),
            Err(EncodeError::InvalidName)
        );
    }

    // `is_valid_name` accepts the root name, names with or without a
    // trailing dot, digits/hyphens and upper case; it rejects the empty
    // string, a leading dot and an empty segment.
    #[test]
    fn test_valid_name() {
        assert!(is_valid_name("."));
        assert!(is_valid_name("foo.com."));
        assert!(is_valid_name("foo-123.com."));
        assert!(is_valid_name("FOO-BAR.COM"));

        assert!(!is_valid_name(""));
        assert!(!is_valid_name(".foo.com"));
        assert!(!is_valid_name("foo..bar.com"));
    }
}