Skip to main content

stackforge_core/layer/
field_ext.rs

1//! Extended field types for complex protocol fields.
2//!
3//! This module provides:
4//! - `DnsName`: DNS domain name with pointer compression support
5//! - `FlagValue`: Named bit flags with display support
6//! - Bit field reader utilities
7
8use std::collections::HashMap;
9use std::fmt;
10
11use super::field::FieldError;
12
13// ============================================================================
14// DNS Name (RFC 1035 Section 4.1.4)
15// ============================================================================
16
/// Upper bound on compression pointers followed while decoding one name.
/// Guards against pointer cycles in malformed or hostile packets.
const DNS_MAX_POINTER_HOPS: usize = 128;
/// Longest permitted label, in bytes (RFC 1035 §2.3.4).
const DNS_MAX_LABEL_LEN: usize = 63;
/// Longest permitted name in dotted text form, excluding the trailing dot.
/// Adding the first length byte and the root terminator gives the RFC 1035
/// §2.3.4 wire limit of 255 octets (`DNS_MAX_NAME_LEN + 2`).
const DNS_MAX_NAME_LEN: usize = 253;
/// Mask for the two high bits of a length byte; when both are set the byte
/// begins a two-byte compression pointer (RFC 1035 §4.1.4).
const DNS_POINTER_FLAG: u8 = 0b1100_0000;
25
/// A domain name held as an ordered list of labels, most specific first.
///
/// The root name is represented by an empty label list. Wire-format
/// encoding and decoding, including RFC 1035 §4.1.4 compression pointers,
/// are provided by the inherent methods.
///
/// Equality and hashing compare labels byte-for-byte, so they are
/// case-sensitive.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct DnsName {
    /// Labels in order, e.g. `["www", "example", "com"]` for `www.example.com.`.
    pub labels: Vec<String>,
}
34
35impl DnsName {
36    /// Create a new `DnsName` from labels.
37    #[must_use]
38    pub fn new(labels: Vec<String>) -> Self {
39        Self { labels }
40    }
41
42    /// Create a root name (empty labels).
43    #[must_use]
44    pub fn root() -> Self {
45        Self { labels: vec![] }
46    }
47
48    /// Parse a DNS name from a dot-separated string.
49    /// "www.example.com" → labels: ["www", "example", "com"]
50    /// "www.example.com." → same (trailing dot is ignored)
51    pub fn from_str_dotted(s: &str) -> Result<Self, FieldError> {
52        if s.is_empty() || s == "." {
53            return Ok(Self::root());
54        }
55        let s = s.strip_suffix('.').unwrap_or(s);
56        let labels: Vec<String> = s.split('.').map(std::string::ToString::to_string).collect();
57        // Validate label lengths
58        for label in &labels {
59            if label.len() > DNS_MAX_LABEL_LEN {
60                return Err(FieldError::InvalidValue(format!(
61                    "DNS label too long: {} bytes (max {})",
62                    label.len(),
63                    DNS_MAX_LABEL_LEN
64                )));
65            }
66        }
67        let total_len: usize = labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1;
68        if total_len > DNS_MAX_NAME_LEN + 2 {
69            return Err(FieldError::InvalidValue(format!(
70                "DNS name too long: {total_len} bytes (max {DNS_MAX_NAME_LEN})"
71            )));
72        }
73        Ok(Self { labels })
74    }
75
76    /// Check if this is the root name.
77    #[must_use]
78    pub fn is_root(&self) -> bool {
79        self.labels.is_empty()
80    }
81
82    /// Get the fully qualified domain name string.
83    /// Returns "www.example.com." with trailing dot.
84    #[must_use]
85    pub fn to_fqdn(&self) -> String {
86        if self.labels.is_empty() {
87            return ".".to_string();
88        }
89        format!("{}.", self.labels.join("."))
90    }
91
92    /// Encode to wire format without compression.
93    /// Each label is preceded by its length byte, terminated by a zero byte.
94    #[must_use]
95    pub fn encode(&self) -> Vec<u8> {
96        let mut out = Vec::new();
97        for label in &self.labels {
98            out.push(label.len() as u8);
99            out.extend_from_slice(label.as_bytes());
100        }
101        out.push(0); // root label
102        out
103    }
104
105    /// Encode to wire format with compression.
106    /// Uses `compression_map` to track previously written name positions.
107    /// `current_offset` is where this name will be written in the packet.
108    pub fn encode_compressed(
109        &self,
110        current_offset: usize,
111        compression_map: &mut HashMap<String, u16>,
112    ) -> Vec<u8> {
113        let mut out = Vec::new();
114        let mut offset = current_offset;
115
116        for i in 0..self.labels.len() {
117            // Check if the suffix from this label onwards was already written
118            let suffix = self.labels[i..].join(".");
119            if let Some(&ptr) = compression_map.get(&suffix) {
120                // Write a pointer to the previous occurrence
121                out.push(DNS_POINTER_FLAG | ((ptr >> 8) as u8));
122                out.push((ptr & 0xFF) as u8);
123                return out;
124            }
125            // Record this suffix position in the compression map
126            if offset < 0x3FFF {
127                compression_map.insert(suffix, offset as u16);
128            }
129            // Write the label
130            let label = &self.labels[i];
131            out.push(label.len() as u8);
132            out.extend_from_slice(label.as_bytes());
133            offset += 1 + label.len();
134        }
135        out.push(0); // root label
136        out
137    }
138
139    /// Decode a DNS name from wire format with pointer decompression.
140    ///
141    /// `packet` is the full packet buffer (needed for pointer resolution).
142    /// `offset` is the starting position of the name.
143    ///
144    /// Returns the decoded name and the number of bytes consumed from `offset`
145    /// (not counting bytes reached via pointers).
146    pub fn decode(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
147        let mut labels = Vec::new();
148        let mut pos = offset;
149        let mut bytes_consumed = 0;
150        let mut followed_pointer = false;
151        let mut hops = 0;
152
153        loop {
154            if pos >= packet.len() {
155                return Err(FieldError::BufferTooShort {
156                    offset: pos,
157                    need: 1,
158                    have: packet.len(),
159                });
160            }
161
162            let len_or_ptr = packet[pos];
163
164            if len_or_ptr == 0 {
165                // Root label — end of name
166                if !followed_pointer {
167                    bytes_consumed = pos - offset + 1;
168                }
169                break;
170            } else if len_or_ptr & DNS_POINTER_FLAG == DNS_POINTER_FLAG {
171                // Compression pointer
172                if pos + 1 >= packet.len() {
173                    return Err(FieldError::BufferTooShort {
174                        offset: pos,
175                        need: 2,
176                        have: packet.len(),
177                    });
178                }
179                let ptr = (((len_or_ptr & 0x3F) as usize) << 8) | (packet[pos + 1] as usize);
180
181                if !followed_pointer {
182                    bytes_consumed = pos - offset + 2;
183                    followed_pointer = true;
184                }
185
186                hops += 1;
187                if hops > DNS_MAX_POINTER_HOPS {
188                    return Err(FieldError::InvalidValue(
189                        "DNS name compression loop detected".to_string(),
190                    ));
191                }
192
193                if ptr >= packet.len() {
194                    return Err(FieldError::InvalidValue(format!(
195                        "DNS compression pointer {:#06x} out of bounds (packet len {})",
196                        ptr,
197                        packet.len()
198                    )));
199                }
200
201                pos = ptr;
202            } else {
203                // Regular label
204                let label_len = len_or_ptr as usize;
205                if label_len > DNS_MAX_LABEL_LEN {
206                    return Err(FieldError::InvalidValue(format!(
207                        "DNS label too long: {label_len} bytes (max {DNS_MAX_LABEL_LEN})"
208                    )));
209                }
210                if pos + 1 + label_len > packet.len() {
211                    return Err(FieldError::BufferTooShort {
212                        offset: pos + 1,
213                        need: label_len,
214                        have: packet.len() - pos - 1,
215                    });
216                }
217                let label =
218                    String::from_utf8_lossy(&packet[pos + 1..pos + 1 + label_len]).into_owned();
219                labels.push(label);
220                pos += 1 + label_len;
221            }
222        }
223
224        Ok((Self { labels }, bytes_consumed))
225    }
226
227    /// Wire-format length of this name without compression.
228    #[must_use]
229    pub fn wire_len(&self) -> usize {
230        if self.labels.is_empty() {
231            return 1; // just the root label (0x00)
232        }
233        self.labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1
234    }
235}
236
237impl fmt::Display for DnsName {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        if self.labels.is_empty() {
240            write!(f, ".")
241        } else {
242            write!(f, "{}.", self.labels.join("."))
243        }
244    }
245}
246
247impl From<&str> for DnsName {
248    fn from(s: &str) -> Self {
249        DnsName::from_str_dotted(s).unwrap_or_default()
250    }
251}
252
253// ============================================================================
254// FlagValue - Named bit flags
255// ============================================================================
256
/// A bit-flag value paired with a table of per-bit names.
///
/// `names[i]` is the name of bit `i`, counting from the least significant
/// bit. The table is `'static`, so one table can be shared by every value of
/// a given protocol field.
#[derive(Clone, Debug)]
pub struct FlagValue {
    /// Raw flag bits.
    pub value: u64,
    /// Name of each bit position; index 0 is the least significant bit.
    pub names: &'static [&'static str],
}
267
268impl FlagValue {
269    #[must_use]
270    pub fn new(value: u64, names: &'static [&'static str]) -> Self {
271        Self { value, names }
272    }
273
274    /// Check if a specific flag bit is set.
275    #[must_use]
276    pub fn has(&self, bit: usize) -> bool {
277        (self.value >> bit) & 1 != 0
278    }
279
280    /// Check if a named flag is set. Returns None if name not found.
281    #[must_use]
282    pub fn has_named(&self, name: &str) -> Option<bool> {
283        self.names
284            .iter()
285            .position(|&n| n == name)
286            .map(|bit| self.has(bit))
287    }
288
289    /// Set a specific flag bit.
290    pub fn set(&mut self, bit: usize) {
291        self.value |= 1u64 << bit;
292    }
293
294    /// Clear a specific flag bit.
295    pub fn clear(&mut self, bit: usize) {
296        self.value &= !(1u64 << bit);
297    }
298
299    /// Get the list of set flag names.
300    #[must_use]
301    pub fn set_flags(&self) -> Vec<&'static str> {
302        let mut flags = Vec::new();
303        for (i, &name) in self.names.iter().enumerate() {
304            if !name.is_empty() && self.has(i) {
305                flags.push(name);
306            }
307        }
308        flags
309    }
310}
311
312impl fmt::Display for FlagValue {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        let flags = self.set_flags();
315        if flags.is_empty() {
316            write!(f, "0")
317        } else {
318            write!(f, "{}", flags.join("+"))
319        }
320    }
321}
322
323impl PartialEq for FlagValue {
324    fn eq(&self, other: &Self) -> bool {
325        self.value == other.value
326    }
327}
328
329impl Eq for FlagValue {}
330
#[cfg(test)]
mod tests {
    use super::*;

