Skip to main content

seer_core/dns/
resolver.rs

1use std::net::IpAddr;
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, GOOGLE};
6use hickory_resolver::net::runtime::TokioRuntimeProvider;
7use hickory_resolver::net::NetError;
8use hickory_resolver::proto::dnssec::PublicKey;
9use hickory_resolver::proto::rr::rdata::CAA;
10use hickory_resolver::proto::rr::{RData as HickoryRData, RecordType as HickoryRecordType};
11use hickory_resolver::TokioResolver;
12use tracing::{debug, instrument};
13
14use super::records::{DnsRecord, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::validation::normalize_domain;
17
18/// Convert a DNS lookup result, treating "no records found" as an empty vec
19/// rather than an error. This is correct DNS behavior — the absence of a
20/// record type for a domain is a valid response (NODATA), not a failure.
21fn dns_lookup_or_empty<T>(
22    result: std::result::Result<T, NetError>,
23    record_type: &str,
24) -> Result<Option<T>> {
25    match result {
26        Ok(response) => Ok(Some(response)),
27        Err(e) if e.is_no_records_found() => Ok(None),
28        Err(e) => Err(SeerError::DnsError(format!(
29            "{} lookup failed: {}",
30            record_type, e
31        ))),
32    }
33}
34
35/// Default timeout for DNS queries (5 seconds).
36/// DNS is typically fast; longer timeouts indicate network issues or unreachable servers.
37const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
38
39/// Build a TokioResolver pre-configured with the given upstream config and
40/// our standard options (timeout, retries, no hosts-file consultation).
41///
42/// Build only fails when TLS configuration construction fails; we don't
43/// enable TLS features in seer-core so `expect` is safe here and is the
44/// cleanest expression of that invariant.
45fn build_resolver(config: ResolverConfig, timeout: Duration) -> TokioResolver {
46    let mut builder = TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
47    {
48        let opts = builder.options_mut();
49        opts.timeout = timeout;
50        opts.attempts = 2;
51        opts.use_hosts_file = ResolveHosts::Never;
52    }
53    builder
54        .build()
55        .expect("hickory resolver build is infallible without TLS features")
56}
57
58/// DNS resolver for querying various record types.
59///
60/// Uses Google DNS (8.8.8.8) by default, but supports custom nameservers.
61/// The default resolver is cached and reused across queries to avoid
62/// repeated initialization overhead.
63#[derive(Clone)]
64pub struct DnsResolver {
65    timeout: Duration,
66    /// Cached default resolver (Google DNS). Reused across all queries
67    /// that don't specify a custom nameserver.
68    default_resolver: TokioResolver,
69}
70
71impl std::fmt::Debug for DnsResolver {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("DnsResolver")
74            .field("timeout", &self.timeout)
75            .finish()
76    }
77}
78
79impl Default for DnsResolver {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl DnsResolver {
86    /// Creates a new DNS resolver with default settings.
87    pub fn new() -> Self {
88        Self {
89            timeout: DEFAULT_TIMEOUT,
90            default_resolver: build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), DEFAULT_TIMEOUT),
91        }
92    }
93
94    /// Sets the timeout for DNS queries.
95    ///
96    /// The default is 5 seconds, which is sufficient for most DNS queries.
97    pub fn with_timeout(mut self, timeout: Duration) -> Self {
98        self.timeout = timeout;
99        self.default_resolver = build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), timeout);
100        self
101    }
102
103    fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioResolver> {
104        let ip: IpAddr = nameserver
105            .parse()
106            .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
107
108        // SSRF protection: reject private/reserved IPs for user-supplied nameservers
109        if let Some(reason) = crate::validation::describe_reserved_ip(&ip) {
110            return Err(SeerError::DnsError(format!(
111                "nameserver {} blocked: {}",
112                nameserver, reason
113            )));
114        }
115
116        // Build a config with a single UDP nameserver on the given IP.
117        // In hickory 0.26, NameServerConfig::udp(IpAddr) builds a
118        // ConnectionConfig with the default DNS port (53) for us, so we
119        // no longer need to construct a SocketAddr explicitly.
120        let ns_config = NameServerConfig::udp(ip);
121
122        let mut config = ResolverConfig::from_parts(None, vec![], vec![]);
123        config.add_name_server(ns_config);
124
125        Ok(build_resolver(config, self.timeout))
126    }
127
128    /// Resolves DNS records for a domain.
129    ///
130    /// # Arguments
131    /// * `domain` - The domain name to query
132    /// * `record_type` - The type of DNS record to look up (A, AAAA, MX, etc.)
133    /// * `nameserver` - Optional custom nameserver IP; uses Google DNS if None
134    #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
135    pub async fn resolve(
136        &self,
137        domain: &str,
138        record_type: RecordType,
139        nameserver: Option<&str>,
140    ) -> Result<Vec<DnsRecord>> {
141        // Reuse the cached default resolver when no custom nameserver is specified
142        let custom_resolver;
143        let resolver = if let Some(ns) = nameserver {
144            custom_resolver = self.create_custom_resolver(ns)?;
145            &custom_resolver
146        } else {
147            &self.default_resolver
148        };
149        let domain = normalize_domain(domain)?;
150
151        debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
152
153        match record_type {
154            RecordType::A => self.resolve_a(resolver, &domain).await,
155            RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
156            RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
157            RecordType::MX => self.resolve_mx(resolver, &domain).await,
158            RecordType::NS => self.resolve_ns(resolver, &domain).await,
159            RecordType::TXT => self.resolve_txt(resolver, &domain).await,
160            RecordType::SOA => self.resolve_soa(resolver, &domain).await,
161            RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
162            RecordType::SRV => Err(SeerError::DnsError(
163                "SRV records require service name format: _service._proto.name".to_string(),
164            )),
165            RecordType::CAA => self.resolve_caa(resolver, &domain).await,
166            RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
167            RecordType::DS => self.resolve_ds(resolver, &domain).await,
168            RecordType::ANY => self.resolve_any(resolver, &domain).await,
169            _ => Err(SeerError::DnsError(format!(
170                "Record type {} not implemented",
171                record_type
172            ))),
173        }
174    }
175
176    /// Resolves SRV records for a service.
177    ///
178    /// # Arguments
179    /// * `service` - The service name (e.g., "http", "ldap")
180    /// * `protocol` - The protocol (e.g., "tcp", "udp")
181    /// * `domain` - The domain name
182    /// * `nameserver` - Optional custom nameserver IP
183    #[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
184    pub async fn resolve_srv(
185        &self,
186        service: &str,
187        protocol: &str,
188        domain: &str,
189        nameserver: Option<&str>,
190    ) -> Result<Vec<DnsRecord>> {
191        // Validate service and protocol to prevent DNS query injection
192        if !is_valid_srv_label(service) {
193            return Err(SeerError::DnsError(format!(
194                "invalid SRV service name: {}",
195                service
196            )));
197        }
198        if !is_valid_srv_label(protocol) {
199            return Err(SeerError::DnsError(format!(
200                "invalid SRV protocol name: {}",
201                protocol
202            )));
203        }
204
205        let custom_resolver;
206        let resolver = if let Some(ns) = nameserver {
207            custom_resolver = self.create_custom_resolver(ns)?;
208            &custom_resolver
209        } else {
210            &self.default_resolver
211        };
212        let query_name = format!("_{}._{}.{}", service, protocol, domain);
213
214        let Some(response) = dns_lookup_or_empty(
215            resolver.lookup(&query_name, HickoryRecordType::SRV).await,
216            "SRV",
217        )?
218        else {
219            return Ok(vec![]);
220        };
221
222        let records = response
223            .answers()
224            .iter()
225            .filter_map(|record| {
226                if let HickoryRData::SRV(srv) = &record.data {
227                    Some(DnsRecord {
228                        name: query_name.clone(),
229                        record_type: RecordType::SRV,
230                        ttl: record.ttl,
231                        data: RecordData::SRV {
232                            priority: srv.priority,
233                            weight: srv.weight,
234                            port: srv.port,
235                            target: srv.target.to_string(),
236                        },
237                    })
238                } else {
239                    None
240                }
241            })
242            .collect();
243
244        Ok(records)
245    }
246
247    async fn resolve_a(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
248        let Some(response) =
249            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::A).await, "A")?
250        else {
251            return Ok(vec![]);
252        };
253
254        let records = response
255            .answers()
256            .iter()
257            .filter_map(|record| {
258                if let HickoryRData::A(addr) = &record.data {
259                    Some(DnsRecord {
260                        name: domain.to_string(),
261                        record_type: RecordType::A,
262                        ttl: record.ttl,
263                        data: RecordData::A {
264                            address: addr.0.to_string(),
265                        },
266                    })
267                } else {
268                    None
269                }
270            })
271            .collect();
272
273        Ok(records)
274    }
275
276    async fn resolve_aaaa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
277        let Some(response) = dns_lookup_or_empty(
278            resolver.lookup(domain, HickoryRecordType::AAAA).await,
279            "AAAA",
280        )?
281        else {
282            return Ok(vec![]);
283        };
284
285        let records = response
286            .answers()
287            .iter()
288            .filter_map(|record| {
289                if let HickoryRData::AAAA(addr) = &record.data {
290                    Some(DnsRecord {
291                        name: domain.to_string(),
292                        record_type: RecordType::AAAA,
293                        ttl: record.ttl,
294                        data: RecordData::AAAA {
295                            address: addr.0.to_string(),
296                        },
297                    })
298                } else {
299                    None
300                }
301            })
302            .collect();
303
304        Ok(records)
305    }
306
307    async fn resolve_cname(
308        &self,
309        resolver: &TokioResolver,
310        domain: &str,
311    ) -> Result<Vec<DnsRecord>> {
312        let Some(response) = dns_lookup_or_empty(
313            resolver.lookup(domain, HickoryRecordType::CNAME).await,
314            "CNAME",
315        )?
316        else {
317            return Ok(vec![]);
318        };
319
320        let records = response
321            .answers()
322            .iter()
323            .filter_map(|record| {
324                if let HickoryRData::CNAME(cname) = &record.data {
325                    Some(DnsRecord {
326                        name: domain.to_string(),
327                        record_type: RecordType::CNAME,
328                        ttl: record.ttl,
329                        data: RecordData::CNAME {
330                            target: cname.0.to_string(),
331                        },
332                    })
333                } else {
334                    None
335                }
336            })
337            .collect();
338
339        Ok(records)
340    }
341
342    async fn resolve_mx(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
343        let Some(response) =
344            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::MX).await, "MX")?
345        else {
346            return Ok(vec![]);
347        };
348
349        let mut records: Vec<DnsRecord> = response
350            .answers()
351            .iter()
352            .filter_map(|record| {
353                if let HickoryRData::MX(mx) = &record.data {
354                    Some(DnsRecord {
355                        name: domain.to_string(),
356                        record_type: RecordType::MX,
357                        ttl: record.ttl,
358                        data: RecordData::MX {
359                            preference: mx.preference,
360                            exchange: mx.exchange.to_string(),
361                        },
362                    })
363                } else {
364                    None
365                }
366            })
367            .collect();
368
369        records.sort_by_key(|r| {
370            if let RecordData::MX { preference, .. } = &r.data {
371                *preference
372            } else {
373                0
374            }
375        });
376
377        Ok(records)
378    }
379
380    async fn resolve_ns(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
381        let Some(response) =
382            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::NS).await, "NS")?
383        else {
384            return Ok(vec![]);
385        };
386
387        let records = response
388            .answers()
389            .iter()
390            .filter_map(|record| {
391                if let HickoryRData::NS(ns) = &record.data {
392                    Some(DnsRecord {
393                        name: domain.to_string(),
394                        record_type: RecordType::NS,
395                        ttl: record.ttl,
396                        data: RecordData::NS {
397                            nameserver: ns.0.to_string(),
398                        },
399                    })
400                } else {
401                    None
402                }
403            })
404            .collect();
405
406        Ok(records)
407    }
408
409    async fn resolve_txt(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
410        let Some(response) =
411            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::TXT).await, "TXT")?
412        else {
413            return Ok(vec![]);
414        };
415
416        let records = response
417            .answers()
418            .iter()
419            .filter_map(|record| {
420                if let HickoryRData::TXT(txt) = &record.data {
421                    let text = txt
422                        .txt_data
423                        .iter()
424                        .map(|data| String::from_utf8_lossy(data).to_string())
425                        .collect::<Vec<_>>()
426                        .join("");
427
428                    Some(DnsRecord {
429                        name: domain.to_string(),
430                        record_type: RecordType::TXT,
431                        ttl: record.ttl,
432                        data: RecordData::TXT { text },
433                    })
434                } else {
435                    None
436                }
437            })
438            .collect();
439
440        Ok(records)
441    }
442
443    async fn resolve_soa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
444        let Some(response) =
445            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::SOA).await, "SOA")?
446        else {
447            return Ok(vec![]);
448        };
449
450        let records = response
451            .answers()
452            .iter()
453            .filter_map(|record| {
454                if let HickoryRData::SOA(soa) = &record.data {
455                    Some(DnsRecord {
456                        name: domain.to_string(),
457                        record_type: RecordType::SOA,
458                        ttl: record.ttl,
459                        data: RecordData::SOA {
460                            mname: soa.mname.to_string(),
461                            rname: soa.rname.to_string(),
462                            serial: soa.serial,
463                            refresh: soa.refresh.try_into().unwrap_or(0),
464                            retry: soa.retry.try_into().unwrap_or(0),
465                            expire: soa.expire.try_into().unwrap_or(0),
466                            minimum: soa.minimum,
467                        },
468                    })
469                } else {
470                    None
471                }
472            })
473            .collect();
474
475        Ok(records)
476    }
477
478    async fn resolve_ptr(&self, resolver: &TokioResolver, query: &str) -> Result<Vec<DnsRecord>> {
479        // If it's an IP address, convert to reverse DNS format
480        let query = if let Ok(ip) = IpAddr::from_str(query) {
481            reverse_dns_name(&ip)
482        } else {
483            query.to_string()
484        };
485
486        let Some(response) =
487            dns_lookup_or_empty(resolver.lookup(&query, HickoryRecordType::PTR).await, "PTR")?
488        else {
489            return Ok(vec![]);
490        };
491
492        let records = response
493            .answers()
494            .iter()
495            .filter_map(|record| {
496                if let HickoryRData::PTR(ptr) = &record.data {
497                    Some(DnsRecord {
498                        name: query.clone(),
499                        record_type: RecordType::PTR,
500                        ttl: record.ttl,
501                        data: RecordData::PTR {
502                            target: ptr.0.to_string(),
503                        },
504                    })
505                } else {
506                    None
507                }
508            })
509            .collect();
510
511        Ok(records)
512    }
513
514    async fn resolve_caa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
515        let Some(response) =
516            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::CAA).await, "CAA")?
517        else {
518            return Ok(vec![]);
519        };
520
521        let records = response
522            .answers()
523            .iter()
524            .filter_map(|record| {
525                if let HickoryRData::CAA(caa) = &record.data {
526                    let (flags, tag, value) = parse_caa(caa);
527                    Some(DnsRecord {
528                        name: domain.to_string(),
529                        record_type: RecordType::CAA,
530                        ttl: record.ttl,
531                        data: RecordData::CAA { flags, tag, value },
532                    })
533                } else {
534                    None
535                }
536            })
537            .collect();
538
539        Ok(records)
540    }
541
542    async fn resolve_dnskey(
543        &self,
544        resolver: &TokioResolver,
545        domain: &str,
546    ) -> Result<Vec<DnsRecord>> {
547        use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
548
549        let Some(response) = dns_lookup_or_empty(
550            resolver.lookup(domain, HickoryRecordType::DNSKEY).await,
551            "DNSKEY",
552        )?
553        else {
554            return Ok(vec![]);
555        };
556
557        let records = response
558            .answers()
559            .iter()
560            .filter_map(|record| {
561                if let HickoryRData::DNSSEC(DNSSECRData::DNSKEY(dnskey)) = &record.data {
562                    use base64::{engine::general_purpose::STANDARD, Engine};
563                    let public_key_buf = dnskey.public_key();
564                    let public_key = STANDARD.encode(public_key_buf.public_bytes());
565                    Some(DnsRecord {
566                        name: domain.to_string(),
567                        record_type: RecordType::DNSKEY,
568                        ttl: record.ttl,
569                        data: RecordData::DNSKEY {
570                            flags: dnskey.flags(),
571                            // Protocol is always 3 for DNSSEC (RFC 4034)
572                            protocol: 3,
573                            algorithm: u8::from(public_key_buf.algorithm()),
574                            public_key,
575                        },
576                    })
577                } else {
578                    None
579                }
580            })
581            .collect();
582
583        Ok(records)
584    }
585
586    async fn resolve_ds(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
587        use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
588
589        let Some(response) =
590            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::DS).await, "DS")?
591        else {
592            return Ok(vec![]);
593        };
594
595        let records = response
596            .answers()
597            .iter()
598            .filter_map(|record| {
599                if let HickoryRData::DNSSEC(DNSSECRData::DS(ds)) = &record.data {
600                    let digest = ds
601                        .digest()
602                        .iter()
603                        .map(|b| format!("{:02X}", b))
604                        .collect::<String>();
605                    Some(DnsRecord {
606                        name: domain.to_string(),
607                        record_type: RecordType::DS,
608                        ttl: record.ttl,
609                        data: RecordData::DS {
610                            key_tag: ds.key_tag(),
611                            algorithm: u8::from(ds.algorithm()),
612                            digest_type: u8::from(ds.digest_type()),
613                            digest,
614                        },
615                    })
616                } else {
617                    None
618                }
619            })
620            .collect();
621
622        Ok(records)
623    }
624
625    async fn resolve_any(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
626        let mut all_records = Vec::new();
627
628        // Query common record types
629        let record_types = [
630            RecordType::A,
631            RecordType::AAAA,
632            RecordType::MX,
633            RecordType::NS,
634            RecordType::TXT,
635            RecordType::SOA,
636            RecordType::CAA,
637        ];
638
639        for record_type in record_types {
640            match self.resolve_type(resolver, domain, record_type).await {
641                Ok(records) => all_records.extend(records),
642                Err(_) => continue, // Skip record types that don't exist
643            }
644        }
645
646        Ok(all_records)
647    }
648
649    async fn resolve_type(
650        &self,
651        resolver: &TokioResolver,
652        domain: &str,
653        record_type: RecordType,
654    ) -> Result<Vec<DnsRecord>> {
655        match record_type {
656            RecordType::A => self.resolve_a(resolver, domain).await,
657            RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
658            RecordType::CNAME => self.resolve_cname(resolver, domain).await,
659            RecordType::MX => self.resolve_mx(resolver, domain).await,
660            RecordType::NS => self.resolve_ns(resolver, domain).await,
661            RecordType::TXT => self.resolve_txt(resolver, domain).await,
662            RecordType::SOA => self.resolve_soa(resolver, domain).await,
663            RecordType::CAA => self.resolve_caa(resolver, domain).await,
664            RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
665            RecordType::DS => self.resolve_ds(resolver, domain).await,
666            _ => Err(SeerError::DnsError("unsupported record type".to_string())),
667        }
668    }
669}
670
// Domain normalization is delegated to `crate::validation::normalize_domain`.
672
673fn reverse_dns_name(ip: &IpAddr) -> String {
674    match ip {
675        IpAddr::V4(addr) => {
676            let octets = addr.octets();
677            format!(
678                "{}.{}.{}.{}.in-addr.arpa",
679                octets[3], octets[2], octets[1], octets[0]
680            )
681        }
682        IpAddr::V6(addr) => {
683            let segments = addr.segments();
684            // 32 hex nibbles + 31 dots + ".ip6.arpa" (9) = 72 chars
685            let mut result = String::with_capacity(72);
686            let mut first = true;
687            for segment in segments.iter().rev() {
688                for shift in [0, 4, 8, 12] {
689                    if !first {
690                        result.push('.');
691                    }
692                    first = false;
693                    let nibble = (segment >> shift) & 0xF;
694                    result
695                        .push(char::from_digit(nibble as u32, 16).expect("nibble is always 0-15"));
696                }
697            }
698            result.push_str(".ip6.arpa");
699            result
700        }
701    }
702}
703
704fn parse_caa(caa: &CAA) -> (u8, String, String) {
705    // hickory 0.26: CAA fields are public. `issuer_critical` and `tag` are
706    // plain fields; `value` is a `Vec<u8>` because RFC 8659 permits binary
707    // values for unknown property types. For seer's reporting purposes the
708    // common tags (issue/issuewild/iodef) are always UTF-8, so a lossy
709    // conversion preserves prior behavior without panicking on the rare
710    // binary case.
711    let flags = if caa.issuer_critical { 128 } else { 0 };
712    let tag = caa.tag.clone();
713    let value = String::from_utf8_lossy(&caa.value).to_string();
714    (flags, tag, value)
715}
716
/// Check that `label` is safe to splice into an SRV query name as the
/// service or protocol part.
///
/// A valid label is 1–63 bytes of ASCII letters, digits and hyphens, and
/// does not start or end with a hyphen. Dots, underscores, whitespace and
/// non-ASCII characters are all rejected. The label therefore cannot
/// change which name ends up being queried.
fn is_valid_srv_label(label: &str) -> bool {
    let length_ok = (1..=63).contains(&label.len());
    let edges_ok = !(label.starts_with('-') || label.ends_with('-'));
    let charset_ok = label
        .bytes()
        .all(|b| b == b'-' || b.is_ascii_alphanumeric());
    length_ok && edges_ok && charset_ok
}
725
726#[cfg(test)]
727mod tests {
728    //! Unit tests for the pure helpers and public surface of the DNS
729    //! resolver. Tests that would exercise the hickory wire protocol
730    //! are covered by live-network tests marked `#[ignore]` in the
731    //! sibling modules (`dns/dnssec.rs`, `dns/follow.rs`). Deeper
732    //! coverage of `resolve_*` paths would require a hickory mock,
733    //! which is out of scope for this module.
734    //
735    // TODO: mock hickory resolver for full path coverage.
736
737    use super::*;
738    use std::net::{Ipv4Addr, Ipv6Addr};
739
740    // --- RecordType::from_str edge cases -----------------------------
741
742    #[test]
743    fn record_type_from_str_accepts_lowercase() {
744        assert_eq!(RecordType::from_str("a").unwrap(), RecordType::A);
745        assert_eq!(RecordType::from_str("mx").unwrap(), RecordType::MX);
746        assert_eq!(RecordType::from_str("cname").unwrap(), RecordType::CNAME);
747        assert_eq!(RecordType::from_str("dnskey").unwrap(), RecordType::DNSKEY);
748    }
749
750    #[test]
751    fn record_type_from_str_accepts_mixed_case() {
752        assert_eq!(RecordType::from_str("Mx").unwrap(), RecordType::MX);
753        assert_eq!(RecordType::from_str("cNaMe").unwrap(), RecordType::CNAME);
754    }
755
756    #[test]
757    fn record_type_from_str_rejects_whitespace_padded() {
758        // No trim is done inside from_str; leading/trailing whitespace
759        // must currently cause a parse error so callers don't pass
760        // malformed labels through.
761        assert!(RecordType::from_str(" A").is_err());
762        assert!(RecordType::from_str("A ").is_err());
763        assert!(RecordType::from_str("\tA\n").is_err());
764    }
765
766    #[test]
767    fn record_type_from_str_rejects_unknown() {
768        assert!(RecordType::from_str("NOTAREAL").is_err());
769        assert!(RecordType::from_str("A1").is_err());
770        assert!(RecordType::from_str("").is_err());
771    }
772
773    #[test]
774    fn record_type_from_str_accepts_star_as_any() {
775        assert_eq!(RecordType::from_str("*").unwrap(), RecordType::ANY);
776        assert_eq!(RecordType::from_str("ANY").unwrap(), RecordType::ANY);
777        assert_eq!(RecordType::from_str("any").unwrap(), RecordType::ANY);
778    }
779
780    // --- is_valid_srv_label ------------------------------------------
781
782    #[test]
783    fn srv_label_accepts_alphanumeric_and_hyphen() {
784        assert!(is_valid_srv_label("http"));
785        assert!(is_valid_srv_label("ldap-tls"));
786        assert!(is_valid_srv_label("a1"));
787        assert!(is_valid_srv_label("tcp"));
788    }
789
790    #[test]
791    fn srv_label_rejects_empty() {
792        assert!(!is_valid_srv_label(""));
793    }
794
795    #[test]
796    fn srv_label_rejects_leading_or_trailing_hyphen() {
797        assert!(!is_valid_srv_label("-http"));
798        assert!(!is_valid_srv_label("http-"));
799        assert!(!is_valid_srv_label("-"));
800    }
801
802    #[test]
803    fn srv_label_rejects_dots() {
804        // Dots would let an attacker construct `_service._tcp.evil.com.target`
805        // and pivot the query to a different domain.
806        assert!(!is_valid_srv_label("http.evil"));
807        assert!(!is_valid_srv_label("a.b"));
808    }
809
810    #[test]
811    fn srv_label_rejects_special_chars() {
812        assert!(!is_valid_srv_label("http evil"));
813        assert!(!is_valid_srv_label("http/evil"));
814        assert!(!is_valid_srv_label("http\0"));
815        assert!(!is_valid_srv_label("http\n"));
816    }
817
818    #[test]
819    fn srv_label_rejects_over_63_chars() {
820        let too_long = "a".repeat(64);
821        assert!(!is_valid_srv_label(&too_long));
822        let exactly_63 = "a".repeat(63);
823        assert!(is_valid_srv_label(&exactly_63));
824    }
825
826    // --- reverse_dns_name --------------------------------------------
827
828    #[test]
829    fn reverse_dns_name_formats_ipv4_correctly() {
830        let ip: IpAddr = Ipv4Addr::new(192, 0, 2, 1).into();
831        assert_eq!(reverse_dns_name(&ip), "1.2.0.192.in-addr.arpa");
832    }
833
834    #[test]
835    fn reverse_dns_name_formats_ipv6_correctly() {
836        // ::1 (loopback) → 32 nibbles of 0 followed by ...0.0.0.1 reversed.
837        let ip: IpAddr = Ipv6Addr::LOCALHOST.into();
838        let name = reverse_dns_name(&ip);
839        assert!(
840            name.ends_with(".ip6.arpa"),
841            "must end with .ip6.arpa; got: {}",
842            name
843        );
844        // The first nibble (most-reversed position) must be 1 (from ::1 low bit).
845        assert!(
846            name.starts_with("1."),
847            "expected '1.' prefix, got: {}",
848            name
849        );
850        // 32 nibbles + 31 dots + ".ip6.arpa" (9 chars) = 72.
851        assert_eq!(name.len(), 72);
852    }
853
854    // --- DnsResolver construction ------------------------------------
855
856    #[test]
857    fn resolver_new_has_default_timeout() {
858        let r = DnsResolver::new();
859        assert_eq!(r.timeout, DEFAULT_TIMEOUT);
860    }
861
862    #[test]
863    fn resolver_with_timeout_overrides_default() {
864        let custom = Duration::from_secs(42);
865        let r = DnsResolver::new().with_timeout(custom);
866        assert_eq!(r.timeout, custom);
867    }
868
869    #[test]
870    fn resolver_default_matches_new() {
871        let a = DnsResolver::default();
872        let b = DnsResolver::new();
873        assert_eq!(a.timeout, b.timeout);
874    }
875
876    // --- create_custom_resolver validation ---------------------------
877
878    #[test]
879    fn custom_resolver_rejects_invalid_ip() {
880        let r = DnsResolver::new();
881        let err = r.create_custom_resolver("not-an-ip").unwrap_err();
882        let msg = err.to_string().to_lowercase();
883        assert!(
884            msg.contains("invalid nameserver ip"),
885            "expected 'invalid nameserver ip' in error, got: {}",
886            msg
887        );
888    }
889
890    #[test]
891    fn custom_resolver_rejects_private_ipv4() {
892        // SSRF defense: private / reserved ranges must be blocked even
893        // when passed as a literal IP rather than a hostname.
894        let r = DnsResolver::new();
895        for reserved in ["127.0.0.1", "10.0.0.1", "192.168.1.1", "169.254.169.254"] {
896            let err = r.create_custom_resolver(reserved).unwrap_err();
897            let msg = err.to_string().to_lowercase();
898            assert!(
899                msg.contains("blocked") || msg.contains("reserved"),
900                "reserved IP {} must be rejected, got error: {}",
901                reserved,
902                msg
903            );
904        }
905    }
906
907    #[test]
908    fn custom_resolver_rejects_loopback_ipv6() {
909        let r = DnsResolver::new();
910        let err = r.create_custom_resolver("::1").unwrap_err();
911        let msg = err.to_string().to_lowercase();
912        assert!(
913            msg.contains("blocked") || msg.contains("reserved"),
914            "::1 must be rejected, got error: {}",
915            msg
916        );
917    }
918
919    #[test]
920    fn custom_resolver_accepts_public_ipv4() {
921        // A known public resolver IP must be acceptable.
922        let r = DnsResolver::new();
923        let result = r.create_custom_resolver("8.8.8.8");
924        assert!(
925            result.is_ok(),
926            "8.8.8.8 must be accepted as a public nameserver, got: {:?}",
927            result.err()
928        );
929    }
930
931    // --- SRV query validation (integration between helper + resolver) ----
932
933    #[tokio::test]
934    async fn resolve_srv_rejects_invalid_service_label() {
935        let r = DnsResolver::new();
936        // With_dot service name would construct a malformed DNS query.
937        let result = r.resolve_srv("http.evil", "tcp", "example.com", None).await;
938        assert!(result.is_err());
939        let msg = result.unwrap_err().to_string().to_lowercase();
940        assert!(
941            msg.contains("invalid srv service"),
942            "expected SRV service validation error, got: {}",
943            msg
944        );
945    }
946
947    #[tokio::test]
948    async fn resolve_srv_rejects_invalid_protocol_label() {
949        let r = DnsResolver::new();
950        let result = r.resolve_srv("http", "tcp.evil", "example.com", None).await;
951        assert!(result.is_err());
952        let msg = result.unwrap_err().to_string().to_lowercase();
953        assert!(
954            msg.contains("invalid srv protocol"),
955            "expected SRV protocol validation error, got: {}",
956            msg
957        );
958    }
959
960    // --- Normalization applied before resolution ---------------------
961
962    #[tokio::test]
963    async fn resolve_normalizes_uppercase_domain_input() {
964        // We can't hit the network in unit tests, but we can at least
965        // assert that normalization rejects clearly-invalid input
966        // before any network call is made. Domains with a leading `.`
967        // are rejected by the normalizer.
968        let r = DnsResolver::new();
969        let result = r.resolve(".bad.example", RecordType::A, None).await;
970        assert!(result.is_err(), "leading-dot domain must be rejected");
971    }
972
    // --- SRV via generic `resolve` -----------------------------------
974
975    #[tokio::test]
976    async fn resolve_rejects_srv_record_type_without_srv_helper() {
977        // Calling `resolve` with SRV should return the helpful error
978        // instructing the caller to use `resolve_srv` instead.
979        let r = DnsResolver::new();
980        let result = r.resolve("example.com", RecordType::SRV, None).await;
981        assert!(result.is_err());
982        let msg = result.unwrap_err().to_string();
983        assert!(
984            msg.contains("SRV records require service name format"),
985            "expected helpful SRV error, got: {}",
986            msg
987        );
988    }
989}