stackforge_core/layer/dns/
query.rs1use std::collections::HashMap;
14
15use super::types;
16use crate::layer::field::FieldError;
17use crate::layer::field_ext::DnsName;
18
19#[derive(Debug, Clone, PartialEq)]
21pub struct DnsQuestion {
22 pub qname: DnsName,
24 pub qtype: u16,
26 pub qclass: u16,
29}
30
31impl DnsQuestion {
32 #[must_use]
34 pub fn new(qname: DnsName) -> Self {
35 Self {
36 qname,
37 qtype: types::rr_type::A,
38 qclass: types::dns_class::IN,
39 }
40 }
41
42 pub fn from_name(name: &str) -> Result<Self, FieldError> {
44 Ok(Self {
45 qname: DnsName::from_str_dotted(name)?,
46 qtype: types::rr_type::A,
47 qclass: types::dns_class::IN,
48 })
49 }
50
51 #[must_use]
53 pub fn unicast_response(&self) -> bool {
54 self.qclass & 0x8000 != 0
55 }
56
57 #[must_use]
59 pub fn actual_class(&self) -> u16 {
60 self.qclass & 0x7FFF
61 }
62
63 pub fn set_unicast_response(&mut self, unicast: bool) {
65 if unicast {
66 self.qclass |= 0x8000;
67 } else {
68 self.qclass &= 0x7FFF;
69 }
70 }
71
72 pub fn parse(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
79 let (qname, name_len) = DnsName::decode(packet, offset)?;
80 let type_offset = offset + name_len;
81
82 if type_offset + 4 > packet.len() {
83 return Err(FieldError::BufferTooShort {
84 offset: type_offset,
85 need: 4,
86 have: packet.len() - type_offset,
87 });
88 }
89
90 let qtype = u16::from_be_bytes([packet[type_offset], packet[type_offset + 1]]);
91 let qclass = u16::from_be_bytes([packet[type_offset + 2], packet[type_offset + 3]]);
92
93 Ok((
94 Self {
95 qname,
96 qtype,
97 qclass,
98 },
99 name_len + 4,
100 ))
101 }
102
103 #[must_use]
105 pub fn build(&self) -> Vec<u8> {
106 let mut out = self.qname.encode();
107 out.extend_from_slice(&self.qtype.to_be_bytes());
108 out.extend_from_slice(&self.qclass.to_be_bytes());
109 out
110 }
111
112 pub fn build_compressed(
114 &self,
115 current_offset: usize,
116 compression_map: &mut HashMap<String, u16>,
117 ) -> Vec<u8> {
118 let mut out = self
119 .qname
120 .encode_compressed(current_offset, compression_map);
121 out.extend_from_slice(&self.qtype.to_be_bytes());
122 out.extend_from_slice(&self.qclass.to_be_bytes());
123 out
124 }
125
126 #[must_use]
128 pub fn summary(&self) -> String {
129 format!(
130 "{} {} {}",
131 self.qname,
132 types::dns_type_name(self.qtype),
133 types::dns_class_name(self.actual_class()),
134 )
135 }
136}
137
138impl Default for DnsQuestion {
139 fn default() -> Self {
140 Self {
141 qname: DnsName::from_str_dotted("www.example.com").unwrap_or_default(),
142 qtype: types::rr_type::A,
143 qclass: types::dns_class::IN,
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_question_parse() {
        let mut data = vec![];
        // Name: example.com
        data.extend_from_slice(&[7, b'e', b'x', b'a', b'm', b'p', b'l', b'e']);
        data.extend_from_slice(&[3, b'c', b'o', b'm']);
        data.push(0); // root label
        data.extend_from_slice(&[0x00, 0x01]); // type A
        data.extend_from_slice(&[0x00, 0x01]); // class IN

        let (q, consumed) = DnsQuestion::parse(&data, 0).unwrap();
        assert_eq!(q.qname.labels, vec!["example", "com"]);
        assert_eq!(q.qtype, 1);
        assert_eq!(q.qclass, 1);
        assert_eq!(consumed, data.len());
    }

    #[test]
    fn test_question_build_roundtrip() {
        let q = DnsQuestion {
            qname: DnsName::from_str_dotted("www.example.com").unwrap(),
            qtype: types::rr_type::AAAA,
            qclass: types::dns_class::IN,
        };
        let built = q.build();
        let (parsed, consumed) = DnsQuestion::parse(&built, 0).unwrap();
        assert_eq!(parsed, q);
        assert_eq!(consumed, built.len());
    }

    #[test]
    fn test_question_with_pointer() {
        let mut data = vec![];
        // example.com stored at offset 0 (13 bytes).
        data.extend_from_slice(&[
            7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        ]);
        // Question at offset 13: "www" followed by a pointer to offset 0.
        data.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]);
        data.extend_from_slice(&[0x00, 0x01]); // type A
        data.extend_from_slice(&[0x00, 0x01]); // class IN

        let (q, consumed) = DnsQuestion::parse(&data, 13).unwrap();
        assert_eq!(q.qname.labels, vec!["www", "example", "com"]);
        assert_eq!(q.qtype, 1);
        // 4 ("www") + 2 (pointer) + 4 (type/class); the pointed-to bytes are
        // not counted.
        assert_eq!(consumed, 10);
    }

    #[test]
    fn test_question_mdns_unicast() {
        let mut q = DnsQuestion::new(DnsName::from_str_dotted("test.local").unwrap());
        assert!(!q.unicast_response());
        assert_eq!(q.actual_class(), 1);

        q.set_unicast_response(true);
        assert!(q.unicast_response());
        assert_eq!(q.actual_class(), 1);
        assert_eq!(q.qclass, 0x8001);
    }

    #[test]
    fn test_question_mdns_unicast_clear() {
        let mut q = DnsQuestion::new(DnsName::from_str_dotted("test.local").unwrap());
        q.qclass = 0x8001;

        q.set_unicast_response(false);
        assert!(!q.unicast_response());
        assert_eq!(q.qclass, 1);
        assert_eq!(q.actual_class(), 1);
    }

    #[test]
    fn test_question_summary() {
        let q = DnsQuestion {
            qname: DnsName::from_str_dotted("example.com").unwrap(),
            qtype: types::rr_type::MX,
            qclass: types::dns_class::IN,
        };
        let summary = q.summary();
        assert!(summary.contains("example.com"));
        assert!(summary.contains("MX"));
    }

    #[test]
    fn test_question_from_name() {
        let q = DnsQuestion::from_name("google.com").unwrap();
        assert_eq!(q.qname.labels, vec!["google", "com"]);
        assert_eq!(q.qtype, types::rr_type::A);
    }

    #[test]
    fn test_question_default() {
        let q = DnsQuestion::default();
        assert_eq!(q.qname.labels, vec!["www", "example", "com"]);
        assert_eq!(q.qtype, types::rr_type::A);
        assert_eq!(q.qclass, types::dns_class::IN);
    }

    #[test]
    fn test_question_buffer_too_short() {
        // Name only; type and class are missing entirely.
        let data = vec![4, b't', b'e', b's', b't', 0];
        let result = DnsQuestion::parse(&data, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_question_truncated_class() {
        // Name + type, but the class field is missing.
        let data = vec![4, b't', b'e', b's', b't', 0, 0x00, 0x01];
        assert!(DnsQuestion::parse(&data, 0).is_err());
    }

    #[test]
    fn test_question_build_compressed_roundtrip() {
        let first = DnsQuestion::from_name("example.com").unwrap();
        let second = DnsQuestion::from_name("www.example.com").unwrap();
        let mut map = HashMap::new();

        let mut packet = first.build_compressed(0, &mut map);
        let second_offset = packet.len();
        packet.extend(second.build_compressed(second_offset, &mut map));

        let (parsed_first, first_len) = DnsQuestion::parse(&packet, 0).unwrap();
        assert_eq!(parsed_first.qname.labels, vec!["example", "com"]);
        assert_eq!(first_len, second_offset);

        let (parsed_second, second_len) = DnsQuestion::parse(&packet, second_offset).unwrap();
        assert_eq!(parsed_second.qname.labels, vec!["www", "example", "com"]);
        assert_eq!(parsed_second.qtype, second.qtype);
        assert_eq!(parsed_second.qclass, second.qclass);
        assert_eq!(second_offset + second_len, packet.len());
    }
}