    // ========================================================================
    // DnsName tests
    // ========================================================================

    #[test]
    fn test_dns_name_from_str() {
        let name = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(name.labels, vec!["www", "example", "com"]);
        assert_eq!(name.to_fqdn(), "www.example.com.");
        assert_eq!(name.to_string(), "www.example.com.");
    }

    #[test]
    fn test_dns_name_from_str_trailing_dot() {
        let name = DnsName::from_str_dotted("www.example.com.").unwrap();
        assert_eq!(name.labels, vec!["www", "example", "com"]);
    }

    #[test]
    fn test_dns_name_root() {
        let name = DnsName::from_str_dotted(".").unwrap();
        assert!(name.is_root());
        assert_eq!(name.to_fqdn(), ".");
    }

    #[test]
    fn test_dns_name_empty() {
        let name = DnsName::from_str_dotted("").unwrap();
        assert!(name.is_root());
    }

    #[test]
    fn test_dns_name_encode() {
        let name = DnsName::from_str_dotted("www.example.com").unwrap();
        let encoded = name.encode();
        assert_eq!(
            encoded,
            vec![
                3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o',
                b'm', 0
            ]
        );
    }

    #[test]
    fn test_dns_name_encode_root() {
        let name = DnsName::root();
        assert_eq!(name.encode(), vec![0]);
        assert_eq!(name.wire_len(), 1);
    }

