1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::proto::rr::rdata::CAA;
7use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
8use hickory_resolver::TokioAsyncResolver;
9use tracing::{debug, instrument};
10
11use super::records::{DnsRecord, RecordData, RecordType};
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// Query timeout used by `DnsResolver::new`; it is written into `ResolverOpts::timeout`
/// of the default resolver.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
18
19#[derive(Clone)]
25pub struct DnsResolver {
26 timeout: Duration,
27 default_resolver: TokioAsyncResolver,
30}
31
32impl std::fmt::Debug for DnsResolver {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("DnsResolver")
35 .field("timeout", &self.timeout)
36 .finish()
37 }
38}
39
40impl Default for DnsResolver {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl DnsResolver {
47 pub fn new() -> Self {
49 let mut opts = ResolverOpts::default();
50 opts.timeout = DEFAULT_TIMEOUT;
51 opts.attempts = 2;
52 opts.use_hosts_file = false;
53
54 Self {
55 timeout: DEFAULT_TIMEOUT,
56 default_resolver: TokioAsyncResolver::tokio(ResolverConfig::google(), opts),
57 }
58 }
59
60 pub fn with_timeout(mut self, timeout: Duration) -> Self {
64 self.timeout = timeout;
65 let mut opts = ResolverOpts::default();
67 opts.timeout = timeout;
68 opts.attempts = 2;
69 opts.use_hosts_file = false;
70 self.default_resolver = TokioAsyncResolver::tokio(ResolverConfig::google(), opts);
71 self
72 }
73
74 fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioAsyncResolver> {
75 let mut opts = ResolverOpts::default();
76 opts.timeout = self.timeout;
77 opts.attempts = 2;
78 opts.use_hosts_file = false;
79
80 let ip: IpAddr = nameserver
81 .parse()
82 .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
83
84 let socket_addr = SocketAddr::new(ip, 53);
85 let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
86
87 let mut config = ResolverConfig::new();
88 config.add_name_server(ns_config);
89
90 Ok(TokioAsyncResolver::tokio(config, opts))
91 }
92
93 #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
100 pub async fn resolve(
101 &self,
102 domain: &str,
103 record_type: RecordType,
104 nameserver: Option<&str>,
105 ) -> Result<Vec<DnsRecord>> {
106 let custom_resolver;
108 let resolver = if let Some(ns) = nameserver {
109 custom_resolver = self.create_custom_resolver(ns)?;
110 &custom_resolver
111 } else {
112 &self.default_resolver
113 };
114 let domain = normalize_domain(domain)?;
115
116 debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
117
118 match record_type {
119 RecordType::A => self.resolve_a(resolver, &domain).await,
120 RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
121 RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
122 RecordType::MX => self.resolve_mx(resolver, &domain).await,
123 RecordType::NS => self.resolve_ns(resolver, &domain).await,
124 RecordType::TXT => self.resolve_txt(resolver, &domain).await,
125 RecordType::SOA => self.resolve_soa(resolver, &domain).await,
126 RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
127 RecordType::SRV => Err(SeerError::DnsError(
128 "SRV records require service name format: _service._proto.name".to_string(),
129 )),
130 RecordType::CAA => self.resolve_caa(resolver, &domain).await,
131 RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
132 RecordType::DS => self.resolve_ds(resolver, &domain).await,
133 RecordType::ANY => self.resolve_any(resolver, &domain).await,
134 _ => Err(SeerError::DnsError(format!(
135 "Record type {} not implemented",
136 record_type
137 ))),
138 }
139 }
140
141 pub async fn resolve_srv(
149 &self,
150 service: &str,
151 protocol: &str,
152 domain: &str,
153 nameserver: Option<&str>,
154 ) -> Result<Vec<DnsRecord>> {
155 if !is_valid_srv_label(service) {
157 return Err(SeerError::DnsError(format!(
158 "invalid SRV service name: {}",
159 service
160 )));
161 }
162 if !is_valid_srv_label(protocol) {
163 return Err(SeerError::DnsError(format!(
164 "invalid SRV protocol name: {}",
165 protocol
166 )));
167 }
168
169 let custom_resolver;
170 let resolver = if let Some(ns) = nameserver {
171 custom_resolver = self.create_custom_resolver(ns)?;
172 &custom_resolver
173 } else {
174 &self.default_resolver
175 };
176 let query_name = format!("_{}._{}.{}", service, protocol, domain);
177
178 let response = resolver
179 .srv_lookup(&query_name)
180 .await
181 .map_err(|e| SeerError::DnsError(format!("SRV lookup failed: {}", e)))?;
182
183 let records = response
184 .iter()
185 .map(|srv| DnsRecord {
186 name: query_name.clone(),
187 record_type: RecordType::SRV,
188 ttl: response
189 .as_lookup()
190 .record_iter()
191 .next()
192 .map(|r| r.ttl())
193 .unwrap_or(0),
194 data: RecordData::SRV {
195 priority: srv.priority(),
196 weight: srv.weight(),
197 port: srv.port(),
198 target: srv.target().to_string(),
199 },
200 })
201 .collect();
202
203 Ok(records)
204 }
205
206 async fn resolve_a(
207 &self,
208 resolver: &TokioAsyncResolver,
209 domain: &str,
210 ) -> Result<Vec<DnsRecord>> {
211 let response = resolver
212 .ipv4_lookup(domain)
213 .await
214 .map_err(|e| SeerError::DnsError(format!("A lookup failed: {}", e)))?;
215
216 let ttl = response
217 .as_lookup()
218 .record_iter()
219 .next()
220 .map(|r| r.ttl())
221 .unwrap_or(0);
222
223 let records = response
224 .iter()
225 .map(|addr| DnsRecord {
226 name: domain.to_string(),
227 record_type: RecordType::A,
228 ttl,
229 data: RecordData::A {
230 address: addr.to_string(),
231 },
232 })
233 .collect();
234
235 Ok(records)
236 }
237
238 async fn resolve_aaaa(
239 &self,
240 resolver: &TokioAsyncResolver,
241 domain: &str,
242 ) -> Result<Vec<DnsRecord>> {
243 let response = resolver
244 .ipv6_lookup(domain)
245 .await
246 .map_err(|e| SeerError::DnsError(format!("AAAA lookup failed: {}", e)))?;
247
248 let ttl = response
249 .as_lookup()
250 .record_iter()
251 .next()
252 .map(|r| r.ttl())
253 .unwrap_or(0);
254
255 let records = response
256 .iter()
257 .map(|addr| DnsRecord {
258 name: domain.to_string(),
259 record_type: RecordType::AAAA,
260 ttl,
261 data: RecordData::AAAA {
262 address: addr.to_string(),
263 },
264 })
265 .collect();
266
267 Ok(records)
268 }
269
270 async fn resolve_cname(
271 &self,
272 resolver: &TokioAsyncResolver,
273 domain: &str,
274 ) -> Result<Vec<DnsRecord>> {
275 let response = resolver
276 .lookup(domain, HickoryRecordType::CNAME)
277 .await
278 .map_err(|e| SeerError::DnsError(format!("CNAME lookup failed: {}", e)))?;
279
280 let records = response
281 .record_iter()
282 .filter_map(|record| {
283 if let Some(rdata) = record.data() {
284 if let Some(cname) = rdata.as_cname() {
285 return Some(DnsRecord {
286 name: domain.to_string(),
287 record_type: RecordType::CNAME,
288 ttl: record.ttl(),
289 data: RecordData::CNAME {
290 target: cname.0.to_string(),
291 },
292 });
293 }
294 }
295 None
296 })
297 .collect();
298
299 Ok(records)
300 }
301
302 async fn resolve_mx(
303 &self,
304 resolver: &TokioAsyncResolver,
305 domain: &str,
306 ) -> Result<Vec<DnsRecord>> {
307 let response = resolver
308 .mx_lookup(domain)
309 .await
310 .map_err(|e| SeerError::DnsError(format!("MX lookup failed: {}", e)))?;
311
312 let ttl = response
313 .as_lookup()
314 .record_iter()
315 .next()
316 .map(|r| r.ttl())
317 .unwrap_or(0);
318
319 let mut records: Vec<DnsRecord> = response
320 .iter()
321 .map(|mx| DnsRecord {
322 name: domain.to_string(),
323 record_type: RecordType::MX,
324 ttl,
325 data: RecordData::MX {
326 preference: mx.preference(),
327 exchange: mx.exchange().to_string(),
328 },
329 })
330 .collect();
331
332 records.sort_by_key(|r| {
333 if let RecordData::MX { preference, .. } = &r.data {
334 *preference
335 } else {
336 0
337 }
338 });
339
340 Ok(records)
341 }
342
343 async fn resolve_ns(
344 &self,
345 resolver: &TokioAsyncResolver,
346 domain: &str,
347 ) -> Result<Vec<DnsRecord>> {
348 let response = resolver
349 .ns_lookup(domain)
350 .await
351 .map_err(|e| SeerError::DnsError(format!("NS lookup failed: {}", e)))?;
352
353 let ttl = response
354 .as_lookup()
355 .record_iter()
356 .next()
357 .map(|r| r.ttl())
358 .unwrap_or(0);
359
360 let records = response
361 .iter()
362 .map(|ns| DnsRecord {
363 name: domain.to_string(),
364 record_type: RecordType::NS,
365 ttl,
366 data: RecordData::NS {
367 nameserver: ns.0.to_string(),
368 },
369 })
370 .collect();
371
372 Ok(records)
373 }
374
375 async fn resolve_txt(
376 &self,
377 resolver: &TokioAsyncResolver,
378 domain: &str,
379 ) -> Result<Vec<DnsRecord>> {
380 let response = resolver
381 .txt_lookup(domain)
382 .await
383 .map_err(|e| SeerError::DnsError(format!("TXT lookup failed: {}", e)))?;
384
385 let ttl = response
386 .as_lookup()
387 .record_iter()
388 .next()
389 .map(|r| r.ttl())
390 .unwrap_or(0);
391
392 let records = response
393 .iter()
394 .map(|txt| {
395 let text = txt
396 .iter()
397 .map(|data| String::from_utf8_lossy(data).to_string())
398 .collect::<Vec<_>>()
399 .join("");
400
401 DnsRecord {
402 name: domain.to_string(),
403 record_type: RecordType::TXT,
404 ttl,
405 data: RecordData::TXT { text },
406 }
407 })
408 .collect();
409
410 Ok(records)
411 }
412
413 async fn resolve_soa(
414 &self,
415 resolver: &TokioAsyncResolver,
416 domain: &str,
417 ) -> Result<Vec<DnsRecord>> {
418 let response = resolver
419 .soa_lookup(domain)
420 .await
421 .map_err(|e| SeerError::DnsError(format!("SOA lookup failed: {}", e)))?;
422
423 let ttl = response
424 .as_lookup()
425 .record_iter()
426 .next()
427 .map(|r| r.ttl())
428 .unwrap_or(0);
429
430 let records = response
431 .iter()
432 .map(|soa| DnsRecord {
433 name: domain.to_string(),
434 record_type: RecordType::SOA,
435 ttl,
436 data: RecordData::SOA {
437 mname: soa.mname().to_string(),
438 rname: soa.rname().to_string(),
439 serial: soa.serial(),
440 refresh: soa.refresh().try_into().unwrap_or(0),
441 retry: soa.retry().try_into().unwrap_or(0),
442 expire: soa.expire().try_into().unwrap_or(0),
443 minimum: soa.minimum(),
444 },
445 })
446 .collect();
447
448 Ok(records)
449 }
450
451 async fn resolve_ptr(
452 &self,
453 resolver: &TokioAsyncResolver,
454 query: &str,
455 ) -> Result<Vec<DnsRecord>> {
456 let query = if let Ok(ip) = IpAddr::from_str(query) {
458 reverse_dns_name(&ip)
459 } else {
460 query.to_string()
461 };
462
463 let response = resolver
464 .lookup(&query, HickoryRecordType::PTR)
465 .await
466 .map_err(|e| SeerError::DnsError(format!("PTR lookup failed: {}", e)))?;
467
468 let records = response
469 .record_iter()
470 .filter_map(|record| {
471 if let Some(rdata) = record.data() {
472 if let Some(ptr) = rdata.as_ptr() {
473 return Some(DnsRecord {
474 name: query.clone(),
475 record_type: RecordType::PTR,
476 ttl: record.ttl(),
477 data: RecordData::PTR {
478 target: ptr.0.to_string(),
479 },
480 });
481 }
482 }
483 None
484 })
485 .collect();
486
487 Ok(records)
488 }
489
490 async fn resolve_caa(
491 &self,
492 resolver: &TokioAsyncResolver,
493 domain: &str,
494 ) -> Result<Vec<DnsRecord>> {
495 let response = resolver
496 .lookup(domain, HickoryRecordType::CAA)
497 .await
498 .map_err(|e| SeerError::DnsError(format!("CAA lookup failed: {}", e)))?;
499
500 let records = response
501 .record_iter()
502 .filter_map(|record| {
503 if let Some(rdata) = record.data() {
504 if let Some(caa) = rdata.as_caa() {
505 let (flags, tag, value) = parse_caa(caa);
506 return Some(DnsRecord {
507 name: domain.to_string(),
508 record_type: RecordType::CAA,
509 ttl: record.ttl(),
510 data: RecordData::CAA { flags, tag, value },
511 });
512 }
513 }
514 None
515 })
516 .collect();
517
518 Ok(records)
519 }
520
521 async fn resolve_dnskey(
522 &self,
523 resolver: &TokioAsyncResolver,
524 domain: &str,
525 ) -> Result<Vec<DnsRecord>> {
526 use hickory_resolver::proto::rr::RData as HickoryRData;
527
528 let response = resolver
529 .lookup(domain, HickoryRecordType::DNSKEY)
530 .await
531 .map_err(|e| SeerError::DnsError(format!("DNSKEY lookup failed: {}", e)))?;
532
533 let records = response
534 .record_iter()
535 .filter_map(|record| {
536 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
537 if let Some(dnskey) = dnssec_rdata.as_dnskey() {
538 use base64::{engine::general_purpose::STANDARD, Engine};
539 let public_key = STANDARD.encode(dnskey.public_key());
540 return Some(DnsRecord {
541 name: domain.to_string(),
542 record_type: RecordType::DNSKEY,
543 ttl: record.ttl(),
544 data: RecordData::DNSKEY {
545 flags: dnskey.flags(),
546 protocol: 3, algorithm: u8::from(dnskey.algorithm()),
548 public_key,
549 },
550 });
551 }
552 }
553 None
554 })
555 .collect();
556
557 Ok(records)
558 }
559
560 async fn resolve_ds(
561 &self,
562 resolver: &TokioAsyncResolver,
563 domain: &str,
564 ) -> Result<Vec<DnsRecord>> {
565 use hickory_resolver::proto::rr::RData as HickoryRData;
566
567 let response = resolver
568 .lookup(domain, HickoryRecordType::DS)
569 .await
570 .map_err(|e| SeerError::DnsError(format!("DS lookup failed: {}", e)))?;
571
572 let records = response
573 .record_iter()
574 .filter_map(|record| {
575 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
576 if let Some(ds) = dnssec_rdata.as_ds() {
577 let digest = ds
578 .digest()
579 .iter()
580 .map(|b| format!("{:02X}", b))
581 .collect::<String>();
582 return Some(DnsRecord {
583 name: domain.to_string(),
584 record_type: RecordType::DS,
585 ttl: record.ttl(),
586 data: RecordData::DS {
587 key_tag: ds.key_tag(),
588 algorithm: u8::from(ds.algorithm()),
589 digest_type: u8::from(ds.digest_type()),
590 digest,
591 },
592 });
593 }
594 }
595 None
596 })
597 .collect();
598
599 Ok(records)
600 }
601
602 async fn resolve_any(
603 &self,
604 resolver: &TokioAsyncResolver,
605 domain: &str,
606 ) -> Result<Vec<DnsRecord>> {
607 let mut all_records = Vec::new();
608
609 let record_types = [
611 RecordType::A,
612 RecordType::AAAA,
613 RecordType::MX,
614 RecordType::NS,
615 RecordType::TXT,
616 RecordType::SOA,
617 RecordType::CAA,
618 ];
619
620 for record_type in record_types {
621 match self.resolve_type(resolver, domain, record_type).await {
622 Ok(records) => all_records.extend(records),
623 Err(_) => continue, }
625 }
626
627 Ok(all_records)
628 }
629
630 async fn resolve_type(
631 &self,
632 resolver: &TokioAsyncResolver,
633 domain: &str,
634 record_type: RecordType,
635 ) -> Result<Vec<DnsRecord>> {
636 match record_type {
637 RecordType::A => self.resolve_a(resolver, domain).await,
638 RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
639 RecordType::CNAME => self.resolve_cname(resolver, domain).await,
640 RecordType::MX => self.resolve_mx(resolver, domain).await,
641 RecordType::NS => self.resolve_ns(resolver, domain).await,
642 RecordType::TXT => self.resolve_txt(resolver, domain).await,
643 RecordType::SOA => self.resolve_soa(resolver, domain).await,
644 RecordType::CAA => self.resolve_caa(resolver, domain).await,
645 RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
646 RecordType::DS => self.resolve_ds(resolver, domain).await,
647 _ => Err(SeerError::DnsError("unsupported record type".to_string())),
648 }
649 }
650}
651
/// Builds the reverse-lookup (PTR) name for `ip`.
///
/// An IPv4 address yields its octets in reverse order under `in-addr.arpa`. An IPv6
/// address yields all 32 nibbles in reverse order, as lowercase hex separated by dots,
/// under `ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [first, second, third, fourth] = v4.octets();
            format!(
                "{}.{}.{}.{}.in-addr.arpa",
                fourth, third, second, first
            )
        }
        IpAddr::V6(v6) => {
            // Walking the bytes back-to-front and emitting the low nibble before the
            // high nibble gives the fully reversed nibble sequence.
            let labels: Vec<String> = v6
                .octets()
                .iter()
                .rev()
                .flat_map(|byte| [byte & 0x0F, byte >> 4])
                .map(|nibble| format!("{:x}", nibble))
                .collect();
            format!("{}.ip6.arpa", labels.join("."))
        }
    }
}
683
684fn parse_caa(caa: &CAA) -> (u8, String, String) {
685 let flags = if caa.issuer_critical() { 128 } else { 0 };
686 let tag = caa.tag().as_str().to_string();
687 let value = caa.value().to_string();
688 (flags, tag, value)
689}
690
/// Checks that `label` can be used as the service or protocol part of an SRV name.
///
/// Valid labels are 1 to 63 bytes long. They contain only ASCII letters, digits and `-`,
/// and do not begin or end with `-`.
fn is_valid_srv_label(label: &str) -> bool {
    let has_valid_length = (1..=63).contains(&label.len());
    let has_valid_chars = label
        .bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || byte == b'-');
    let has_valid_edges = !(label.starts_with('-') || label.ends_with('-'));

    has_valid_length && has_valid_chars && has_valid_edges
}