Skip to main content

synapse_pingora/crawler/
dns_resolver.rs

//! Async DNS resolution for crawler verification.
//!
//! ## Security
//! - Rate limiting via semaphore prevents resource exhaustion at scale
//! - Timeout enforcement prevents slow DNS servers from blocking requests
//! - Forward-confirmed reverse DNS (IP -> PTR hostname -> IPs, which must include the
//!   original IP) prevents PTR spoofing, where an attacker points their own IP's reverse
//!   record at a crawler's hostname
7
8use std::net::IpAddr;
9use std::sync::Arc;
10use std::time::Duration;
11use thiserror::Error;
12use tokio::sync::Semaphore;
13use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
14use trust_dns_resolver::TokioAsyncResolver;
15
16/// DNS resolution errors.
17#[derive(Debug, Error, Clone)]
18pub enum DnsError {
19    #[error("DNS resolver creation failed: {0}")]
20    ResolverCreation(String),
21    #[error("DNS lookup failed: {0}")]
22    LookupFailed(String),
23    #[error("DNS timeout after {0}ms")]
24    Timeout(u64),
25    #[error("DNS rate limit exceeded, try again later")]
26    RateLimited,
27    #[error(
28        "DNS verification failed: IP not in forward lookup results (possible rebinding attack)"
29    )]
30    IpMismatch,
31}
32
33impl From<trust_dns_resolver::error::ResolveError> for DnsError {
34    fn from(e: trust_dns_resolver::error::ResolveError) -> Self {
35        DnsError::ResolverCreation(e.to_string())
36    }
37}
38
39/// Async DNS resolver for crawler verification with rate limiting.
40#[derive(Debug)]
41pub struct DnsResolver {
42    resolver: TokioAsyncResolver,
43    timeout: Duration,
44    /// Semaphore to limit concurrent DNS lookups
45    semaphore: Arc<Semaphore>,
46    /// Maximum concurrent lookups (for logging/metrics)
47    max_concurrent: usize,
48}
49
50impl DnsResolver {
51    /// Create a new DNS resolver with rate limiting.
52    ///
53    /// # Arguments
54    /// * `timeout_ms` - DNS lookup timeout in milliseconds
55    /// * `max_concurrent` - Maximum concurrent DNS lookups (semaphore permits)
56    pub async fn new(timeout_ms: u64, max_concurrent: usize) -> Result<Self, DnsError> {
57        let mut opts = ResolverOpts::default();
58        opts.timeout = Duration::from_millis(timeout_ms);
59        opts.attempts = 2;
60
61        let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), opts);
62
63        Ok(Self {
64            resolver,
65            timeout: Duration::from_millis(timeout_ms),
66            semaphore: Arc::new(Semaphore::new(max_concurrent)),
67            max_concurrent,
68        })
69    }
70
71    /// Get current number of available permits (for metrics)
72    pub fn available_permits(&self) -> usize {
73        self.semaphore.available_permits()
74    }
75
76    /// Get maximum concurrent lookups configured
77    pub fn max_concurrent(&self) -> usize {
78        self.max_concurrent
79    }
80
81    /// Acquire a permit for DNS lookup, with non-blocking try.
82    /// Returns None if rate limit is exceeded.
83    async fn acquire_permit(&self) -> Option<tokio::sync::SemaphorePermit<'_>> {
84        // Try to acquire without blocking - if we can't get a permit immediately,
85        // we're at capacity and should return a rate limit error
86        match self.semaphore.try_acquire() {
87            Ok(permit) => Some(permit),
88            Err(_) => {
89                tracing::warn!(
90                    "DNS rate limit reached: {}/{} permits in use",
91                    self.max_concurrent - self.semaphore.available_permits(),
92                    self.max_concurrent
93                );
94                None
95            }
96        }
97    }
98
99    /// Reverse DNS lookup: IP -> hostname.
100    ///
101    /// Rate-limited via semaphore to prevent resource exhaustion.
102    pub async fn reverse_lookup(&self, ip: IpAddr) -> Result<Option<String>, DnsError> {
103        let _permit = self.acquire_permit().await.ok_or(DnsError::RateLimited)?;
104
105        match tokio::time::timeout(self.timeout, self.resolver.reverse_lookup(ip)).await {
106            Ok(Ok(response)) => {
107                // Get first PTR record
108                if let Some(record) = response.iter().next() {
109                    Ok(Some(record.to_string().trim_end_matches('.').to_string()))
110                } else {
111                    Ok(None)
112                }
113            }
114            Ok(Err(e)) => {
115                // No PTR record is common, not an error
116                tracing::debug!("Reverse DNS lookup for {} failed: {}", ip, e);
117                Ok(None)
118            }
119            Err(_) => {
120                tracing::debug!(
121                    "Reverse DNS lookup for {} timed out after {}ms",
122                    ip,
123                    self.timeout.as_millis()
124                );
125                Err(DnsError::Timeout(self.timeout.as_millis() as u64))
126            }
127        }
128    }
129
130    /// Forward DNS lookup: hostname -> IPs.
131    ///
132    /// Rate-limited via semaphore to prevent resource exhaustion.
133    pub async fn forward_lookup(&self, hostname: &str) -> Result<Vec<IpAddr>, DnsError> {
134        let _permit = self.acquire_permit().await.ok_or(DnsError::RateLimited)?;
135
136        match tokio::time::timeout(self.timeout, self.resolver.lookup_ip(hostname)).await {
137            Ok(Ok(response)) => Ok(response.iter().collect()),
138            Ok(Err(e)) => {
139                tracing::debug!("Forward DNS lookup for {} failed: {}", hostname, e);
140                Err(DnsError::LookupFailed(e.to_string()))
141            }
142            Err(_) => {
143                tracing::debug!(
144                    "Forward DNS lookup for {} timed out after {}ms",
145                    hostname,
146                    self.timeout.as_millis()
147                );
148                Err(DnsError::Timeout(self.timeout.as_millis() as u64))
149            }
150        }
151    }
152
153    /// Verify IP via reverse+forward DNS lookup.
154    ///
155    /// Returns (verified, hostname) where verified is true only if:
156    /// 1. Reverse lookup (IP -> hostname) succeeds
157    /// 2. Forward lookup (hostname -> IPs) succeeds
158    /// 3. Original IP is contained in the forward lookup results
159    ///
160    /// This prevents DNS rebinding attacks where an attacker controls a domain
161    /// that initially resolves to a legitimate IP, then changes after verification.
162    ///
163    /// ## Security
164    /// The IP round-trip check is critical: we verify that the hostname
165    /// the IP claims to be actually resolves back to that IP.
166    pub async fn verify_ip(&self, ip: IpAddr) -> Result<(bool, Option<String>), DnsError> {
167        // Step 1: Reverse lookup IP -> hostname
168        let hostname = match self.reverse_lookup(ip).await? {
169            Some(h) => h,
170            None => return Ok((false, None)),
171        };
172
173        // Step 2: Forward lookup hostname -> IPs
174        let resolved_ips = match self.forward_lookup(&hostname).await {
175            Ok(ips) => ips,
176            Err(DnsError::RateLimited) => return Err(DnsError::RateLimited),
177            Err(_) => return Ok((false, Some(hostname))),
178        };
179
180        // Step 3: CRITICAL - Verify original IP is in the resolved IPs
181        // This prevents DNS rebinding attacks
182        let verified = resolved_ips.contains(&ip);
183
184        if !verified {
185            tracing::warn!(
186                ip = %ip,
187                hostname = %hostname,
188                resolved_ips = ?resolved_ips,
189                "DNS rebinding check failed: requesting IP not in forward lookup results"
190            );
191        }
192
193        Ok((verified, Some(hostname)))
194    }
195
196    /// Verify IP with explicit rebinding protection.
197    ///
198    /// Same as `verify_ip` but returns a specific error on IP mismatch
199    /// instead of just returning (false, hostname).
200    pub async fn verify_ip_strict(&self, ip: IpAddr) -> Result<String, DnsError> {
201        // Step 1: Reverse lookup IP -> hostname
202        let hostname = match self.reverse_lookup(ip).await? {
203            Some(h) => h,
204            None => return Err(DnsError::LookupFailed("No PTR record".to_string())),
205        };
206
207        // Step 2: Forward lookup hostname -> IPs
208        let resolved_ips = self.forward_lookup(&hostname).await?;
209
210        // Step 3: CRITICAL - Verify original IP is in the resolved IPs
211        if !resolved_ips.contains(&ip) {
212            tracing::warn!(
213                ip = %ip,
214                hostname = %hostname,
215                resolved_ips = ?resolved_ips,
216                "DNS rebinding attack detected: IP not in forward lookup results"
217            );
218            return Err(DnsError::IpMismatch);
219        }
220
221        Ok(hostname)
222    }
223}