    #[test]
    fn test_dns_name_decode_simple() {
        let data = vec![
            3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm',
            0,
        ];
        let (name, consumed) = DnsName::decode(&data, 0).unwrap();
        assert_eq!(name.labels, vec!["www", "example", "com"]);
        assert_eq!(consumed, 17);
    }

    #[test]
    fn test_dns_name_decode_with_pointer() {
        // Build a packet with a pointer:
        // offset 0: \x07example\x03com\x00  (13 bytes)
        // offset 13: \x03www\xc0\x00        (www + pointer to offset 0)
        let mut data = vec![];
        // "example.com" at offset 0
        data.extend_from_slice(&[
            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        ]);
        // "www" + pointer to offset 0
        data.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]);

        let (name, consumed) = DnsName::decode(&data, 13).unwrap();
        assert_eq!(name.labels, vec!["www", "example", "com"]);
        assert_eq!(consumed, 6); // 1+3 for "www" + 2 for pointer
    }

    #[test]
    fn test_dns_name_decode_pointer_loop() {
        // Create a loop: offset 0 points to offset 2, offset 2 points to offset 0
        let data = vec![0xC0, 0x02, 0xC0, 0x00];
        let result = DnsName::decode(&data, 0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("loop detected"));
    }

    #[test]
    fn test_dns_name_decode_pointer_out_of_bounds() {
        // 0xC0 0xFF is a pointer to offset 0x00FF (255): the top two bits are
        // the pointer flag and the low 14 bits the offset. The buffer holds
        // only 2 bytes.
        let data = vec![0xC0, 0xFF];
        let result = DnsName::decode(&data, 0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_dns_name_label_too_long() {
        let long_label = "a".repeat(64);
        let result = DnsName::from_str_dotted(&long_label);
        assert!(result.is_err());
    }

    #[test]
    fn test_dns_name_compression_roundtrip() {
        let name1 = DnsName::from_str_dotted("www.example.com").unwrap();
        let name2 = DnsName::from_str_dotted("mail.example.com").unwrap();

        let mut compression_map = HashMap::new();
        let mut packet = Vec::new();

        // Write first name at offset 0
        let encoded1 = name1.encode_compressed(0, &mut compression_map);
        packet.extend_from_slice(&encoded1);

        // Write second name — should compress "example.com" part
        let encoded2 = name2.encode_compressed(packet.len(), &mut compression_map);
        packet.extend_from_slice(&encoded2);

        // The second name should use a pointer for "example.com"
        // It should be shorter than encoding without compression
        let uncompressed2 = name2.encode();
        assert!(encoded2.len() < uncompressed2.len());

        // Decode both names and verify
        let (decoded1, _) = DnsName::decode(&packet, 0).unwrap();
        assert_eq!(decoded1, name1);

        let (decoded2, _) = DnsName::decode(&packet, encoded1.len()).unwrap();
        assert_eq!(decoded2, name2);
    }

    #[test]
    fn test_dns_name_wire_len() {
        let name = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(name.wire_len(), 17); // 1+3 + 1+7 + 1+3 + 1
    }

    #[test]
    fn test_dns_name_wire_len_matches_encode() {
        for text in &["", ".", "com", "www.example.com", "a.b.c.d.e"] {
            let name = DnsName::from_str_dotted(text).unwrap();
            assert_eq!(name.wire_len(), name.encode().len(), "{text}");
        }
    }

    #[test]
    fn test_dns_name_from_str_ref_falls_back_to_root() {
        assert_eq!(DnsName::from("example.com").labels, vec!["example", "com"]);
        // Invalid input silently becomes the root name.
        assert!(DnsName::from("a".repeat(64).as_str()).is_root());
    }

    #[test]
    fn test_dns_name_decode_at_offset() {
        // Simulating a DNS packet where the name starts at a non-zero offset
        let mut data = vec![0xAA, 0xBB]; // Some header bytes
        data.extend_from_slice(&[4, b't', b'e', b's', b't', 0]);
        let (name, consumed) = DnsName::decode(&data, 2).unwrap();
        assert_eq!(name.labels, vec!["test"]);
        assert_eq!(consumed, 6);
    }

    // ========================================================================
    // FlagValue tests
    // ========================================================================

    static TCP_FLAG_NAMES: &[&str] = &["FIN", "SYN", "RST", "PSH", "ACK", "URG", "ECE", "CWR"];

    #[test]
    fn test_flag_value_display() {
        let flags = FlagValue::new(0b00010010, TCP_FLAG_NAMES); // SYN + ACK
        assert_eq!(flags.to_string(), "SYN+ACK");
    }

    #[test]
    fn test_flag_value_empty() {
        let flags = FlagValue::new(0, TCP_FLAG_NAMES);
        assert_eq!(flags.to_string(), "0");
    }

    #[test]
    fn test_flag_value_has() {
        let flags = FlagValue::new(0b00000010, TCP_FLAG_NAMES); // SYN
        assert!(flags.has(1)); // SYN bit
        assert!(!flags.has(0)); // FIN bit
        assert!(!flags.has(4)); // ACK bit
    }

    #[test]
    fn test_flag_value_has_named() {
        let flags = FlagValue::new(0b00010010, TCP_FLAG_NAMES);
        assert_eq!(flags.has_named("SYN"), Some(true));
        assert_eq!(flags.has_named("ACK"), Some(true));
        assert_eq!(flags.has_named("FIN"), Some(false));
        assert_eq!(flags.has_named("NONEXISTENT"), None);
    }

    #[test]
    fn test_flag_value_set_clear() {
        let mut flags = FlagValue::new(0, TCP_FLAG_NAMES);
        flags.set(1); // Set SYN
        assert!(flags.has(1));
        assert_eq!(flags.value, 2);

        flags.set(4); // Set ACK
        assert_eq!(flags.to_string(), "SYN+ACK");

        flags.clear(1); // Clear SYN
        assert_eq!(flags.to_string(), "ACK");
    }

    #[test]
    fn test_flag_value_set_flags() {
        let flags = FlagValue::new(0b00010011, TCP_FLAG_NAMES); // FIN+SYN+ACK
        let set = flags.set_flags();
        assert_eq!(set, vec!["FIN", "SYN", "ACK"]);
    }

    #[test]
    fn test_flag_value_eq_ignores_names() {
        static OTHER_NAMES: &[&str] = &["A", "B"];
        assert_eq!(FlagValue::new(3, TCP_FLAG_NAMES), FlagValue::new(3, OTHER_NAMES));
        assert_ne!(FlagValue::new(3, TCP_FLAG_NAMES), FlagValue::new(2, TCP_FLAG_NAMES));
    }
}