Skip to main content

stackforge_core/layer/dns/
header.rs

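//! DNS header parsing (RFC 1035 Section 4.1.1).
//!
//! The DNS header is 12 bytes long: six 16-bit words in big-endian order.
//! ```text
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                      ID                       |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |QR|   Opcode  |AA|TC|RD|RA| Z|AD|CD|   RCODE   |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    QDCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    ANCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    NSCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! |                    ARCOUNT                    |
//! +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//! ```
//!
//! RFC 1035 reserves the three bits between RA and RCODE as `Z`; the DNSSEC
//! specifications (RFC 2535, now RFC 4035) later assigned the lower two of
//! them as `AD` and `CD`.
//!
//! Every accessor takes a `base` offset, so the header can be read from or
//! written into a larger buffer (for example, one that starts with a UDP header).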
19
20use crate::layer::field::{Field, FieldError};
21
/// Fixed length of a DNS message header, in bytes.
///
/// The header consists of six 16-bit words: ID, flags, and the four
/// section counts (QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT).
pub const DNS_HEADER_LEN: usize = 6 * 2;
24
// Byte offset of each 16-bit word, relative to the start of the header.
// Every field is two bytes wide, so each offset is the previous one plus 2.
const ID_OFFSET: usize = 0;
const FLAGS_OFFSET: usize = ID_OFFSET + 2;
const QDCOUNT_OFFSET: usize = FLAGS_OFFSET + 2;
const ANCOUNT_OFFSET: usize = QDCOUNT_OFFSET + 2;
const NSCOUNT_OFFSET: usize = ANCOUNT_OFFSET + 2;
const ARCOUNT_OFFSET: usize = NSCOUNT_OFFSET + 2;
32
// Masks for the sub-fields packed into the 16-bit flags word, listed from
// the most significant bit (15, QR) down to the least significant (3..0, RCODE).
const FLAG_QR: u16 = 1 << 15;
const FLAG_OPCODE_SHIFT: u16 = 11;
const FLAG_OPCODE_MASK: u16 = 0b1111 << FLAG_OPCODE_SHIFT;
const FLAG_AA: u16 = 1 << 10;
const FLAG_TC: u16 = 1 << 9;
const FLAG_RD: u16 = 1 << 8;
const FLAG_RA: u16 = 1 << 7;
const FLAG_Z: u16 = 1 << 6;
const FLAG_AD: u16 = 1 << 5;
const FLAG_CD: u16 = 1 << 4;
const FLAG_RCODE_MASK: u16 = 0b1111;
45
46/// Read the DNS transaction ID.
47#[inline]
48pub fn read_id(buf: &[u8], base: usize) -> Result<u16, FieldError> {
49    u16::read(buf, base + ID_OFFSET)
50}
51
52/// Write the DNS transaction ID.
53#[inline]
54pub fn write_id(buf: &mut [u8], base: usize, id: u16) -> Result<(), FieldError> {
55    id.write(buf, base + ID_OFFSET)
56}
57
58/// Read the 16-bit flags field.
59#[inline]
60fn read_flags(buf: &[u8], base: usize) -> Result<u16, FieldError> {
61    u16::read(buf, base + FLAGS_OFFSET)
62}
63
64/// Write the 16-bit flags field.
65#[inline]
66fn write_flags(buf: &mut [u8], base: usize, flags: u16) -> Result<(), FieldError> {
67    flags.write(buf, base + FLAGS_OFFSET)
68}
69
70/// QR flag: 0 = query, 1 = response.
71#[inline]
72pub fn read_qr(buf: &[u8], base: usize) -> Result<bool, FieldError> {
73    Ok(read_flags(buf, base)? & FLAG_QR != 0)
74}
75
76pub fn write_qr(buf: &mut [u8], base: usize, qr: bool) -> Result<(), FieldError> {
77    let mut flags = read_flags(buf, base)?;
78    if qr {
79        flags |= FLAG_QR;
80    } else {
81        flags &= !FLAG_QR;
82    }
83    write_flags(buf, base, flags)
84}
85
86/// Opcode (4 bits).
87#[inline]
88pub fn read_opcode(buf: &[u8], base: usize) -> Result<u8, FieldError> {
89    let flags = read_flags(buf, base)?;
90    Ok(((flags & FLAG_OPCODE_MASK) >> FLAG_OPCODE_SHIFT) as u8)
91}
92
93pub fn write_opcode(buf: &mut [u8], base: usize, opcode: u8) -> Result<(), FieldError> {
94    let mut flags = read_flags(buf, base)?;
95    flags &= !FLAG_OPCODE_MASK;
96    flags |= ((opcode as u16) << FLAG_OPCODE_SHIFT) & FLAG_OPCODE_MASK;
97    write_flags(buf, base, flags)
98}
99
100/// AA flag: Authoritative Answer.
101#[inline]
102pub fn read_aa(buf: &[u8], base: usize) -> Result<bool, FieldError> {
103    Ok(read_flags(buf, base)? & FLAG_AA != 0)
104}
105
106pub fn write_aa(buf: &mut [u8], base: usize, aa: bool) -> Result<(), FieldError> {
107    let mut flags = read_flags(buf, base)?;
108    if aa {
109        flags |= FLAG_AA;
110    } else {
111        flags &= !FLAG_AA;
112    }
113    write_flags(buf, base, flags)
114}
115
116/// TC flag: Truncation.
117#[inline]
118pub fn read_tc(buf: &[u8], base: usize) -> Result<bool, FieldError> {
119    Ok(read_flags(buf, base)? & FLAG_TC != 0)
120}
121
122pub fn write_tc(buf: &mut [u8], base: usize, tc: bool) -> Result<(), FieldError> {
123    let mut flags = read_flags(buf, base)?;
124    if tc {
125        flags |= FLAG_TC;
126    } else {
127        flags &= !FLAG_TC;
128    }
129    write_flags(buf, base, flags)
130}
131
132/// RD flag: Recursion Desired.
133#[inline]
134pub fn read_rd(buf: &[u8], base: usize) -> Result<bool, FieldError> {
135    Ok(read_flags(buf, base)? & FLAG_RD != 0)
136}
137
138pub fn write_rd(buf: &mut [u8], base: usize, rd: bool) -> Result<(), FieldError> {
139    let mut flags = read_flags(buf, base)?;
140    if rd {
141        flags |= FLAG_RD;
142    } else {
143        flags &= !FLAG_RD;
144    }
145    write_flags(buf, base, flags)
146}
147
148/// RA flag: Recursion Available.
149#[inline]
150pub fn read_ra(buf: &[u8], base: usize) -> Result<bool, FieldError> {
151    Ok(read_flags(buf, base)? & FLAG_RA != 0)
152}
153
154pub fn write_ra(buf: &mut [u8], base: usize, ra: bool) -> Result<(), FieldError> {
155    let mut flags = read_flags(buf, base)?;
156    if ra {
157        flags |= FLAG_RA;
158    } else {
159        flags &= !FLAG_RA;
160    }
161    write_flags(buf, base, flags)
162}
163
164/// Z flag (reserved).
165#[inline]
166pub fn read_z(buf: &[u8], base: usize) -> Result<bool, FieldError> {
167    Ok(read_flags(buf, base)? & FLAG_Z != 0)
168}
169
170pub fn write_z(buf: &mut [u8], base: usize, z: bool) -> Result<(), FieldError> {
171    let mut flags = read_flags(buf, base)?;
172    if z {
173        flags |= FLAG_Z;
174    } else {
175        flags &= !FLAG_Z;
176    }
177    write_flags(buf, base, flags)
178}
179
180/// AD flag: Authenticated Data (DNSSEC).
181#[inline]
182pub fn read_ad(buf: &[u8], base: usize) -> Result<bool, FieldError> {
183    Ok(read_flags(buf, base)? & FLAG_AD != 0)
184}
185
186pub fn write_ad(buf: &mut [u8], base: usize, ad: bool) -> Result<(), FieldError> {
187    let mut flags = read_flags(buf, base)?;
188    if ad {
189        flags |= FLAG_AD;
190    } else {
191        flags &= !FLAG_AD;
192    }
193    write_flags(buf, base, flags)
194}
195
196/// CD flag: Checking Disabled (DNSSEC).
197#[inline]
198pub fn read_cd(buf: &[u8], base: usize) -> Result<bool, FieldError> {
199    Ok(read_flags(buf, base)? & FLAG_CD != 0)
200}
201
202pub fn write_cd(buf: &mut [u8], base: usize, cd: bool) -> Result<(), FieldError> {
203    let mut flags = read_flags(buf, base)?;
204    if cd {
205        flags |= FLAG_CD;
206    } else {
207        flags &= !FLAG_CD;
208    }
209    write_flags(buf, base, flags)
210}
211
212/// RCODE (4 bits): Response code.
213#[inline]
214pub fn read_rcode(buf: &[u8], base: usize) -> Result<u8, FieldError> {
215    let flags = read_flags(buf, base)?;
216    Ok((flags & FLAG_RCODE_MASK) as u8)
217}
218
219pub fn write_rcode(buf: &mut [u8], base: usize, rcode: u8) -> Result<(), FieldError> {
220    let mut flags = read_flags(buf, base)?;
221    flags &= !FLAG_RCODE_MASK;
222    flags |= (rcode as u16) & FLAG_RCODE_MASK;
223    write_flags(buf, base, flags)
224}
225
226/// Question count.
227#[inline]
228pub fn read_qdcount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
229    u16::read(buf, base + QDCOUNT_OFFSET)
230}
231
232#[inline]
233pub fn write_qdcount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
234    count.write(buf, base + QDCOUNT_OFFSET)
235}
236
237/// Answer count.
238#[inline]
239pub fn read_ancount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
240    u16::read(buf, base + ANCOUNT_OFFSET)
241}
242
243#[inline]
244pub fn write_ancount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
245    count.write(buf, base + ANCOUNT_OFFSET)
246}
247
248/// Authority count.
249#[inline]
250pub fn read_nscount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
251    u16::read(buf, base + NSCOUNT_OFFSET)
252}
253
254#[inline]
255pub fn write_nscount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
256    count.write(buf, base + NSCOUNT_OFFSET)
257}
258
259/// Additional count.
260#[inline]
261pub fn read_arcount(buf: &[u8], base: usize) -> Result<u16, FieldError> {
262    u16::read(buf, base + ARCOUNT_OFFSET)
263}
264
265#[inline]
266pub fn write_arcount(buf: &mut [u8], base: usize, count: u16) -> Result<(), FieldError> {
267    count.write(buf, base + ARCOUNT_OFFSET)
268}
269
270/// Build a raw 16-bit flags value from individual components.
271pub fn build_flags(
272    qr: bool,
273    opcode: u8,
274    aa: bool,
275    tc: bool,
276    rd: bool,
277    ra: bool,
278    z: bool,
279    ad: bool,
280    cd: bool,
281    rcode: u8,
282) -> u16 {
283    let mut flags: u16 = 0;
284    if qr {
285        flags |= FLAG_QR;
286    }
287    flags |= ((opcode as u16) << FLAG_OPCODE_SHIFT) & FLAG_OPCODE_MASK;
288    if aa {
289        flags |= FLAG_AA;
290    }
291    if tc {
292        flags |= FLAG_TC;
293    }
294    if rd {
295        flags |= FLAG_RD;
296    }
297    if ra {
298        flags |= FLAG_RA;
299    }
300    if z {
301        flags |= FLAG_Z;
302    }
303    if ad {
304        flags |= FLAG_AD;
305    }
306    if cd {
307        flags |= FLAG_CD;
308    }
309    flags |= (rcode as u16) & FLAG_RCODE_MASK;
310    flags
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Lays out a 12-byte header word by word with `to_be_bytes`, rather than
    /// through the accessors under test.
    // One parameter per wire field keeps each call aligned with the RFC layout.
    #[allow(clippy::too_many_arguments)]
    fn make_header(
        id: u16,
        qr: bool,
        opcode: u8,
        aa: bool,
        tc: bool,
        rd: bool,
        ra: bool,
        z: bool,
        ad: bool,
        cd: bool,
        rcode: u8,
        qdcount: u16,
        ancount: u16,
        nscount: u16,
        arcount: u16,
    ) -> Vec<u8> {
        let flags = build_flags(qr, opcode, aa, tc, rd, ra, z, ad, cd, rcode);
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        buf[0..2].copy_from_slice(&id.to_be_bytes());
        buf[2..4].copy_from_slice(&flags.to_be_bytes());
        buf[4..6].copy_from_slice(&qdcount.to_be_bytes());
        buf[6..8].copy_from_slice(&ancount.to_be_bytes());
        buf[8..10].copy_from_slice(&nscount.to_be_bytes());
        buf[10..12].copy_from_slice(&arcount.to_be_bytes());
        buf
    }

    /// Decodes the raw flags word of a header at offset 0 straight from bytes 2..4.
    fn raw_flags(buf: &[u8]) -> u16 {
        u16::from_be_bytes([buf[2], buf[3]])
    }

    #[test]
    fn test_read_query_header() {
        let buf = make_header(
            0x1234, false, 0, false, false, true, false, false, false, false, 0, 1, 0, 0, 0,
        );
        assert_eq!(read_id(&buf, 0).unwrap(), 0x1234);
        assert!(!read_qr(&buf, 0).unwrap());
        assert_eq!(read_opcode(&buf, 0).unwrap(), 0);
        assert!(!read_aa(&buf, 0).unwrap());
        assert!(!read_tc(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(!read_ra(&buf, 0).unwrap());
        assert_eq!(read_rcode(&buf, 0).unwrap(), 0);
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 0);
    }

    #[test]
    fn test_read_response_header() {
        let buf = make_header(
            0xABCD, true, 0, true, false, true, true, false, true, false, 0, 1, 2, 0, 1,
        );
        assert_eq!(read_id(&buf, 0).unwrap(), 0xABCD);
        assert!(read_qr(&buf, 0).unwrap());
        assert!(read_aa(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(read_ra(&buf, 0).unwrap());
        assert!(read_ad(&buf, 0).unwrap());
        assert!(!read_cd(&buf, 0).unwrap());
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 2);
        assert_eq!(read_arcount(&buf, 0).unwrap(), 1);
    }

    #[test]
    fn test_read_counts_and_reserved_bits() {
        // id, qr, opcode, aa, tc, rd, ra, z, ad, cd, rcode, qd, an, ns, ar
        let buf = make_header(
            0x0001, true, 0, false, false, false, false, true, false, true, 0, 3, 4, 5, 6,
        );
        assert!(read_z(&buf, 0).unwrap());
        assert!(!read_ad(&buf, 0).unwrap());
        assert!(read_cd(&buf, 0).unwrap());
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 3);
        assert_eq!(read_ancount(&buf, 0).unwrap(), 4);
        assert_eq!(read_nscount(&buf, 0).unwrap(), 5);
        assert_eq!(read_arcount(&buf, 0).unwrap(), 6);
    }

    #[test]
    fn test_write_flags() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        write_id(&mut buf, 0, 0x5678).unwrap();
        write_qr(&mut buf, 0, true).unwrap();
        write_opcode(&mut buf, 0, 0).unwrap();
        write_aa(&mut buf, 0, true).unwrap();
        write_rd(&mut buf, 0, true).unwrap();
        write_ra(&mut buf, 0, true).unwrap();
        write_rcode(&mut buf, 0, 3).unwrap(); // NXDOMAIN
        write_qdcount(&mut buf, 0, 1).unwrap();
        write_ancount(&mut buf, 0, 0).unwrap();

        assert_eq!(read_id(&buf, 0).unwrap(), 0x5678);
        assert!(read_qr(&buf, 0).unwrap());
        assert!(read_aa(&buf, 0).unwrap());
        assert!(read_rd(&buf, 0).unwrap());
        assert!(read_ra(&buf, 0).unwrap());
        assert_eq!(read_rcode(&buf, 0).unwrap(), 3);
    }

    #[test]
    fn test_clear_flags() {
        let mut buf = make_header(
            0x4242, true, 0, true, true, true, true, true, true, true, 0, 1, 1, 1, 1,
        );
        assert_eq!(raw_flags(&buf), 0x87F0);
        write_qr(&mut buf, 0, false).unwrap();
        write_aa(&mut buf, 0, false).unwrap();
        write_tc(&mut buf, 0, false).unwrap();
        write_rd(&mut buf, 0, false).unwrap();
        write_ra(&mut buf, 0, false).unwrap();
        write_z(&mut buf, 0, false).unwrap();
        write_ad(&mut buf, 0, false).unwrap();
        write_cd(&mut buf, 0, false).unwrap();
        assert_eq!(raw_flags(&buf), 0);
        // Neighbouring words must be untouched by flag writes.
        assert_eq!(read_id(&buf, 0).unwrap(), 0x4242);
        assert_eq!(read_qdcount(&buf, 0).unwrap(), 1);
    }

    #[test]
    fn test_single_flag_writers_touch_only_their_bit() {
        let setters: [(fn(&mut [u8], usize, bool) -> Result<(), FieldError>, u16); 8] = [
            (write_qr, FLAG_QR),
            (write_aa, FLAG_AA),
            (write_tc, FLAG_TC),
            (write_rd, FLAG_RD),
            (write_ra, FLAG_RA),
            (write_z, FLAG_Z),
            (write_ad, FLAG_AD),
            (write_cd, FLAG_CD),
        ];
        for (setter, mask) in setters {
            let mut buf = vec![0u8; DNS_HEADER_LEN];
            setter(&mut buf, 0, true).unwrap();
            assert_eq!(raw_flags(&buf), mask);
            assert!(buf[..2].iter().chain(&buf[4..]).all(|&b| b == 0));
        }
    }

    #[test]
    fn test_build_flags() {
        let flags = build_flags(true, 0, true, false, true, true, false, true, false, 3);
        assert_eq!(flags & FLAG_QR, FLAG_QR);
        assert_eq!(flags & FLAG_AA, FLAG_AA);
        assert_eq!(flags & FLAG_RD, FLAG_RD);
        assert_eq!(flags & FLAG_RA, FLAG_RA);
        assert_eq!(flags & FLAG_AD, FLAG_AD);
        assert_eq!(flags & FLAG_RCODE_MASK, 3);
    }

    #[test]
    fn test_opcode_values() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        for opcode in 0..=15 {
            write_opcode(&mut buf, 0, opcode).unwrap();
            assert_eq!(read_opcode(&buf, 0).unwrap(), opcode);
        }
    }

    #[test]
    fn test_rcode_values() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        for rcode in 0..=15 {
            write_rcode(&mut buf, 0, rcode).unwrap();
            assert_eq!(read_rcode(&buf, 0).unwrap(), rcode);
        }
    }

    #[test]
    fn test_oversized_opcode_and_rcode_are_masked() {
        let mut buf = vec![0u8; DNS_HEADER_LEN];
        write_opcode(&mut buf, 0, 0xFF).unwrap();
        assert_eq!(read_opcode(&buf, 0).unwrap(), 0x0F);
        assert!(!read_qr(&buf, 0).unwrap());
        write_rcode(&mut buf, 0, 0xFF).unwrap();
        assert_eq!(read_rcode(&buf, 0).unwrap(), 0x0F);
        assert!(!read_cd(&buf, 0).unwrap());
        assert_eq!(raw_flags(&buf), FLAG_OPCODE_MASK | FLAG_RCODE_MASK);

        let built = build_flags(false, 0xFF, false, false, false, false, false, false, false, 0xFF);
        assert_eq!(built, FLAG_OPCODE_MASK | FLAG_RCODE_MASK);
    }

    #[test]
    fn test_header_at_offset() {
        // Simulate DNS header not at position 0 (e.g., after UDP)
        let mut buf = vec![0xAA; 20 + DNS_HEADER_LEN];
        let base = 20;
        write_id(&mut buf, base, 0x9999).unwrap();
        write_qdcount(&mut buf, base, 5).unwrap();
        assert_eq!(read_id(&buf, base).unwrap(), 0x9999);
        assert_eq!(read_qdcount(&buf, base).unwrap(), 5);
        // Bytes in front of the header must be left alone.
        assert!(buf[..base].iter().all(|&b| b == 0xAA));
    }
}