Skip to main content

seer_core/dns/
resolver.rs

1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::error::ResolveErrorKind;
7use hickory_resolver::proto::rr::rdata::CAA;
8use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
9use hickory_resolver::TokioAsyncResolver;
10use tracing::{debug, instrument};
11
12use super::records::{DnsRecord, RecordData, RecordType};
13use crate::error::{Result, SeerError};
14use crate::validation::normalize_domain;
15
16/// Convert a DNS lookup result, treating "no records found" as an empty vec
17/// rather than an error. This is correct DNS behavior — the absence of a
18/// record type for a domain is a valid response (NODATA), not a failure.
19fn dns_lookup_or_empty<T>(
20    result: std::result::Result<T, hickory_resolver::error::ResolveError>,
21    record_type: &str,
22) -> Result<Option<T>> {
23    match result {
24        Ok(response) => Ok(Some(response)),
25        Err(e) => match e.kind() {
26            ResolveErrorKind::NoRecordsFound { .. } => Ok(None),
27            _ => Err(SeerError::DnsError(format!(
28                "{} lookup failed: {}",
29                record_type, e
30            ))),
31        },
32    }
33}
34
/// Per-query DNS timeout used unless [`DnsResolver::with_timeout`] overrides it.
///
/// Five seconds: healthy DNS answers arrive far sooner, so waiting longer mostly
/// just delays reporting an unreachable server or a network problem.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
38
39/// DNS resolver for querying various record types.
40///
41/// Uses Google DNS (8.8.8.8) by default, but supports custom nameservers.
42/// The default resolver is cached and reused across queries to avoid
43/// repeated initialization overhead.
44#[derive(Clone)]
45pub struct DnsResolver {
46    timeout: Duration,
47    /// Cached default resolver (Google DNS). Reused across all queries
48    /// that don't specify a custom nameserver.
49    default_resolver: TokioAsyncResolver,
50}
51
52impl std::fmt::Debug for DnsResolver {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("DnsResolver")
55            .field("timeout", &self.timeout)
56            .finish()
57    }
58}
59
60impl Default for DnsResolver {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl DnsResolver {
67    /// Creates a new DNS resolver with default settings.
68    pub fn new() -> Self {
69        let mut opts = ResolverOpts::default();
70        opts.timeout = DEFAULT_TIMEOUT;
71        opts.attempts = 2;
72        opts.use_hosts_file = false;
73
74        Self {
75            timeout: DEFAULT_TIMEOUT,
76            default_resolver: TokioAsyncResolver::tokio(ResolverConfig::google(), opts),
77        }
78    }
79
80    /// Sets the timeout for DNS queries.
81    ///
82    /// The default is 5 seconds, which is sufficient for most DNS queries.
83    pub fn with_timeout(mut self, timeout: Duration) -> Self {
84        self.timeout = timeout;
85        // Recreate default resolver with new timeout
86        let mut opts = ResolverOpts::default();
87        opts.timeout = timeout;
88        opts.attempts = 2;
89        opts.use_hosts_file = false;
90        self.default_resolver = TokioAsyncResolver::tokio(ResolverConfig::google(), opts);
91        self
92    }
93
94    fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioAsyncResolver> {
95        let mut opts = ResolverOpts::default();
96        opts.timeout = self.timeout;
97        opts.attempts = 2;
98        opts.use_hosts_file = false;
99
100        let ip: IpAddr = nameserver
101            .parse()
102            .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
103
104        // SSRF protection: reject private/reserved IPs for user-supplied nameservers
105        if let Some(reason) = crate::validation::describe_reserved_ip(&ip) {
106            return Err(SeerError::DnsError(format!(
107                "nameserver {} blocked: {}",
108                nameserver, reason
109            )));
110        }
111
112        let socket_addr = SocketAddr::new(ip, 53);
113        let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
114
115        let mut config = ResolverConfig::new();
116        config.add_name_server(ns_config);
117
118        Ok(TokioAsyncResolver::tokio(config, opts))
119    }
120
121    /// Resolves DNS records for a domain.
122    ///
123    /// # Arguments
124    /// * `domain` - The domain name to query
125    /// * `record_type` - The type of DNS record to look up (A, AAAA, MX, etc.)
126    /// * `nameserver` - Optional custom nameserver IP; uses Google DNS if None
127    #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
128    pub async fn resolve(
129        &self,
130        domain: &str,
131        record_type: RecordType,
132        nameserver: Option<&str>,
133    ) -> Result<Vec<DnsRecord>> {
134        // Reuse the cached default resolver when no custom nameserver is specified
135        let custom_resolver;
136        let resolver = if let Some(ns) = nameserver {
137            custom_resolver = self.create_custom_resolver(ns)?;
138            &custom_resolver
139        } else {
140            &self.default_resolver
141        };
142        let domain = normalize_domain(domain)?;
143
144        debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
145
146        match record_type {
147            RecordType::A => self.resolve_a(resolver, &domain).await,
148            RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
149            RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
150            RecordType::MX => self.resolve_mx(resolver, &domain).await,
151            RecordType::NS => self.resolve_ns(resolver, &domain).await,
152            RecordType::TXT => self.resolve_txt(resolver, &domain).await,
153            RecordType::SOA => self.resolve_soa(resolver, &domain).await,
154            RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
155            RecordType::SRV => Err(SeerError::DnsError(
156                "SRV records require service name format: _service._proto.name".to_string(),
157            )),
158            RecordType::CAA => self.resolve_caa(resolver, &domain).await,
159            RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
160            RecordType::DS => self.resolve_ds(resolver, &domain).await,
161            RecordType::ANY => self.resolve_any(resolver, &domain).await,
162            _ => Err(SeerError::DnsError(format!(
163                "Record type {} not implemented",
164                record_type
165            ))),
166        }
167    }
168
169    /// Resolves SRV records for a service.
170    ///
171    /// # Arguments
172    /// * `service` - The service name (e.g., "http", "ldap")
173    /// * `protocol` - The protocol (e.g., "tcp", "udp")
174    /// * `domain` - The domain name
175    /// * `nameserver` - Optional custom nameserver IP
176    #[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
177    pub async fn resolve_srv(
178        &self,
179        service: &str,
180        protocol: &str,
181        domain: &str,
182        nameserver: Option<&str>,
183    ) -> Result<Vec<DnsRecord>> {
184        // Validate service and protocol to prevent DNS query injection
185        if !is_valid_srv_label(service) {
186            return Err(SeerError::DnsError(format!(
187                "invalid SRV service name: {}",
188                service
189            )));
190        }
191        if !is_valid_srv_label(protocol) {
192            return Err(SeerError::DnsError(format!(
193                "invalid SRV protocol name: {}",
194                protocol
195            )));
196        }
197
198        let custom_resolver;
199        let resolver = if let Some(ns) = nameserver {
200            custom_resolver = self.create_custom_resolver(ns)?;
201            &custom_resolver
202        } else {
203            &self.default_resolver
204        };
205        let query_name = format!("_{}._{}.{}", service, protocol, domain);
206
207        let Some(response) = dns_lookup_or_empty(resolver.srv_lookup(&query_name).await, "SRV")?
208        else {
209            return Ok(vec![]);
210        };
211
212        let records = response
213            .iter()
214            .map(|srv| DnsRecord {
215                name: query_name.clone(),
216                record_type: RecordType::SRV,
217                ttl: response
218                    .as_lookup()
219                    .record_iter()
220                    .next()
221                    .map(|r| r.ttl())
222                    .unwrap_or(0),
223                data: RecordData::SRV {
224                    priority: srv.priority(),
225                    weight: srv.weight(),
226                    port: srv.port(),
227                    target: srv.target().to_string(),
228                },
229            })
230            .collect();
231
232        Ok(records)
233    }
234
235    async fn resolve_a(
236        &self,
237        resolver: &TokioAsyncResolver,
238        domain: &str,
239    ) -> Result<Vec<DnsRecord>> {
240        let Some(response) = dns_lookup_or_empty(resolver.ipv4_lookup(domain).await, "A")? else {
241            return Ok(vec![]);
242        };
243
244        let ttl = response
245            .as_lookup()
246            .record_iter()
247            .next()
248            .map(|r| r.ttl())
249            .unwrap_or(0);
250
251        let records = response
252            .iter()
253            .map(|addr| DnsRecord {
254                name: domain.to_string(),
255                record_type: RecordType::A,
256                ttl,
257                data: RecordData::A {
258                    address: addr.to_string(),
259                },
260            })
261            .collect();
262
263        Ok(records)
264    }
265
266    async fn resolve_aaaa(
267        &self,
268        resolver: &TokioAsyncResolver,
269        domain: &str,
270    ) -> Result<Vec<DnsRecord>> {
271        let Some(response) = dns_lookup_or_empty(resolver.ipv6_lookup(domain).await, "AAAA")?
272        else {
273            return Ok(vec![]);
274        };
275
276        let ttl = response
277            .as_lookup()
278            .record_iter()
279            .next()
280            .map(|r| r.ttl())
281            .unwrap_or(0);
282
283        let records = response
284            .iter()
285            .map(|addr| DnsRecord {
286                name: domain.to_string(),
287                record_type: RecordType::AAAA,
288                ttl,
289                data: RecordData::AAAA {
290                    address: addr.to_string(),
291                },
292            })
293            .collect();
294
295        Ok(records)
296    }
297
298    async fn resolve_cname(
299        &self,
300        resolver: &TokioAsyncResolver,
301        domain: &str,
302    ) -> Result<Vec<DnsRecord>> {
303        let Some(response) = dns_lookup_or_empty(
304            resolver.lookup(domain, HickoryRecordType::CNAME).await,
305            "CNAME",
306        )?
307        else {
308            return Ok(vec![]);
309        };
310
311        let records = response
312            .record_iter()
313            .filter_map(|record| {
314                if let Some(rdata) = record.data() {
315                    if let Some(cname) = rdata.as_cname() {
316                        return Some(DnsRecord {
317                            name: domain.to_string(),
318                            record_type: RecordType::CNAME,
319                            ttl: record.ttl(),
320                            data: RecordData::CNAME {
321                                target: cname.0.to_string(),
322                            },
323                        });
324                    }
325                }
326                None
327            })
328            .collect();
329
330        Ok(records)
331    }
332
333    async fn resolve_mx(
334        &self,
335        resolver: &TokioAsyncResolver,
336        domain: &str,
337    ) -> Result<Vec<DnsRecord>> {
338        let Some(response) = dns_lookup_or_empty(resolver.mx_lookup(domain).await, "MX")? else {
339            return Ok(vec![]);
340        };
341
342        let ttl = response
343            .as_lookup()
344            .record_iter()
345            .next()
346            .map(|r| r.ttl())
347            .unwrap_or(0);
348
349        let mut records: Vec<DnsRecord> = response
350            .iter()
351            .map(|mx| DnsRecord {
352                name: domain.to_string(),
353                record_type: RecordType::MX,
354                ttl,
355                data: RecordData::MX {
356                    preference: mx.preference(),
357                    exchange: mx.exchange().to_string(),
358                },
359            })
360            .collect();
361
362        records.sort_by_key(|r| {
363            if let RecordData::MX { preference, .. } = &r.data {
364                *preference
365            } else {
366                0
367            }
368        });
369
370        Ok(records)
371    }
372
373    async fn resolve_ns(
374        &self,
375        resolver: &TokioAsyncResolver,
376        domain: &str,
377    ) -> Result<Vec<DnsRecord>> {
378        let Some(response) = dns_lookup_or_empty(resolver.ns_lookup(domain).await, "NS")? else {
379            return Ok(vec![]);
380        };
381
382        let ttl = response
383            .as_lookup()
384            .record_iter()
385            .next()
386            .map(|r| r.ttl())
387            .unwrap_or(0);
388
389        let records = response
390            .iter()
391            .map(|ns| DnsRecord {
392                name: domain.to_string(),
393                record_type: RecordType::NS,
394                ttl,
395                data: RecordData::NS {
396                    nameserver: ns.0.to_string(),
397                },
398            })
399            .collect();
400
401        Ok(records)
402    }
403
404    async fn resolve_txt(
405        &self,
406        resolver: &TokioAsyncResolver,
407        domain: &str,
408    ) -> Result<Vec<DnsRecord>> {
409        let Some(response) = dns_lookup_or_empty(resolver.txt_lookup(domain).await, "TXT")? else {
410            return Ok(vec![]);
411        };
412
413        let ttl = response
414            .as_lookup()
415            .record_iter()
416            .next()
417            .map(|r| r.ttl())
418            .unwrap_or(0);
419
420        let records = response
421            .iter()
422            .map(|txt| {
423                let text = txt
424                    .iter()
425                    .map(|data| String::from_utf8_lossy(data).to_string())
426                    .collect::<Vec<_>>()
427                    .join("");
428
429                DnsRecord {
430                    name: domain.to_string(),
431                    record_type: RecordType::TXT,
432                    ttl,
433                    data: RecordData::TXT { text },
434                }
435            })
436            .collect();
437
438        Ok(records)
439    }
440
441    async fn resolve_soa(
442        &self,
443        resolver: &TokioAsyncResolver,
444        domain: &str,
445    ) -> Result<Vec<DnsRecord>> {
446        let Some(response) = dns_lookup_or_empty(resolver.soa_lookup(domain).await, "SOA")? else {
447            return Ok(vec![]);
448        };
449
450        let ttl = response
451            .as_lookup()
452            .record_iter()
453            .next()
454            .map(|r| r.ttl())
455            .unwrap_or(0);
456
457        let records = response
458            .iter()
459            .map(|soa| DnsRecord {
460                name: domain.to_string(),
461                record_type: RecordType::SOA,
462                ttl,
463                data: RecordData::SOA {
464                    mname: soa.mname().to_string(),
465                    rname: soa.rname().to_string(),
466                    serial: soa.serial(),
467                    refresh: soa.refresh().try_into().unwrap_or(0),
468                    retry: soa.retry().try_into().unwrap_or(0),
469                    expire: soa.expire().try_into().unwrap_or(0),
470                    minimum: soa.minimum(),
471                },
472            })
473            .collect();
474
475        Ok(records)
476    }
477
478    async fn resolve_ptr(
479        &self,
480        resolver: &TokioAsyncResolver,
481        query: &str,
482    ) -> Result<Vec<DnsRecord>> {
483        // If it's an IP address, convert to reverse DNS format
484        let query = if let Ok(ip) = IpAddr::from_str(query) {
485            reverse_dns_name(&ip)
486        } else {
487            query.to_string()
488        };
489
490        let Some(response) =
491            dns_lookup_or_empty(resolver.lookup(&query, HickoryRecordType::PTR).await, "PTR")?
492        else {
493            return Ok(vec![]);
494        };
495
496        let records = response
497            .record_iter()
498            .filter_map(|record| {
499                if let Some(rdata) = record.data() {
500                    if let Some(ptr) = rdata.as_ptr() {
501                        return Some(DnsRecord {
502                            name: query.clone(),
503                            record_type: RecordType::PTR,
504                            ttl: record.ttl(),
505                            data: RecordData::PTR {
506                                target: ptr.0.to_string(),
507                            },
508                        });
509                    }
510                }
511                None
512            })
513            .collect();
514
515        Ok(records)
516    }
517
518    async fn resolve_caa(
519        &self,
520        resolver: &TokioAsyncResolver,
521        domain: &str,
522    ) -> Result<Vec<DnsRecord>> {
523        let Some(response) =
524            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::CAA).await, "CAA")?
525        else {
526            return Ok(vec![]);
527        };
528
529        let records = response
530            .record_iter()
531            .filter_map(|record| {
532                if let Some(rdata) = record.data() {
533                    if let Some(caa) = rdata.as_caa() {
534                        let (flags, tag, value) = parse_caa(caa);
535                        return Some(DnsRecord {
536                            name: domain.to_string(),
537                            record_type: RecordType::CAA,
538                            ttl: record.ttl(),
539                            data: RecordData::CAA { flags, tag, value },
540                        });
541                    }
542                }
543                None
544            })
545            .collect();
546
547        Ok(records)
548    }
549
550    async fn resolve_dnskey(
551        &self,
552        resolver: &TokioAsyncResolver,
553        domain: &str,
554    ) -> Result<Vec<DnsRecord>> {
555        use hickory_resolver::proto::rr::RData as HickoryRData;
556
557        let Some(response) = dns_lookup_or_empty(
558            resolver.lookup(domain, HickoryRecordType::DNSKEY).await,
559            "DNSKEY",
560        )?
561        else {
562            return Ok(vec![]);
563        };
564
565        let records = response
566            .record_iter()
567            .filter_map(|record| {
568                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
569                    if let Some(dnskey) = dnssec_rdata.as_dnskey() {
570                        use base64::{engine::general_purpose::STANDARD, Engine};
571                        let public_key = STANDARD.encode(dnskey.public_key());
572                        return Some(DnsRecord {
573                            name: domain.to_string(),
574                            record_type: RecordType::DNSKEY,
575                            ttl: record.ttl(),
576                            data: RecordData::DNSKEY {
577                                flags: dnskey.flags(),
578                                protocol: 3, // Protocol is always 3 for DNSSEC (RFC 4034)
579                                algorithm: u8::from(dnskey.algorithm()),
580                                public_key,
581                            },
582                        });
583                    }
584                }
585                None
586            })
587            .collect();
588
589        Ok(records)
590    }
591
592    async fn resolve_ds(
593        &self,
594        resolver: &TokioAsyncResolver,
595        domain: &str,
596    ) -> Result<Vec<DnsRecord>> {
597        use hickory_resolver::proto::rr::RData as HickoryRData;
598
599        let Some(response) =
600            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::DS).await, "DS")?
601        else {
602            return Ok(vec![]);
603        };
604
605        let records = response
606            .record_iter()
607            .filter_map(|record| {
608                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
609                    if let Some(ds) = dnssec_rdata.as_ds() {
610                        let digest = ds
611                            .digest()
612                            .iter()
613                            .map(|b| format!("{:02X}", b))
614                            .collect::<String>();
615                        return Some(DnsRecord {
616                            name: domain.to_string(),
617                            record_type: RecordType::DS,
618                            ttl: record.ttl(),
619                            data: RecordData::DS {
620                                key_tag: ds.key_tag(),
621                                algorithm: u8::from(ds.algorithm()),
622                                digest_type: u8::from(ds.digest_type()),
623                                digest,
624                            },
625                        });
626                    }
627                }
628                None
629            })
630            .collect();
631
632        Ok(records)
633    }
634
635    async fn resolve_any(
636        &self,
637        resolver: &TokioAsyncResolver,
638        domain: &str,
639    ) -> Result<Vec<DnsRecord>> {
640        let mut all_records = Vec::new();
641
642        // Query common record types
643        let record_types = [
644            RecordType::A,
645            RecordType::AAAA,
646            RecordType::MX,
647            RecordType::NS,
648            RecordType::TXT,
649            RecordType::SOA,
650            RecordType::CAA,
651        ];
652
653        for record_type in record_types {
654            match self.resolve_type(resolver, domain, record_type).await {
655                Ok(records) => all_records.extend(records),
656                Err(_) => continue, // Skip record types that don't exist
657            }
658        }
659
660        Ok(all_records)
661    }
662
663    async fn resolve_type(
664        &self,
665        resolver: &TokioAsyncResolver,
666        domain: &str,
667        record_type: RecordType,
668    ) -> Result<Vec<DnsRecord>> {
669        match record_type {
670            RecordType::A => self.resolve_a(resolver, domain).await,
671            RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
672            RecordType::CNAME => self.resolve_cname(resolver, domain).await,
673            RecordType::MX => self.resolve_mx(resolver, domain).await,
674            RecordType::NS => self.resolve_ns(resolver, domain).await,
675            RecordType::TXT => self.resolve_txt(resolver, domain).await,
676            RecordType::SOA => self.resolve_soa(resolver, domain).await,
677            RecordType::CAA => self.resolve_caa(resolver, domain).await,
678            RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
679            RecordType::DS => self.resolve_ds(resolver, domain).await,
680            _ => Err(SeerError::DnsError("unsupported record type".to_string())),
681        }
682    }
683}
684
// Domain normalization is provided by `crate::validation::normalize_domain`.
686
/// Builds the reverse-lookup (PTR) name for an IP address.
///
/// IPv4 `a.b.c.d` becomes `d.c.b.a.in-addr.arpa`. IPv6 becomes 32 lowercase
/// hex nibbles in reverse order, dot-separated, followed by `.ip6.arpa`
/// (72 characters in total).
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{}.{}.{}.{}.in-addr.arpa", d, c, b, a)
        }
        IpAddr::V6(v6) => {
            // Walk the address from its last byte to its first; each byte
            // contributes its low nibble, then its high nibble.
            let nibbles: Vec<String> = v6
                .octets()
                .iter()
                .rev()
                .copied()
                .flat_map(|byte| [byte & 0x0F, byte >> 4])
                .map(|nibble| format!("{:x}", nibble))
                .collect();
            format!("{}.ip6.arpa", nibbles.join("."))
        }
    }
}
717
718fn parse_caa(caa: &CAA) -> (u8, String, String) {
719    let flags = if caa.issuer_critical() { 128 } else { 0 };
720    let tag = caa.tag().as_str().to_string();
721    let value = caa.value().to_string();
722    (flags, tag, value)
723}
724
/// Checks that an SRV service or protocol label is safe to splice into a query name.
///
/// Accepts 1-63 bytes of ASCII letters, digits and hyphens, not starting or
/// ending with a hyphen. Dots are rejected so the label cannot add extra name
/// components to the query.
fn is_valid_srv_label(label: &str) -> bool {
    let bytes = label.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&first), Some(&last)) => {
            bytes.len() <= 63
                && first != b'-'
                && last != b'-'
                && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
        }
        _ => false,
    }
}
733
734#[cfg(test)]
735mod tests {
736    //! Unit tests for the pure helpers and public surface of the DNS
737    //! resolver. Tests that would exercise the hickory wire protocol
738    //! are covered by live-network tests marked `#[ignore]` in the
739    //! sibling modules (`dns/dnssec.rs`, `dns/follow.rs`). Deeper
740    //! coverage of `resolve_*` paths would require a hickory mock,
741    //! which is out of scope for this module.
742    //
743    // TODO: mock hickory resolver for full path coverage.
744
745    use super::*;
746    use std::net::{Ipv4Addr, Ipv6Addr};
747
748    // --- RecordType::from_str edge cases -----------------------------
749
750    #[test]
751    fn record_type_from_str_accepts_lowercase() {
752        assert_eq!(RecordType::from_str("a").unwrap(), RecordType::A);
753        assert_eq!(RecordType::from_str("mx").unwrap(), RecordType::MX);
754        assert_eq!(RecordType::from_str("cname").unwrap(), RecordType::CNAME);
755        assert_eq!(RecordType::from_str("dnskey").unwrap(), RecordType::DNSKEY);
756    }
757
758    #[test]
759    fn record_type_from_str_accepts_mixed_case() {
760        assert_eq!(RecordType::from_str("Mx").unwrap(), RecordType::MX);
761        assert_eq!(RecordType::from_str("cNaMe").unwrap(), RecordType::CNAME);
762    }
763
764    #[test]
765    fn record_type_from_str_rejects_whitespace_padded() {
766        // No trim is done inside from_str; leading/trailing whitespace
767        // must currently cause a parse error so callers don't pass
768        // malformed labels through.
769        assert!(RecordType::from_str(" A").is_err());
770        assert!(RecordType::from_str("A ").is_err());
771        assert!(RecordType::from_str("\tA\n").is_err());
772    }
773
774    #[test]
775    fn record_type_from_str_rejects_unknown() {
776        assert!(RecordType::from_str("NOTAREAL").is_err());
777        assert!(RecordType::from_str("A1").is_err());
778        assert!(RecordType::from_str("").is_err());
779    }
780
781    #[test]
782    fn record_type_from_str_accepts_star_as_any() {
783        assert_eq!(RecordType::from_str("*").unwrap(), RecordType::ANY);
784        assert_eq!(RecordType::from_str("ANY").unwrap(), RecordType::ANY);
785        assert_eq!(RecordType::from_str("any").unwrap(), RecordType::ANY);
786    }
787
788    // --- is_valid_srv_label ------------------------------------------
789
790    #[test]
791    fn srv_label_accepts_alphanumeric_and_hyphen() {
792        assert!(is_valid_srv_label("http"));
793        assert!(is_valid_srv_label("ldap-tls"));
794        assert!(is_valid_srv_label("a1"));
795        assert!(is_valid_srv_label("tcp"));
796    }
797
798    #[test]
799    fn srv_label_rejects_empty() {
800        assert!(!is_valid_srv_label(""));
801    }
802
803    #[test]
804    fn srv_label_rejects_leading_or_trailing_hyphen() {
805        assert!(!is_valid_srv_label("-http"));
806        assert!(!is_valid_srv_label("http-"));
807        assert!(!is_valid_srv_label("-"));
808    }
809
810    #[test]
811    fn srv_label_rejects_dots() {
812        // Dots would let an attacker construct `_service._tcp.evil.com.target`
813        // and pivot the query to a different domain.
814        assert!(!is_valid_srv_label("http.evil"));
815        assert!(!is_valid_srv_label("a.b"));
816    }
817
818    #[test]
819    fn srv_label_rejects_special_chars() {
820        assert!(!is_valid_srv_label("http evil"));
821        assert!(!is_valid_srv_label("http/evil"));
822        assert!(!is_valid_srv_label("http\0"));
823        assert!(!is_valid_srv_label("http\n"));
824    }
825
826    #[test]
827    fn srv_label_rejects_over_63_chars() {
828        let too_long = "a".repeat(64);
829        assert!(!is_valid_srv_label(&too_long));
830        let exactly_63 = "a".repeat(63);
831        assert!(is_valid_srv_label(&exactly_63));
832    }
833
834    // --- reverse_dns_name --------------------------------------------
835
836    #[test]
837    fn reverse_dns_name_formats_ipv4_correctly() {
838        let ip: IpAddr = Ipv4Addr::new(192, 0, 2, 1).into();
839        assert_eq!(reverse_dns_name(&ip), "1.2.0.192.in-addr.arpa");
840    }
841
842    #[test]
843    fn reverse_dns_name_formats_ipv6_correctly() {
844        // ::1 (loopback) → 32 nibbles of 0 followed by ...0.0.0.1 reversed.
845        let ip: IpAddr = Ipv6Addr::LOCALHOST.into();
846        let name = reverse_dns_name(&ip);
847        assert!(
848            name.ends_with(".ip6.arpa"),
849            "must end with .ip6.arpa; got: {}",
850            name
851        );
852        // The first nibble (most-reversed position) must be 1 (from ::1 low bit).
853        assert!(
854            name.starts_with("1."),
855            "expected '1.' prefix, got: {}",
856            name
857        );
858        // 32 nibbles + 31 dots + ".ip6.arpa" (9 chars) = 72.
859        assert_eq!(name.len(), 72);
860    }
861
862    // --- DnsResolver construction ------------------------------------
863
864    #[test]
865    fn resolver_new_has_default_timeout() {
866        let r = DnsResolver::new();
867        assert_eq!(r.timeout, DEFAULT_TIMEOUT);
868    }
869
870    #[test]
871    fn resolver_with_timeout_overrides_default() {
872        let custom = Duration::from_secs(42);
873        let r = DnsResolver::new().with_timeout(custom);
874        assert_eq!(r.timeout, custom);
875    }
876
877    #[test]
878    fn resolver_default_matches_new() {
879        let a = DnsResolver::default();
880        let b = DnsResolver::new();
881        assert_eq!(a.timeout, b.timeout);
882    }
883
884    // --- create_custom_resolver validation ---------------------------
885
886    #[test]
887    fn custom_resolver_rejects_invalid_ip() {
888        let r = DnsResolver::new();
889        let err = r.create_custom_resolver("not-an-ip").unwrap_err();
890        let msg = err.to_string().to_lowercase();
891        assert!(
892            msg.contains("invalid nameserver ip"),
893            "expected 'invalid nameserver ip' in error, got: {}",
894            msg
895        );
896    }
897
898    #[test]
899    fn custom_resolver_rejects_private_ipv4() {
900        // SSRF defense: private / reserved ranges must be blocked even
901        // when passed as a literal IP rather than a hostname.
902        let r = DnsResolver::new();
903        for reserved in ["127.0.0.1", "10.0.0.1", "192.168.1.1", "169.254.169.254"] {
904            let err = r.create_custom_resolver(reserved).unwrap_err();
905            let msg = err.to_string().to_lowercase();
906            assert!(
907                msg.contains("blocked") || msg.contains("reserved"),
908                "reserved IP {} must be rejected, got error: {}",
909                reserved,
910                msg
911            );
912        }
913    }
914
915    #[test]
916    fn custom_resolver_rejects_loopback_ipv6() {
917        let r = DnsResolver::new();
918        let err = r.create_custom_resolver("::1").unwrap_err();
919        let msg = err.to_string().to_lowercase();
920        assert!(
921            msg.contains("blocked") || msg.contains("reserved"),
922            "::1 must be rejected, got error: {}",
923            msg
924        );
925    }
926
927    #[test]
928    fn custom_resolver_accepts_public_ipv4() {
929        // A known public resolver IP must be acceptable.
930        let r = DnsResolver::new();
931        let result = r.create_custom_resolver("8.8.8.8");
932        assert!(
933            result.is_ok(),
934            "8.8.8.8 must be accepted as a public nameserver, got: {:?}",
935            result.err()
936        );
937    }
938
939    // --- SRV query validation (integration between helper + resolver) ----
940
941    #[tokio::test]
942    async fn resolve_srv_rejects_invalid_service_label() {
943        let r = DnsResolver::new();
944        // With_dot service name would construct a malformed DNS query.
945        let result = r.resolve_srv("http.evil", "tcp", "example.com", None).await;
946        assert!(result.is_err());
947        let msg = result.unwrap_err().to_string().to_lowercase();
948        assert!(
949            msg.contains("invalid srv service"),
950            "expected SRV service validation error, got: {}",
951            msg
952        );
953    }
954
955    #[tokio::test]
956    async fn resolve_srv_rejects_invalid_protocol_label() {
957        let r = DnsResolver::new();
958        let result = r.resolve_srv("http", "tcp.evil", "example.com", None).await;
959        assert!(result.is_err());
960        let msg = result.unwrap_err().to_string().to_lowercase();
961        assert!(
962            msg.contains("invalid srv protocol"),
963            "expected SRV protocol validation error, got: {}",
964            msg
965        );
966    }
967
968    // --- Normalization applied before resolution ---------------------
969
970    #[tokio::test]
971    async fn resolve_normalizes_uppercase_domain_input() {
972        // We can't hit the network in unit tests, but we can at least
973        // assert that normalization rejects clearly-invalid input
974        // before any network call is made. Domains with a leading `.`
975        // are rejected by the normalizer.
976        let r = DnsResolver::new();
977        let result = r.resolve(".bad.example", RecordType::A, None).await;
978        assert!(result.is_err(), "leading-dot domain must be rejected");
979    }
980
981    // --- SRV record -------------------------------------------------
982
983    #[tokio::test]
984    async fn resolve_rejects_srv_record_type_without_srv_helper() {
985        // Calling `resolve` with SRV should return the helpful error
986        // instructing the caller to use `resolve_srv` instead.
987        let r = DnsResolver::new();
988        let result = r.resolve("example.com", RecordType::SRV, None).await;
989        assert!(result.is_err());
990        let msg = result.unwrap_err().to_string();
991        assert!(
992            msg.contains("SRV records require service name format"),
993            "expected helpful SRV error, got: {}",
994            msg
995        );
996    }
997}