1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::error::ResolveErrorKind;
7use hickory_resolver::proto::rr::rdata::CAA;
8use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
9use hickory_resolver::TokioAsyncResolver;
10use tracing::{debug, instrument};
11
12use super::records::{DnsRecord, RecordData, RecordType};
13use crate::error::{Result, SeerError};
14use crate::validation::normalize_domain;
15
16fn dns_lookup_or_empty<T>(
20 result: std::result::Result<T, hickory_resolver::error::ResolveError>,
21 record_type: &str,
22) -> Result<Option<T>> {
23 match result {
24 Ok(response) => Ok(Some(response)),
25 Err(e) => match e.kind() {
26 ResolveErrorKind::NoRecordsFound { .. } => Ok(None),
27 _ => Err(SeerError::DnsError(format!(
28 "{} lookup failed: {}",
29 record_type, e
30 ))),
31 },
32 }
33}
34
/// Query timeout used by `DnsResolver::new`; it is copied into hickory's
/// `ResolverOpts::timeout`. Override it with `DnsResolver::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
38
39#[derive(Clone)]
45pub struct DnsResolver {
46 timeout: Duration,
47 default_resolver: TokioAsyncResolver,
50}
51
52impl std::fmt::Debug for DnsResolver {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("DnsResolver")
55 .field("timeout", &self.timeout)
56 .finish()
57 }
58}
59
60impl Default for DnsResolver {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl DnsResolver {
67 pub fn new() -> Self {
69 let mut opts = ResolverOpts::default();
70 opts.timeout = DEFAULT_TIMEOUT;
71 opts.attempts = 2;
72 opts.use_hosts_file = false;
73
74 Self {
75 timeout: DEFAULT_TIMEOUT,
76 default_resolver: TokioAsyncResolver::tokio(ResolverConfig::google(), opts),
77 }
78 }
79
80 pub fn with_timeout(mut self, timeout: Duration) -> Self {
84 self.timeout = timeout;
85 let mut opts = ResolverOpts::default();
87 opts.timeout = timeout;
88 opts.attempts = 2;
89 opts.use_hosts_file = false;
90 self.default_resolver = TokioAsyncResolver::tokio(ResolverConfig::google(), opts);
91 self
92 }
93
94 fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioAsyncResolver> {
95 let mut opts = ResolverOpts::default();
96 opts.timeout = self.timeout;
97 opts.attempts = 2;
98 opts.use_hosts_file = false;
99
100 let ip: IpAddr = nameserver
101 .parse()
102 .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
103
104 if let Some(reason) = crate::validation::describe_reserved_ip(&ip) {
106 return Err(SeerError::DnsError(format!(
107 "nameserver {} blocked: {}",
108 nameserver, reason
109 )));
110 }
111
112 let socket_addr = SocketAddr::new(ip, 53);
113 let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
114
115 let mut config = ResolverConfig::new();
116 config.add_name_server(ns_config);
117
118 Ok(TokioAsyncResolver::tokio(config, opts))
119 }
120
121 #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
128 pub async fn resolve(
129 &self,
130 domain: &str,
131 record_type: RecordType,
132 nameserver: Option<&str>,
133 ) -> Result<Vec<DnsRecord>> {
134 let custom_resolver;
136 let resolver = if let Some(ns) = nameserver {
137 custom_resolver = self.create_custom_resolver(ns)?;
138 &custom_resolver
139 } else {
140 &self.default_resolver
141 };
142 let domain = normalize_domain(domain)?;
143
144 debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
145
146 match record_type {
147 RecordType::A => self.resolve_a(resolver, &domain).await,
148 RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
149 RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
150 RecordType::MX => self.resolve_mx(resolver, &domain).await,
151 RecordType::NS => self.resolve_ns(resolver, &domain).await,
152 RecordType::TXT => self.resolve_txt(resolver, &domain).await,
153 RecordType::SOA => self.resolve_soa(resolver, &domain).await,
154 RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
155 RecordType::SRV => Err(SeerError::DnsError(
156 "SRV records require service name format: _service._proto.name".to_string(),
157 )),
158 RecordType::CAA => self.resolve_caa(resolver, &domain).await,
159 RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
160 RecordType::DS => self.resolve_ds(resolver, &domain).await,
161 RecordType::ANY => self.resolve_any(resolver, &domain).await,
162 _ => Err(SeerError::DnsError(format!(
163 "Record type {} not implemented",
164 record_type
165 ))),
166 }
167 }
168
169 #[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
177 pub async fn resolve_srv(
178 &self,
179 service: &str,
180 protocol: &str,
181 domain: &str,
182 nameserver: Option<&str>,
183 ) -> Result<Vec<DnsRecord>> {
184 if !is_valid_srv_label(service) {
186 return Err(SeerError::DnsError(format!(
187 "invalid SRV service name: {}",
188 service
189 )));
190 }
191 if !is_valid_srv_label(protocol) {
192 return Err(SeerError::DnsError(format!(
193 "invalid SRV protocol name: {}",
194 protocol
195 )));
196 }
197
198 let custom_resolver;
199 let resolver = if let Some(ns) = nameserver {
200 custom_resolver = self.create_custom_resolver(ns)?;
201 &custom_resolver
202 } else {
203 &self.default_resolver
204 };
205 let query_name = format!("_{}._{}.{}", service, protocol, domain);
206
207 let Some(response) = dns_lookup_or_empty(resolver.srv_lookup(&query_name).await, "SRV")?
208 else {
209 return Ok(vec![]);
210 };
211
212 let records = response
213 .iter()
214 .map(|srv| DnsRecord {
215 name: query_name.clone(),
216 record_type: RecordType::SRV,
217 ttl: response
218 .as_lookup()
219 .record_iter()
220 .next()
221 .map(|r| r.ttl())
222 .unwrap_or(0),
223 data: RecordData::SRV {
224 priority: srv.priority(),
225 weight: srv.weight(),
226 port: srv.port(),
227 target: srv.target().to_string(),
228 },
229 })
230 .collect();
231
232 Ok(records)
233 }
234
235 async fn resolve_a(
236 &self,
237 resolver: &TokioAsyncResolver,
238 domain: &str,
239 ) -> Result<Vec<DnsRecord>> {
240 let Some(response) = dns_lookup_or_empty(resolver.ipv4_lookup(domain).await, "A")? else {
241 return Ok(vec![]);
242 };
243
244 let ttl = response
245 .as_lookup()
246 .record_iter()
247 .next()
248 .map(|r| r.ttl())
249 .unwrap_or(0);
250
251 let records = response
252 .iter()
253 .map(|addr| DnsRecord {
254 name: domain.to_string(),
255 record_type: RecordType::A,
256 ttl,
257 data: RecordData::A {
258 address: addr.to_string(),
259 },
260 })
261 .collect();
262
263 Ok(records)
264 }
265
266 async fn resolve_aaaa(
267 &self,
268 resolver: &TokioAsyncResolver,
269 domain: &str,
270 ) -> Result<Vec<DnsRecord>> {
271 let Some(response) = dns_lookup_or_empty(resolver.ipv6_lookup(domain).await, "AAAA")?
272 else {
273 return Ok(vec![]);
274 };
275
276 let ttl = response
277 .as_lookup()
278 .record_iter()
279 .next()
280 .map(|r| r.ttl())
281 .unwrap_or(0);
282
283 let records = response
284 .iter()
285 .map(|addr| DnsRecord {
286 name: domain.to_string(),
287 record_type: RecordType::AAAA,
288 ttl,
289 data: RecordData::AAAA {
290 address: addr.to_string(),
291 },
292 })
293 .collect();
294
295 Ok(records)
296 }
297
298 async fn resolve_cname(
299 &self,
300 resolver: &TokioAsyncResolver,
301 domain: &str,
302 ) -> Result<Vec<DnsRecord>> {
303 let Some(response) = dns_lookup_or_empty(
304 resolver.lookup(domain, HickoryRecordType::CNAME).await,
305 "CNAME",
306 )?
307 else {
308 return Ok(vec![]);
309 };
310
311 let records = response
312 .record_iter()
313 .filter_map(|record| {
314 if let Some(rdata) = record.data() {
315 if let Some(cname) = rdata.as_cname() {
316 return Some(DnsRecord {
317 name: domain.to_string(),
318 record_type: RecordType::CNAME,
319 ttl: record.ttl(),
320 data: RecordData::CNAME {
321 target: cname.0.to_string(),
322 },
323 });
324 }
325 }
326 None
327 })
328 .collect();
329
330 Ok(records)
331 }
332
333 async fn resolve_mx(
334 &self,
335 resolver: &TokioAsyncResolver,
336 domain: &str,
337 ) -> Result<Vec<DnsRecord>> {
338 let Some(response) = dns_lookup_or_empty(resolver.mx_lookup(domain).await, "MX")? else {
339 return Ok(vec![]);
340 };
341
342 let ttl = response
343 .as_lookup()
344 .record_iter()
345 .next()
346 .map(|r| r.ttl())
347 .unwrap_or(0);
348
349 let mut records: Vec<DnsRecord> = response
350 .iter()
351 .map(|mx| DnsRecord {
352 name: domain.to_string(),
353 record_type: RecordType::MX,
354 ttl,
355 data: RecordData::MX {
356 preference: mx.preference(),
357 exchange: mx.exchange().to_string(),
358 },
359 })
360 .collect();
361
362 records.sort_by_key(|r| {
363 if let RecordData::MX { preference, .. } = &r.data {
364 *preference
365 } else {
366 0
367 }
368 });
369
370 Ok(records)
371 }
372
373 async fn resolve_ns(
374 &self,
375 resolver: &TokioAsyncResolver,
376 domain: &str,
377 ) -> Result<Vec<DnsRecord>> {
378 let Some(response) = dns_lookup_or_empty(resolver.ns_lookup(domain).await, "NS")? else {
379 return Ok(vec![]);
380 };
381
382 let ttl = response
383 .as_lookup()
384 .record_iter()
385 .next()
386 .map(|r| r.ttl())
387 .unwrap_or(0);
388
389 let records = response
390 .iter()
391 .map(|ns| DnsRecord {
392 name: domain.to_string(),
393 record_type: RecordType::NS,
394 ttl,
395 data: RecordData::NS {
396 nameserver: ns.0.to_string(),
397 },
398 })
399 .collect();
400
401 Ok(records)
402 }
403
404 async fn resolve_txt(
405 &self,
406 resolver: &TokioAsyncResolver,
407 domain: &str,
408 ) -> Result<Vec<DnsRecord>> {
409 let Some(response) = dns_lookup_or_empty(resolver.txt_lookup(domain).await, "TXT")? else {
410 return Ok(vec![]);
411 };
412
413 let ttl = response
414 .as_lookup()
415 .record_iter()
416 .next()
417 .map(|r| r.ttl())
418 .unwrap_or(0);
419
420 let records = response
421 .iter()
422 .map(|txt| {
423 let text = txt
424 .iter()
425 .map(|data| String::from_utf8_lossy(data).to_string())
426 .collect::<Vec<_>>()
427 .join("");
428
429 DnsRecord {
430 name: domain.to_string(),
431 record_type: RecordType::TXT,
432 ttl,
433 data: RecordData::TXT { text },
434 }
435 })
436 .collect();
437
438 Ok(records)
439 }
440
441 async fn resolve_soa(
442 &self,
443 resolver: &TokioAsyncResolver,
444 domain: &str,
445 ) -> Result<Vec<DnsRecord>> {
446 let Some(response) = dns_lookup_or_empty(resolver.soa_lookup(domain).await, "SOA")? else {
447 return Ok(vec![]);
448 };
449
450 let ttl = response
451 .as_lookup()
452 .record_iter()
453 .next()
454 .map(|r| r.ttl())
455 .unwrap_or(0);
456
457 let records = response
458 .iter()
459 .map(|soa| DnsRecord {
460 name: domain.to_string(),
461 record_type: RecordType::SOA,
462 ttl,
463 data: RecordData::SOA {
464 mname: soa.mname().to_string(),
465 rname: soa.rname().to_string(),
466 serial: soa.serial(),
467 refresh: soa.refresh().try_into().unwrap_or(0),
468 retry: soa.retry().try_into().unwrap_or(0),
469 expire: soa.expire().try_into().unwrap_or(0),
470 minimum: soa.minimum(),
471 },
472 })
473 .collect();
474
475 Ok(records)
476 }
477
478 async fn resolve_ptr(
479 &self,
480 resolver: &TokioAsyncResolver,
481 query: &str,
482 ) -> Result<Vec<DnsRecord>> {
483 let query = if let Ok(ip) = IpAddr::from_str(query) {
485 reverse_dns_name(&ip)
486 } else {
487 query.to_string()
488 };
489
490 let Some(response) =
491 dns_lookup_or_empty(resolver.lookup(&query, HickoryRecordType::PTR).await, "PTR")?
492 else {
493 return Ok(vec![]);
494 };
495
496 let records = response
497 .record_iter()
498 .filter_map(|record| {
499 if let Some(rdata) = record.data() {
500 if let Some(ptr) = rdata.as_ptr() {
501 return Some(DnsRecord {
502 name: query.clone(),
503 record_type: RecordType::PTR,
504 ttl: record.ttl(),
505 data: RecordData::PTR {
506 target: ptr.0.to_string(),
507 },
508 });
509 }
510 }
511 None
512 })
513 .collect();
514
515 Ok(records)
516 }
517
518 async fn resolve_caa(
519 &self,
520 resolver: &TokioAsyncResolver,
521 domain: &str,
522 ) -> Result<Vec<DnsRecord>> {
523 let Some(response) =
524 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::CAA).await, "CAA")?
525 else {
526 return Ok(vec![]);
527 };
528
529 let records = response
530 .record_iter()
531 .filter_map(|record| {
532 if let Some(rdata) = record.data() {
533 if let Some(caa) = rdata.as_caa() {
534 let (flags, tag, value) = parse_caa(caa);
535 return Some(DnsRecord {
536 name: domain.to_string(),
537 record_type: RecordType::CAA,
538 ttl: record.ttl(),
539 data: RecordData::CAA { flags, tag, value },
540 });
541 }
542 }
543 None
544 })
545 .collect();
546
547 Ok(records)
548 }
549
550 async fn resolve_dnskey(
551 &self,
552 resolver: &TokioAsyncResolver,
553 domain: &str,
554 ) -> Result<Vec<DnsRecord>> {
555 use hickory_resolver::proto::rr::RData as HickoryRData;
556
557 let Some(response) = dns_lookup_or_empty(
558 resolver.lookup(domain, HickoryRecordType::DNSKEY).await,
559 "DNSKEY",
560 )?
561 else {
562 return Ok(vec![]);
563 };
564
565 let records = response
566 .record_iter()
567 .filter_map(|record| {
568 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
569 if let Some(dnskey) = dnssec_rdata.as_dnskey() {
570 use base64::{engine::general_purpose::STANDARD, Engine};
571 let public_key = STANDARD.encode(dnskey.public_key());
572 return Some(DnsRecord {
573 name: domain.to_string(),
574 record_type: RecordType::DNSKEY,
575 ttl: record.ttl(),
576 data: RecordData::DNSKEY {
577 flags: dnskey.flags(),
578 protocol: 3, algorithm: u8::from(dnskey.algorithm()),
580 public_key,
581 },
582 });
583 }
584 }
585 None
586 })
587 .collect();
588
589 Ok(records)
590 }
591
592 async fn resolve_ds(
593 &self,
594 resolver: &TokioAsyncResolver,
595 domain: &str,
596 ) -> Result<Vec<DnsRecord>> {
597 use hickory_resolver::proto::rr::RData as HickoryRData;
598
599 let Some(response) =
600 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::DS).await, "DS")?
601 else {
602 return Ok(vec![]);
603 };
604
605 let records = response
606 .record_iter()
607 .filter_map(|record| {
608 if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
609 if let Some(ds) = dnssec_rdata.as_ds() {
610 let digest = ds
611 .digest()
612 .iter()
613 .map(|b| format!("{:02X}", b))
614 .collect::<String>();
615 return Some(DnsRecord {
616 name: domain.to_string(),
617 record_type: RecordType::DS,
618 ttl: record.ttl(),
619 data: RecordData::DS {
620 key_tag: ds.key_tag(),
621 algorithm: u8::from(ds.algorithm()),
622 digest_type: u8::from(ds.digest_type()),
623 digest,
624 },
625 });
626 }
627 }
628 None
629 })
630 .collect();
631
632 Ok(records)
633 }
634
635 async fn resolve_any(
636 &self,
637 resolver: &TokioAsyncResolver,
638 domain: &str,
639 ) -> Result<Vec<DnsRecord>> {
640 let mut all_records = Vec::new();
641
642 let record_types = [
644 RecordType::A,
645 RecordType::AAAA,
646 RecordType::MX,
647 RecordType::NS,
648 RecordType::TXT,
649 RecordType::SOA,
650 RecordType::CAA,
651 ];
652
653 for record_type in record_types {
654 match self.resolve_type(resolver, domain, record_type).await {
655 Ok(records) => all_records.extend(records),
656 Err(_) => continue, }
658 }
659
660 Ok(all_records)
661 }
662
663 async fn resolve_type(
664 &self,
665 resolver: &TokioAsyncResolver,
666 domain: &str,
667 record_type: RecordType,
668 ) -> Result<Vec<DnsRecord>> {
669 match record_type {
670 RecordType::A => self.resolve_a(resolver, domain).await,
671 RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
672 RecordType::CNAME => self.resolve_cname(resolver, domain).await,
673 RecordType::MX => self.resolve_mx(resolver, domain).await,
674 RecordType::NS => self.resolve_ns(resolver, domain).await,
675 RecordType::TXT => self.resolve_txt(resolver, domain).await,
676 RecordType::SOA => self.resolve_soa(resolver, domain).await,
677 RecordType::CAA => self.resolve_caa(resolver, domain).await,
678 RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
679 RecordType::DS => self.resolve_ds(resolver, domain).await,
680 _ => Err(SeerError::DnsError("unsupported record type".to_string())),
681 }
682 }
683}
684
/// Builds the reverse-DNS query name for `ip`.
///
/// * IPv4: the octets in reverse order under `in-addr.arpa`.
///   For example, `192.0.2.1` becomes `1.2.0.192.in-addr.arpa`.
/// * IPv6: all 32 nibbles in reverse order, as lowercase hex separated by dots, under
///   `ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    match *ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{d}.{c}.{b}.{a}.in-addr.arpa")
        }
        IpAddr::V6(v6) => {
            // 32 nibbles, each followed by '.', then "ip6.arpa" = 72 bytes.
            let mut name = String::with_capacity(72);
            for &byte in v6.octets().iter().rev() {
                // Within a byte, the low nibble comes first in reversed order.
                name.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
                name.push('.');
                name.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                name.push('.');
            }
            name.push_str("ip6.arpa");
            name
        }
    }
}
717
718fn parse_caa(caa: &CAA) -> (u8, String, String) {
719 let flags = if caa.issuer_critical() { 128 } else { 0 };
720 let tag = caa.tag().as_str().to_string();
721 let value = caa.value().to_string();
722 (flags, tag, value)
723}
724
/// Checks an SRV service/protocol label (without its leading underscore).
///
/// A valid label has 1–63 bytes of ASCII letters, digits and '-', and neither starts
/// nor ends with '-'.
fn is_valid_srv_label(label: &str) -> bool {
    let bytes = label.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&first), Some(&last)) => {
            bytes.len() <= 63
                && first != b'-'
                && last != b'-'
                && bytes.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'-')
        }
        // Only the empty label has no first/last byte.
        _ => false,
    }
}