// stackforge_core/layer/dns/header.rs

//! DNS header parsing (RFC 1035 Section 4.1.1).
//!
//! The DNS header is a fixed 12 bytes made up of six 16-bit fields. Bits are
//! numbered as in the RFC, so bit 0 is the most significant bit of each field.
//! AD and CD occupy bits of RFC 1035's original 3-bit Z field; they were
//! assigned later by the DNSSEC specifications (RFC 2535, updated by RFC 4035).
//! ```text
//!                                 1  1  1  1  1  1
//!   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                      ID                       |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |QR|   Opcode  |AA|TC|RD|RA| Z|AD|CD|   RCODE   |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    QDCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    ANCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    NSCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    ARCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! ```
19
20use crate::layer::field::{Field, FieldError};
21
/// DNS header size in bytes: six consecutive 16-bit fields.
pub const DNS_HEADER_LEN: usize = 12;

// Byte offsets of each 16-bit field, relative to the start of the header.
// Fields are two bytes wide, so ARCOUNT (the last one) ends exactly at
// `DNS_HEADER_LEN`.
const ID_OFFSET: usize = 0;
const FLAGS_OFFSET: usize = 2;
const QDCOUNT_OFFSET: usize = 4;
const ANCOUNT_OFFSET: usize = 6;
const NSCOUNT_OFFSET: usize = 8;
const ARCOUNT_OFFSET: usize = 10;

// Masks within the 16-bit flags word. RFC bit 0 is the most significant bit,
// which is why QR (RFC bit 0) is 0x8000 and RCODE (RFC bits 12-15) is 0x000F.
// Together the masks below cover all 16 bits exactly once.

