1use std::collections::HashMap;
9use std::fmt;
10
11use super::field::FieldError;
12
/// Upper bound on compression pointers followed while decoding a single name;
/// exceeding it is reported as a compression loop.
const DNS_MAX_POINTER_HOPS: usize = 128;
/// Longest permitted label, in bytes (the largest value that fits below the two
/// flag bits of a label octet).
const DNS_MAX_LABEL_LEN: usize = 0b0011_1111;
/// Longest permitted name in dotted text form, without a trailing dot; this is
/// 255 bytes in wire form once the length octets and root terminator are added.
const DNS_MAX_NAME_LEN: usize = 253;
/// Top two bits of a label octet; when both are set the octet starts a
/// two-byte compression pointer.
const DNS_POINTER_FLAG: u8 = 0b1100_0000;
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
30pub struct DnsName {
31 pub labels: Vec<String>,
33}
34
35impl DnsName {
36 #[must_use]
38 pub fn new(labels: Vec<String>) -> Self {
39 Self { labels }
40 }
41
42 #[must_use]
44 pub fn root() -> Self {
45 Self { labels: vec![] }
46 }
47
48 pub fn from_str_dotted(s: &str) -> Result<Self, FieldError> {
52 if s.is_empty() || s == "." {
53 return Ok(Self::root());
54 }
55 let s = s.strip_suffix('.').unwrap_or(s);
56 let labels: Vec<String> = s.split('.').map(std::string::ToString::to_string).collect();
57 for label in &labels {
59 if label.len() > DNS_MAX_LABEL_LEN {
60 return Err(FieldError::InvalidValue(format!(
61 "DNS label too long: {} bytes (max {})",
62 label.len(),
63 DNS_MAX_LABEL_LEN
64 )));
65 }
66 }
67 let total_len: usize = labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1;
68 if total_len > DNS_MAX_NAME_LEN + 2 {
69 return Err(FieldError::InvalidValue(format!(
70 "DNS name too long: {total_len} bytes (max {DNS_MAX_NAME_LEN})"
71 )));
72 }
73 Ok(Self { labels })
74 }
75
76 #[must_use]
78 pub fn is_root(&self) -> bool {
79 self.labels.is_empty()
80 }
81
82 #[must_use]
85 pub fn to_fqdn(&self) -> String {
86 if self.labels.is_empty() {
87 return ".".to_string();
88 }
89 format!("{}.", self.labels.join("."))
90 }
91
92 #[must_use]
95 pub fn encode(&self) -> Vec<u8> {
96 let mut out = Vec::new();
97 for label in &self.labels {
98 out.push(label.len() as u8);
99 out.extend_from_slice(label.as_bytes());
100 }
101 out.push(0); out
103 }
104
105 pub fn encode_compressed(
109 &self,
110 current_offset: usize,
111 compression_map: &mut HashMap<String, u16>,
112 ) -> Vec<u8> {
113 let mut out = Vec::new();
114 let mut offset = current_offset;
115
116 for i in 0..self.labels.len() {
117 let suffix = self.labels[i..].join(".");
119 if let Some(&ptr) = compression_map.get(&suffix) {
120 out.push(DNS_POINTER_FLAG | ((ptr >> 8) as u8));
122 out.push((ptr & 0xFF) as u8);
123 return out;
124 }
125 if offset < 0x3FFF {
127 compression_map.insert(suffix, offset as u16);
128 }
129 let label = &self.labels[i];
131 out.push(label.len() as u8);
132 out.extend_from_slice(label.as_bytes());
133 offset += 1 + label.len();
134 }
135 out.push(0); out
137 }
138
139 pub fn decode(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
147 let mut labels = Vec::new();
148 let mut pos = offset;
149 let mut bytes_consumed = 0;
150 let mut followed_pointer = false;
151 let mut hops = 0;
152
153 loop {
154 if pos >= packet.len() {
155 return Err(FieldError::BufferTooShort {
156 offset: pos,
157 need: 1,
158 have: packet.len(),
159 });
160 }
161
162 let len_or_ptr = packet[pos];
163
164 if len_or_ptr == 0 {
165 if !followed_pointer {
167 bytes_consumed = pos - offset + 1;
168 }
169 break;
170 } else if len_or_ptr & DNS_POINTER_FLAG == DNS_POINTER_FLAG {
171 if pos + 1 >= packet.len() {
173 return Err(FieldError::BufferTooShort {
174 offset: pos,
175 need: 2,
176 have: packet.len(),
177 });
178 }
179 let ptr = (((len_or_ptr & 0x3F) as usize) << 8) | (packet[pos + 1] as usize);
180
181 if !followed_pointer {
182 bytes_consumed = pos - offset + 2;
183 followed_pointer = true;
184 }
185
186 hops += 1;
187 if hops > DNS_MAX_POINTER_HOPS {
188 return Err(FieldError::InvalidValue(
189 "DNS name compression loop detected".to_string(),
190 ));
191 }
192
193 if ptr >= packet.len() {
194 return Err(FieldError::InvalidValue(format!(
195 "DNS compression pointer {:#06x} out of bounds (packet len {})",
196 ptr,
197 packet.len()
198 )));
199 }
200
201 pos = ptr;
202 } else {
203 let label_len = len_or_ptr as usize;
205 if label_len > DNS_MAX_LABEL_LEN {
206 return Err(FieldError::InvalidValue(format!(
207 "DNS label too long: {label_len} bytes (max {DNS_MAX_LABEL_LEN})"
208 )));
209 }
210 if pos + 1 + label_len > packet.len() {
211 return Err(FieldError::BufferTooShort {
212 offset: pos + 1,
213 need: label_len,
214 have: packet.len() - pos - 1,
215 });
216 }
217 let label =
218 String::from_utf8_lossy(&packet[pos + 1..pos + 1 + label_len]).into_owned();
219 labels.push(label);
220 pos += 1 + label_len;
221 }
222 }
223
224 Ok((Self { labels }, bytes_consumed))
225 }
226
227 #[must_use]
229 pub fn wire_len(&self) -> usize {
230 if self.labels.is_empty() {
231 return 1; }
233 self.labels.iter().map(|l| l.len() + 1).sum::<usize>() + 1
234 }
235}
236
237impl fmt::Display for DnsName {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 if self.labels.is_empty() {
240 write!(f, ".")
241 } else {
242 write!(f, "{}.", self.labels.join("."))
243 }
244 }
245}
246
247impl From<&str> for DnsName {
248 fn from(s: &str) -> Self {
249 DnsName::from_str_dotted(s).unwrap_or_default()
250 }
251}
252
/// A bit-flag field value paired with a table of per-bit names.
///
/// Bit `i` of `value` (least-significant bit = 0) is named by `names[i]`; an
/// empty string marks an unnamed bit. Only bits `0..64` exist in the `u64`.
#[derive(Debug, Clone)]
pub struct FlagValue {
    /// Raw flag bits.
    pub value: u64,
    /// Name of each bit, indexed by bit position.
    pub names: &'static [&'static str],
}

