stackforge_core/layer/dns/
query.rs1use std::collections::HashMap;
14
15use super::types;
16use crate::layer::field::FieldError;
17use crate::layer::field_ext::DnsName;
18
19#[derive(Debug, Clone, PartialEq)]
21pub struct DnsQuestion {
22 pub qname: DnsName,
24 pub qtype: u16,
26 pub qclass: u16,
29}
30
31impl DnsQuestion {
32 pub fn new(qname: DnsName) -> Self {
34 Self {
35 qname,
36 qtype: types::rr_type::A,
37 qclass: types::dns_class::IN,
38 }
39 }
40
41 pub fn from_name(name: &str) -> Result<Self, FieldError> {
43 Ok(Self {
44 qname: DnsName::from_str_dotted(name)?,
45 qtype: types::rr_type::A,
46 qclass: types::dns_class::IN,
47 })
48 }
49
50 pub fn unicast_response(&self) -> bool {
52 self.qclass & 0x8000 != 0
53 }
54
55 pub fn actual_class(&self) -> u16 {
57 self.qclass & 0x7FFF
58 }
59
60 pub fn set_unicast_response(&mut self, unicast: bool) {
62 if unicast {
63 self.qclass |= 0x8000;
64 } else {
65 self.qclass &= 0x7FFF;
66 }
67 }
68
69 pub fn parse(packet: &[u8], offset: usize) -> Result<(Self, usize), FieldError> {
76 let (qname, name_len) = DnsName::decode(packet, offset)?;
77 let type_offset = offset + name_len;
78
79 if type_offset + 4 > packet.len() {
80 return Err(FieldError::BufferTooShort {
81 offset: type_offset,
82 need: 4,
83 have: packet.len() - type_offset,
84 });
85 }
86
87 let qtype = u16::from_be_bytes([packet[type_offset], packet[type_offset + 1]]);
88 let qclass = u16::from_be_bytes([packet[type_offset + 2], packet[type_offset + 3]]);
89
90 Ok((
91 Self {
92 qname,
93 qtype,
94 qclass,
95 },
96 name_len + 4,
97 ))
98 }
99
100 pub fn build(&self) -> Vec<u8> {
102 let mut out = self.qname.encode();
103 out.extend_from_slice(&self.qtype.to_be_bytes());
104 out.extend_from_slice(&self.qclass.to_be_bytes());
105 out
106 }
107
108 pub fn build_compressed(
110 &self,
111 current_offset: usize,
112 compression_map: &mut HashMap<String, u16>,
113 ) -> Vec<u8> {
114 let mut out = self
115 .qname
116 .encode_compressed(current_offset, compression_map);
117 out.extend_from_slice(&self.qtype.to_be_bytes());
118 out.extend_from_slice(&self.qclass.to_be_bytes());
119 out
120 }
121
122 pub fn summary(&self) -> String {
124 format!(
125 "{} {} {}",
126 self.qname,
127 types::dns_type_name(self.qtype),
128 types::dns_class_name(self.actual_class()),
129 )
130 }
131}
132
133impl Default for DnsQuestion {
134 fn default() -> Self {
135 Self {
136 qname: DnsName::from_str_dotted("www.example.com").unwrap_or_default(),
137 qtype: types::rr_type::A,
138 qclass: types::dns_class::IN,
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncompressed wire form of `example.com`, root terminator included.
    const EXAMPLE_COM: [u8; 13] = [
        7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
    ];

    /// Type `A` (0x0001) followed by class `IN` (0x0001), both big-endian.
    const TYPE_A_CLASS_IN: [u8; 4] = [0x00, 0x01, 0x00, 0x01];

    /// A plain, uncompressed question parses and consumes the whole buffer.
    #[test]
    fn test_question_parse() {
        let wire = [&EXAMPLE_COM[..], &TYPE_A_CLASS_IN[..]].concat();

        let (question, used) = DnsQuestion::parse(&wire, 0).unwrap();

        assert_eq!(question.qname.labels, vec!["example", "com"]);
        assert_eq!(question.qtype, 1);
        assert_eq!(question.qclass, 1);
        assert_eq!(used, wire.len());
    }

    /// `build` followed by `parse` yields the original question.
    #[test]
    fn test_question_build_roundtrip() {
        let original = DnsQuestion {
            qname: DnsName::from_str_dotted("www.example.com").unwrap(),
            qtype: types::rr_type::AAAA,
            qclass: types::dns_class::IN,
        };

        let wire = original.build();
        let (decoded, used) = DnsQuestion::parse(&wire, 0).unwrap();

        assert_eq!(decoded, original);
        assert_eq!(used, wire.len());
    }

    /// A name ending in a compression pointer is expanded, while the consumed
    /// length counts only the bytes present at the question's own offset.
    #[test]
    fn test_question_with_pointer() {
        // `www` label followed by a pointer back to offset 0 (`example.com`).
        let www_then_pointer = [3, b'w', b'w', b'w', 0xC0, 0x00];
        let wire = [
            &EXAMPLE_COM[..],
            &www_then_pointer[..],
            &TYPE_A_CLASS_IN[..],
        ]
        .concat();

        let (question, used) = DnsQuestion::parse(&wire, 13).unwrap();

        assert_eq!(question.qname.labels, vec!["www", "example", "com"]);
        assert_eq!(question.qtype, 1);
        // 4 bytes of label + 2 bytes of pointer + 4 bytes of type/class.
        assert_eq!(used, 10);
    }

    /// Setting the unicast-response flag changes only the top bit of `qclass`.
    #[test]
    fn test_question_mdns_unicast() {
        let name = DnsName::from_str_dotted("test.local").unwrap();
        let mut question = DnsQuestion::new(name);
        assert!(!question.unicast_response());
        assert_eq!(question.actual_class(), 1);

        question.set_unicast_response(true);

        assert!(question.unicast_response());
        assert_eq!(question.actual_class(), 1);
        assert_eq!(question.qclass, 0x8001);
    }

    /// The summary mentions both the name and the type mnemonic.
    #[test]
    fn test_question_summary() {
        let text = DnsQuestion {
            qname: DnsName::from_str_dotted("example.com").unwrap(),
            qtype: types::rr_type::MX,
            qclass: types::dns_class::IN,
        }
        .summary();

        assert!(text.contains("example.com"));
        assert!(text.contains("MX"));
    }

    /// `from_name` splits the dotted string into labels and defaults to type `A`.
    #[test]
    fn test_question_from_name() {
        let question = DnsQuestion::from_name("google.com").unwrap();

        assert_eq!(question.qname.labels, vec!["google", "com"]);
        assert_eq!(question.qtype, types::rr_type::A);
    }

    /// A name with no type/class bytes after it is rejected.
    #[test]
    fn test_question_buffer_too_short() {
        let name_only = [4, b't', b'e', b's', b't', 0];

        assert!(DnsQuestion::parse(&name_only, 0).is_err());
    }
}