seer_core/dns/
resolver.rs

1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3use std::time::Duration;
4
5use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
6use hickory_resolver::proto::rr::rdata::CAA;
7use hickory_resolver::proto::rr::RecordType as HickoryRecordType;
8use hickory_resolver::TokioAsyncResolver;
9use tracing::{debug, instrument};
10
11use super::records::{DnsRecord, RecordData, RecordType};
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// Per-request timeout applied to resolvers built by [`DnsResolver::new`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
16
/// DNS lookup front-end.
///
/// The struct only stores the request timeout; a fresh hickory resolver is
/// built from it for every query.
#[derive(Debug, Clone)]
pub struct DnsResolver {
    /// Timeout copied into the resolver options of every resolver created.
    timeout: Duration,
}
21
22impl Default for DnsResolver {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl DnsResolver {
29    pub fn new() -> Self {
30        Self {
31            timeout: DEFAULT_TIMEOUT,
32        }
33    }
34
35    pub fn with_timeout(mut self, timeout: Duration) -> Self {
36        self.timeout = timeout;
37        self
38    }
39
40    fn create_resolver(&self, nameserver: Option<&str>) -> Result<TokioAsyncResolver> {
41        let mut opts = ResolverOpts::default();
42        opts.timeout = self.timeout;
43        opts.attempts = 2;
44        opts.use_hosts_file = false;
45
46        let config = if let Some(ns) = nameserver {
47            let ip: IpAddr = ns
48                .parse()
49                .map_err(|_| SeerError::DnsError(format!("Invalid nameserver IP: {}", ns)))?;
50
51            let socket_addr = SocketAddr::new(ip, 53);
52            let ns_config = NameServerConfig::new(socket_addr, Protocol::Udp);
53
54            let mut config = ResolverConfig::new();
55            config.add_name_server(ns_config);
56            config
57        } else {
58            ResolverConfig::google()
59        };
60
61        Ok(TokioAsyncResolver::tokio(config, opts))
62    }
63
64    #[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
65    pub async fn resolve(
66        &self,
67        domain: &str,
68        record_type: RecordType,
69        nameserver: Option<&str>,
70    ) -> Result<Vec<DnsRecord>> {
71        let resolver = self.create_resolver(nameserver)?;
72        let domain = normalize_domain(domain)?;
73
74        debug!(
75            nameserver = nameserver.unwrap_or("system"),
76            "Resolving DNS"
77        );
78
79        match record_type {
80            RecordType::A => self.resolve_a(&resolver, &domain).await,
81            RecordType::AAAA => self.resolve_aaaa(&resolver, &domain).await,
82            RecordType::CNAME => self.resolve_cname(&resolver, &domain).await,
83            RecordType::MX => self.resolve_mx(&resolver, &domain).await,
84            RecordType::NS => self.resolve_ns(&resolver, &domain).await,
85            RecordType::TXT => self.resolve_txt(&resolver, &domain).await,
86            RecordType::SOA => self.resolve_soa(&resolver, &domain).await,
87            RecordType::PTR => self.resolve_ptr(&resolver, &domain).await,
88            RecordType::SRV => Err(SeerError::DnsError(
89                "SRV records require service name format: _service._proto.name".to_string(),
90            )),
91            RecordType::CAA => self.resolve_caa(&resolver, &domain).await,
92            RecordType::DNSKEY => self.resolve_dnskey(&resolver, &domain).await,
93            RecordType::DS => self.resolve_ds(&resolver, &domain).await,
94            RecordType::ANY => self.resolve_any(&resolver, &domain).await,
95            _ => Err(SeerError::DnsError(format!(
96                "Record type {} not implemented",
97                record_type
98            ))),
99        }
100    }
101
102    pub async fn resolve_srv(
103        &self,
104        service: &str,
105        protocol: &str,
106        domain: &str,
107        nameserver: Option<&str>,
108    ) -> Result<Vec<DnsRecord>> {
109        // Validate service and protocol to prevent DNS query injection
110        if !is_valid_srv_label(service) {
111            return Err(SeerError::DnsError(format!(
112                "Invalid SRV service name: {}",
113                service
114            )));
115        }
116        if !is_valid_srv_label(protocol) {
117            return Err(SeerError::DnsError(format!(
118                "Invalid SRV protocol name: {}",
119                protocol
120            )));
121        }
122
123        let resolver = self.create_resolver(nameserver)?;
124        let query_name = format!("_{}._{}.{}", service, protocol, domain);
125
126        let response = resolver
127            .srv_lookup(&query_name)
128            .await
129            .map_err(|e| SeerError::DnsError(format!("SRV lookup failed: {}", e)))?;
130
131        let records = response
132            .iter()
133            .map(|srv| DnsRecord {
134                name: query_name.clone(),
135                record_type: RecordType::SRV,
136                ttl: response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0),
137                data: RecordData::SRV {
138                    priority: srv.priority(),
139                    weight: srv.weight(),
140                    port: srv.port(),
141                    target: srv.target().to_string(),
142                },
143            })
144            .collect();
145
146        Ok(records)
147    }
148
149    async fn resolve_a(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
150        let response = resolver
151            .ipv4_lookup(domain)
152            .await
153            .map_err(|e| SeerError::DnsError(format!("A lookup failed: {}", e)))?;
154
155        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
156
157        let records = response
158            .iter()
159            .map(|addr| DnsRecord {
160                name: domain.to_string(),
161                record_type: RecordType::A,
162                ttl,
163                data: RecordData::A {
164                    address: addr.to_string(),
165                },
166            })
167            .collect();
168
169        Ok(records)
170    }
171
172    async fn resolve_aaaa(
173        &self,
174        resolver: &TokioAsyncResolver,
175        domain: &str,
176    ) -> Result<Vec<DnsRecord>> {
177        let response = resolver
178            .ipv6_lookup(domain)
179            .await
180            .map_err(|e| SeerError::DnsError(format!("AAAA lookup failed: {}", e)))?;
181
182        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
183
184        let records = response
185            .iter()
186            .map(|addr| DnsRecord {
187                name: domain.to_string(),
188                record_type: RecordType::AAAA,
189                ttl,
190                data: RecordData::AAAA {
191                    address: addr.to_string(),
192                },
193            })
194            .collect();
195
196        Ok(records)
197    }
198
199    async fn resolve_cname(
200        &self,
201        resolver: &TokioAsyncResolver,
202        domain: &str,
203    ) -> Result<Vec<DnsRecord>> {
204        let response = resolver
205            .lookup(domain, HickoryRecordType::CNAME)
206            .await
207            .map_err(|e| SeerError::DnsError(format!("CNAME lookup failed: {}", e)))?;
208
209        let records = response
210            .record_iter()
211            .filter_map(|record| {
212                if let Some(rdata) = record.data() {
213                    if let Some(cname) = rdata.as_cname() {
214                        return Some(DnsRecord {
215                            name: domain.to_string(),
216                            record_type: RecordType::CNAME,
217                            ttl: record.ttl(),
218                            data: RecordData::CNAME {
219                                target: cname.0.to_string(),
220                            },
221                        });
222                    }
223                }
224                None
225            })
226            .collect();
227
228        Ok(records)
229    }
230
231    async fn resolve_mx(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
232        let response = resolver
233            .mx_lookup(domain)
234            .await
235            .map_err(|e| SeerError::DnsError(format!("MX lookup failed: {}", e)))?;
236
237        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
238
239        let mut records: Vec<DnsRecord> = response
240            .iter()
241            .map(|mx| DnsRecord {
242                name: domain.to_string(),
243                record_type: RecordType::MX,
244                ttl,
245                data: RecordData::MX {
246                    preference: mx.preference(),
247                    exchange: mx.exchange().to_string(),
248                },
249            })
250            .collect();
251
252        records.sort_by_key(|r| {
253            if let RecordData::MX { preference, .. } = &r.data {
254                *preference
255            } else {
256                0
257            }
258        });
259
260        Ok(records)
261    }
262
263    async fn resolve_ns(&self, resolver: &TokioAsyncResolver, domain: &str) -> Result<Vec<DnsRecord>> {
264        let response = resolver
265            .ns_lookup(domain)
266            .await
267            .map_err(|e| SeerError::DnsError(format!("NS lookup failed: {}", e)))?;
268
269        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
270
271        let records = response
272            .iter()
273            .map(|ns| DnsRecord {
274                name: domain.to_string(),
275                record_type: RecordType::NS,
276                ttl,
277                data: RecordData::NS {
278                    nameserver: ns.0.to_string(),
279                },
280            })
281            .collect();
282
283        Ok(records)
284    }
285
286    async fn resolve_txt(
287        &self,
288        resolver: &TokioAsyncResolver,
289        domain: &str,
290    ) -> Result<Vec<DnsRecord>> {
291        let response = resolver
292            .txt_lookup(domain)
293            .await
294            .map_err(|e| SeerError::DnsError(format!("TXT lookup failed: {}", e)))?;
295
296        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
297
298        let records = response
299            .iter()
300            .map(|txt| {
301                let text = txt
302                    .iter()
303                    .map(|data| String::from_utf8_lossy(data).to_string())
304                    .collect::<Vec<_>>()
305                    .join("");
306
307                DnsRecord {
308                    name: domain.to_string(),
309                    record_type: RecordType::TXT,
310                    ttl,
311                    data: RecordData::TXT { text },
312                }
313            })
314            .collect();
315
316        Ok(records)
317    }
318
319    async fn resolve_soa(
320        &self,
321        resolver: &TokioAsyncResolver,
322        domain: &str,
323    ) -> Result<Vec<DnsRecord>> {
324        let response = resolver
325            .soa_lookup(domain)
326            .await
327            .map_err(|e| SeerError::DnsError(format!("SOA lookup failed: {}", e)))?;
328
329        let ttl = response.as_lookup().record_iter().next().map(|r| r.ttl()).unwrap_or(0);
330
331        let records = response
332            .iter()
333            .map(|soa| DnsRecord {
334                name: domain.to_string(),
335                record_type: RecordType::SOA,
336                ttl,
337                data: RecordData::SOA {
338                    mname: soa.mname().to_string(),
339                    rname: soa.rname().to_string(),
340                    serial: soa.serial(),
341                    refresh: soa.refresh().try_into().unwrap_or(0),
342                    retry: soa.retry().try_into().unwrap_or(0),
343                    expire: soa.expire().try_into().unwrap_or(0),
344                    minimum: soa.minimum(),
345                },
346            })
347            .collect();
348
349        Ok(records)
350    }
351
352    async fn resolve_ptr(
353        &self,
354        resolver: &TokioAsyncResolver,
355        query: &str,
356    ) -> Result<Vec<DnsRecord>> {
357        // If it's an IP address, convert to reverse DNS format
358        let query = if let Ok(ip) = IpAddr::from_str(query) {
359            reverse_dns_name(&ip)
360        } else {
361            query.to_string()
362        };
363
364        let response = resolver
365            .lookup(&query, HickoryRecordType::PTR)
366            .await
367            .map_err(|e| SeerError::DnsError(format!("PTR lookup failed: {}", e)))?;
368
369        let records = response
370            .record_iter()
371            .filter_map(|record| {
372                if let Some(rdata) = record.data() {
373                    if let Some(ptr) = rdata.as_ptr() {
374                        return Some(DnsRecord {
375                            name: query.clone(),
376                            record_type: RecordType::PTR,
377                            ttl: record.ttl(),
378                            data: RecordData::PTR {
379                                target: ptr.0.to_string(),
380                            },
381                        });
382                    }
383                }
384                None
385            })
386            .collect();
387
388        Ok(records)
389    }
390
391    async fn resolve_caa(
392        &self,
393        resolver: &TokioAsyncResolver,
394        domain: &str,
395    ) -> Result<Vec<DnsRecord>> {
396        let response = resolver
397            .lookup(domain, HickoryRecordType::CAA)
398            .await
399            .map_err(|e| SeerError::DnsError(format!("CAA lookup failed: {}", e)))?;
400
401        let records = response
402            .record_iter()
403            .filter_map(|record| {
404                if let Some(rdata) = record.data() {
405                    if let Some(caa) = rdata.as_caa() {
406                        let (flags, tag, value) = parse_caa(caa);
407                        return Some(DnsRecord {
408                            name: domain.to_string(),
409                            record_type: RecordType::CAA,
410                            ttl: record.ttl(),
411                            data: RecordData::CAA { flags, tag, value },
412                        });
413                    }
414                }
415                None
416            })
417            .collect();
418
419        Ok(records)
420    }
421
422    async fn resolve_dnskey(
423        &self,
424        resolver: &TokioAsyncResolver,
425        domain: &str,
426    ) -> Result<Vec<DnsRecord>> {
427        use hickory_resolver::proto::rr::RData as HickoryRData;
428
429        let response = resolver
430            .lookup(domain, HickoryRecordType::DNSKEY)
431            .await
432            .map_err(|e| SeerError::DnsError(format!("DNSKEY lookup failed: {}", e)))?;
433
434        let records = response
435            .record_iter()
436            .filter_map(|record| {
437                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
438                    if let Some(dnskey) = dnssec_rdata.as_dnskey() {
439                            use base64::{engine::general_purpose::STANDARD, Engine};
440                            let public_key = STANDARD.encode(dnskey.public_key());
441                            return Some(DnsRecord {
442                                name: domain.to_string(),
443                                record_type: RecordType::DNSKEY,
444                                ttl: record.ttl(),
445                                data: RecordData::DNSKEY {
446                                    flags: dnskey.flags(),
447                                    protocol: 3, // Protocol is always 3 for DNSSEC (RFC 4034)
448                                    algorithm: u8::from(dnskey.algorithm()),
449                                    public_key,
450                                },
451                            });
452                        }
453                }
454                None
455            })
456            .collect();
457
458        Ok(records)
459    }
460
461    async fn resolve_ds(
462        &self,
463        resolver: &TokioAsyncResolver,
464        domain: &str,
465    ) -> Result<Vec<DnsRecord>> {
466        use hickory_resolver::proto::rr::RData as HickoryRData;
467
468        let response = resolver
469            .lookup(domain, HickoryRecordType::DS)
470            .await
471            .map_err(|e| SeerError::DnsError(format!("DS lookup failed: {}", e)))?;
472
473        let records = response
474            .record_iter()
475            .filter_map(|record| {
476                if let Some(HickoryRData::DNSSEC(dnssec_rdata)) = record.data() {
477                    if let Some(ds) = dnssec_rdata.as_ds() {
478                            let digest = ds
479                                .digest()
480                                .iter()
481                                .map(|b| format!("{:02X}", b))
482                                .collect::<String>();
483                            return Some(DnsRecord {
484                                name: domain.to_string(),
485                                record_type: RecordType::DS,
486                                ttl: record.ttl(),
487                                data: RecordData::DS {
488                                    key_tag: ds.key_tag(),
489                                    algorithm: u8::from(ds.algorithm()),
490                                    digest_type: u8::from(ds.digest_type()),
491                                    digest,
492                                },
493                            });
494                        }
495                }
496                None
497            })
498            .collect();
499
500        Ok(records)
501    }
502
503    async fn resolve_any(
504        &self,
505        resolver: &TokioAsyncResolver,
506        domain: &str,
507    ) -> Result<Vec<DnsRecord>> {
508        let mut all_records = Vec::new();
509
510        // Query common record types
511        let record_types = [
512            RecordType::A,
513            RecordType::AAAA,
514            RecordType::MX,
515            RecordType::NS,
516            RecordType::TXT,
517            RecordType::SOA,
518            RecordType::CAA,
519        ];
520
521        for record_type in record_types {
522            match self.resolve_type(resolver, domain, record_type).await {
523                Ok(records) => all_records.extend(records),
524                Err(_) => continue, // Skip record types that don't exist
525            }
526        }
527
528        Ok(all_records)
529    }
530
531    async fn resolve_type(
532        &self,
533        resolver: &TokioAsyncResolver,
534        domain: &str,
535        record_type: RecordType,
536    ) -> Result<Vec<DnsRecord>> {
537        match record_type {
538            RecordType::A => self.resolve_a(resolver, domain).await,
539            RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
540            RecordType::CNAME => self.resolve_cname(resolver, domain).await,
541            RecordType::MX => self.resolve_mx(resolver, domain).await,
542            RecordType::NS => self.resolve_ns(resolver, domain).await,
543            RecordType::TXT => self.resolve_txt(resolver, domain).await,
544            RecordType::SOA => self.resolve_soa(resolver, domain).await,
545            RecordType::CAA => self.resolve_caa(resolver, domain).await,
546            RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
547            RecordType::DS => self.resolve_ds(resolver, domain).await,
548            _ => Err(SeerError::DnsError("Unsupported record type".to_string())),
549        }
550    }
551}
552
553// Domain normalization is now handled by the shared validation module
554
/// Builds the reverse-lookup name for `ip`, without a trailing dot.
///
/// * IPv4: the four octets in reverse order under `in-addr.arpa`.
/// * IPv6: all 32 hexadecimal nibbles (lower case) in reverse order, each
///   separated by a dot, under `ip6.arpa`.
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{}.{}.{}.{}.in-addr.arpa", d, c, b, a)
        }
        IpAddr::V6(v6) => {
            // 32 nibbles, each followed by '.', plus the "ip6.arpa" suffix.
            let mut name = String::with_capacity(32 * 2 + 8);
            for &byte in v6.octets().iter().rev() {
                // Walking backwards, the low nibble of a byte comes first.
                for nibble in [byte & 0x0f, byte >> 4] {
                    // A nibble is always < 16, so `from_digit` cannot fail.
                    name.push(char::from_digit(u32::from(nibble), 16).unwrap_or('0'));
                    name.push('.');
                }
            }
            name.push_str("ip6.arpa");
            name
        }
    }
}
578
579fn parse_caa(caa: &CAA) -> (u8, String, String) {
580    let flags = if caa.issuer_critical() { 128 } else { 0 };
581    let tag = caa.tag().as_str().to_string();
582    let value = caa.value().to_string();
583    (flags, tag, value)
584}
585
/// Validates a bare SRV service/protocol label (without the leading `_`).
///
/// A label is accepted when it
/// * is non-empty and at most 62 bytes long — `resolve_srv` prefixes it with
///   `_`, and the resulting DNS label may not exceed 63 octets,
/// * consists only of ASCII letters, digits and hyphens (so no dots or
///   underscores), and
/// * neither starts nor ends with a hyphen.
fn is_valid_srv_label(label: &str) -> bool {
    // 63-octet DNS label limit minus the underscore added when the query
    // name is built.
    const MAX_LEN: usize = 62;

    if label.is_empty() || label.len() > MAX_LEN {
        return false;
    }
    if label.starts_with('-') || label.ends_with('-') {
        return false;
    }
    // Any non-ASCII character contributes bytes >= 0x80, which fail this test.
    label
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-')
}