1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::proto::rr::rdata::CAA;
7use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
8use hickory_resolver::TokioAsyncResolver;
9use tracing::{debug, instrument};
10
11use super::records::{DnsRecord, RecordData, RecordType};
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// Timeout a freshly constructed [`DnsResolver`] uses when none is set
/// explicitly via `with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
16
/// DNS lookup front-end.
///
/// The struct only stores configuration; an actual hickory resolver is built
/// for every query by `create_resolver`.
#[derive(Debug, Clone)]
pub struct DnsResolver {
    /// Timeout copied into the resolver options of every resolver built.
    timeout: Duration,
}
21
22impl Default for DnsResolver {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl DnsResolver {
29 pub fn new() -> Self {
30 Self {
31 timeout: DEFAULT_TIMEOUT,
32 }
33 }
34
35 pub fn with_timeout(mut self, timeout: Duration) -> Self {
36 self.timeout = timeout;
37 self
38 }
39
40 fn create_resolver(&self, nameserver: Option<&str>) -> Result<TokioAsyncResolver> {
41 let mut opts = ResolverOpts::default();
42 opts.timeout = self.timeout;
43 opts.attempts = 2;
44 opts.use_hosts_file = false;
45
46 let config = if let Some(ns) = nameserver {
47 let ip: IpAddr = ns
48 .parse()
49 .map_err(|_| SeerError::DnsError(format!("Invalid nameserver IP: {}", ns)))?;
50
51 let socket_addr = SocketAddr::new(ip, 53);
52 let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
53
54 let mut config = ResolverConfig::new();
55 config.add_name_server(ns_config);
56 config
57 } else {
58 ResolverConfig::google()
59 };
60
61 Ok(TokioAsyncResolver::tokio(config, opts))
62 }
63
64 #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
65 pub async fn resolve(
66 &self,
67 domain: &str,
68 record_type: RecordType,
69 nameserver: Option<&str>,
70 ) -> Result<Vec<DnsRecord>> {
71 let resolver = self.create_resolver(nameserver)?;
72 let domain = normalize_domain(domain)?;
73
74 debug!(
75 nameserver = nameserver.unwrap_or("system"),
76 "Resolving DNS"
77 );
78
79 match record_type {
80 RecordType::A => self.resolve_a(&resolver, &domain).await,
81 RecordType::AAAA => self.resolve_aaaa(&resolver, &domain).await,
82 RecordType::CNAME => self.resolve_cname(&resolver, &domain).await,
83 RecordType::MX => self.resolve_mx(&resolver, &domain).await,
84 RecordType::NS => self.resolve_ns(&resolver, &domain).await,
85 RecordType::TXT => self.resolve_txt(&resolver, &domain).await,
86 RecordType::SOA => self.resolve_soa(&resolver, &domain).await,
87 RecordType::PTR => self.resolve_ptr(&resolver, &domain).await,
88 RecordType::SRV => Err(SeerError::DnsError(
89 "SRV records require service name format: _service._proto.name".to_string(),
90 )),
91 RecordType::CAA => self.resolve_caa(&resolver, &domain).await,
92 RecordType::DNSKEY => self.resolve_dnskey(&resolver, &domain).await,
93 RecordType::DS => self.resolve_ds(&resolver, &domain).await,
94 RecordType::ANY => self.resolve_any(&resolver, &domain).await,
95 _ => Err(SeerError::DnsError(format!(
96 "Record type {} not implemented",
97 record_type
98 ))),
99 }
100 }
101
102 pub async fn resolve_srv(
103 &self,
104 service: &str,
105 protocol: &str,
106 domain: &str,
107 nameserver: Option<&str>,
108 ) -> Result<Vec<DnsRecord>> {
109 if !is_valid_srv_label(service) {
111 return Err(SeerError::DnsError(format!(
112 "Invalid SRV service name: {}",
113 service
114 )));
115 }
116 if !is_valid_srv_label(protocol) {
117 return Err(SeerError::DnsError(format!(
118 "Invalid SRV protocol name: {}",
119 protocol
120 )));
121 }
122
123 let resolver = self.create_resolver(nameserver)?;
124 let query_name = format!("_{}._{}.{}", service, protocol, domain);
125
126 let response = resolver
127 .srv_lookup(&query_name)
128 .await
129 .map_err(|e| SeerError::DnsError(format!("SRV lookup failed: {}", e)))?;
130
131 let records = response
132 .iter()
133 .map(|srv| DnsRecord {
134 name: query_name.clone(),
135 record_type: RecordType::SRV,
136 ttl: response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0),
137 data: RecordData::SRV {
138 priority: srv.priority(),
139 weight: srv.weight(),
140 port: srv.port(),
141 target: srv.target().to_string(),
142 },
143 })
144 .collect();
145
146 Ok(records)
147 }
148
149 async fn resolve_a(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
150 let response = resolver
151 .ipv4_lookup(domain)
152 .await
153 .map_err(|e| SeerError::DnsError(format!("A lookup failed: {}", e)))?;
154
155 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
156
157 let records = response
158 .iter()
159 .map(|addr| DnsRecord {
160 name: domain.to_string(),
161 record_type: RecordType::A,
162 ttl,
163 data: RecordData::A {
164 address: addr.to_string(),
165 },
166 })
167 .collect();
168
169 Ok(records)
170 }
171
172 async fn resolve_aaaa(
173 &self,
174 resolver: &TokioAsyncResolver,
175 domain: &str,
176 ) -> Result<Vec<DnsRecord>> {
177 let response = resolver
178 .ipv6_lookup(domain)
179 .await
180 .map_err(|e| SeerError::DnsError(format!("AAAA lookup failed: {}", e)))?;
181
182 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
183
184 let records = response
185 .iter()
186 .map(|addr| DnsRecord {
187 name: domain.to_string(),
188 record_type: RecordType::AAAA,
189 ttl,
190 data: RecordData::AAAA {
191 address: addr.to_string(),
192 },
193 })
194 .collect();
195
196 Ok(records)
197 }
198
199 async fn resolve_cname(
200 &self,
201 resolver: &TokioAsyncResolver,
202 domain: &str,
203 ) -> Result<Vec<DnsRecord>> {
204 let response = resolver
205 .lookup(domain, HickoryRecordType::CNAME)
206 .await
207 .map_err(|e| SeerError::DnsError(format!("CNAME lookup failed: {}", e)))?;
208
209 let records = response
210 .record_iter()
211 .filter_map(|record| {
212 if let Some(rdata) = record.data() {
213 if let Some(cname) = rdata.as_cname() {
214 return Some(DnsRecord {
215 name: domain.to_string(),
216 record_type: RecordType::CNAME,
217 ttl: record.ttl(),
218 data: RecordData::CNAME {
219 target: cname.0.to_string(),
220 },
221 });
222 }
223 }
224 None
225 })
226 .collect();
227
228 Ok(records)
229 }
230
231 async fn resolve_mx(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
232 let response = resolver
233 .mx_lookup(domain)
234 .await
235 .map_err(|e| SeerError::DnsError(format!("MX lookup failed: {}", e)))?;
236
237 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
238
239 let mut records: Vec<DnsRecord> = response
240 .iter()
241 .map(|mx| DnsRecord {
242 name: domain.to_string(),
243 record_type: RecordType::MX,
244 ttl,
245 data: RecordData::MX {
246 preference: mx.preference(),
247 exchange: mx.exchange().to_string(),
248 },
249 })
250 .collect();
251
252 records.sort_by_key(|r| {
253 if let RecordData::MX { preference, .. } = &r.data {
254 *preference
255 } else {
256 0
257 }
258 });
259
260 Ok(records)
261 }
262
263 async fn resolve_ns(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
264 let response = resolver
265 .ns_lookup(domain)
266 .await
267 .map_err(|e| SeerError::DnsError(format!("NS lookup failed: {}", e)))?;
268
269 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
270
271 let records = response
272 .iter()
273 .map(|ns| DnsRecord {
274 name: domain.to_string(),
275 record_type: RecordType::NS,
276 ttl,
277 data: RecordData::NS {
278 nameserver: ns.0.to_string(),
279 },
280 })
281 .collect();
282
283 Ok(records)
284 }
285
286 async fn resolve_txt(
287 &self,
288 resolver: &TokioAsyncResolver,
289 domain: &str,
290 ) -> Result<Vec<DnsRecord>> {
291 let response = resolver
292 .txt_lookup(domain)
293 .await
294 .map_err(|e| SeerError::DnsError(format!("TXT lookup failed: {}", e)))?;
295
296 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
297
298 let records = response
299 .iter()
300 .map(|txt| {
301 let text = txt
302 .iter()
303 .map(|data| String::from_utf8_lossy(data).to_string())
304 .collect::<Vec<_>>()
305 .join("");
306
307 DnsRecord {
308 name: domain.to_string(),
309 record_type: RecordType::TXT,
310 ttl,
311 data: RecordData::TXT { text },
312 }
313 })
314 .collect();
315
316 Ok(records)
317 }
318
319 async fn resolve_soa(
320 &self,
321 resolver: &TokioAsyncResolver,
322 domain: &str,
323 ) -> Result<Vec<DnsRecord>> {
324 let response = resolver
325 .soa_lookup(domain)
326 .await
327 .map_err(|e| SeerError::DnsError(format!("SOA lookup failed: {}", e)))?;
328
329 let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
330
331 let records = response
332 .iter()
333 .map(|soa| DnsRecord {
334 name: domain.to_string(),
335 record_type: RecordType::SOA,
336 ttl,
337 data: RecordData::SOA {
338 mname: soa.mname().to_string(),
339 rname: soa.rname().to_string(),
340 serial: soa.serial(),
341 refresh: soa.refresh().try_into().unwrap_or(0),
342 retry: soa.retry().try_into().unwrap_or(0),
343 expire: soa.expire().try_into().unwrap_or(0),
344 minimum: soa.minimum(),
345 },
346 })
347 .collect();
348
349 Ok(records)
350 }
351
352 async fn resolve_ptr(
353 &self,
354 resolver: &TokioAsyncResolver,
355 query: &str,
356 ) -> Result<Vec<DnsRecord>> {
357 let query = if let Ok(ip) = IpAddr::from_str(query) {
359 reverse_dns_name(&ip)
360 } else {
361 query.to_string()
362 };
363
364 let response = resolver
365 .lookup(&query, HickoryRecordType::PTR)
366 .await
367 .map_err(|e| SeerError::DnsError(format!("PTR lookup failed: {}", e)))?;
368
369 let records = response
370 .record_iter()
371 .filter_map(|record| {
372 if let Some(rdata) = record.data() {
373 if let Some(ptr) = rdata.as_ptr() {
374 return Some(DnsRecord {
375 name: query.clone(),
376 record_type: RecordType::PTR,
377 ttl: record.ttl(),
378 data: RecordData::PTR {
379 target: ptr.0.to_string(),
380 },
381 });
382 }
383 }
384 None
385 })
386 .collect();
387
388 Ok(records)
389 }
390
391 async fn resolve_caa(
392 &self,
393 resolver: &TokioAsyncResolver,
394 domain: &str,
395 ) -> Result<Vec<DnsRecord>> {
396 let response = resolver
397 .lookup(domain, HickoryRecordType::CAA)
398 .await
399 .map_err(|e| SeerError::DnsError(format!("CAA lookup failed: {}", e)))?;
400
401 let records = response
402 .record_iter()
403 .filter_map(|record| {
404 if let Some(rdata) = record.data() {
405 if let Some(caa) = rdata.as_caa() {
406 let (flags, tag, value) = parse_caa(caa);
407 return Some(DnsRecord {
408 name: domain.to_string(),
409 record_type: RecordType::CAA,
410 ttl: record.ttl(),
411 data: RecordData::CAA { flags, tag, value },
412 });
413 }
414 }
415 None
416 })
417 .collect();
418
419 Ok(records)
420 }
421
422 async fn resolve_dnskey(
423 &self,
424 resolver: &TokioAsyncResolver,
425 domain: &str,
426 ) -> Result<Vec<DnsRecord>> {
427 use hickory_resolver::proto::rr::RData as HickoryRData;
428
429 let response = resolver
430 .lookup(domain, HickoryRecordType::DNSKEY)
431 .await
432 .map_err(|e| SeerError::DnsError(format!("DNSKEY lookup failed: {}", e)))?;
433
434 let records = response
435 .record_iter()
436 .filter_map(|record| {
437 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
438 if let Some(dnskey) = dnssec_rdata.as_dnskey() {
439 use base64::{engine::general_purpose::STANDARD, Engine};
440 let public_key = STANDARD.encode(dnskey.public_key());
441 return Some(DnsRecord {
442 name: domain.to_string(),
443 record_type: RecordType::DNSKEY,
444 ttl: record.ttl(),
445 data: RecordData::DNSKEY {
446 flags: dnskey.flags(),
447 protocol: 3, algorithm: u8::from(dnskey.algorithm()),
449 public_key,
450 },
451 });
452 }
453 }
454 None
455 })
456 .collect();
457
458 Ok(records)
459 }
460
461 async fn resolve_ds(
462 &self,
463 resolver: &TokioAsyncResolver,
464 domain: &str,
465 ) -> Result<Vec<DnsRecord>> {
466 use hickory_resolver::proto::rr::RData as HickoryRData;
467
468 let response = resolver
469 .lookup(domain, HickoryRecordType::DS)
470 .await
471 .map_err(|e| SeerError::DnsError(format!("DS lookup failed: {}", e)))?;
472
473 let records = response
474 .record_iter()
475 .filter_map(|record| {
476 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
477 if let Some(ds) = dnssec_rdata.as_ds() {
478 let digest = ds
479 .digest()
480 .iter()
481 .map(|b| format!("{:02X}", b))
482 .collect::<String>();
483 return Some(DnsRecord {
484 name: domain.to_string(),
485 record_type: RecordType::DS,
486 ttl: record.ttl(),
487 data: RecordData::DS {
488 key_tag: ds.key_tag(),
489 algorithm: u8::from(ds.algorithm()),
490 digest_type: u8::from(ds.digest_type()),
491 digest,
492 },
493 });
494 }
495 }
496 None
497 })
498 .collect();
499
500 Ok(records)
501 }
502
503 async fn resolve_any(
504 &self,
505 resolver: &TokioAsyncResolver,
506 domain: &str,
507 ) -> Result<Vec<DnsRecord>> {
508 let mut all_records = Vec::new();
509
510 let record_types = [
512 RecordType::A,
513 RecordType::AAAA,
514 RecordType::MX,
515 RecordType::NS,
516 RecordType::TXT,
517 RecordType::SOA,
518 RecordType::CAA,
519 ];
520
521 for record_type in record_types {
522 match self.resolve_type(resolver, domain, record_type).await {
523 Ok(records) => all_records.extend(records),
524 Err(_) => continue, }
526 }
527
528 Ok(all_records)
529 }
530
531 async fn resolve_type(
532 &self,
533 resolver: &TokioAsyncResolver,
534 domain: &str,
535 record_type: RecordType,
536 ) -> Result<Vec<DnsRecord>> {
537 match record_type {
538 RecordType::A => self.resolve_a(resolver, domain).await,
539 RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
540 RecordType::CNAME => self.resolve_cname(resolver, domain).await,
541 RecordType::MX => self.resolve_mx(resolver, domain).await,
542 RecordType::NS => self.resolve_ns(resolver, domain).await,
543 RecordType::TXT => self.resolve_txt(resolver, domain).await,
544 RecordType::SOA => self.resolve_soa(resolver, domain).await,
545 RecordType::CAA => self.resolve_caa(resolver, domain).await,
546 RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
547 RecordType::DS => self.resolve_ds(resolver, domain).await,
548 _ => Err(SeerError::DnsError("Unsupported record type".to_string())),
549 }
550 }
551}
552
/// Builds the reverse-lookup (PTR) query name for `ip`.
///
/// * IPv4: the four octets in reverse order followed by `.in-addr.arpa`.
/// * IPv6: all 32 hexadecimal nibbles (lower-case) in reverse order, each
///   separated by a dot, followed by `.ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{}.{}.{}.{}.in-addr.arpa", d, c, b, a)
        }
        IpAddr::V6(v6) => {
            // 32 nibbles, each followed by a dot, plus the "ip6.arpa" suffix.
            let mut name = String::with_capacity(32 * 2 + 8);
            for byte in v6.octets().iter().rev() {
                // Walking backwards, the low nibble of each byte comes first.
                for nibble in [*byte & 0x0f, *byte >> 4] {
                    name.push(HEX_DIGITS[usize::from(nibble)] as char);
                    name.push('.');
                }
            }
            name.push_str("ip6.arpa");
            name
        }
    }
}
578
579fn parse_caa(caa: &CAA) -> (u8, String, String) {
580 let flags = if caa.issuer_critical() { 128 } else { 0 };
581 let tag = caa.tag().as_str().to_string();
582 let value = caa.value().to_string();
583 (flags, tag, value)
584}
585
/// Checks whether `label` may be used as the service or protocol part of an
/// SRV query name.
///
/// A valid label is non-empty, at most 62 bytes long, consists only of ASCII
/// letters, digits and hyphens, and neither starts nor ends with a hyphen.
///
/// The limit is 62 rather than 63 because the label is sent as `_<label>`:
/// the underscore counts toward the 63-octet maximum of a DNS label.
fn is_valid_srv_label(label: &str) -> bool {
    const MAX_LABEL_LEN: usize = 62;

    !label.is_empty()
        && label.len() <= MAX_LABEL_LEN
        // Any non-ASCII character contains a byte >= 0x80 and is rejected.
        && label
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-')
        && !label.starts_with('-')
        && !label.ends_with('-')
}