Skip to main content

stackforge_core/layer/dns/
builder.rs

//! DNS packet builder with fluent API.
//!
//! Provides a builder pattern for constructing DNS packets,
//! similar to Scapy's `DNS()` constructor.
5
6use std::collections::HashMap;
7
8use super::header;
9use super::query::DnsQuestion;
10use super::rr::DnsResourceRecord;
11use super::types;
12
13/// Builder for constructing DNS packets.
14#[derive(Debug, Clone)]
15pub struct DnsBuilder {
16    /// Transaction ID.
17    pub id: u16,
18    /// Query/Response flag.
19    pub qr: bool,
20    /// Operation code.
21    pub opcode: u8,
22    /// Authoritative Answer flag.
23    pub aa: bool,
24    /// Truncation flag.
25    pub tc: bool,
26    /// Recursion Desired flag.
27    pub rd: bool,
28    /// Recursion Available flag.
29    pub ra: bool,
30    /// Reserved (Z) flag.
31    pub z: bool,
32    /// Authenticated Data flag (DNSSEC).
33    pub ad: bool,
34    /// Checking Disabled flag (DNSSEC).
35    pub cd: bool,
36    /// Response code.
37    pub rcode: u8,
38    /// Question section.
39    pub questions: Vec<DnsQuestion>,
40    /// Answer section.
41    pub answers: Vec<DnsResourceRecord>,
42    /// Authority section.
43    pub authorities: Vec<DnsResourceRecord>,
44    /// Additional section.
45    pub additionals: Vec<DnsResourceRecord>,
46    /// Whether to use DNS name compression when building.
47    pub compress: bool,
48}
49
50impl DnsBuilder {
51    /// Create a new DNS builder with default values (standard query with RD=1).
52    #[must_use]
53    pub fn new() -> Self {
54        Self {
55            id: 0,
56            qr: false,
57            opcode: types::opcode::QUERY,
58            aa: false,
59            tc: false,
60            rd: true,
61            ra: false,
62            z: false,
63            ad: false,
64            cd: false,
65            rcode: types::rcode::NOERROR,
66            questions: Vec::new(),
67            answers: Vec::new(),
68            authorities: Vec::new(),
69            additionals: Vec::new(),
70            compress: true,
71        }
72    }
73
74    /// Create a builder for a standard query.
75    #[must_use]
76    pub fn query(qname: &str, qtype: u16) -> Self {
77        let mut b = Self::new();
78        if let Ok(q) = DnsQuestion::from_name(qname) {
79            let mut q = q;
80            q.qtype = qtype;
81            b.questions.push(q);
82        }
83        b
84    }
85
86    /// Create a builder for a standard response.
87    #[must_use]
88    pub fn response() -> Self {
89        let mut b = Self::new();
90        b.qr = true;
91        b.ra = true;
92        b
93    }
94
95    // Fluent setters
96
97    #[must_use]
98    pub fn id(mut self, id: u16) -> Self {
99        self.id = id;
100        self
101    }
102
103    #[must_use]
104    pub fn qr(mut self, qr: bool) -> Self {
105        self.qr = qr;
106        self
107    }
108
109    #[must_use]
110    pub fn opcode(mut self, opcode: u8) -> Self {
111        self.opcode = opcode;
112        self
113    }
114
115    #[must_use]
116    pub fn aa(mut self, aa: bool) -> Self {
117        self.aa = aa;
118        self
119    }
120
121    #[must_use]
122    pub fn tc(mut self, tc: bool) -> Self {
123        self.tc = tc;
124        self
125    }
126
127    #[must_use]
128    pub fn rd(mut self, rd: bool) -> Self {
129        self.rd = rd;
130        self
131    }
132
133    #[must_use]
134    pub fn ra(mut self, ra: bool) -> Self {
135        self.ra = ra;
136        self
137    }
138
139    #[must_use]
140    pub fn z(mut self, z: bool) -> Self {
141        self.z = z;
142        self
143    }
144
145    #[must_use]
146    pub fn ad(mut self, ad: bool) -> Self {
147        self.ad = ad;
148        self
149    }
150
151    #[must_use]
152    pub fn cd(mut self, cd: bool) -> Self {
153        self.cd = cd;
154        self
155    }
156
157    #[must_use]
158    pub fn rcode(mut self, rcode: u8) -> Self {
159        self.rcode = rcode;
160        self
161    }
162
163    #[must_use]
164    pub fn compress(mut self, compress: bool) -> Self {
165        self.compress = compress;
166        self
167    }
168
169    /// Add a question to the question section.
170    #[must_use]
171    pub fn question(mut self, q: DnsQuestion) -> Self {
172        self.questions.push(q);
173        self
174    }
175
176    /// Add a resource record to the answer section.
177    #[must_use]
178    pub fn answer(mut self, rr: DnsResourceRecord) -> Self {
179        self.answers.push(rr);
180        self
181    }
182
183    /// Add a resource record to the authority section.
184    #[must_use]
185    pub fn authority(mut self, rr: DnsResourceRecord) -> Self {
186        self.authorities.push(rr);
187        self
188    }
189
190    /// Add a resource record to the additional section.
191    #[must_use]
192    pub fn additional(mut self, rr: DnsResourceRecord) -> Self {
193        self.additionals.push(rr);
194        self
195    }
196
197    /// Build the DNS packet bytes.
198    #[must_use]
199    pub fn build(&self) -> Vec<u8> {
200        if self.compress {
201            self.build_compressed()
202        } else {
203            self.build_uncompressed()
204        }
205    }
206
207    /// Build without name compression.
208    fn build_uncompressed(&self) -> Vec<u8> {
209        let mut out = Vec::with_capacity(512);
210
211        // Header (12 bytes)
212        out.extend_from_slice(&self.id.to_be_bytes());
213        let flags = header::build_flags(
214            self.qr,
215            self.opcode,
216            self.aa,
217            self.tc,
218            self.rd,
219            self.ra,
220            self.z,
221            self.ad,
222            self.cd,
223            self.rcode,
224        );
225        out.extend_from_slice(&flags.to_be_bytes());
226        out.extend_from_slice(&(self.questions.len() as u16).to_be_bytes());
227        out.extend_from_slice(&(self.answers.len() as u16).to_be_bytes());
228        out.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes());
229        out.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes());
230
231        // Questions
232        for q in &self.questions {
233            out.extend_from_slice(&q.build());
234        }
235
236        // Answers
237        for rr in &self.answers {
238            out.extend_from_slice(&rr.build());
239        }
240
241        // Authorities
242        for rr in &self.authorities {
243            out.extend_from_slice(&rr.build());
244        }
245
246        // Additionals
247        for rr in &self.additionals {
248            out.extend_from_slice(&rr.build());
249        }
250
251        out
252    }
253
254    /// Build with name compression.
255    fn build_compressed(&self) -> Vec<u8> {
256        let mut out = Vec::with_capacity(512);
257        let mut compression_map: HashMap<String, u16> = HashMap::new();
258
259        // Header (12 bytes)
260        out.extend_from_slice(&self.id.to_be_bytes());
261        let flags = header::build_flags(
262            self.qr,
263            self.opcode,
264            self.aa,
265            self.tc,
266            self.rd,
267            self.ra,
268            self.z,
269            self.ad,
270            self.cd,
271            self.rcode,
272        );
273        out.extend_from_slice(&flags.to_be_bytes());
274        out.extend_from_slice(&(self.questions.len() as u16).to_be_bytes());
275        out.extend_from_slice(&(self.answers.len() as u16).to_be_bytes());
276        out.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes());
277        out.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes());
278
279        // Questions
280        for q in &self.questions {
281            let encoded = q.build_compressed(out.len(), &mut compression_map);
282            out.extend_from_slice(&encoded);
283        }
284
285        // Answers
286        for rr in &self.answers {
287            let encoded = rr.build_compressed(out.len(), &mut compression_map);
288            out.extend_from_slice(&encoded);
289        }
290
291        // Authorities
292        for rr in &self.authorities {
293            let encoded = rr.build_compressed(out.len(), &mut compression_map);
294            out.extend_from_slice(&encoded);
295        }
296
297        // Additionals
298        for rr in &self.additionals {
299            let encoded = rr.build_compressed(out.len(), &mut compression_map);
300            out.extend_from_slice(&encoded);
301        }
302
303        out
304    }
305
306    /// Get the minimum header size.
307    #[must_use]
308    pub fn header_size(&self) -> usize {
309        header::DNS_HEADER_LEN
310    }
311}
312
313impl Default for DnsBuilder {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::super::types::rr_type;
    use super::*;
    use crate::layer::field_ext::DnsName;

    /// `new()` yields a standard recursive query with no questions.
    #[test]
    fn test_builder_default_query() {
        let b = DnsBuilder::new();
        assert!(!b.qr);
        assert!(b.rd);
        assert_eq!(b.opcode, 0);
        assert_eq!(b.rcode, 0);
        assert!(b.questions.is_empty());
    }

    /// `query()` adds exactly one question with the requested type.
    #[test]
    fn test_builder_query_shortcut() {
        let b = DnsBuilder::query("example.com", rr_type::A);
        assert_eq!(b.questions.len(), 1);
        assert_eq!(b.questions[0].qtype, rr_type::A);
        assert_eq!(b.questions[0].qname.labels, vec!["example", "com"]);
    }

    /// Every fluent setter used in the chain lands in its field.
    #[test]
    fn test_builder_fluent_api() {
        let b = DnsBuilder::new()
            .id(0x1234)
            .qr(true)
            .aa(true)
            .rd(true)
            .ra(true)
            .rcode(0);
        assert_eq!(b.id, 0x1234);
        assert!(b.qr);
        assert!(b.aa);
        assert!(b.rd);
        assert!(b.ra);
        assert_eq!(b.rcode, 0);
    }

    /// The built header carries the ID and the section counts.
    #[test]
    fn test_builder_build_simple_query() {
        let b = DnsBuilder::query("example.com", rr_type::A).id(0x1234);
        let packet = b.build();

        // Verify header
        assert!(packet.len() >= 12);
        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 0x1234); // ID
        let qdcount = u16::from_be_bytes([packet[4], packet[5]]);
        assert_eq!(qdcount, 1);
        // Answer, authority and additional sections are empty.
        assert_eq!(&packet[6..12], &[0u8; 6]);
    }

    /// The uncompressed build is at least header + name + type/class long.
    #[test]
    fn test_builder_build_uncompressed() {
        let b = DnsBuilder::query("example.com", rr_type::A)
            .id(0x5678)
            .compress(false);
        let packet = b.build();
        assert!(packet.len() >= 12 + 4 + 13); // header + type/class + "example.com" encoded
    }

    /// Two names sharing a suffix serialize shorter with compression on.
    #[test]
    fn test_builder_compression_reduces_size() {
        let q1 = DnsQuestion::from_name("www.example.com").unwrap();
        let q2 = DnsQuestion::from_name("mail.example.com").unwrap();
        let b = DnsBuilder::new()
            .question(q1.clone())
            .question(q2.clone())
            .compress(true);
        let compressed = b.build();

        let b2 = DnsBuilder::new().question(q1).question(q2).compress(false);
        let uncompressed = b2.build();

        assert!(compressed.len() < uncompressed.len());
    }

    /// `response()` sets both QR and RA.
    #[test]
    fn test_builder_response() {
        let b = DnsBuilder::response();
        assert!(b.qr);
        assert!(b.ra);
    }
}