impl FlagValue {
    /// Number of addressable bits in `value`.
    const BITS: usize = u64::BITS as usize;

    /// Creates a flag value from raw bits and a bit-name table.
    #[must_use]
    pub fn new(value: u64, names: &'static [&'static str]) -> Self {
        Self { value, names }
    }

    /// Returns `true` if bit `bit` is set.
    ///
    /// Bits at position 64 or above do not exist and are reported as clear
    /// (an unchecked shift there would panic in debug builds and test the
    /// wrong bit in release builds).
    #[must_use]
    pub fn has(&self, bit: usize) -> bool {
        bit < Self::BITS && (self.value >> bit) & 1 != 0
    }

    /// Looks up `name` in the name table and reports whether that bit is set.
    ///
    /// Returns `None` if `name` is not in the table.
    #[must_use]
    pub fn has_named(&self, name: &str) -> Option<bool> {
        self.names
            .iter()
            .position(|&n| n == name)
            .map(|bit| self.has(bit))
    }

    /// Sets bit `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= 64`. Without this check, debug builds would panic on
    /// the shift overflow anyway, and release builds would silently set a
    /// different bit.
    pub fn set(&mut self, bit: usize) {
        assert!(
            bit < Self::BITS,
            "flag bit {bit} out of range for a 64-bit value"
        );
        self.value |= 1u64 << bit;
    }

    /// Clears bit `bit`. Bits at position 64 or above are already clear, so
    /// clearing them is a no-op.
    pub fn clear(&mut self, bit: usize) {
        if bit < Self::BITS {
            self.value &= !(1u64 << bit);
        }
    }

    /// Returns the names of all set bits that have a non-empty name, in bit order.
    #[must_use]
    pub fn set_flags(&self) -> Vec<&'static str> {
        self.names
            .iter()
            .enumerate()
            .filter(|&(bit, name)| !name.is_empty() && self.has(bit))
            .map(|(_, &name)| name)
            .collect()
    }
}
311
312impl fmt::Display for FlagValue {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 let flags = self.set_flags();
315 if flags.is_empty() {
316 write!(f, "0")
317 } else {
318 write!(f, "{}", flags.join("+"))
319 }
320 }
321}
322
323impl PartialEq for FlagValue {
324 fn eq(&self, other: &Self) -> bool {
325 self.value == other.value
326 }
327}
328
329impl Eq for FlagValue {}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncompressed wire form of `www.example.com`.
    const WWW_EXAMPLE_COM: &[u8] = b"\x03www\x07example\x03com\x00";

    /// Bit names for the TCP flags octet, least-significant bit first.
    static TCP_FLAG_NAMES: &[&str] = &["FIN", "SYN", "RST", "PSH", "ACK", "URG", "ECE", "CWR"];

    #[test]
    fn test_dns_name_from_str() {
        let parsed = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(parsed.labels, ["www", "example", "com"]);
        assert_eq!(parsed.to_fqdn(), "www.example.com.");
        assert_eq!(parsed.to_string(), "www.example.com.");
    }

    #[test]
    fn test_dns_name_from_str_trailing_dot() {
        let parsed = DnsName::from_str_dotted("www.example.com.").unwrap();
        assert_eq!(parsed.labels, ["www", "example", "com"]);
    }

    #[test]
    fn test_dns_name_root() {
        let root = DnsName::from_str_dotted(".").unwrap();
        assert!(root.is_root());
        assert_eq!(root.to_fqdn(), ".");
    }

    #[test]
    fn test_dns_name_empty() {
        assert!(DnsName::from_str_dotted("").unwrap().is_root());
    }

    #[test]
    fn test_dns_name_encode() {
        let parsed = DnsName::from_str_dotted("www.example.com").unwrap();
        assert_eq!(parsed.encode(), WWW_EXAMPLE_COM);
    }

    #[test]
    fn test_dns_name_encode_root() {
        let root = DnsName::root();
        assert_eq!(root.encode(), [0]);
        assert_eq!(root.wire_len(), 1);
    }

    #[test]
    fn test_dns_name_decode_simple() {
        let (decoded, used) = DnsName::decode(WWW_EXAMPLE_COM, 0).unwrap();
        assert_eq!(decoded.labels, ["www", "example", "com"]);
        assert_eq!(used, 17);
    }

    #[test]
    fn test_dns_name_decode_with_pointer() {
        // `example.com` at offset 0, then `www` + pointer to 0 at offset 13.
        let mut packet = b"\x07example\x03com\x00".to_vec();
        packet.extend_from_slice(b"\x03www\xC0\x00");

        let (decoded, used) = DnsName::decode(&packet, 13).unwrap();
        assert_eq!(decoded.labels, ["www", "example", "com"]);
        // Three label octets + "www" + the two pointer octets.
        assert_eq!(used, 6);
    }

    #[test]
    fn test_dns_name_decode_pointer_loop() {
        // Two pointers that reference each other forever.
        let err = DnsName::decode(&[0xC0, 0x02, 0xC0, 0x00], 0)
            .expect_err("a pointer cycle must be rejected");
        assert!(err.to_string().contains("loop detected"));
    }

    #[test]
    fn test_dns_name_decode_pointer_out_of_bounds() {
        // Pointer target 0x00FF lies beyond the 2-byte packet.
        assert!(DnsName::decode(&[0xC0, 0xFF], 0).is_err());
    }

    #[test]
    fn test_dns_name_label_too_long() {
        assert!(DnsName::from_str_dotted(&"a".repeat(64)).is_err());
    }

    #[test]
    fn test_dns_name_compression_roundtrip() {
        let first = DnsName::from_str_dotted("www.example.com").unwrap();
        let second = DnsName::from_str_dotted("mail.example.com").unwrap();

        let mut seen = HashMap::new();
        let first_wire = first.encode_compressed(0, &mut seen);
        let mut packet = first_wire.clone();

        let second_wire = second.encode_compressed(packet.len(), &mut seen);
        packet.extend_from_slice(&second_wire);

        // The shared `example.com` suffix must have been replaced by a pointer.
        assert!(second_wire.len() < second.encode().len());

        assert_eq!(DnsName::decode(&packet, 0).unwrap().0, first);
        assert_eq!(DnsName::decode(&packet, first_wire.len()).unwrap().0, second);
    }

    #[test]
    fn test_dns_name_wire_len() {
        let parsed = DnsName::from_str_dotted("www.example.com").unwrap();
        // 3 length octets + 13 label bytes + terminator.
        assert_eq!(parsed.wire_len(), 17);
    }

    #[test]
    fn test_dns_name_decode_at_offset() {
        // Two bytes of unrelated data precede the name.
        let mut packet = b"\xAA\xBB".to_vec();
        packet.extend_from_slice(b"\x04test\x00");

        let (decoded, used) = DnsName::decode(&packet, 2).unwrap();
        assert_eq!(decoded.labels, ["test"]);
        assert_eq!(used, 6);
    }

    #[test]
    fn test_flag_value_display() {
        // SYN (bit 1) + ACK (bit 4).
        let syn_ack = FlagValue::new(0b0001_0010, TCP_FLAG_NAMES);
        assert_eq!(syn_ack.to_string(), "SYN+ACK");
    }

    #[test]
    fn test_flag_value_empty() {
        assert_eq!(FlagValue::new(0, TCP_FLAG_NAMES).to_string(), "0");
    }

    #[test]
    fn test_flag_value_has() {
        // Only SYN (bit 1) is set.
        let syn = FlagValue::new(0b0000_0010, TCP_FLAG_NAMES);
        assert!(syn.has(1));
        assert!(!syn.has(0));
        assert!(!syn.has(4));
    }

    #[test]
    fn test_flag_value_has_named() {
        let syn_ack = FlagValue::new(0b0001_0010, TCP_FLAG_NAMES);
        assert_eq!(syn_ack.has_named("SYN"), Some(true));
        assert_eq!(syn_ack.has_named("ACK"), Some(true));
        assert_eq!(syn_ack.has_named("FIN"), Some(false));
        assert_eq!(syn_ack.has_named("NONEXISTENT"), None);
    }

    #[test]
    fn test_flag_value_set_clear() {
        let mut bits = FlagValue::new(0, TCP_FLAG_NAMES);

        bits.set(1); // SYN
        assert!(bits.has(1));
        assert_eq!(bits.value, 2);

        bits.set(4); // ACK
        assert_eq!(bits.to_string(), "SYN+ACK");

        bits.clear(1); // drop SYN
        assert_eq!(bits.to_string(), "ACK");
    }

    #[test]
    fn test_flag_value_set_flags() {
        // FIN (bit 0) + SYN (bit 1) + ACK (bit 4).
        let bits = FlagValue::new(0b0001_0011, TCP_FLAG_NAMES);
        assert_eq!(bits.set_flags(), ["FIN", "SYN", "ACK"]);
    }
}
538}