Skip to main content

seer_core/dns/
resolver.rs

1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::proto::rr::rdata::CAA;
7use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
8use hickory_resolver::TokioAsyncResolver;
9use tracing::{debug, instrument};
10
11use super::records::{DnsRecord, RecordData, RecordType};
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// Default per-query DNS timeout: five seconds.
///
/// DNS answers normally arrive well inside this window; needing more time
/// usually points at network trouble or an unreachable nameserver.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
18
19/// DNS resolver for querying various record types.
20///
21/// Uses Google DNS (8.8.8.8) by default, but supports custom nameservers.
22/// The default resolver is cached and reused across queries to avoid
23/// repeated initialization overhead.
24#[derive(Clone)]
25pub struct DnsResolver {
26    timeout: Duration,
27    /// Cached default resolver (Google DNS). Reused across all queries
28    /// that don't specify a custom nameserver.
29    default_resolver: TokioAsyncResolver,
30}
31
32impl std::fmt::Debug for DnsResolver {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("DnsResolver")
35            .field("timeout", &self.timeout)
36            .finish()
37    }
38}
39
40impl Default for DnsResolver {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl DnsResolver {
47    /// Creates a new DNS resolver with default settings.
48    pub fn new() -> Self {
49        let mut opts = ResolverOpts::default();
50        opts.timeout = DEFAULT_TIMEOUT;
51        opts.attempts = 2;
52        opts.use_hosts_file = false;
53
54        Self {
55            timeout: DEFAULT_TIMEOUT,
56            default_resolver: TokioAsyncResolver::tokio(ResolverConfig::google(), opts),
57        }
58    }
59
60    /// Sets the timeout for DNS queries.
61    ///
62    /// The default is 5 seconds, which is sufficient for most DNS queries.
63    pub fn with_timeout(mut self, timeout: Duration) -> Self {
64        self.timeout = timeout;
65        // Recreate default resolver with new timeout
66        let mut opts = ResolverOpts::default();
67        opts.timeout = timeout;
68        opts.attempts = 2;
69        opts.use_hosts_file = false;
70        self.default_resolver = TokioAsyncResolver::tokio(ResolverConfig::google(), opts);
71        self
72    }
73
74    fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioAsyncResolver> {
75        let mut opts = ResolverOpts::default();
76        opts.timeout = self.timeout;
77        opts.attempts = 2;
78        opts.use_hosts_file = false;
79
80        let ip: IpAddr = nameserver
81            .parse()
82            .map_err(|_| SeerError::DnsError(format!("invalid nameserver IP: {}", nameserver)))?;
83
84        let socket_addr = SocketAddr::new(ip, 53);
85        let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
86
87        let mut config = ResolverConfig::new();
88        config.add_name_server(ns_config);
89
90        Ok(TokioAsyncResolver::tokio(config, opts))
91    }
92
93    /// Resolves DNS records for a domain.
94    ///
95    /// # Arguments
96    /// * `domain` - The domain name to query
97    /// * `record_type` - The type of DNS record to look up (A, AAAA, MX, etc.)
98    /// * `nameserver` - Optional custom nameserver IP; uses Google DNS if None
99    #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
100    pub async fn resolve(
101        &self,
102        domain: &str,
103        record_type: RecordType,
104        nameserver: Option<&str>,
105    ) -> Result<Vec<DnsRecord>> {
106        // Reuse the cached default resolver when no custom nameserver is specified
107        let custom_resolver;
108        let resolver = if let Some(ns) = nameserver {
109            custom_resolver = self.create_custom_resolver(ns)?;
110            &custom_resolver
111        } else {
112            &self.default_resolver
113        };
114        let domain = normalize_domain(domain)?;
115
116        debug!(nameserver = nameserver.unwrap_or("system"), "Resolving DNS");
117
118        match record_type {
119            RecordType::A => self.resolve_a(resolver, &domain).await,
120            RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
121            RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
122            RecordType::MX => self.resolve_mx(resolver, &domain).await,
123            RecordType::NS => self.resolve_ns(resolver, &domain).await,
124            RecordType::TXT => self.resolve_txt(resolver, &domain).await,
125            RecordType::SOA => self.resolve_soa(resolver, &domain).await,
126            RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
127            RecordType::SRV => Err(SeerError::DnsError(
128                "SRV records require service name format: _service._proto.name".to_string(),
129            )),
130            RecordType::CAA => self.resolve_caa(resolver, &domain).await,
131            RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
132            RecordType::DS => self.resolve_ds(resolver, &domain).await,
133            RecordType::ANY => self.resolve_any(resolver, &domain).await,
134            _ => Err(SeerError::DnsError(format!(
135                "Record type {} not implemented",
136                record_type
137            ))),
138        }
139    }
140
141    /// Resolves SRV records for a service.
142    ///
143    /// # Arguments
144    /// * `service` - The service name (e.g., "http", "ldap")
145    /// * `protocol` - The protocol (e.g., "tcp", "udp")
146    /// * `domain` - The domain name
147    /// * `nameserver` - Optional custom nameserver IP
148    pub async fn resolve_srv(
149        &self,
150        service: &str,
151        protocol: &str,
152        domain: &str,
153        nameserver: Option<&str>,
154    ) -> Result<Vec<DnsRecord>> {
155        // Validate service and protocol to prevent DNS query injection
156        if !is_valid_srv_label(service) {
157            return Err(SeerError::DnsError(format!(
158                "invalid SRV service name: {}",
159                service
160            )));
161        }
162        if !is_valid_srv_label(protocol) {
163            return Err(SeerError::DnsError(format!(
164                "invalid SRV protocol name: {}",
165                protocol
166            )));
167        }
168
169        let custom_resolver;
170        let resolver = if let Some(ns) = nameserver {
171            custom_resolver = self.create_custom_resolver(ns)?;
172            &custom_resolver
173        } else {
174            &self.default_resolver
175        };
176        let query_name = format!("_{}._{}.{}", service, protocol, domain);
177
178        let response = resolver
179            .srv_lookup(&query_name)
180            .await
181            .map_err(|e| SeerError::DnsError(format!("SRV lookup failed: {}", e)))?;
182
183        let records = response
184            .iter()
185            .map(|srv| DnsRecord {
186                name: query_name.clone(),
187                record_type: RecordType::SRV,
188                ttl: response
189                    .as_lookup()
190                    .record_iter()
191                    .next()
192                    .map(|r| r.ttl())
193                    .unwrap_or(0),
194                data: RecordData::SRV {
195                    priority: srv.priority(),
196                    weight: srv.weight(),
197                    port: srv.port(),
198                    target: srv.target().to_string(),
199                },
200            })
201            .collect();
202
203        Ok(records)
204    }
205
206    async fn resolve_a(
207        &self,
208        resolver: &TokioAsyncResolver,
209        domain: &str,
210    ) -> Result<Vec<DnsRecord>> {
211        let response = resolver
212            .ipv4_lookup(domain)
213            .await
214            .map_err(|e| SeerError::DnsError(format!("A lookup failed: {}", e)))?;
215
216        let ttl = response
217            .as_lookup()
218            .record_iter()
219            .next()
220            .map(|r| r.ttl())
221            .unwrap_or(0);
222
223        let records = response
224            .iter()
225            .map(|addr| DnsRecord {
226                name: domain.to_string(),
227                record_type: RecordType::A,
228                ttl,
229                data: RecordData::A {
230                    address: addr.to_string(),
231                },
232            })
233            .collect();
234
235        Ok(records)
236    }
237
238    async fn resolve_aaaa(
239        &self,
240        resolver: &TokioAsyncResolver,
241        domain: &str,
242    ) -> Result<Vec<DnsRecord>> {
243        let response = resolver
244            .ipv6_lookup(domain)
245            .await
246            .map_err(|e| SeerError::DnsError(format!("AAAA lookup failed: {}", e)))?;
247
248        let ttl = response
249            .as_lookup()
250            .record_iter()
251            .next()
252            .map(|r| r.ttl())
253            .unwrap_or(0);
254
255        let records = response
256            .iter()
257            .map(|addr| DnsRecord {
258                name: domain.to_string(),
259                record_type: RecordType::AAAA,
260                ttl,
261                data: RecordData::AAAA {
262                    address: addr.to_string(),
263                },
264            })
265            .collect();
266
267        Ok(records)
268    }
269
270    async fn resolve_cname(
271        &self,
272        resolver: &TokioAsyncResolver,
273        domain: &str,
274    ) -> Result<Vec<DnsRecord>> {
275        let response = resolver
276            .lookup(domain, HickoryRecordType::CNAME)
277            .await
278            .map_err(|e| SeerError::DnsError(format!("CNAME lookup failed: {}", e)))?;
279
280        let records = response
281            .record_iter()
282            .filter_map(|record| {
283                if let Some(rdata) = record.data() {
284                    if let Some(cname) = rdata.as_cname() {
285                        return Some(DnsRecord {
286                            name: domain.to_string(),
287                            record_type: RecordType::CNAME,
288                            ttl: record.ttl(),
289                            data: RecordData::CNAME {
290                                target: cname.0.to_string(),
291                            },
292                        });
293                    }
294                }
295                None
296            })
297            .collect();
298
299        Ok(records)
300    }
301
302    async fn resolve_mx(
303        &self,
304        resolver: &TokioAsyncResolver,
305        domain: &str,
306    ) -> Result<Vec<DnsRecord>> {
307        let response = resolver
308            .mx_lookup(domain)
309            .await
310            .map_err(|e| SeerError::DnsError(format!("MX lookup failed: {}", e)))?;
311
312        let ttl = response
313            .as_lookup()
314            .record_iter()
315            .next()
316            .map(|r| r.ttl())
317            .unwrap_or(0);
318
319        let mut records: Vec<DnsRecord> = response
320            .iter()
321            .map(|mx| DnsRecord {
322                name: domain.to_string(),
323                record_type: RecordType::MX,
324                ttl,
325                data: RecordData::MX {
326                    preference: mx.preference(),
327                    exchange: mx.exchange().to_string(),
328                },
329            })
330            .collect();
331
332        records.sort_by_key(|r| {
333            if let RecordData::MX { preference, .. } = &r.data {
334                *preference
335            } else {
336                0
337            }
338        });
339
340        Ok(records)
341    }
342
343    async fn resolve_ns(
344        &self,
345        resolver: &TokioAsyncResolver,
346        domain: &str,
347    ) -> Result<Vec<DnsRecord>> {
348        let response = resolver
349            .ns_lookup(domain)
350            .await
351            .map_err(|e| SeerError::DnsError(format!("NS lookup failed: {}", e)))?;
352
353        let ttl = response
354            .as_lookup()
355            .record_iter()
356            .next()
357            .map(|r| r.ttl())
358            .unwrap_or(0);
359
360        let records = response
361            .iter()
362            .map(|ns| DnsRecord {
363                name: domain.to_string(),
364                record_type: RecordType::NS,
365                ttl,
366                data: RecordData::NS {
367                    nameserver: ns.0.to_string(),
368                },
369            })
370            .collect();
371
372        Ok(records)
373    }
374
375    async fn resolve_txt(
376        &self,
377        resolver: &TokioAsyncResolver,
378        domain: &str,
379    ) -> Result<Vec<DnsRecord>> {
380        let response = resolver
381            .txt_lookup(domain)
382            .await
383            .map_err(|e| SeerError::DnsError(format!("TXT lookup failed: {}", e)))?;
384
385        let ttl = response
386            .as_lookup()
387            .record_iter()
388            .next()
389            .map(|r| r.ttl())
390            .unwrap_or(0);
391
392        let records = response
393            .iter()
394            .map(|txt| {
395                let text = txt
396                    .iter()
397                    .map(|data| String::from_utf8_lossy(data).to_string())
398                    .collect::<Vec<_>>()
399                    .join("");
400
401                DnsRecord {
402                    name: domain.to_string(),
403                    record_type: RecordType::TXT,
404                    ttl,
405                    data: RecordData::TXT { text },
406                }
407            })
408            .collect();
409
410        Ok(records)
411    }
412
413    async fn resolve_soa(
414        &self,
415        resolver: &TokioAsyncResolver,
416        domain: &str,
417    ) -> Result<Vec<DnsRecord>> {
418        let response = resolver
419            .soa_lookup(domain)
420            .await
421            .map_err(|e| SeerError::DnsError(format!("SOA lookup failed: {}", e)))?;
422
423        let ttl = response
424            .as_lookup()
425            .record_iter()
426            .next()
427            .map(|r| r.ttl())
428            .unwrap_or(0);
429
430        let records = response
431            .iter()
432            .map(|soa| DnsRecord {
433                name: domain.to_string(),
434                record_type: RecordType::SOA,
435                ttl,
436                data: RecordData::SOA {
437                    mname: soa.mname().to_string(),
438                    rname: soa.rname().to_string(),
439                    serial: soa.serial(),
440                    refresh: soa.refresh().try_into().unwrap_or(0),
441                    retry: soa.retry().try_into().unwrap_or(0),
442                    expire: soa.expire().try_into().unwrap_or(0),
443                    minimum: soa.minimum(),
444                },
445            })
446            .collect();
447
448        Ok(records)
449    }
450
451    async fn resolve_ptr(
452        &self,
453        resolver: &TokioAsyncResolver,
454        query: &str,
455    ) -> Result<Vec<DnsRecord>> {
456        // If it's an IP address, convert to reverse DNS format
457        let query = if let Ok(ip) = IpAddr::from_str(query) {
458            reverse_dns_name(&ip)
459        } else {
460            query.to_string()
461        };
462
463        let response = resolver
464            .lookup(&query, HickoryRecordType::PTR)
465            .await
466            .map_err(|e| SeerError::DnsError(format!("PTR lookup failed: {}", e)))?;
467
468        let records = response
469            .record_iter()
470            .filter_map(|record| {
471                if let Some(rdata) = record.data() {
472                    if let Some(ptr) = rdata.as_ptr() {
473                        return Some(DnsRecord {
474                            name: query.clone(),
475                            record_type: RecordType::PTR,
476                            ttl: record.ttl(),
477                            data: RecordData::PTR {
478                                target: ptr.0.to_string(),
479                            },
480                        });
481                    }
482                }
483                None
484            })
485            .collect();
486
487        Ok(records)
488    }
489
490    async fn resolve_caa(
491        &self,
492        resolver: &TokioAsyncResolver,
493        domain: &str,
494    ) -> Result<Vec<DnsRecord>> {
495        let response = resolver
496            .lookup(domain, HickoryRecordType::CAA)
497            .await
498            .map_err(|e| SeerError::DnsError(format!("CAA lookup failed: {}", e)))?;
499
500        let records = response
501            .record_iter()
502            .filter_map(|record| {
503                if let Some(rdata) = record.data() {
504                    if let Some(caa) = rdata.as_caa() {
505                        let (flags, tag, value) = parse_caa(caa);
506                        return Some(DnsRecord {
507                            name: domain.to_string(),
508                            record_type: RecordType::CAA,
509                            ttl: record.ttl(),
510                            data: RecordData::CAA { flags, tag, value },
511                        });
512                    }
513                }
514                None
515            })
516            .collect();
517
518        Ok(records)
519    }
520
521    async fn resolve_dnskey(
522        &self,
523        resolver: &TokioAsyncResolver,
524        domain: &str,
525    ) -> Result<Vec<DnsRecord>> {
526        use hickory_resolver::proto::rr::RData as HickoryRData;
527
528        let response = resolver
529            .lookup(domain, HickoryRecordType::DNSKEY)
530            .await
531            .map_err(|e| SeerError::DnsError(format!("DNSKEY lookup failed: {}", e)))?;
532
533        let records = response
534            .record_iter()
535            .filter_map(|record| {
536                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
537                    if let Some(dnskey) = dnssec_rdata.as_dnskey() {
538                        use base64::{engine::general_purpose::STANDARD, Engine};
539                        let public_key = STANDARD.encode(dnskey.public_key());
540                        return Some(DnsRecord {
541                            name: domain.to_string(),
542                            record_type: RecordType::DNSKEY,
543                            ttl: record.ttl(),
544                            data: RecordData::DNSKEY {
545                                flags: dnskey.flags(),
546                                protocol: 3, // Protocol is always 3 for DNSSEC (RFC 4034)
547                                algorithm: u8::from(dnskey.algorithm()),
548                                public_key,
549                            },
550                        });
551                    }
552                }
553                None
554            })
555            .collect();
556
557        Ok(records)
558    }
559
560    async fn resolve_ds(
561        &self,
562        resolver: &TokioAsyncResolver,
563        domain: &str,
564    ) -> Result<Vec<DnsRecord>> {
565        use hickory_resolver::proto::rr::RData as HickoryRData;
566
567        let response = resolver
568            .lookup(domain, HickoryRecordType::DS)
569            .await
570            .map_err(|e| SeerError::DnsError(format!("DS lookup failed: {}", e)))?;
571
572        let records = response
573            .record_iter()
574            .filter_map(|record| {
575                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
576                    if let Some(ds) = dnssec_rdata.as_ds() {
577                        let digest = ds
578                            .digest()
579                            .iter()
580                            .map(|b| format!("{:02X}", b))
581                            .collect::<String>();
582                        return Some(DnsRecord {
583                            name: domain.to_string(),
584                            record_type: RecordType::DS,
585                            ttl: record.ttl(),
586                            data: RecordData::DS {
587                                key_tag: ds.key_tag(),
588                                algorithm: u8::from(ds.algorithm()),
589                                digest_type: u8::from(ds.digest_type()),
590                                digest,
591                            },
592                        });
593                    }
594                }
595                None
596            })
597            .collect();
598
599        Ok(records)
600    }
601
602    async fn resolve_any(
603        &self,
604        resolver: &TokioAsyncResolver,
605        domain: &str,
606    ) -> Result<Vec<DnsRecord>> {
607        let mut all_records = Vec::new();
608
609        // Query common record types
610        let record_types = [
611            RecordType::A,
612            RecordType::AAAA,
613            RecordType::MX,
614            RecordType::NS,
615            RecordType::TXT,
616            RecordType::SOA,
617            RecordType::CAA,
618        ];
619
620        for record_type in record_types {
621            match self.resolve_type(resolver, domain, record_type).await {
622                Ok(records) => all_records.extend(records),
623                Err(_) => continue, // Skip record types that don't exist
624            }
625        }
626
627        Ok(all_records)
628    }
629
630    async fn resolve_type(
631        &self,
632        resolver: &TokioAsyncResolver,
633        domain: &str,
634        record_type: RecordType,
635    ) -> Result<Vec<DnsRecord>> {
636        match record_type {
637            RecordType::A => self.resolve_a(resolver, domain).await,
638            RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
639            RecordType::CNAME => self.resolve_cname(resolver, domain).await,
640            RecordType::MX => self.resolve_mx(resolver, domain).await,
641            RecordType::NS => self.resolve_ns(resolver, domain).await,
642            RecordType::TXT => self.resolve_txt(resolver, domain).await,
643            RecordType::SOA => self.resolve_soa(resolver, domain).await,
644            RecordType::CAA => self.resolve_caa(resolver, domain).await,
645            RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
646            RecordType::DS => self.resolve_ds(resolver, domain).await,
647            _ => Err(SeerError::DnsError("unsupported record type".to_string())),
648        }
649    }
650}
651
// Domain-name normalization is handled by `crate::validation::normalize_domain`.
653
/// Builds the reverse-DNS (PTR) query name for an IP address.
///
/// * IPv4 `a.b.c.d` becomes `d.c.b.a.in-addr.arpa`.
/// * IPv6 is expanded to all 32 hex nibbles, written in reverse order with
///   lowercase digits and dots between them, and suffixed with `ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    match ip {
        IpAddr::V4(v4) => {
            let [first, second, third, fourth] = v4.octets();
            format!(
                "{}.{}.{}.{}.in-addr.arpa",
                fourth, third, second, first
            )
        }
        IpAddr::V6(v6) => {
            // 32 nibbles, each followed by '.', then "ip6.arpa" (8) = 72 chars.
            let mut name = String::with_capacity(72);
            // Walk bytes from last to first; within a byte the low nibble
            // comes before the high nibble, giving fully reversed nibble order.
            for byte in v6.octets().iter().rev() {
                for nibble in [byte & 0x0F, byte >> 4] {
                    name.push(char::from(HEX_DIGITS[usize::from(nibble)]));
                    name.push('.');
                }
            }
            name.push_str("ip6.arpa");
            name
        }
    }
}
683
684fn parse_caa(caa: &CAA) -> (u8, String, String) {
685    let flags = if caa.issuer_critical() { 128 } else { 0 };
686    let tag = caa.tag().as_str().to_string();
687    let value = caa.value().to_string();
688    (flags, tag, value)
689}
690
/// Validates an SRV service or protocol name before it is embedded in a query
/// name as `_<label>`.
///
/// Accepts 1–62 ASCII letters, digits and hyphens, not starting or ending with
/// a hyphen. Dots, underscores and all other characters are rejected, so a
/// caller cannot smuggle extra labels into the query name.
///
/// The limit is 62 rather than 63 because the caller prepends `_`, and that
/// byte counts toward the 63-octet DNS label limit (RFC 1035 §2.3.4).
fn is_valid_srv_label(label: &str) -> bool {
    // 63-octet DNS label limit, minus one byte for the `_` prefix.
    const MAX_LEN: usize = 63 - 1;

    !label.is_empty()
        && label.len() <= MAX_LEN
        && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
        && !label.starts_with('-')
        && !label.ends_with('-')
}