1use std::collections::HashMap;
9use std::fmt;
10
11use super::field::FieldError;
12
// --- Limits applied while parsing and decoding names ---

/// Maximum number of compression pointers `DnsName::decode` follows before it
/// reports a compression loop.
const DNS_MAX_POINTER_HOPS: usize = 128;

/// Largest number of bytes allowed in a single label.
const DNS_MAX_LABEL_LEN: usize = 63;

/// Name-length limit used by `DnsName::from_str_dotted`; the encoded size
/// (one length octet per label plus the root octet) may exceed it by two bytes.
const DNS_MAX_NAME_LEN: usize = 253;

// --- Wire-format markers ---

/// A length octet with both of these top bits set introduces a two-byte
/// compression pointer instead of a label.
const DNS_POINTER_FLAG: u8 = 0b1100_0000;
25
/// A DNS domain name stored as its sequence of labels, leftmost label first.
///
/// The root name is represented by an empty label list, which is also the
/// `Default` value.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct DnsName {
    /// Labels of the name without separating dots; empty for the root name.
    pub labels: Vec<String>,
}
34
35impl DnsName {
36 pub fn new(labels: Vec<String>) -> Self {
38 Self { labels }
39 }
40
41 pub fn root() -> Self {
43 Self { labels: vec![] }
44 }
45
46 pub fn from_str_dotted(s: &str) -> Result<Self, FieldError> {
50 if s.is_empty() || s == "." {
51 return Ok(Self::root());
52 }
53 let s = s.strip_suffix('.').unwrap_or(s);
54 let labels: Vec<String> = s.split('.').map(|l| l.to_string()).collect();
55 for label in &labels {
57 if label.len() > DNS_MAX_LABEL_LEN {
58 return Err(FieldError::InvalidValue(format!(
59 "DNS label too long: {} bytes (max {})",
60 label.len(),
61 DNS_MAX_LABEL_LEN
62 )));
63 }
64 }
65 let total_len: usize = labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1;
66 if total_len > DNS_MAX_NAME_LEN + 2 {
67 return Err(FieldError::InvalidValue(format!(
68 "DNS name too long: {} bytes (max {})",
69 total_len, DNS_MAX_NAME_LEN
70 )));
71 }
72 Ok(Self { labels })
73 }
74
75 pub fn is_root(&self) -> bool {
77 self.labels.is_empty()
78 }
79
80 pub fn to_fqdn(&self) -> String {
83 if self.labels.is_empty() {
84 return ".".to_string();
85 }
86 format!("{}.", self.labels.join("."))
87 }
88
89 pub fn encode(&self) -> Vec<u8> {
92 let mut out = Vec::new();
93 for label in &self.labels {
94 out.push(label.len() as u8);
95 out.extend_from_slice(label.as_bytes());
96 }
97 out.push(0); out
99 }
100
101 pub fn encode_compressed(
105 &self,
106 current_offset: usize,
107 compression_map: &mut HashMap<String, u16>,
108 ) -> Vec<u8> {
109 let mut out = Vec::new();
110 let mut offset = current_offset;
111
112 for i in 0..self.labels.len() {
113 let suffix = self.labels[i..].join(".");
115 if let Some(&ptr) = compression_map.get(&suffix) {
116 out.push(DNS_POINTER_FLAG | ((ptr >> 8) as u8));
118 out.push((ptr & 0xFF) as u8);
119 return out;
120 }
121 if offset < 0x3FFF {
123 compression_map.insert(suffix, offset as u16);
124 }
125 let label = &self.labels[i];
127 out.push(label.len() as u8);
128 out.extend_from_slice(label.as_bytes());
129 offset += 1 + label.len();
130 }
131 out.push(0); out
133 }
134
135 pub fn decode(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
143 let mut labels = Vec::new();
144 let mut pos = offset;
145 let mut bytes_consumed = 0;
146 let mut followed_pointer = false;
147 let mut hops = 0;
148
149 loop {
150 if pos >= packet.len() {
151 return Err(FieldError::BufferTooShort {
152 offset: pos,
153 need: 1,
154 have: packet.len(),
155 });
156 }
157
158 let len_or_ptr = packet[pos];
159
160 if len_or_ptr == 0 {
161 if !followed_pointer {
163 bytes_consumed = pos - offset + 1;
164 }
165 break;
166 } else if len_or_ptr & DNS_POINTER_FLAG == DNS_POINTER_FLAG {
167 if pos + 1 >= packet.len() {
169 return Err(FieldError::BufferTooShort {
170 offset: pos,
171 need: 2,
172 have: packet.len(),
173 });
174 }
175 let ptr = (((len_or_ptr & 0x3F) as usize) << 8) | (packet[pos + 1] as usize);
176
177 if !followed_pointer {
178 bytes_consumed = pos - offset + 2;
179 followed_pointer = true;
180 }
181
182 hops += 1;
183 if hops > DNS_MAX_POINTER_HOPS {
184 return Err(FieldError::InvalidValue(
185 "DNS name compression loop detected".to_string(),
186 ));
187 }
188
189 if ptr >= packet.len() {
190 return Err(FieldError::InvalidValue(format!(
191 "DNS compression pointer {:#06x} out of bounds (packet len {})",
192 ptr,
193 packet.len()
194 )));
195 }
196
197 pos = ptr;
198 } else {
199 let label_len = len_or_ptr as usize;
201 if label_len > DNS_MAX_LABEL_LEN {
202 return Err(FieldError::InvalidValue(format!(
203 "DNS label too long: {} bytes (max {})",
204 label_len, DNS_MAX_LABEL_LEN
205 )));
206 }
207 if pos + 1 + label_len > packet.len() {
208 return Err(FieldError::BufferTooShort {
209 offset: pos + 1,
210 need: label_len,
211 have: packet.len() - pos - 1,
212 });
213 }
214 let label =
215 String::from_utf8_lossy(&packet[pos + 1..pos + 1 + label_len]).into_owned();
216 labels.push(label);
217 pos += 1 + label_len;
218 }
219 }
220
221 Ok((Self { labels }, bytes_consumed))
222 }
223
224 pub fn wire_len(&self) -> usize {
226 if self.labels.is_empty() {
227 return 1; }
229 self.labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1
230 }
231}
232
233impl fmt::Display for DnsName {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 if self.labels.is_empty() {
236 write!(f, ".")
237 } else {
238 write!(f, "{}.", self.labels.join("."))
239 }
240 }
241}
242
243impl From<&str> for DnsName {
244 fn from(s: &str) -> Self {
245 DnsName::from_str_dotted(s).unwrap_or_default()
246 }
247}
248
/// A bit-flag field paired with a table of per-bit names.
///
/// `names[i]` is the name of bit `i` of `value` (bit 0 is the least
/// significant bit); an empty string marks a bit without a name.
#[derive(Clone, Debug)]
pub struct FlagValue {
    /// Raw flag bits.
    pub value: u64,
    /// Name table indexed by bit position.
    pub names: &'static [&'static str],
}
263
264impl FlagValue {
265 pub fn new(value: u64, names: &'static [&'static str]) -> Self {
266 Self { value, names }
267 }
268
269 pub fn has(&self, bit: usize) -> bool {
271 (self.value >> bit) & 1 != 0
272 }
273
274 pub fn has_named(&self, name: &str) -> Option<bool> {
276 self.names
277 .iter()
278 .position(|&n| n == name)
279 .map(|bit| self.has(bit))
280 }
281
282 pub fn set(&mut self, bit: usize) {
284 self.value |= 1u64 << bit;
285 }
286
287 pub fn clear(&mut self, bit: usize) {
289 self.value &= !(1u64 << bit);
290 }
291
292 pub fn set_flags(&self) -> Vec<&'static str> {
294 let mut flags = Vec::new();
295 for (i, &name) in self.names.iter().enumerate() {
296 if !name.is_empty() && self.has(i) {
297 flags.push(name);
298 }
299 }
300 flags
301 }
302}
303
304impl fmt::Display for FlagValue {
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 let flags = self.set_flags();
307 if flags.is_empty() {
308 write!(f, "0")
309 } else {
310 write!(f, "{}", flags.join("+"))
311 }
312 }
313}
314
315impl PartialEq for FlagValue {
316 fn eq(&self, other: &Self) -> bool {
317 self.value == other.value
318 }
319}
320
321impl Eq for FlagValue {}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncompressed wire encoding of `www.example.com.` (17 bytes).
    const WWW_EXAMPLE_COM_WIRE: &[u8] = b"\x03www\x07example\x03com\x00";

    // ---- DnsName: parsing ----

    #[test]
    fn test_dns_name_from_str() {
        let parsed = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(parsed.labels, vec!["www", "example", "com"]);
        assert_eq!(parsed.to_fqdn(), "www.example.com.");
        assert_eq!(parsed.to_string(), "www.example.com.");
    }

    #[test]
    fn test_dns_name_from_str_trailing_dot() {
        let parsed = DnsName::from_str_dotted("www.example.com.").unwrap();
        assert_eq!(parsed.labels, vec!["www", "example", "com"]);
    }

    #[test]
    fn test_dns_name_root() {
        let root = DnsName::from_str_dotted(".").unwrap();
        assert!(root.is_root());
        assert_eq!(root.to_fqdn(), ".");
    }

    #[test]
    fn test_dns_name_empty() {
        let parsed = DnsName::from_str_dotted("").unwrap();
        assert!(parsed.is_root());
    }

    // ---- DnsName: encoding ----

    #[test]
    fn test_dns_name_encode() {
        let wire = DnsName::from_str_dotted("www.example.com")
            .unwrap()
            .encode();
        assert_eq!(wire, WWW_EXAMPLE_COM_WIRE.to_vec());
    }

    #[test]
    fn test_dns_name_encode_root() {
        let root = DnsName::root();
        assert_eq!(root.encode(), vec![0]);
        assert_eq!(root.wire_len(), 1);
    }

    // ---- DnsName: decoding ----

    #[test]
    fn test_dns_name_decode_simple() {
        let packet = WWW_EXAMPLE_COM_WIRE.to_vec();
        let (decoded, used) = DnsName::decode(&packet, 0).unwrap();
        assert_eq!(decoded.labels, vec!["www", "example", "com"]);
        assert_eq!(used, 17);
    }

    #[test]
    fn test_dns_name_decode_with_pointer() {
        // Offset 0: "example.com." written in full (13 bytes).
        let mut packet = b"\x07example\x03com\x00".to_vec();
        // Offset 13: "www" followed by a pointer back to offset 0.
        packet.extend_from_slice(b"\x03www\xC0\x00");

        let (decoded, used) = DnsName::decode(&packet, 13).unwrap();
        assert_eq!(decoded.labels, vec!["www", "example", "com"]);
        // 4 bytes for the "www" label plus 2 bytes for the pointer.
        assert_eq!(used, 6);
    }

    #[test]
    fn test_dns_name_decode_pointer_loop() {
        // Pointer at 0 targets 2, pointer at 2 targets 0.
        let packet = vec![0xC0, 0x02, 0xC0, 0x00];
        let outcome = DnsName::decode(&packet, 0);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("loop detected"));
    }

    #[test]
    fn test_dns_name_decode_pointer_out_of_bounds() {
        // Pointer to offset 0xFF in a two-byte packet.
        let packet = vec![0xC0, 0xFF];
        assert!(DnsName::decode(&packet, 0).is_err());
    }

    #[test]
    fn test_dns_name_label_too_long() {
        let oversized = "a".repeat(64);
        assert!(DnsName::from_str_dotted(&oversized).is_err());
    }

    #[test]
    fn test_dns_name_compression_roundtrip() {
        let first = DnsName::from_str_dotted("www.example.com").unwrap();
        let second = DnsName::from_str_dotted("mail.example.com").unwrap();

        let mut seen_suffixes = HashMap::new();
        let mut packet = Vec::new();

        let first_bytes = first.encode_compressed(0, &mut seen_suffixes);
        packet.extend_from_slice(&first_bytes);

        let second_bytes = second.encode_compressed(packet.len(), &mut seen_suffixes);
        packet.extend_from_slice(&second_bytes);

        // The second name shares "example.com" and must come out shorter.
        assert!(second_bytes.len() < second.encode().len());

        let (first_decoded, _) = DnsName::decode(&packet, 0).unwrap();
        assert_eq!(first_decoded, first);

        let (second_decoded, _) = DnsName::decode(&packet, first_bytes.len()).unwrap();
        assert_eq!(second_decoded, second);
    }

    #[test]
    fn test_dns_name_wire_len() {
        let parsed = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(parsed.wire_len(), 17);
    }

    #[test]
    fn test_dns_name_decode_at_offset() {
        // Two unrelated bytes precede the encoded name.
        let mut packet = vec![0xAA, 0xBB];
        packet.extend_from_slice(b"\x04test\x00");
        let (decoded, used) = DnsName::decode(&packet, 2).unwrap();
        assert_eq!(decoded.labels, vec!["test"]);
        assert_eq!(used, 6);
    }

    // ---- FlagValue ----

    /// Flag names indexed by bit position, least significant bit first.
    static TCP_FLAG_NAMES: &[&str] = &["FIN", "SYN", "RST", "PSH", "ACK", "URG", "ECE", "CWR"];

    #[test]
    fn test_flag_value_display() {
        // Bits 1 (SYN) and 4 (ACK).
        let syn_ack = FlagValue::new(0b00010010, TCP_FLAG_NAMES);
        assert_eq!(syn_ack.to_string(), "SYN+ACK");
    }

    #[test]
    fn test_flag_value_empty() {
        let none = FlagValue::new(0, TCP_FLAG_NAMES);
        assert_eq!(none.to_string(), "0");
    }

    #[test]
    fn test_flag_value_has() {
        let syn = FlagValue::new(0b00000010, TCP_FLAG_NAMES);
        assert!(syn.has(1));
        assert!(!syn.has(0));
        assert!(!syn.has(4));
    }

    #[test]
    fn test_flag_value_has_named() {
        let syn_ack = FlagValue::new(0b00010010, TCP_FLAG_NAMES);
        assert_eq!(syn_ack.has_named("SYN"), Some(true));
        assert_eq!(syn_ack.has_named("ACK"), Some(true));
        assert_eq!(syn_ack.has_named("FIN"), Some(false));
        assert_eq!(syn_ack.has_named("NONEXISTENT"), None);
    }

    #[test]
    fn test_flag_value_set_clear() {
        let mut flags = FlagValue::new(0, TCP_FLAG_NAMES);

        flags.set(1);
        assert!(flags.has(1));
        assert_eq!(flags.value, 2);

        flags.set(4);
        assert_eq!(flags.to_string(), "SYN+ACK");

        flags.clear(1);
        assert_eq!(flags.to_string(), "ACK");
    }

    #[test]
    fn test_flag_value_set_flags() {
        // Bits 0 (FIN), 1 (SYN) and 4 (ACK).
        let flags = FlagValue::new(0b00010011, TCP_FLAG_NAMES);
        let active = flags.set_flags();
        assert_eq!(active, vec!["FIN", "SYN", "ACK"]);
    }
}