Skip to main content

seer_core/dns/
resolver.rs

1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::error::ResolveErrorKind;
7use hickory_resolver::proto::rr::rdata::CAA;
8use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
9use hickory_resolver::TokioAsyncResolver;
10use tracing::{debug, instrument};
11
12use super::records::{DnsRecord, RecordData, RecordType};
13use crate::error::{Result, SeerError};
14use crate::validation::normalize_domain;
15
16/// Convert a DNS lookup result, treating "no records found" as an empty vec
17/// rather than an error. This is correct DNS behavior — the absence of a
18/// record type for a domain is a valid response (NODATA), not a failure.
19fn dns_lookup_or_empty<T>(
20    result: std::result::Result<T, hickory_resolver::error::ResolveError>,
21    record_type: &str,
22) -> Result<Option<T>> {
23    match result {
24        Ok(response) => Ok(Some(response)),
25        Err(e) => match e.kind() {
26            ResolveErrorKind::NoRecordsFound { .. } => Ok(None),
27            _ => Err(SeerError::DnsError(format!(
28                "{} lookup failed: {}",
29                record_type, e
30            ))),
31        },
32    }
33}
34
/// Timeout applied to DNS queries unless overridden (5 seconds).
///
/// Healthy DNS answers arrive well inside this window; running into it
/// usually means a network problem or an unreachable server.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
38
39/// DNS resolver for querying various record types.
40///
41/// Uses Google DNS (8.8.8.8) by default, but supports custom nameservers.
42/// The default resolver is cached and reused across queries to avoid
43/// repeated initialization overhead.
44#[derive(Clone)]
45pub struct DnsResolver {
46    timeout: Duration,
47    /// Cached default resolver (Google DNS). Reused across all queries
48    /// that don't specify a custom nameserver.
49    default_resolver: TokioAsyncResolver,
50}
51
52impl std::fmt::Debug for DnsResolver {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("DnsResolver")
55            .field("timeout", &self.timeout)
56            .finish()
57    }
58}
59
60impl Default for DnsResolver {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl DnsResolver {
67    /// Creates a new DNS resolver with default settings.
68    pub fn new() -> Self {
69        let mut opts = ResolverOpts::default();
70        opts.timeout = DEFAULT_TIMEOUT;
71        opts.attempts = 2;
72        opts.use_hosts_file = false;
73
74        Self {
75            timeout: DEFAULT_TIMEOUT,
76            default_resolver: TokioAsyncResolver::tokio(ResolverConfig::google(), opts),
77        }
78    }
79
80    /// Sets the timeout for DNS queries.
81    ///
82    /// The default is 5 seconds, which is sufficient for most DNS queries.
83    pub fn with_timeout(mut self, timeout: Duration) -> Self {
84        self.timeout = timeout;
85        // Recreate default resolver with new timeout
86        let mut opts = ResolverOpts::default();
87        opts.timeout = timeout;
88        opts.attempts = 2;
89        opts.use_hosts_file = false;
90        self.default_resolver = TokioAsyncResolver::tokio(ResolverConfig::google(), opts);
91        self
92    }
93
94    fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioAsyncResolver> {
95        let mut opts = ResolverOpts::default();
96        opts.timeout = self.timeout;
97        opts.attempts = 2;
98        opts.use_hosts_file = false;
99
100        let ip: IpAddr = nameserver
101            .parse()
102            .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
103
104        // SSRF protection: reject private/reserved IPs for user-supplied nameservers
105        if let Some(reason) = crate::validation::describe_reserved_ip(&ip) {
106            return Err(SeerError::DnsError(format!(
107                "nameserver {} blocked: {}",
108                nameserver, reason
109            )));
110        }
111
112        let socket_addr = SocketAddr::new(ip, 53);
113        let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
114
115        let mut config = ResolverConfig::new();
116        config.add_name_server(ns_config);
117
118        Ok(TokioAsyncResolver::tokio(config, opts))
119    }
120
121    /// Resolves DNS records for a domain.
122    ///
123    /// # Arguments
124    /// * `domain` - The domain name to query
125    /// * `record_type` - The type of DNS record to look up (A, AAAA, MX, etc.)
126    /// * `nameserver` - Optional custom nameserver IP; uses Google DNS if None
127    #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
128    pub async fn resolve(
129        &self,
130        domain: &str,
131        record_type: RecordType,
132        nameserver: Option<&str>,
133    ) -> Result<Vec<DnsRecord>> {
134        // Reuse the cached default resolver when no custom nameserver is specified
135        let custom_resolver;
136        let resolver = if let Some(ns) = nameserver {
137            custom_resolver = self.create_custom_resolver(ns)?;
138            &custom_resolver
139        } else {
140            &self.default_resolver
141        };
142        let domain = normalize_domain(domain)?;
143
144        debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
145
146        match record_type {
147            RecordType::A => self.resolve_a(resolver, &domain).await,
148            RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
149            RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
150            RecordType::MX => self.resolve_mx(resolver, &domain).await,
151            RecordType::NS => self.resolve_ns(resolver, &domain).await,
152            RecordType::TXT => self.resolve_txt(resolver, &domain).await,
153            RecordType::SOA => self.resolve_soa(resolver, &domain).await,
154            RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
155            RecordType::SRV => Err(SeerError::DnsError(
156                "SRV records require service name format: _service._proto.name".to_string(),
157            )),
158            RecordType::CAA => self.resolve_caa(resolver, &domain).await,
159            RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
160            RecordType::DS => self.resolve_ds(resolver, &domain).await,
161            RecordType::ANY => self.resolve_any(resolver, &domain).await,
162            _ => Err(SeerError::DnsError(format!(
163                "Record type {} not implemented",
164                record_type
165            ))),
166        }
167    }
168
169    /// Resolves SRV records for a service.
170    ///
171    /// # Arguments
172    /// * `service` - The service name (e.g., "http", "ldap")
173    /// * `protocol` - The protocol (e.g., "tcp", "udp")
174    /// * `domain` - The domain name
175    /// * `nameserver` - Optional custom nameserver IP
176    #[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
177    pub async fn resolve_srv(
178        &self,
179        service: &str,
180        protocol: &str,
181        domain: &str,
182        nameserver: Option<&str>,
183    ) -> Result<Vec<DnsRecord>> {
184        // Validate service and protocol to prevent DNS query injection
185        if !is_valid_srv_label(service) {
186            return Err(SeerError::DnsError(format!(
187                "invalid SRV service name: {}",
188                service
189            )));
190        }
191        if !is_valid_srv_label(protocol) {
192            return Err(SeerError::DnsError(format!(
193                "invalid SRV protocol name: {}",
194                protocol
195            )));
196        }
197
198        let custom_resolver;
199        let resolver = if let Some(ns) = nameserver {
200            custom_resolver = self.create_custom_resolver(ns)?;
201            &custom_resolver
202        } else {
203            &self.default_resolver
204        };
205        let query_name = format!("_{}._{}.{}", service, protocol, domain);
206
207        let Some(response) = dns_lookup_or_empty(resolver.srv_lookup(&query_name).await, "SRV")?
208        else {
209            return Ok(vec![]);
210        };
211
212        let records = response
213            .iter()
214            .map(|srv| DnsRecord {
215                name: query_name.clone(),
216                record_type: RecordType::SRV,
217                ttl: response
218                    .as_lookup()
219                    .record_iter()
220                    .next()
221                    .map(|r| r.ttl())
222                    .unwrap_or(0),
223                data: RecordData::SRV {
224                    priority: srv.priority(),
225                    weight: srv.weight(),
226                    port: srv.port(),
227                    target: srv.target().to_string(),
228                },
229            })
230            .collect();
231
232        Ok(records)
233    }
234
235    async fn resolve_a(
236        &self,
237        resolver: &TokioAsyncResolver,
238        domain: &str,
239    ) -> Result<Vec<DnsRecord>> {
240        let Some(response) = dns_lookup_or_empty(resolver.ipv4_lookup(domain).await, "A")? else {
241            return Ok(vec![]);
242        };
243
244        let ttl = response
245            .as_lookup()
246            .record_iter()
247            .next()
248            .map(|r| r.ttl())
249            .unwrap_or(0);
250
251        let records = response
252            .iter()
253            .map(|addr| DnsRecord {
254                name: domain.to_string(),
255                record_type: RecordType::A,
256                ttl,
257                data: RecordData::A {
258                    address: addr.to_string(),
259                },
260            })
261            .collect();
262
263        Ok(records)
264    }
265
266    async fn resolve_aaaa(
267        &self,
268        resolver: &TokioAsyncResolver,
269        domain: &str,
270    ) -> Result<Vec<DnsRecord>> {
271        let Some(response) = dns_lookup_or_empty(resolver.ipv6_lookup(domain).await, "AAAA")?
272        else {
273            return Ok(vec![]);
274        };
275
276        let ttl = response
277            .as_lookup()
278            .record_iter()
279            .next()
280            .map(|r| r.ttl())
281            .unwrap_or(0);
282
283        let records = response
284            .iter()
285            .map(|addr| DnsRecord {
286                name: domain.to_string(),
287                record_type: RecordType::AAAA,
288                ttl,
289                data: RecordData::AAAA {
290                    address: addr.to_string(),
291                },
292            })
293            .collect();
294
295        Ok(records)
296    }
297
298    async fn resolve_cname(
299        &self,
300        resolver: &TokioAsyncResolver,
301        domain: &str,
302    ) -> Result<Vec<DnsRecord>> {
303        let Some(response) = dns_lookup_or_empty(
304            resolver.lookup(domain, HickoryRecordType::CNAME).await,
305            "CNAME",
306        )?
307        else {
308            return Ok(vec![]);
309        };
310
311        let records = response
312            .record_iter()
313            .filter_map(|record| {
314                if let Some(rdata) = record.data() {
315                    if let Some(cname) = rdata.as_cname() {
316                        return Some(DnsRecord {
317                            name: domain.to_string(),
318                            record_type: RecordType::CNAME,
319                            ttl: record.ttl(),
320                            data: RecordData::CNAME {
321                                target: cname.0.to_string(),
322                            },
323                        });
324                    }
325                }
326                None
327            })
328            .collect();
329
330        Ok(records)
331    }
332
333    async fn resolve_mx(
334        &self,
335        resolver: &TokioAsyncResolver,
336        domain: &str,
337    ) -> Result<Vec<DnsRecord>> {
338        let Some(response) = dns_lookup_or_empty(resolver.mx_lookup(domain).await, "MX")? else {
339            return Ok(vec![]);
340        };
341
342        let ttl = response
343            .as_lookup()
344            .record_iter()
345            .next()
346            .map(|r| r.ttl())
347            .unwrap_or(0);
348
349        let mut records: Vec<DnsRecord> = response
350            .iter()
351            .map(|mx| DnsRecord {
352                name: domain.to_string(),
353                record_type: RecordType::MX,
354                ttl,
355                data: RecordData::MX {
356                    preference: mx.preference(),
357                    exchange: mx.exchange().to_string(),
358                },
359            })
360            .collect();
361
362        records.sort_by_key(|r| {
363            if let RecordData::MX { preference, .. } = &r.data {
364                *preference
365            } else {
366                0
367            }
368        });
369
370        Ok(records)
371    }
372
373    async fn resolve_ns(
374        &self,
375        resolver: &TokioAsyncResolver,
376        domain: &str,
377    ) -> Result<Vec<DnsRecord>> {
378        let Some(response) = dns_lookup_or_empty(resolver.ns_lookup(domain).await, "NS")? else {
379            return Ok(vec![]);
380        };
381
382        let ttl = response
383            .as_lookup()
384            .record_iter()
385            .next()
386            .map(|r| r.ttl())
387            .unwrap_or(0);
388
389        let records = response
390            .iter()
391            .map(|ns| DnsRecord {
392                name: domain.to_string(),
393                record_type: RecordType::NS,
394                ttl,
395                data: RecordData::NS {
396                    nameserver: ns.0.to_string(),
397                },
398            })
399            .collect();
400
401        Ok(records)
402    }
403
404    async fn resolve_txt(
405        &self,
406        resolver: &TokioAsyncResolver,
407        domain: &str,
408    ) -> Result<Vec<DnsRecord>> {
409        let Some(response) = dns_lookup_or_empty(resolver.txt_lookup(domain).await, "TXT")? else {
410            return Ok(vec![]);
411        };
412
413        let ttl = response
414            .as_lookup()
415            .record_iter()
416            .next()
417            .map(|r| r.ttl())
418            .unwrap_or(0);
419
420        let records = response
421            .iter()
422            .map(|txt| {
423                let text = txt
424                    .iter()
425                    .map(|data| String::from_utf8_lossy(data).to_string())
426                    .collect::<Vec<_>>()
427                    .join("");
428
429                DnsRecord {
430                    name: domain.to_string(),
431                    record_type: RecordType::TXT,
432                    ttl,
433                    data: RecordData::TXT { text },
434                }
435            })
436            .collect();
437
438        Ok(records)
439    }
440
441    async fn resolve_soa(
442        &self,
443        resolver: &TokioAsyncResolver,
444        domain: &str,
445    ) -> Result<Vec<DnsRecord>> {
446        let Some(response) = dns_lookup_or_empty(resolver.soa_lookup(domain).await, "SOA")? else {
447            return Ok(vec![]);
448        };
449
450        let ttl = response
451            .as_lookup()
452            .record_iter()
453            .next()
454            .map(|r| r.ttl())
455            .unwrap_or(0);
456
457        let records = response
458            .iter()
459            .map(|soa| DnsRecord {
460                name: domain.to_string(),
461                record_type: RecordType::SOA,
462                ttl,
463                data: RecordData::SOA {
464                    mname: soa.mname().to_string(),
465                    rname: soa.rname().to_string(),
466                    serial: soa.serial(),
467                    refresh: soa.refresh().try_into().unwrap_or(0),
468                    retry: soa.retry().try_into().unwrap_or(0),
469                    expire: soa.expire().try_into().unwrap_or(0),
470                    minimum: soa.minimum(),
471                },
472            })
473            .collect();
474
475        Ok(records)
476    }
477
478    async fn resolve_ptr(
479        &self,
480        resolver: &TokioAsyncResolver,
481        query: &str,
482    ) -> Result<Vec<DnsRecord>> {
483        // If it's an IP address, convert to reverse DNS format
484        let query = if let Ok(ip) = IpAddr::from_str(query) {
485            reverse_dns_name(&ip)
486        } else {
487            query.to_string()
488        };
489
490        let Some(response) =
491            dns_lookup_or_empty(resolver.lookup(&query, HickoryRecordType::PTR).await, "PTR")?
492        else {
493            return Ok(vec![]);
494        };
495
496        let records = response
497            .record_iter()
498            .filter_map(|record| {
499                if let Some(rdata) = record.data() {
500                    if let Some(ptr) = rdata.as_ptr() {
501                        return Some(DnsRecord {
502                            name: query.clone(),
503                            record_type: RecordType::PTR,
504                            ttl: record.ttl(),
505                            data: RecordData::PTR {
506                                target: ptr.0.to_string(),
507                            },
508                        });
509                    }
510                }
511                None
512            })
513            .collect();
514
515        Ok(records)
516    }
517
518    async fn resolve_caa(
519        &self,
520        resolver: &TokioAsyncResolver,
521        domain: &str,
522    ) -> Result<Vec<DnsRecord>> {
523        let Some(response) =
524            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::CAA).await, "CAA")?
525        else {
526            return Ok(vec![]);
527        };
528
529        let records = response
530            .record_iter()
531            .filter_map(|record| {
532                if let Some(rdata) = record.data() {
533                    if let Some(caa) = rdata.as_caa() {
534                        let (flags, tag, value) = parse_caa(caa);
535                        return Some(DnsRecord {
536                            name: domain.to_string(),
537                            record_type: RecordType::CAA,
538                            ttl: record.ttl(),
539                            data: RecordData::CAA { flags, tag, value },
540                        });
541                    }
542                }
543                None
544            })
545            .collect();
546
547        Ok(records)
548    }
549
550    async fn resolve_dnskey(
551        &self,
552        resolver: &TokioAsyncResolver,
553        domain: &str,
554    ) -> Result<Vec<DnsRecord>> {
555        use hickory_resolver::proto::rr::RData as HickoryRData;
556
557        let Some(response) = dns_lookup_or_empty(
558            resolver.lookup(domain, HickoryRecordType::DNSKEY).await,
559            "DNSKEY",
560        )?
561        else {
562            return Ok(vec![]);
563        };
564
565        let records = response
566            .record_iter()
567            .filter_map(|record| {
568                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
569                    if let Some(dnskey) = dnssec_rdata.as_dnskey() {
570                        use base64::{engine::general_purpose::STANDARD, Engine};
571                        let public_key = STANDARD.encode(dnskey.public_key());
572                        return Some(DnsRecord {
573                            name: domain.to_string(),
574                            record_type: RecordType::DNSKEY,
575                            ttl: record.ttl(),
576                            data: RecordData::DNSKEY {
577                                flags: dnskey.flags(),
578                                protocol: 3, // Protocol is always 3 for DNSSEC (RFC 4034)
579                                algorithm: u8::from(dnskey.algorithm()),
580                                public_key,
581                            },
582                        });
583                    }
584                }
585                None
586            })
587            .collect();
588
589        Ok(records)
590    }
591
592    async fn resolve_ds(
593        &self,
594        resolver: &TokioAsyncResolver,
595        domain: &str,
596    ) -> Result<Vec<DnsRecord>> {
597        use hickory_resolver::proto::rr::RData as HickoryRData;
598
599        let Some(response) =
600            dns_lookup_or_empty(resolver.lookup(domain, HickoryRecordType::DS).await, "DS")?
601        else {
602            return Ok(vec![]);
603        };
604
605        let records = response
606            .record_iter()
607            .filter_map(|record| {
608                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
609                    if let Some(ds) = dnssec_rdata.as_ds() {
610                        let digest = ds
611                            .digest()
612                            .iter()
613                            .map(|b| format!("{:02X}", b))
614                            .collect::<String>();
615                        return Some(DnsRecord {
616                            name: domain.to_string(),
617                            record_type: RecordType::DS,
618                            ttl: record.ttl(),
619                            data: RecordData::DS {
620                                key_tag: ds.key_tag(),
621                                algorithm: u8::from(ds.algorithm()),
622                                digest_type: u8::from(ds.digest_type()),
623                                digest,
624                            },
625                        });
626                    }
627                }
628                None
629            })
630            .collect();
631
632        Ok(records)
633    }
634
635    async fn resolve_any(
636        &self,
637        resolver: &TokioAsyncResolver,
638        domain: &str,
639    ) -> Result<Vec<DnsRecord>> {
640        let mut all_records = Vec::new();
641
642        // Query common record types
643        let record_types = [
644            RecordType::A,
645            RecordType::AAAA,
646            RecordType::MX,
647            RecordType::NS,
648            RecordType::TXT,
649            RecordType::SOA,
650            RecordType::CAA,
651        ];
652
653        for record_type in record_types {
654            match self.resolve_type(resolver, domain, record_type).await {
655                Ok(records) => all_records.extend(records),
656                Err(_) => continue, // Skip record types that don't exist
657            }
658        }
659
660        Ok(all_records)
661    }
662
663    async fn resolve_type(
664        &self,
665        resolver: &TokioAsyncResolver,
666        domain: &str,
667        record_type: RecordType,
668    ) -> Result<Vec<DnsRecord>> {
669        match record_type {
670            RecordType::A => self.resolve_a(resolver, domain).await,
671            RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
672            RecordType::CNAME => self.resolve_cname(resolver, domain).await,
673            RecordType::MX => self.resolve_mx(resolver, domain).await,
674            RecordType::NS => self.resolve_ns(resolver, domain).await,
675            RecordType::TXT => self.resolve_txt(resolver, domain).await,
676            RecordType::SOA => self.resolve_soa(resolver, domain).await,
677            RecordType::CAA => self.resolve_caa(resolver, domain).await,
678            RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
679            RecordType::DS => self.resolve_ds(resolver, domain).await,
680            _ => Err(SeerError::DnsError("unsupported record type".to_string())),
681        }
682    }
683}
684
685// Domain normalization is now handled by the shared validation module
686
/// Builds the reverse-lookup (PTR) query name for an IP address.
///
/// IPv4 octets are listed in reverse under `in-addr.arpa`. IPv6 addresses
/// are expanded to 32 lowercase hex nibbles, listed least-significant first
/// under `ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{}.{}.{}.{}.in-addr.arpa", d, c, b, a)
        }
        IpAddr::V6(v6) => {
            // Walk the bytes from last to first; within each byte the low
            // nibble precedes the high nibble.
            let labels: Vec<String> = v6
                .octets()
                .iter()
                .rev()
                .flat_map(|&byte| [byte & 0x0F, byte >> 4])
                .map(|nibble| format!("{:x}", nibble))
                .collect();
            format!("{}.ip6.arpa", labels.join("."))
        }
    }
}
717
718fn parse_caa(caa: &CAA) -> (u8, String, String) {
719    let flags = if caa.issuer_critical() { 128 } else { 0 };
720    let tag = caa.tag().as_str().to_string();
721    let value = caa.value().to_string();
722    (flags, tag, value)
723}
724
/// Checks that an SRV service or protocol label is safe to splice into a query name.
///
/// A valid label is 1–63 ASCII letters, digits or hyphens and does not
/// begin or end with a hyphen. Dots and underscores are rejected, so the
/// label cannot alter the structure of the query name.
fn is_valid_srv_label(label: &str) -> bool {
    let bytes = label.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&first), Some(&last)) => {
            bytes.len() <= 63
                && first != b'-'
                && last != b'-'
                && bytes.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'-')
        }
        _ => false,
    }
}