/// QR: query (clear) or response (set).
const FLAG_QR: u16 = 0x8000;
/// Distance of the 4-bit OPCODE field from bit 0 of the `u16`.
const FLAG_OPCODE_SHIFT: u16 = 11;
/// OPCODE (4 bits). Derived from the shift so the two can never disagree.
const FLAG_OPCODE_MASK: u16 = 0x000F_u16 << FLAG_OPCODE_SHIFT;
/// AA: Authoritative Answer.
const FLAG_AA: u16 = 0x0400;
/// TC: TrunCation.
const FLAG_TC: u16 = 0x0200;
/// RD: Recursion Desired.
const FLAG_RD: u16 = 0x0100;
/// RA: Recursion Available.
const FLAG_RA: u16 = 0x0080;
/// Z: reserved bit.
const FLAG_Z: u16 = 0x0040;
/// AD: Authenticated Data (DNSSEC).
const FLAG_AD: u16 = 0x0020;
/// CD: Checking Disabled (DNSSEC).
const FLAG_CD: u16 = 0x0010;
/// RCODE: 4-bit response code in the low bits, so no shift is needed.
const FLAG_RCODE_MASK: u16 = 0x000F;
45
46/// Read the DNS transaction ID.
47#[inline]
48pub fn read_id(buf: &[u8], base: usize) -> Result<u16, FieldError> {
49    u16::read(buf, base + ID_OFFSET)
50}
51
52/// Write the DNS transaction ID.
53#[inline]
54pub fn write_id(buf: &mut [u8], base: usize, id: u16) -> Result<(), FieldError> {
55    id.write(buf, base + ID_OFFSET)
56}
57
58/// Read the 16-bit flags field.
59#[inline]
60fn read_flags(buf: &[u8], base: usize) -> Result<u16, FieldError> {
61    u16::read(buf, base + FLAGS_OFFSET)
62}
63
64/// Write the 16-bit flags field.
65#[inline]
66fn write_flags(buf: &mut [u8], base: usize, flags: u16) -> Result<(), FieldError> {
67    flags.write(buf, base + FLAGS_OFFSET)
68}
69
70/// QR flag: 0 = query, 1 = response.
71#[inline]
72pub fn read_qr(buf: &[u8], base: usize) -> Result<bool, FieldError> {
73    Ok(read_flags(buf, base)? & FLAG_QR != 0)
74}
75
76pub fn write_qr(buf: &mut [u8], base: usize, qr: bool) -> Result<(), FieldError> {
77    let mut flags = read_flags(buf, base)?;
78    if qr {
79        flags |= FLAG_QR;
80    } else {
81        flags &= !FLAG_QR;
82    }
83    write_flags(buf, base, flags)
84}
85
86/// Opcode (4 bits).
87#[inline]
88pub fn read_opcode(buf: &[u8], base: usize) -> Result<u8, FieldError> {
89    let flags = read_flags(buf, base)?;
90    Ok(((flags & FLAG_OPCODE_MASK) >> FLAG_OPCODE_SHIFT) as u8)
91}
92
93pub fn write_opcode(buf: &mut [u8], base: usize, opcode: u8) -> Result<(), FieldError> {
94    let mut flags = read_flags(buf, base)?;
95    flags &= !FLAG_OPCODE_MASK;
96    flags |= (u16::from(opcode) << FLAG_OPCODE_SHIFT) & FLAG_OPCODE_MASK;
97    write_flags(buf, base, flags)
98}
99
100/// AA flag: Authoritative Answer.
101#[inline]
102pub fn read_aa(buf: &[u8], base: usize) -> Result<bool, FieldError> {
103    Ok(read_flags(buf, base)? & FLAG_AA != 0)
104}
105
106pub fn write_aa(buf: &mut [u8], base: usize, aa: bool) -> Result<(), FieldError> {
107    let mut flags = read_flags(buf, base)?;
108    if aa {
109        flags |= FLAG_AA;
110    } else {
111        flags &= !FLAG_AA;
112    }
113    write_flags(buf, base, flags)
114}
115
116/// TC flag: Truncation.
117#[inline]
118pub fn read_tc(buf: &[u8], base: usize) -> Result<bool, FieldError> {
119    Ok(read_flags(buf, base)? & FLAG_TC != 0)
120}
121
122pub fn write_tc(buf: &mut [u8], base: usize, tc: bool) -> Result<(), FieldError> {
123    let mut flags = read_flags(buf, base)?;
124    if tc {
125        flags |= FLAG_TC;
126    } else {
127        flags &= !FLAG_TC;
128    }
129    write_flags(buf, base, flags)
130}
131
132/// RD flag: Recursion Desired.
133#[inline]
134pub fn read_rd(buf: &[u8], base: usize) -> Result<bool, FieldError> {
135    Ok(read_flags(buf, base)? & FLAG_RD != 0)
136}
137
138pub fn write_rd(buf: &mut [u8], base: usize, rd: bool) -> Result<(), FieldError> {
139    let mut flags = read_flags(buf, base)?;
140    if rd {
141        flags |= FLAG_RD;
142    } else {
143        flags &= !FLAG_RD;
144    }
145    write_flags(buf, base, flags)
146}
147
148/// RA flag: Recursion Available.
149#[inline]
150pub fn read_ra(buf: &[u8], base: usize) -> Result<bool, FieldError> {
151    Ok(read_flags(buf, base)? & FLAG_RA != 0)
152}
153
154pub fn write_ra(buf: &mut [u8], base: usize, ra: bool) -> Result<(), FieldError> {
155    let mut flags = read_flags(buf, base)?;
156    if ra {
157        flags |= FLAG_RA;
158    } else {
159        flags &= !FLAG_RA;
160    }
161    write_flags(buf, base, flags)
162}
163
164/// Z flag (reserved).
165#[inline]
166pub fn read_z(buf: &[u8], base: usize) -> Result<bool, FieldError> {
167    Ok(read_flags(buf, base)? & FLAG_Z != 0)
168}
169
170pub fn write_z(buf: &mut [u8], base: usize, z: bool) -> Result<(), FieldError> {
171    let mut flags = read_flags(buf, base)?;
172    if z {
173        flags |= FLAG_Z;
174    } else {
175        flags &= !FLAG_Z;
176    }
177    write_flags(buf, base, flags)
178}
179
180/// AD flag: Authenticated Data (DNSSEC).
181#[inline]
182pub fn read_ad(buf: &[u8], base: usize) -> Result<bool, FieldError> {
183    Ok(read_flags(buf, base)? & FLAG_AD != 0)
184}
185
186pub fn write_ad(buf: &mut [u8], base: usize, ad: bool) -> Result<(), FieldError> {
187    let mut flags = read_flags(buf, base)?;
188    if ad {
189        flags |= FLAG_AD;
190    } else {
191        flags &= !FLAG_AD;
192    }
193    write_flags(buf, base, flags)
194}
195
196/// CD flag: Checking Disabled (DNSSEC).
197#[inline]
198pub fn read_cd(buf: &[u8], base: usize) -> Result<bool, FieldError> {
199    Ok(read_flags(buf, base)? & FLAG_CD != 0)
200}
201
202pub fn write_cd(buf: &mut [u8], base: usize, cd: bool) -> Result<(), FieldError> {
203    let mut flags = read_flags(buf, base)?;
204    if cd {
205        flags |= FLAG_CD;
206    } else {
207        flags &= !FLAG_CD;
208    }
209    write_flags(buf, base, flags)
210}
211
212/// RCODE (4 bits): Response code.
213#[inline]
214pub fn read_rcode(buf: &[u8], base: usize) -> Result<u8, FieldError> {
215    let flags = read_flags(buf, base)?;
216    Ok((flags & FLAG_RCODE_MASK) as u8)
217}
218
219pub fn write_rcode(buf: &mut [u8], base: usize, rcode: u8) -> Result<(), FieldError> {
220    let mut flags = read_flags(buf, base)?;
221    flags &= !FLAG_RCODE_MASK;
222    flags |= u16::from(rcode) & FLAG_RCODE_MASK;
223    write_flags(buf, base, flags)
224}
225
226/// Question count.
227#[inline]
228pub fn read_qdcount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
229    u16::read(buf, base + QDCOUNT_OFFSET)
230}
231
232#[inline]
233pub fn write_qdcount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
234    count.write(buf, base + QDCOUNT_OFFSET)
235}
236
237/// Answer count.
238#[inline]
239pub fn read_ancount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
240    u16::read(buf, base + ANCOUNT_OFFSET)
241}
242
243#[inline]
244pub fn write_ancount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
245    count.write(buf, base + ANCOUNT_OFFSET)
246}
247
248/// Authority count.
249#[inline]
250pub fn read_nscount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
251    u16::read(buf, base + NSCOUNT_OFFSET)
252}
253
254#[inline]
255pub fn write_nscount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
256    count.write(buf, base + NSCOUNT_OFFSET)
257}
258
259/// Additional count.
260#[inline]
261pub fn read_arcount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
262    u16::read(buf, base + ARCOUNT_OFFSET)
263}
264
265#[inline]
266pub fn write_arcount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
267    count.write(buf, base + ARCOUNT_OFFSET)
268}
269
270/// Build a raw 16-bit flags value from individual components.
271#[must_use]
272pub fn build_flags(
273    qr: bool,
274    opcode: u8,
275    aa: bool,
276    tc: bool,
277    rd: bool,
278    ra: bool,
279    z: bool,
280    ad: bool,
281    cd: bool,
282    rcode: u8,
283) -> u16 {
284    let mut flags: u16 = 0;
285    if qr {
286        flags |= FLAG_QR;
287    }
288    flags |= (u16::from(opcode) << FLAG_OPCODE_SHIFT) & FLAG_OPCODE_MASK;
289    if aa {
290        flags |= FLAG_AA;
291    }
292    if tc {
293        flags |= FLAG_TC;
294    }
295    if rd {
296        flags |= FLAG_RD;
297    }
298    if ra {
299        flags |= FLAG_RA;
300    }
301    if z {
302        flags |= FLAG_Z;
303    }
304    if ad {
305        flags |= FLAG_AD;
306    }
307    if cd {
308        flags |= FLAG_CD;
309    }
310    flags |= u16::from(rcode) & FLAG_RCODE_MASK;
311    flags
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a header directly from big-endian bytes, without going through
    /// the `write_*` accessors, so the readers are checked against a known layout.
    // One argument per header field keeps call sites aligned with the wire
    // format; a struct would only add noise to the tests.
    #[allow(clippy::too_many_arguments, clippy::fn_params_excessive_bools)]
    fn make_header(
        id: u16,
        qr: bool,
        opcode: u8,
        aa: bool,
        tc: bool,
        rd: bool,
        ra: bool,
        z: bool,
        ad: bool,
        cd: bool,
        rcode: u8,
        qdcount: u16,
        ancount: u16,
        nscount: u16,
        arcount: u16,
    ) -> Vec<u8> {
        let flags = build_flags(qr, opcode, aa, tc, rd, ra, z, ad, cd, rcode);
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        buf[0..2].copy_from_slice(&id.to_be_bytes());
        buf[2..4].copy_from_slice(&flags.to_be_bytes());
        buf[4..6].copy_from_slice(&qdcount.to_be_bytes());
        buf[6..8].copy_from_slice(&ancount.to_be_bytes());
        buf[8..10].copy_from_slice(&nscount.to_be_bytes());
        buf[10..12].copy_from_slice(&arcount.to_be_bytes());
        buf
    }

    #[test]
    fn test_read_query_header() {
        let buf = make_header(
            0x1234, false, 0, false, false, true, false, false, false, false, 0, 1, 0, 0, 0,
        );
        assert_eq!(read_id(&buf, 0).unwrap(), 0x1234);
        assert!(!read_qr(&buf, 0).unwrap());
        assert_eq!(read_opcode(&buf, 0).unwrap(), 0);
        assert!(!read_aa(&buf, 0).unwrap());
        assert!(!read_tc(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(!read_ra(&buf, 0).unwrap());
        assert!(!read_z(&buf, 0).unwrap());
        assert!(!read_ad(&buf, 0).unwrap());
        assert!(!read_cd(&buf, 0).unwrap());
        assert_eq!(read_rcode(&buf, 0).unwrap(), 0);
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 0);
        assert_eq!(read_nscount(&buf, 0).unwrap(), 0);
        assert_eq!(read_arcount(&buf, 0).unwrap(), 0);
    }

    #[test]
    fn test_read_response_header() {
        let buf = make_header(
            0xABCD, true, 0, true, false, true, true, false, true, false, 0, 1, 2, 0, 1,
        );
        assert_eq!(read_id(&buf, 0).unwrap(), 0xABCD);
        assert!(read_qr(&buf, 0).unwrap());
        assert!(read_aa(&buf, 0).unwrap());
        assert!(!read_tc(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(read_ra(&buf, 0).unwrap());
        assert!(!read_z(&buf, 0).unwrap());
        assert!(read_ad(&buf, 0).unwrap());
        assert!(!read_cd(&buf, 0).unwrap());
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 2);
        assert_eq!(read_nscount(&buf, 0).unwrap(), 0);
        assert_eq!(read_arcount(&buf, 0).unwrap(), 1);
    }

    #[test]
    fn test_write_flags() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        write_id(&mut buf, 0, 0x5678).unwrap();
        write_qr(&mut buf, 0, true).unwrap();
        write_opcode(&mut buf, 0, 0).unwrap();
        write_aa(&mut buf, 0, true).unwrap();
        write_rd(&mut buf, 0, true).unwrap();
        write_ra(&mut buf, 0, true).unwrap();
        write_rcode(&mut buf, 0, 3).unwrap(); // NXDOMAIN
        write_qdcount(&mut buf, 0, 1).unwrap();
        write_ancount(&mut buf, 0, 0).unwrap();

        assert_eq!(read_id(&buf, 0).unwrap(), 0x5678);
        assert!(read_qr(&buf, 0).unwrap());
        assert_eq!(read_opcode(&buf, 0).unwrap(), 0);
        assert!(read_aa(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(read_ra(&buf, 0).unwrap());
        assert_eq!(read_rcode(&buf, 0).unwrap(), 3);
        // QR | AA | RD | RA | RCODE=3, and nothing else.
        assert_eq!(read_flags(&buf, 0).unwrap(), 0x8583);
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 0);
    }

    #[test]
    fn test_build_flags() {
        let flags = build_flags(true, 0, true, false, true, true, false, true, false, 3);
        assert_eq!(flags & FLAG_QR, FLAG_QR);
        assert_eq!(flags & FLAG_AA, FLAG_AA);
        assert_eq!(flags & FLAG_RD, FLAG_RD);
        assert_eq!(flags & FLAG_RA, FLAG_RA);
        assert_eq!(flags & FLAG_AD, FLAG_AD);
        assert_eq!(flags & FLAG_RCODE_MASK, 3);
        assert_eq!(flags, 0x85A3);
    }

    #[test]
    fn test_build_flags_covers_every_bit_once() {
        assert_eq!(
            build_flags(false, 0, false, false, false, false, false, false, false, 0),
            0
        );
        assert_eq!(
            build_flags(true, 15, true, true, true, true, true, true, true, 15),
            0xFFFF
        );
    }

    #[test]
    fn test_opcode_values() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        for opcode in 0..=15 {
            write_opcode(&mut buf, 0, opcode).unwrap();
            assert_eq!(read_opcode(&buf, 0).unwrap(), opcode);
        }
    }

    #[test]
    fn test_rcode_values() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        for rcode in 0..=15 {
            write_rcode(&mut buf, 0, rcode).unwrap();
            assert_eq!(read_rcode(&buf, 0).unwrap(), rcode);
        }
    }

    #[test]
    fn test_bool_flag_writes_touch_only_their_bit() {
        type Reader = fn(&[u8], usize) -> Result<bool, FieldError>;
        type Writer = fn(&mut [u8], usize, bool) -> Result<(), FieldError>;
        let accessors: [(Reader, Writer, u16); 8] = [
            (read_qr, write_qr, FLAG_QR),
            (read_aa, write_aa, FLAG_AA),
            (read_tc, write_tc, FLAG_TC),
            (read_rd, write_rd, FLAG_RD),
            (read_ra, write_ra, FLAG_RA),
            (read_z, write_z, FLAG_Z),
            (read_ad, write_ad, FLAG_AD),
            (read_cd, write_cd, FLAG_CD),
        ];
        for &(read, write, mask) in &accessors {
            // Start from all ones so that clearing exactly one bit is observable.
            let mut buf = vec![0u8; DNS_HEADER_LEN];
            write_flags(&mut buf, 0, 0xFFFF).unwrap();

            write(&mut buf, 0, false).unwrap();
            assert!(!read(&buf, 0).unwrap());
            assert_eq!(read_flags(&buf, 0).unwrap(), !mask);

            write(&mut buf, 0, true).unwrap();
            assert!(read(&buf, 0).unwrap());
            assert_eq!(read_flags(&buf, 0).unwrap(), 0xFFFF);
        }
    }

    #[test]
    fn test_multibit_writes_are_masked_and_preserve_neighbours() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];

        // Out-of-range values keep only their low four bits and must not
        // spill into QR (above OPCODE) or CD (above RCODE).
        write_opcode(&mut buf, 0, 0x1F).unwrap();
        assert_eq!(read_opcode(&buf, 0).unwrap(), 0x0F);
        assert!(!read_qr(&buf, 0).unwrap());
        write_rcode(&mut buf, 0, 0x13).unwrap();
        assert_eq!(read_rcode(&buf, 0).unwrap(), 3);
        assert!(!read_cd(&buf, 0).unwrap());

        // Rewriting a multi-bit field leaves the surrounding flags intact.
        write_qr(&mut buf, 0, true).unwrap();
        write_cd(&mut buf, 0, true).unwrap();
        write_opcode(&mut buf, 0, 2).unwrap();
        write_rcode(&mut buf, 0, 0).unwrap();
        // QR (0x8000) | OPCODE=2 (0x1000) | CD (0x0010)
        assert_eq!(read_flags(&buf, 0).unwrap(), 0x9010);
    }

    #[test]
    fn test_section_counts_round_trip() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        write_qdcount(&mut buf, 0, 1).unwrap();
        write_ancount(&mut buf, 0, 0x0203).unwrap();
        write_nscount(&mut buf, 0, 0x0405).unwrap();
        write_arcount(&mut buf, 0, 0xFFFF).unwrap();

        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 0x0203);
        assert_eq!(read_nscount(&buf, 0).unwrap(), 0x0405);
        assert_eq!(read_arcount(&buf, 0).unwrap(), 0xFFFF);
        // Counts land at their fixed offsets in network byte order.
        assert_eq!(&buf[4..12], &[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0xFF, 0xFF]);
    }

    #[test]
    fn test_header_at_offset() {
        // Simulate a DNS header that does not start at position 0 (e.g. after UDP).
        let mut buf = vec![0xAA; 20 + DNS_HEADER_LEN];
        let base = 20;
        write_id(&mut buf, base, 0x9999).unwrap();
        write_qdcount(&mut buf, base, 5).unwrap();
        assert_eq!(read_id(&buf, base).unwrap(), 0x9999);
        assert_eq!(read_qdcount(&buf, base).unwrap(), 5);
        // Bytes before the header must not be touched.
        assert!(buf[..base].iter().all(|&b| b == 0xAA));
    }
}