// stackforge_core/layer/field_ext.rs

//! Extended field types for complex protocol fields.
//!
//! This module provides:
//! - `DnsName`: DNS domain name with RFC 1035 pointer compression support
//! - `FlagValue`: named bit flags with display support and per-bit
//!   `has` / `set` / `clear` accessors
7
8use std::collections::HashMap;
9use std::fmt;
10
11use super::field::FieldError;
12
13// ============================================================================
14// DNS Name (RFC 1035 Section 4.1.4)
15// ============================================================================
16
/// Upper bound on compression-pointer hops followed while decoding a single
/// name; guards against pointer loops in malformed packets.
const DNS_MAX_POINTER_HOPS: usize = 128;
/// Longest label allowed by RFC 1035: a label length must fit in the low six
/// bits of its length byte.
const DNS_MAX_LABEL_LEN: usize = 0x3F;
/// Longest name in dotted-text form (without the trailing dot). The wire
/// encoding adds a leading length byte and a trailing root byte, giving at
/// most 255 octets, the RFC 1035 limit.
const DNS_MAX_NAME_LEN: usize = 253;
/// Top two bits of a length byte that mark it as a compression pointer.
const DNS_POINTER_FLAG: u8 = 0b1100_0000;
25
26/// A DNS domain name consisting of labels.
27///
28/// Supports encoding/decoding with RFC 1035 compression pointers.
/// A DNS domain name stored as a sequence of labels.
///
/// Wire encoding/decoding, including RFC 1035 compression pointers, is
/// provided by the inherent `impl`. Equality and hashing are derived, so
/// labels are compared byte-for-byte (case-sensitively), even though DNS
/// name comparison is case-insensitive (RFC 4343).
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct DnsName {
    /// Labels from leftmost to rightmost, e.g. `["www", "example", "com"]`;
    /// an empty vector represents the root name.
    pub labels: Vec<String>,
}
34
35impl DnsName {
36    /// Create a new DnsName from labels.
37    pub fn new(labels: Vec<String>) -> Self {
38        Self { labels }
39    }
40
41    /// Create a root name (empty labels).
42    pub fn root() -> Self {
43        Self { labels: vec![] }
44    }
45
46    /// Parse a DNS name from a dot-separated string.
47    /// "www.example.com" → labels: ["www", "example", "com"]
48    /// "www.example.com." → same (trailing dot is ignored)
49    pub fn from_str_dotted(s: &str) -> Result<Self, FieldError> {
50        if s.is_empty() || s == "." {
51            return Ok(Self::root());
52        }
53        let s = s.strip_suffix('.').unwrap_or(s);
54        let labels: Vec<String> = s.split('.').map(|l| l.to_string()).collect();
55        // Validate label lengths
56        for label in &labels {
57            if label.len() > DNS_MAX_LABEL_LEN {
58                return Err(FieldError::InvalidValue(format!(
59                    "DNS label too long: {} bytes (max {})",
60                    label.len(),
61                    DNS_MAX_LABEL_LEN
62                )));
63            }
64        }
65        let total_len: usize = labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1;
66        if total_len > DNS_MAX_NAME_LEN + 2 {
67            return Err(FieldError::InvalidValue(format!(
68                "DNS name too long: {} bytes (max {})",
69                total_len, DNS_MAX_NAME_LEN
70            )));
71        }
72        Ok(Self { labels })
73    }
74
75    /// Check if this is the root name.
76    pub fn is_root(&self) -> bool {
77        self.labels.is_empty()
78    }
79
80    /// Get the fully qualified domain name string.
81    /// Returns "www.example.com." with trailing dot.
82    pub fn to_fqdn(&self) -> String {
83        if self.labels.is_empty() {
84            return ".".to_string();
85        }
86        format!("{}.", self.labels.join("."))
87    }
88
89    /// Encode to wire format without compression.
90    /// Each label is preceded by its length byte, terminated by a zero byte.
91    pub fn encode(&self) -> Vec<u8> {
92        let mut out = Vec::new();
93        for label in &self.labels {
94            out.push(label.len() as u8);
95            out.extend_from_slice(label.as_bytes());
96        }
97        out.push(0); // root label
98        out
99    }
100
101    /// Encode to wire format with compression.
102    /// Uses `compression_map` to track previously written name positions.
103    /// `current_offset` is where this name will be written in the packet.
104    pub fn encode_compressed(
105        &self,
106        current_offset: usize,
107        compression_map: &mut HashMap<String, u16>,
108    ) -> Vec<u8> {
109        let mut out = Vec::new();
110        let mut offset = current_offset;
111
112        for i in 0..self.labels.len() {
113            // Check if the suffix from this label onwards was already written
114            let suffix = self.labels[i..].join(".");
115            if let Some(&ptr) = compression_map.get(&suffix) {
116                // Write a pointer to the previous occurrence
117                out.push(DNS_POINTER_FLAG | ((ptr >> 8) as u8));
118                out.push((ptr & 0xFF) as u8);
119                return out;
120            }
121            // Record this suffix position in the compression map
122            if offset < 0x3FFF {
123                compression_map.insert(suffix, offset as u16);
124            }
125            // Write the label
126            let label = &self.labels[i];
127            out.push(label.len() as u8);
128            out.extend_from_slice(label.as_bytes());
129            offset += 1 + label.len();
130        }
131        out.push(0); // root label
132        out
133    }
134
135    /// Decode a DNS name from wire format with pointer decompression.
136    ///
137    /// `packet` is the full packet buffer (needed for pointer resolution).
138    /// `offset` is the starting position of the name.
139    ///
140    /// Returns the decoded name and the number of bytes consumed from `offset`
141    /// (not counting bytes reached via pointers).
142    pub fn decode(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
143        let mut labels = Vec::new();
144        let mut pos = offset;
145        let mut bytes_consumed = 0;
146        let mut followed_pointer = false;
147        let mut hops = 0;
148
149        loop {
150            if pos >= packet.len() {
151                return Err(FieldError::BufferTooShort {
152                    offset: pos,
153                    need: 1,
154                    have: packet.len(),
155                });
156            }
157
158            let len_or_ptr = packet[pos];
159
160            if len_or_ptr == 0 {
161                // Root label — end of name
162                if !followed_pointer {
163                    bytes_consumed = pos - offset + 1;
164                }
165                break;
166            } else if len_or_ptr & DNS_POINTER_FLAG == DNS_POINTER_FLAG {
167                // Compression pointer
168                if pos + 1 >= packet.len() {
169                    return Err(FieldError::BufferTooShort {
170                        offset: pos,
171                        need: 2,
172                        have: packet.len(),
173                    });
174                }
175                let ptr = (((len_or_ptr & 0x3F) as usize) << 8) | (packet[pos + 1] as usize);
176
177                if !followed_pointer {
178                    bytes_consumed = pos - offset + 2;
179                    followed_pointer = true;
180                }
181
182                hops += 1;
183                if hops > DNS_MAX_POINTER_HOPS {
184                    return Err(FieldError::InvalidValue(
185                        "DNS name compression loop detected".to_string(),
186                    ));
187                }
188
189                if ptr >= packet.len() {
190                    return Err(FieldError::InvalidValue(format!(
191                        "DNS compression pointer {:#06x} out of bounds (packet len {})",
192                        ptr,
193                        packet.len()
194                    )));
195                }
196
197                pos = ptr;
198            } else {
199                // Regular label
200                let label_len = len_or_ptr as usize;
201                if label_len > DNS_MAX_LABEL_LEN {
202                    return Err(FieldError::InvalidValue(format!(
203                        "DNS label too long: {} bytes (max {})",
204                        label_len, DNS_MAX_LABEL_LEN
205                    )));
206                }
207                if pos + 1 + label_len > packet.len() {
208                    return Err(FieldError::BufferTooShort {
209                        offset: pos + 1,
210                        need: label_len,
211                        have: packet.len() - pos - 1,
212                    });
213                }
214                let label =
215                    String::from_utf8_lossy(&packet[pos + 1..pos + 1 + label_len]).into_owned();
216                labels.push(label);
217                pos += 1 + label_len;
218            }
219        }
220
221        Ok((Self { labels }, bytes_consumed))
222    }
223
224    /// Wire-format length of this name without compression.
225    pub fn wire_len(&self) -> usize {
226        if self.labels.is_empty() {
227            return 1; // just the root label (0x00)
228        }
229        self.labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1
230    }
231}
232
233impl fmt::Display for DnsName {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        if self.labels.is_empty() {
236            write!(f, ".")
237        } else {
238            write!(f, "{}.", self.labels.join("."))
239        }
240    }
241}
242
243impl From<&str> for DnsName {
244    fn from(s: &str) -> Self {
245        DnsName::from_str_dotted(s).unwrap_or_default()
246    }
247}
248
249// ============================================================================
250// FlagValue - Named bit flags
251// ============================================================================
252
253/// Named bit flags for protocol fields.
254///
255/// Stores a flag value along with names for each bit position.
256#[derive(Debug, Clone)]
257pub struct FlagValue {
258    /// The raw flag bits.
259    pub value: u64,
260    /// Flag names indexed by bit position (LSB = index 0).
261    pub names: &'static [&'static str],
262}
263
264impl FlagValue {
265    pub fn new(value: u64, names: &'static [&'static str]) -> Self {
266        Self { value, names }
267    }
268
269    /// Check if a specific flag bit is set.
270    pub fn has(&self, bit: usize) -> bool {
271        (self.value >> bit) & 1 != 0
272    }
273
274    /// Check if a named flag is set. Returns None if name not found.
275    pub fn has_named(&self, name: &str) -> Option<bool> {
276        self.names
277            .iter()
278            .position(|&n| n == name)
279            .map(|bit| self.has(bit))
280    }
281
282    /// Set a specific flag bit.
283    pub fn set(&mut self, bit: usize) {
284        self.value |= 1u64 << bit;
285    }
286
287    /// Clear a specific flag bit.
288    pub fn clear(&mut self, bit: usize) {
289        self.value &= !(1u64 << bit);
290    }
291
292    /// Get the list of set flag names.
293    pub fn set_flags(&self) -> Vec<&'static str> {
294        let mut flags = Vec::new();
295        for (i, &name) in self.names.iter().enumerate() {
296            if !name.is_empty() && self.has(i) {
297                flags.push(name);
298            }
299        }
300        flags
301    }
302}
303
304impl fmt::Display for FlagValue {
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        let flags = self.set_flags();
307        if flags.is_empty() {
308            write!(f, "0")
309        } else {
310            write!(f, "{}", flags.join("+"))
311        }
312    }
313}
314
315impl PartialEq for FlagValue {
316    fn eq(&self, other: &Self) -> bool {
317        self.value == other.value
318    }
319}
320
321impl Eq for FlagValue {}
322
323#[cfg(test)]
324mod tests {
325    use super::*;
326
327    // ========================================================================
328    // DnsName tests
329    // ========================================================================
330
331    #[test]
332    fn test_dns_name_from_str() {
333        let name = DnsName::from_str_dotted("www.example.com").unwrap();
334        assert_eq!(name.labels, vec!["www", "example", "com"]);
335        assert_eq!(name.to_fqdn(), "www.example.com.");
336        assert_eq!(name.to_string(), "www.example.com.");
337    }
338
339    #[test]
340    fn test_dns_name_from_str_trailing_dot() {
341        let name = DnsName::from_str_dotted("www.example.com.").unwrap();
342        assert_eq!(name.labels, vec!["www", "example", "com"]);
343    }
344
345    #[test]
346    fn test_dns_name_root() {
347        let name = DnsName::from_str_dotted(".").unwrap();
348        assert!(name.is_root());
349        assert_eq!(name.to_fqdn(), ".");
350    }
351
352    #[test]
353    fn test_dns_name_empty() {
354        let name = DnsName::from_str_dotted("").unwrap();
355        assert!(name.is_root());
356    }
357
358    #[test]
359    fn test_dns_name_encode() {
360        let name = DnsName::from_str_dotted("www.example.com").unwrap();
361        let encoded = name.encode();
362        assert_eq!(
363            encoded,
364            vec![
365                3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o',
366                b'm', 0
367            ]
368        );
369    }
370
371    #[test]
372    fn test_dns_name_encode_root() {
373        let name = DnsName::root();
374        assert_eq!(name.encode(), vec![0]);
375        assert_eq!(name.wire_len(), 1);
376    }
377
378    #[test]
379    fn test_dns_name_decode_simple() {
380        let data = vec![
381            3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm',
382            0,
383        ];
384        let (name, consumed) = DnsName::decode(&data, 0).unwrap();
385        assert_eq!(name.labels, vec!["www", "example", "com"]);
386        assert_eq!(consumed, 17);
387    }
388
389    #[test]
390    fn test_dns_name_decode_with_pointer() {
391        // Build a packet with a pointer:
392        // offset 0: \x07example\x03com\x00  (13 bytes)
393        // offset 13: \x03www\xc0\x00        (www + pointer to offset 0)
394        let mut data = vec![];
395        // "example.com" at offset 0
396        data.extend_from_slice(&[
397            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
398        ]);
399        // "www" + pointer to offset 0
400        data.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]);
401
402        let (name, consumed) = DnsName::decode(&data, 13).unwrap();
403        assert_eq!(name.labels, vec!["www", "example", "com"]);
404        assert_eq!(consumed, 6); // 1+3 for "www" + 2 for pointer
405    }
406
407    #[test]
408    fn test_dns_name_decode_pointer_loop() {
409        // Create a loop: offset 0 points to offset 2, offset 2 points to offset 0
410        let data = vec![0xC0, 0x02, 0xC0, 0x00];
411        let result = DnsName::decode(&data, 0);
412        assert!(result.is_err());
413        assert!(result.unwrap_err().to_string().contains("loop detected"));
414    }
415
416    #[test]
417    fn test_dns_name_decode_pointer_out_of_bounds() {
418        let data = vec![0xC0, 0xFF]; // Pointer to offset 0x3FF, way beyond buffer
419        let result = DnsName::decode(&data, 0);
420        assert!(result.is_err());
421    }
422
423    #[test]
424    fn test_dns_name_label_too_long() {
425        let long_label = "a".repeat(64);
426        let result = DnsName::from_str_dotted(&long_label);
427        assert!(result.is_err());
428    }
429
430    #[test]
431    fn test_dns_name_compression_roundtrip() {
432        let name1 = DnsName::from_str_dotted("www.example.com").unwrap();
433        let name2 = DnsName::from_str_dotted("mail.example.com").unwrap();
434
435        let mut compression_map = HashMap::new();
436        let mut packet = Vec::new();
437
438        // Write first name at offset 0
439        let encoded1 = name1.encode_compressed(0, &mut compression_map);
440        packet.extend_from_slice(&encoded1);
441
442        // Write second name — should compress "example.com" part
443        let encoded2 = name2.encode_compressed(packet.len(), &mut compression_map);
444        packet.extend_from_slice(&encoded2);
445
446        // The second name should use a pointer for "example.com"
447        // It should be shorter than encoding without compression
448        let uncompressed2 = name2.encode();
449        assert!(encoded2.len() < uncompressed2.len());
450
451        // Decode both names and verify
452        let (decoded1, _) = DnsName::decode(&packet, 0).unwrap();
453        assert_eq!(decoded1, name1);
454
455        let (decoded2, _) = DnsName::decode(&packet, encoded1.len()).unwrap();
456        assert_eq!(decoded2, name2);
457    }
458
459    #[test]
460    fn test_dns_name_wire_len() {
461        let name = DnsName::from_str_dotted("www.example.com").unwrap();
462        assert_eq!(name.wire_len(), 17); // 1+3 + 1+7 + 1+3 + 1
463    }
464
465    #[test]
466    fn test_dns_name_decode_at_offset() {
467        // Simulating a DNS packet where the name starts at a non-zero offset
468        let mut data = vec![0xAA, 0xBB]; // Some header bytes
469        data.extend_from_slice(&[4, b't', b'e', b's', b't', 0]);
470        let (name, consumed) = DnsName::decode(&data, 2).unwrap();
471        assert_eq!(name.labels, vec!["test"]);
472        assert_eq!(consumed, 6);
473    }
474
475    // ========================================================================
476    // FlagValue tests
477    // ========================================================================
478
479    static TCP_FLAG_NAMES: &[&str] = &["FIN", "SYN", "RST", "PSH", "ACK", "URG", "ECE", "CWR"];
480
481    #[test]
482    fn test_flag_value_display() {
483        let flags = FlagValue::new(0b00010010, TCP_FLAG_NAMES); // SYN + ACK
484        assert_eq!(flags.to_string(), "SYN+ACK");
485    }
486
487    #[test]
488    fn test_flag_value_empty() {
489        let flags = FlagValue::new(0, TCP_FLAG_NAMES);
490        assert_eq!(flags.to_string(), "0");
491    }
492
493    #[test]
494    fn test_flag_value_has() {
495        let flags = FlagValue::new(0b00000010, TCP_FLAG_NAMES); // SYN
496        assert!(flags.has(1)); // SYN bit
497        assert!(!flags.has(0)); // FIN bit
498        assert!(!flags.has(4)); // ACK bit
499    }
500
501    #[test]
502    fn test_flag_value_has_named() {
503        let flags = FlagValue::new(0b00010010, TCP_FLAG_NAMES);
504        assert_eq!(flags.has_named("SYN"), Some(true));
505        assert_eq!(flags.has_named("ACK"), Some(true));
506        assert_eq!(flags.has_named("FIN"), Some(false));
507        assert_eq!(flags.has_named("NONEXISTENT"), None);
508    }
509
510    #[test]
511    fn test_flag_value_set_clear() {
512        let mut flags = FlagValue::new(0, TCP_FLAG_NAMES);
513        flags.set(1); // Set SYN
514        assert!(flags.has(1));
515        assert_eq!(flags.value, 2);
516
517        flags.set(4); // Set ACK
518        assert_eq!(flags.to_string(), "SYN+ACK");
519
520        flags.clear(1); // Clear SYN
521        assert_eq!(flags.to_string(), "ACK");
522    }
523
524    #[test]
525    fn test_flag_value_set_flags() {
526        let flags = FlagValue::new(0b00010011, TCP_FLAG_NAMES); // FIN+SYN+ACK
527        let set = flags.set_flags();
528        assert_eq!(set, vec!["FIN", "SYN", "ACK"]);
529    }
530}