1use std::net::IpAddr;
12use std::str::FromStr;
13use std::time::Duration;
14
15use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, GOOGLE};
16use hickory_resolver::net::runtime::TokioRuntimeProvider;
17use hickory_resolver::net::NetError;
18use hickory_resolver::proto::dnssec::PublicKey;
19use hickory_resolver::proto::rr::rdata::CAA;
20use hickory_resolver::proto::rr::{RData as HickoryRData, RecordType as HickoryRecordType};
21use hickory_resolver::TokioResolver;
22use tracing::{debug, instrument};
23
24use super::records::{DnsRecord, RecordData, RecordType};
25use crate::error::{Result, SeerError};
26use crate::validation::normalize_domain;
27
28fn dns_lookup_or_empty<T>(
32 result: std::result::Result<T, NetError>,
33 record_type: &str,
34) -> Result<Option<T>> {
35 match result {
36 Ok(response) => Ok(Some(response)),
37 Err(e) if e.is_no_records_found() => Ok(None),
38 Err(e) => Err(SeerError::DnsError(format!(
39 "{} lookup failed: {}",
40 record_type, e
41 ))),
42 }
43}
44
/// Per-query timeout used by `DnsResolver::new` for the default resolver and
/// for per-call custom resolvers, unless overridden via `DnsResolver::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
48
49fn build_resolver(config: ResolverConfig, timeout: Duration) -> TokioResolver {
56 let mut builder = TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
57 {
58 let opts = builder.options_mut();
59 opts.timeout = timeout;
60 opts.attempts = 2;
61 opts.use_hosts_file = ResolveHosts::Never;
62 }
63 builder
64 .build()
65 .expect("hickory resolver build is infallible without TLS features")
66}
67
/// Async DNS client built on hickory's Tokio resolver.
///
/// Queries go through a default resolver (hickory's `GOOGLE` name-server
/// group) unless a caller names a specific nameserver. In that case a
/// dedicated resolver is built for that call.
#[derive(Clone)]
pub struct DnsResolver {
    /// Per-query timeout applied to the default resolver and to any per-call
    /// custom resolver.
    timeout: Duration,
    /// Resolver used when no nameserver is given. It is also used to resolve
    /// nameserver hostnames.
    default_resolver: TokioResolver,
}
80
81impl std::fmt::Debug for DnsResolver {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("DnsResolver")
84 .field("timeout", &self.timeout)
85 .finish()
86 }
87}
88
89impl Default for DnsResolver {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl DnsResolver {
96 pub fn new() -> Self {
98 Self {
99 timeout: DEFAULT_TIMEOUT,
100 default_resolver: build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), DEFAULT_TIMEOUT),
101 }
102 }
103
104 pub fn with_timeout(mut self, timeout: Duration) -> Self {
108 self.timeout = timeout;
109 self.default_resolver = build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), timeout);
110 self
111 }
112
113 async fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioResolver> {
114 let ips: Vec<IpAddr> = if let Ok(ip) = nameserver.parse::<IpAddr>() {
122 vec![ip]
123 } else {
124 let response = self
125 .default_resolver
126 .lookup_ip(nameserver)
127 .await
128 .map_err(|e| {
129 SeerError::DnsError(format!(
130 "failed to resolve nameserver hostname {}: {}",
131 nameserver, e
132 ))
133 })?;
134 let resolved: Vec<IpAddr> = response.iter().collect();
135 if resolved.is_empty() {
136 return Err(SeerError::DnsError(format!(
137 "nameserver {} did not resolve to any addresses",
138 nameserver
139 )));
140 }
141 resolved
142 };
143
144 for ip in &ips {
148 if let Some(reason) = crate::validation::describe_reserved_ip(ip) {
149 return Err(SeerError::DnsError(format!(
150 "nameserver {} blocked: {}",
151 nameserver, reason
152 )));
153 }
154 }
155
156 let mut config = ResolverConfig::from_parts(None, vec![], vec![]);
161 for ip in ips {
162 config.add_name_server(NameServerConfig::udp(ip));
163 }
164
165 Ok(build_resolver(config, self.timeout))
166 }
167
168 #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
175 pub async fn resolve(
176 &self,
177 domain: &str,
178 record_type: RecordType,
179 nameserver: Option<&str>,
180 ) -> Result<Vec<DnsRecord>> {
181 let custom_resolver;
183 let resolver = if let Some(ns) = nameserver {
184 custom_resolver = self.create_custom_resolver(ns).await?;
185 &custom_resolver
186 } else {
187 &self.default_resolver
188 };
189 let domain = prepare_query(domain, record_type)?;
190
191 debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
192
193 match record_type {
194 RecordType::A => self.resolve_a(resolver, &domain).await,
195 RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
196 RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
197 RecordType::MX => self.resolve_mx(resolver, &domain).await,
198 RecordType::NS => self.resolve_ns(resolver, &domain).await,
199 RecordType::TXT => self.resolve_txt(resolver, &domain).await,
200 RecordType::SOA => self.resolve_soa(resolver, &domain).await,
201 RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
202 RecordType::SRV => match parse_srv_query(&domain) {
203 Some((service, protocol, name)) => {
205 self.resolve_srv_core(resolver, &service, &protocol, &name)
206 .await
207 }
208 None => Err(SeerError::InvalidInput(
211 "SRV records require service name format: _service._proto.name".to_string(),
212 )),
213 },
214 RecordType::CAA => self.resolve_caa(resolver, &domain).await,
215 RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
216 RecordType::DS => self.resolve_ds(resolver, &domain).await,
217 RecordType::TLSA => self.resolve_tlsa(resolver, &domain).await,
218 RecordType::SSHFP => self.resolve_sshfp(resolver, &domain).await,
219 RecordType::NAPTR => self.resolve_naptr(resolver, &domain).await,
220 RecordType::ANY => self.resolve_any(resolver, &domain).await,
221 }
222 }
223
224 #[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
232 pub async fn resolve_srv(
233 &self,
234 service: &str,
235 protocol: &str,
236 domain: &str,
237 nameserver: Option<&str>,
238 ) -> Result<Vec<DnsRecord>> {
239 let custom_resolver;
240 let resolver = if let Some(ns) = nameserver {
241 custom_resolver = self.create_custom_resolver(ns).await?;
242 &custom_resolver
243 } else {
244 &self.default_resolver
245 };
246 self.resolve_srv_core(resolver, service, protocol, domain)
247 .await
248 }
249
250 async fn resolve_srv_core(
257 &self,
258 resolver: &TokioResolver,
259 service: &str,
260 protocol: &str,
261 domain: &str,
262 ) -> Result<Vec<DnsRecord>> {
263 if !is_valid_srv_label(service) {
264 return Err(SeerError::InvalidInput(format!(
265 "invalid SRV service name: {}",
266 service
267 )));
268 }
269 if !is_valid_srv_label(protocol) {
270 return Err(SeerError::InvalidInput(format!(
271 "invalid SRV protocol name: {}",
272 protocol
273 )));
274 }
275
276 let query_name = format!("_{}._{}.{}", service, protocol, domain);
277
278 let Some(response) = dns_lookup_or_empty(
279 resolver.lookup(&query_name, HickoryRecordType::SRV).await,
280 "SRV",
281 )?
282 else {
283 return Ok(vec![]);
284 };
285
286 let records = response
287 .answers()
288 .iter()
289 .filter_map(|record| {
290 if let HickoryRData::SRV(srv) = &record.data {
291 Some(DnsRecord {
292 name: query_name.clone(),
293 record_type: RecordType::SRV,
294 ttl: record.ttl,
295 data: RecordData::SRV {
296 priority: srv.priority,
297 weight: srv.weight,
298 port: srv.port,
299 target: srv.target.to_string(),
300 },
301 })
302 } else {
303 None
304 }
305 })
306 .collect();
307
308 Ok(records)
309 }
310
311 async fn resolve_a(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
312 let Some(response) =
313 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::A).await, "A")?
314 else {
315 return Ok(vec![]);
316 };
317
318 let records = response
319 .answers()
320 .iter()
321 .filter_map(|record| {
322 if let HickoryRData::A(addr) = &record.data {
323 Some(DnsRecord {
324 name: domain.to_string(),
325 record_type: RecordType::A,
326 ttl: record.ttl,
327 data: RecordData::A {
328 address: addr.0.to_string(),
329 },
330 })
331 } else {
332 None
333 }
334 })
335 .collect();
336
337 Ok(records)
338 }
339
340 async fn resolve_aaaa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
341 let Some(response) = dns_lookup_or_empty(
342 resolver.lookup(domain, HickoryRecordType::AAAA).await,
343 "AAAA",
344 )?
345 else {
346 return Ok(vec![]);
347 };
348
349 let records = response
350 .answers()
351 .iter()
352 .filter_map(|record| {
353 if let HickoryRData::AAAA(addr) = &record.data {
354 Some(DnsRecord {
355 name: domain.to_string(),
356 record_type: RecordType::AAAA,
357 ttl: record.ttl,
358 data: RecordData::AAAA {
359 address: addr.0.to_string(),
360 },
361 })
362 } else {
363 None
364 }
365 })
366 .collect();
367
368 Ok(records)
369 }
370
371 async fn resolve_cname(
372 &self,
373 resolver: &TokioResolver,
374 domain: &str,
375 ) -> Result<Vec<DnsRecord>> {
376 let Some(response) = dns_lookup_or_empty(
377 resolver.lookup(domain, HickoryRecordType::CNAME).await,
378 "CNAME",
379 )?
380 else {
381 return Ok(vec![]);
382 };
383
384 let records = response
385 .answers()
386 .iter()
387 .filter_map(|record| {
388 if let HickoryRData::CNAME(cname) = &record.data {
389 Some(DnsRecord {
390 name: domain.to_string(),
391 record_type: RecordType::CNAME,
392 ttl: record.ttl,
393 data: RecordData::CNAME {
394 target: cname.0.to_string(),
395 },
396 })
397 } else {
398 None
399 }
400 })
401 .collect();
402
403 Ok(records)
404 }
405
406 async fn resolve_mx(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
407 let Some(response) =
408 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::MX).await, "MX")?
409 else {
410 return Ok(vec![]);
411 };
412
413 let mut records: Vec<DnsRecord> = response
414 .answers()
415 .iter()
416 .filter_map(|record| {
417 if let HickoryRData::MX(mx) = &record.data {
418 Some(DnsRecord {
419 name: domain.to_string(),
420 record_type: RecordType::MX,
421 ttl: record.ttl,
422 data: RecordData::MX {
423 preference: mx.preference,
424 exchange: mx.exchange.to_string(),
425 },
426 })
427 } else {
428 None
429 }
430 })
431 .collect();
432
433 records.sort_by_key(|r| {
434 if let RecordData::MX { preference, .. } = &r.data {
435 *preference
436 } else {
437 0
438 }
439 });
440
441 Ok(records)
442 }
443
444 async fn resolve_ns(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
445 let Some(response) =
446 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::NS).await, "NS")?
447 else {
448 return Ok(vec![]);
449 };
450
451 let records = response
452 .answers()
453 .iter()
454 .filter_map(|record| {
455 if let HickoryRData::NS(ns) = &record.data {
456 Some(DnsRecord {
457 name: domain.to_string(),
458 record_type: RecordType::NS,
459 ttl: record.ttl,
460 data: RecordData::NS {
461 nameserver: ns.0.to_string(),
462 },
463 })
464 } else {
465 None
466 }
467 })
468 .collect();
469
470 Ok(records)
471 }
472
473 async fn resolve_txt(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
474 let Some(response) =
475 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::TXT).await, "TXT")?
476 else {
477 return Ok(vec![]);
478 };
479
480 let records = response
481 .answers()
482 .iter()
483 .filter_map(|record| {
484 if let HickoryRData::TXT(txt) = &record.data {
485 let text = txt
486 .txt_data
487 .iter()
488 .map(|data| String::from_utf8_lossy(data).to_string())
489 .collect::<Vec<_>>()
490 .join("");
491
492 Some(DnsRecord {
493 name: domain.to_string(),
494 record_type: RecordType::TXT,
495 ttl: record.ttl,
496 data: RecordData::TXT { text },
497 })
498 } else {
499 None
500 }
501 })
502 .collect();
503
504 Ok(records)
505 }
506
507 async fn resolve_soa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
508 let Some(response) =
509 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::SOA).await, "SOA")?
510 else {
511 return Ok(vec![]);
512 };
513
514 let records = response
515 .answers()
516 .iter()
517 .filter_map(|record| {
518 if let HickoryRData::SOA(soa) = &record.data {
519 Some(DnsRecord {
520 name: domain.to_string(),
521 record_type: RecordType::SOA,
522 ttl: record.ttl,
523 data: RecordData::SOA {
524 mname: soa.mname.to_string(),
525 rname: soa.rname.to_string(),
526 serial: soa.serial,
527 refresh: soa.refresh as u32,
534 retry: soa.retry as u32,
535 expire: soa.expire as u32,
536 minimum: soa.minimum,
537 },
538 })
539 } else {
540 None
541 }
542 })
543 .collect();
544
545 Ok(records)
546 }
547
548 async fn resolve_ptr(&self, resolver: &TokioResolver, query: &str) -> Result<Vec<DnsRecord>> {
549 let query = if let Ok(ip) = IpAddr::from_str(query) {
551 reverse_dns_name(&ip)
552 } else {
553 query.to_string()
554 };
555
556 let Some(response) =
557 dns_lookup_or_empty(resolver.lookup(&query, HickoryRecordType::PTR).await, "PTR")?
558 else {
559 return Ok(vec![]);
560 };
561
562 let records = response
563 .answers()
564 .iter()
565 .filter_map(|record| {
566 if let HickoryRData::PTR(ptr) = &record.data {
567 Some(DnsRecord {
568 name: query.clone(),
569 record_type: RecordType::PTR,
570 ttl: record.ttl,
571 data: RecordData::PTR {
572 target: ptr.0.to_string(),
573 },
574 })
575 } else {
576 None
577 }
578 })
579 .collect();
580
581 Ok(records)
582 }
583
584 async fn resolve_caa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
585 let Some(response) =
586 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::CAA).await, "CAA")?
587 else {
588 return Ok(vec![]);
589 };
590
591 let records = response
592 .answers()
593 .iter()
594 .filter_map(|record| {
595 if let HickoryRData::CAA(caa) = &record.data {
596 let (flags, tag, value) = parse_caa(caa);
597 Some(DnsRecord {
598 name: domain.to_string(),
599 record_type: RecordType::CAA,
600 ttl: record.ttl,
601 data: RecordData::CAA { flags, tag, value },
602 })
603 } else {
604 None
605 }
606 })
607 .collect();
608
609 Ok(records)
610 }
611
612 async fn resolve_dnskey(
613 &self,
614 resolver: &TokioResolver,
615 domain: &str,
616 ) -> Result<Vec<DnsRecord>> {
617 use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
618
619 let Some(response) = dns_lookup_or_empty(
620 resolver.lookup(domain, HickoryRecordType::DNSKEY).await,
621 "DNSKEY",
622 )?
623 else {
624 return Ok(vec![]);
625 };
626
627 let records = response
628 .answers()
629 .iter()
630 .filter_map(|record| {
631 if let HickoryRData::DNSSEC(DNSSECRData::DNSKEY(dnskey)) = &record.data {
632 use base64::{engine::general_purpose::STANDARD, Engine};
633 let public_key_buf = dnskey.public_key();
634 let public_key = STANDARD.encode(public_key_buf.public_bytes());
635 Some(DnsRecord {
636 name: domain.to_string(),
637 record_type: RecordType::DNSKEY,
638 ttl: record.ttl,
639 data: RecordData::DNSKEY {
640 flags: dnskey.flags(),
641 protocol: 3,
643 algorithm: u8::from(public_key_buf.algorithm()),
644 public_key,
645 },
646 })
647 } else {
648 None
649 }
650 })
651 .collect();
652
653 Ok(records)
654 }
655
656 async fn resolve_ds(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
657 use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
658
659 let Some(response) =
660 dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::DS).await, "DS")?
661 else {
662 return Ok(vec![]);
663 };
664
665 let records = response
666 .answers()
667 .iter()
668 .filter_map(|record| {
669 if let HickoryRData::DNSSEC(DNSSECRData::DS(ds)) = &record.data {
670 let digest = ds
671 .digest()
672 .iter()
673 .map(|b| format!("{:02X}", b))
674 .collect::<String>();
675 Some(DnsRecord {
676 name: domain.to_string(),
677 record_type: RecordType::DS,
678 ttl: record.ttl,
679 data: RecordData::DS {
680 key_tag: ds.key_tag(),
681 algorithm: u8::from(ds.algorithm()),
682 digest_type: u8::from(ds.digest_type()),
683 digest,
684 },
685 })
686 } else {
687 None
688 }
689 })
690 .collect();
691
692 Ok(records)
693 }
694
695 async fn resolve_tlsa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
696 let Some(response) = dns_lookup_or_empty(
702 resolver.lookup(domain, HickoryRecordType::TLSA).await,
703 "TLSA",
704 )?
705 else {
706 return Ok(vec![]);
707 };
708
709 let records = response
710 .answers()
711 .iter()
712 .filter_map(|record| {
713 if let HickoryRData::TLSA(tlsa) = &record.data {
714 let cert_data = tlsa
715 .cert_data
716 .iter()
717 .map(|b| format!("{:02X}", b))
718 .collect::<String>();
719 Some(DnsRecord {
720 name: domain.to_string(),
721 record_type: RecordType::TLSA,
722 ttl: record.ttl,
723 data: RecordData::TLSA {
724 cert_usage: u8::from(tlsa.cert_usage),
725 selector: u8::from(tlsa.selector),
726 matching: u8::from(tlsa.matching),
727 cert_data,
728 },
729 })
730 } else {
731 None
732 }
733 })
734 .collect();
735
736 Ok(records)
737 }
738
739 async fn resolve_sshfp(
740 &self,
741 resolver: &TokioResolver,
742 domain: &str,
743 ) -> Result<Vec<DnsRecord>> {
744 let Some(response) = dns_lookup_or_empty(
745 resolver.lookup(domain, HickoryRecordType::SSHFP).await,
746 "SSHFP",
747 )?
748 else {
749 return Ok(vec![]);
750 };
751
752 let records = response
753 .answers()
754 .iter()
755 .filter_map(|record| {
756 if let HickoryRData::SSHFP(sshfp) = &record.data {
757 let fingerprint = sshfp
758 .fingerprint
759 .iter()
760 .map(|b| format!("{:02X}", b))
761 .collect::<String>();
762 Some(DnsRecord {
763 name: domain.to_string(),
764 record_type: RecordType::SSHFP,
765 ttl: record.ttl,
766 data: RecordData::SSHFP {
767 algorithm: u8::from(sshfp.algorithm),
768 fingerprint_type: u8::from(sshfp.fingerprint_type),
769 fingerprint,
770 },
771 })
772 } else {
773 None
774 }
775 })
776 .collect();
777
778 Ok(records)
779 }
780
781 async fn resolve_naptr(
782 &self,
783 resolver: &TokioResolver,
784 domain: &str,
785 ) -> Result<Vec<DnsRecord>> {
786 let Some(response) = dns_lookup_or_empty(
787 resolver.lookup(domain, HickoryRecordType::NAPTR).await,
788 "NAPTR",
789 )?
790 else {
791 return Ok(vec![]);
792 };
793
794 let records = response
795 .answers()
796 .iter()
797 .filter_map(|record| {
798 if let HickoryRData::NAPTR(naptr) = &record.data {
799 Some(DnsRecord {
800 name: domain.to_string(),
801 record_type: RecordType::NAPTR,
802 ttl: record.ttl,
803 data: RecordData::NAPTR {
807 order: naptr.order,
808 preference: naptr.preference,
809 flags: String::from_utf8_lossy(&naptr.flags).into_owned(),
810 services: String::from_utf8_lossy(&naptr.services).into_owned(),
811 regexp: String::from_utf8_lossy(&naptr.regexp).into_owned(),
812 replacement: naptr.replacement.to_string(),
813 },
814 })
815 } else {
816 None
817 }
818 })
819 .collect();
820
821 Ok(records)
822 }
823
824 async fn resolve_any(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
825 let mut all_records = Vec::new();
826
827 let record_types = [
829 RecordType::A,
830 RecordType::AAAA,
831 RecordType::MX,
832 RecordType::NS,
833 RecordType::TXT,
834 RecordType::SOA,
835 RecordType::CAA,
836 ];
837
838 let mut any_ok = false;
843 let mut last_err = None;
844 for record_type in record_types {
845 match self.resolve_type(resolver, domain, record_type).await {
846 Ok(records) => {
847 any_ok = true;
848 all_records.extend(records);
849 }
850 Err(e) => last_err = Some(e),
851 }
852 }
853
854 match last_err {
855 Some(e) if !any_ok => Err(e),
856 _ => Ok(all_records),
857 }
858 }
859
860 async fn resolve_type(
861 &self,
862 resolver: &TokioResolver,
863 domain: &str,
864 record_type: RecordType,
865 ) -> Result<Vec<DnsRecord>> {
866 match record_type {
867 RecordType::A => self.resolve_a(resolver, domain).await,
868 RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
869 RecordType::CNAME => self.resolve_cname(resolver, domain).await,
870 RecordType::MX => self.resolve_mx(resolver, domain).await,
871 RecordType::NS => self.resolve_ns(resolver, domain).await,
872 RecordType::TXT => self.resolve_txt(resolver, domain).await,
873 RecordType::SOA => self.resolve_soa(resolver, domain).await,
874 RecordType::CAA => self.resolve_caa(resolver, domain).await,
875 RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
876 RecordType::DS => self.resolve_ds(resolver, domain).await,
877 _ => Err(SeerError::DnsError("unsupported record type".to_string())),
878 }
879 }
880}
881
/// Tri-state answer to "is this domain present in DNS?", derived from the
/// outcome of an NS lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsPresence {
    /// The NS lookup succeeded and returned at least one record.
    Present,
    /// The NS lookup succeeded but returned no records.
    Absent,
    /// The NS lookup failed, so presence could not be determined.
    Unknown,
}
895
896fn classify_ns_presence(result: &Result<Vec<DnsRecord>>) -> DnsPresence {
901 match result {
902 Ok(records) if records.is_empty() => DnsPresence::Absent,
903 Ok(_) => DnsPresence::Present,
904 Err(_) => DnsPresence::Unknown,
905 }
906}
907
908impl DnsResolver {
909 pub async fn presence(&self, domain: &str) -> DnsPresence {
917 classify_ns_presence(&self.resolve(domain, RecordType::NS, None).await)
918 }
919}
920
921fn prepare_query(domain: &str, record_type: RecordType) -> Result<String> {
934 if record_type == RecordType::PTR {
935 if let Ok(ip) = IpAddr::from_str(domain.trim()) {
936 return Ok(ip.to_string());
937 }
938 }
939 normalize_domain(domain)
940}
941
/// Splits a dig-style SRV name `_service._proto.name` into its three parts.
///
/// The service and protocol labels have their single leading underscore
/// removed. The remainder (which may contain further dots) is kept intact.
/// Returns `None` when fewer than three labels are present, when either of
/// the first two labels lacks the underscore or is empty after stripping it,
/// or when the remainder is empty.
fn parse_srv_query(name: &str) -> Option<(String, String, String)> {
    let (first, remainder) = name.split_once('.')?;
    let (second, rest) = remainder.split_once('.')?;

    let service = first.strip_prefix('_').filter(|label| !label.is_empty())?;
    let protocol = second.strip_prefix('_').filter(|label| !label.is_empty())?;
    if rest.is_empty() {
        return None;
    }

    Some((service.to_owned(), protocol.to_owned(), rest.to_owned()))
}
956
/// Builds the reverse-lookup (PTR) name for `ip`.
///
/// IPv4 addresses become `d.c.b.a.in-addr.arpa`. IPv6 addresses become 32
/// dot-separated lowercase hex nibbles, least-significant nibble first,
/// followed by `.ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{d}.{c}.{b}.{a}.in-addr.arpa")
        }
        IpAddr::V6(v6) => {
            // Walk the bytes from last to first, emitting each byte's low
            // nibble before its high nibble.
            let nibbles: Vec<String> = v6
                .octets()
                .iter()
                .rev()
                .flat_map(|&byte| [byte & 0x0F, byte >> 4])
                .map(|nibble| format!("{nibble:x}"))
                .collect();
            format!("{}.ip6.arpa", nibbles.join("."))
        }
    }
}
987
988fn parse_caa(caa: &CAA) -> (u8, String, String) {
989 let flags = if caa.issuer_critical { 128 } else { 0 };
996 let tag = caa.tag.clone();
997 let value = String::from_utf8_lossy(&caa.value).to_string();
998 (flags, tag, value)
999}
1000
/// Returns `true` if `label` is safe to interpolate as an SRV service or
/// protocol label.
///
/// A valid label is 1–63 bytes of ASCII letters, digits and hyphens, and
/// neither starts nor ends with a hyphen. Dots, whitespace, underscores and
/// non-ASCII characters are all rejected.
fn is_valid_srv_label(label: &str) -> bool {
    let bytes = label.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&first), Some(&last)) => {
            bytes.len() <= 63
                && first != b'-'
                && last != b'-'
                && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
        }
        _ => false,
    }
}
1009
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // --- RecordType parsing -------------------------------------------------

    #[test]
    fn record_type_from_str_accepts_lowercase() {
        assert_eq!(RecordType::from_str("a").unwrap(), RecordType::A);
        assert_eq!(RecordType::from_str("mx").unwrap(), RecordType::MX);
        assert_eq!(RecordType::from_str("cname").unwrap(), RecordType::CNAME);
        assert_eq!(RecordType::from_str("dnskey").unwrap(), RecordType::DNSKEY);
    }

    #[test]
    fn record_type_from_str_accepts_mixed_case() {
        assert_eq!(RecordType::from_str("Mx").unwrap(), RecordType::MX);
        assert_eq!(RecordType::from_str("cNaMe").unwrap(), RecordType::CNAME);
    }

    #[test]
    fn record_type_from_str_rejects_whitespace_padded() {
        assert!(RecordType::from_str(" A").is_err());
        assert!(RecordType::from_str("A ").is_err());
        assert!(RecordType::from_str("\tA\n").is_err());
    }

    #[test]
    fn record_type_from_str_rejects_unknown() {
        assert!(RecordType::from_str("NOTAREAL").is_err());
        assert!(RecordType::from_str("A1").is_err());
        assert!(RecordType::from_str("").is_err());
    }

    #[test]
    fn record_type_from_str_accepts_star_as_any() {
        assert_eq!(RecordType::from_str("*").unwrap(), RecordType::ANY);
        assert_eq!(RecordType::from_str("ANY").unwrap(), RecordType::ANY);
        assert_eq!(RecordType::from_str("any").unwrap(), RecordType::ANY);
    }

    // --- SRV label validation -----------------------------------------------

    #[test]
    fn srv_label_accepts_alphanumeric_and_hyphen() {
        assert!(is_valid_srv_label("http"));
        assert!(is_valid_srv_label("ldap-tls"));
        assert!(is_valid_srv_label("a1"));
        assert!(is_valid_srv_label("tcp"));
    }

    #[test]
    fn srv_label_rejects_empty() {
        assert!(!is_valid_srv_label(""));
    }

    #[test]
    fn srv_label_rejects_leading_or_trailing_hyphen() {
        assert!(!is_valid_srv_label("-http"));
        assert!(!is_valid_srv_label("http-"));
        assert!(!is_valid_srv_label("-"));
    }

    #[test]
    fn srv_label_rejects_dots() {
        // A dot would let a caller inject extra labels into the SRV query name.
        assert!(!is_valid_srv_label("http.evil"));
        assert!(!is_valid_srv_label("a.b"));
    }

    #[test]
    fn srv_label_rejects_special_chars() {
        assert!(!is_valid_srv_label("http evil"));
        assert!(!is_valid_srv_label("http/evil"));
        assert!(!is_valid_srv_label("http\0"));
        assert!(!is_valid_srv_label("http\n"));
    }

    #[test]
    fn srv_label_rejects_over_63_chars() {
        let too_long = "a".repeat(64);
        assert!(!is_valid_srv_label(&too_long));
        let exactly_63 = "a".repeat(63);
        assert!(is_valid_srv_label(&exactly_63));
    }

    // --- NS presence classification -------------------------------------------

    #[test]
    fn classify_ns_presence_absent_on_empty_ok() {
        let r: Result<Vec<DnsRecord>> = Ok(vec![]);
        assert_eq!(classify_ns_presence(&r), DnsPresence::Absent);
    }

    #[test]
    fn classify_ns_presence_present_on_records() {
        let rec = DnsRecord {
            name: "example.test.".to_string(),
            record_type: RecordType::NS,
            ttl: 3600,
            data: RecordData::NS {
                nameserver: "ns1.example.net.".to_string(),
            },
        };
        let r: Result<Vec<DnsRecord>> = Ok(vec![rec]);
        assert_eq!(classify_ns_presence(&r), DnsPresence::Present);
    }

    #[test]
    fn classify_ns_presence_unknown_on_error() {
        let r: Result<Vec<DnsRecord>> = Err(SeerError::DnsError("servfail".to_string()));
        assert_eq!(classify_ns_presence(&r), DnsPresence::Unknown);
    }

    // --- Reverse-lookup names -------------------------------------------------

    #[test]
    fn reverse_dns_name_formats_ipv4_correctly() {
        let ip: IpAddr = Ipv4Addr::new(192, 0, 2, 1).into();
        assert_eq!(reverse_dns_name(&ip), "1.2.0.192.in-addr.arpa");
    }

    #[test]
    fn reverse_dns_name_formats_ipv6_correctly() {
        let ip: IpAddr = Ipv6Addr::LOCALHOST.into();
        let name = reverse_dns_name(&ip);
        assert!(
            name.ends_with(".ip6.arpa"),
            "must end with .ip6.arpa; got: {}",
            name
        );
        assert!(
            name.starts_with("1."),
            "expected '1.' prefix, got: {}",
            name
        );
        // 32 nibbles + 31 dots + ".ip6.arpa".
        assert_eq!(name.len(), 72);
        assert_eq!(name, format!("1{}.ip6.arpa", ".0".repeat(31)));
    }

    #[test]
    fn reverse_dns_name_matches_rfc3596_example() {
        // Worked example from RFC 3596 §2.5 (lowercased).
        let ip: IpAddr = "4321:0:1:2:3:4:567:89ab".parse().unwrap();
        assert_eq!(
            reverse_dns_name(&ip),
            "b.a.9.8.7.6.5.0.4.0.0.0.3.0.0.0.2.0.0.0.1.0.0.0.0.0.0.0.1.2.3.4.ip6.arpa"
        );
    }

    // --- Resolver construction --------------------------------------------------

    #[test]
    fn resolver_new_has_default_timeout() {
        let r = DnsResolver::new();
        assert_eq!(r.timeout, DEFAULT_TIMEOUT);
    }

    #[test]
    fn resolver_with_timeout_overrides_default() {
        let custom = Duration::from_secs(42);
        let r = DnsResolver::new().with_timeout(custom);
        assert_eq!(r.timeout, custom);
    }

    #[test]
    fn resolver_default_matches_new() {
        let a = DnsResolver::default();
        let b = DnsResolver::new();
        assert_eq!(a.timeout, b.timeout);
    }

    // --- Custom nameserver handling --------------------------------------------

    #[tokio::test]
    async fn custom_resolver_rejects_invalid_input() {
        let r = DnsResolver::new();
        let err = r.create_custom_resolver("..").await.unwrap_err();
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("dns resolution failed") || msg.contains("invalid"),
            "expected resolution failure, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn custom_resolver_rejects_private_ipv4() {
        let r = DnsResolver::new();
        for reserved in ["127.0.0.1", "10.0.0.1", "192.168.1.1", "169.254.169.254"] {
            let err = r.create_custom_resolver(reserved).await.unwrap_err();
            let msg = err.to_string().to_lowercase();
            assert!(
                msg.contains("blocked") || msg.contains("reserved"),
                "reserved IP {} must be rejected, got error: {}",
                reserved,
                msg
            );
        }
    }

    #[tokio::test]
    async fn custom_resolver_rejects_loopback_ipv6() {
        let r = DnsResolver::new();
        let err = r.create_custom_resolver("::1").await.unwrap_err();
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("blocked") || msg.contains("reserved"),
            "::1 must be rejected, got error: {}",
            msg
        );
    }

    #[tokio::test]
    async fn custom_resolver_accepts_public_ipv4() {
        let r = DnsResolver::new();
        // An IP literal needs no lookup, so this runs without network access.
        let result = r.create_custom_resolver("8.8.8.8").await;
        assert!(
            result.is_ok(),
            "8.8.8.8 must be accepted as a public nameserver, got: {:?}",
            result.err()
        );
    }

    // --- SRV input validation ---------------------------------------------------

    #[tokio::test]
    async fn resolve_srv_rejects_invalid_service_label() {
        let r = DnsResolver::new();
        let result = r.resolve_srv("http.evil", "tcp", "example.com", None).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            msg.contains("invalid srv service"),
            "expected SRV service validation error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn resolve_srv_rejects_invalid_protocol_label() {
        let r = DnsResolver::new();
        let result = r.resolve_srv("http", "tcp.evil", "example.com", None).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            msg.contains("invalid srv protocol"),
            "expected SRV protocol validation error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn resolve_rejects_leading_dot_domain() {
        let r = DnsResolver::new();
        let result = r.resolve(".bad.example", RecordType::A, None).await;
        assert!(result.is_err(), "leading-dot domain must be rejected");
    }

    #[test]
    fn parse_srv_query_extracts_service_proto_and_name() {
        assert_eq!(
            parse_srv_query("_sip._tcp.example.com"),
            Some((
                "sip".to_string(),
                "tcp".to_string(),
                "example.com".to_string()
            ))
        );
    }

    #[test]
    fn parse_srv_query_keeps_multilabel_domain() {
        assert_eq!(
            parse_srv_query("_sip._tcp.sip.voice.google.com"),
            Some((
                "sip".to_string(),
                "tcp".to_string(),
                "sip.voice.google.com".to_string()
            ))
        );
    }

    #[test]
    fn parse_srv_query_rejects_bare_domain() {
        assert_eq!(parse_srv_query("example.com"), None);
    }

    #[test]
    fn parse_srv_query_rejects_missing_proto_label() {
        assert_eq!(parse_srv_query("_sip.example.com"), None);
    }

    #[tokio::test]
    async fn resolve_rejects_bare_domain_for_srv_as_input_error() {
        let r = DnsResolver::new();
        let err = r
            .resolve("example.com", RecordType::SRV, None)
            .await
            .expect_err("bare-domain SRV must error");
        assert!(
            matches!(err, SeerError::InvalidInput(_)),
            "bare-domain SRV should be an input error, got: {err:?}"
        );
        assert!(err.to_string().contains("_service._proto"));
    }

    // --- Live-network tests (run with `cargo test -- --ignored`) ---------------

    #[tokio::test]
    #[ignore = "live network"]
    async fn resolve_srv_via_dig_style_name_returns_records() {
        let r = DnsResolver::new();
        let records = r
            .resolve("_caldavs._tcp.google.com", RecordType::SRV, None)
            .await
            .expect("dig-style SRV lookup should succeed");
        assert!(!records.is_empty(), "expected SRV records");
        assert!(records.iter().all(|r| r.record_type == RecordType::SRV));
    }

    #[tokio::test]
    #[ignore = "live network"]
    async fn resolve_naptr_returns_records() {
        let r = DnsResolver::new();
        let records = r
            .resolve("sip2sip.info", RecordType::NAPTR, None)
            .await
            .expect("NAPTR lookup should succeed");
        assert!(!records.is_empty(), "expected NAPTR records");
        assert!(records.iter().all(|r| r.record_type == RecordType::NAPTR));
    }

    // --- Query preparation ------------------------------------------------------

    #[test]
    fn prepare_query_passes_ipv6_literal_through_for_ptr() {
        let out = prepare_query("2606:4700:4700::1111", RecordType::PTR).unwrap();
        assert_eq!(out, "2606:4700:4700::1111");
    }

    #[test]
    fn prepare_query_passes_ipv6_loopback_through_for_ptr() {
        let out = prepare_query("::1", RecordType::PTR).unwrap();
        assert_eq!(out, "::1");
    }

    #[test]
    fn prepare_query_passes_ipv4_literal_through_for_ptr() {
        let out = prepare_query("8.8.8.8", RecordType::PTR).unwrap();
        assert_eq!(out, "8.8.8.8");
    }

    #[test]
    fn prepare_query_normalizes_non_ip_ptr_names() {
        let out = prepare_query("1.1.1.1.in-addr.arpa", RecordType::PTR).unwrap();
        assert_eq!(out, "1.1.1.1.in-addr.arpa");
    }

    #[test]
    fn prepare_query_normalizes_domains_for_non_ptr() {
        let out = prepare_query("HTTPS://WWW.Example.com/path", RecordType::A).unwrap();
        assert_eq!(out, "example.com");
    }
}