Skip to main content

seer_core/status/
client.rs

1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Timeout used for each HTTP request and for each step of the TLS probe (10 seconds).
/// Long enough for slow servers to answer while keeping a status check responsive.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of HTTP redirects followed before giving up.
const MAX_REDIRECTS: usize = 5;
23
24/// Pre-compiled regex for extracting HTML title.
25static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
26    Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
27});
28
29/// Client for checking domain status (HTTP, SSL, expiration)
30#[derive(Debug, Clone)]
31pub struct StatusClient {
32    timeout: Duration,
33    /// Cached DNS resolver reused across status checks.
34    dns_resolver: DnsResolver,
35    /// Reusable SmartLookup for domain expiration checks.
36    smart_lookup: SmartLookup,
37}
38
39impl Default for StatusClient {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl StatusClient {
46    /// Creates a new StatusClient with default settings.
47    pub fn new() -> Self {
48        Self {
49            timeout: DEFAULT_TIMEOUT,
50            dns_resolver: DnsResolver::new(),
51            smart_lookup: SmartLookup::new(),
52        }
53    }
54
55    /// Sets the timeout for HTTP and TLS operations.
56    pub fn with_timeout(mut self, timeout: Duration) -> Self {
57        self.timeout = timeout;
58        self
59    }
60
61    /// Checks the status of a domain (HTTP, SSL, expiration, DNS).
62    #[instrument(skip(self), fields(domain = %domain))]
63    pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64        // Normalize domain format (doesn't require DNS resolution)
65        let domain = normalize_domain(domain)?;
66        debug!("Checking status for domain: {}", domain);
67
68        let mut response = StatusResponse::new(domain.clone());
69
70        // Fetch HTTP status, SSL cert info, domain expiration, and DNS resolution concurrently
71        // HTTP and SSL checks include SSRF protection internally
72        let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73            self.fetch_http_info(&domain),
74            self.fetch_certificate_info(&domain),
75            self.fetch_domain_expiration(&domain),
76            self.fetch_dns_resolution(&domain)
77        );
78
79        // Apply HTTP info
80        match http_result {
81            Ok((status, status_text, title)) => {
82                response.http_status = Some(status);
83                response.http_status_text = Some(status_text);
84                response.title = title;
85            }
86            Err(e) => response.errors.push(super::types::StatusError {
87                check: "http".to_string(),
88                message: e.to_string(),
89            }),
90        }
91
92        // Apply certificate info
93        match cert_result {
94            Ok(cert_info) => response.certificate = Some(cert_info),
95            Err(e) => response.errors.push(super::types::StatusError {
96                check: "ssl".to_string(),
97                message: e.to_string(),
98            }),
99        }
100
101        // Apply domain expiration info
102        match expiry_result {
103            Ok(expiry_info) => response.domain_expiration = expiry_info,
104            Err(e) => response.errors.push(super::types::StatusError {
105                check: "expiration".to_string(),
106                message: e.to_string(),
107            }),
108        }
109
110        // Apply DNS resolution info
111        match dns_result {
112            Ok(dns_info) => response.dns_resolution = Some(dns_info),
113            Err(e) => response.errors.push(super::types::StatusError {
114                check: "dns".to_string(),
115                message: e.to_string(),
116            }),
117        }
118
119        Ok(response)
120    }
121
122    /// Fetches the HTTP status code and page title.
123    ///
124    /// Redirects are followed manually with IP validation at each hop.
125    /// Resolved IPs are pinned on the HTTP client via `resolve_to_addrs` to
126    /// prevent DNS rebinding attacks (TOCTOU between validation and connect).
127    ///
128    /// # Security Note
129    /// Redirect targets are validated for SSRF but the HTTP response body (page title)
130    /// comes from an unauthenticated connection and should be treated as untrusted.
131    async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
132        let mut url = Url::parse(&format!("https://{}", domain))
133            .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
134        let mut visited = HashSet::new();
135
136        for _ in 0..=MAX_REDIRECTS {
137            let validated_addrs = validate_url_target(&url).await?;
138
139            if !visited.insert(url.clone()) {
140                return Err(SeerError::HttpError("redirect loop detected".to_string()));
141            }
142
143            // Build a per-hop client that pins the validated IPs so reqwest
144            // cannot re-resolve the hostname to a different (potentially
145            // private) address (DNS rebinding protection).
146            let host = url
147                .host_str()
148                .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
149            let client = Client::builder()
150                .redirect(reqwest::redirect::Policy::none())
151                .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
152                .resolve_to_addrs(host, &validated_addrs)
153                .build()
154                .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
155
156            let response = client
157                .get(url.clone())
158                .timeout(self.timeout)
159                .send()
160                .await
161                .map_err(|e| SeerError::HttpError(e.to_string()))?;
162
163            if response.status().is_redirection() {
164                let location = response.headers().get(reqwest::header::LOCATION);
165                let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
166                    SeerError::HttpError("redirect missing location header".to_string())
167                })?;
168                let next_url = url
169                    .join(location)
170                    .or_else(|_| Url::parse(location))
171                    .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
172                url = next_url;
173                continue;
174            }
175
176            let status = response.status();
177            let status_code = status.as_u16();
178            let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
179
180            // Only try to get title for successful HTML responses
181            let title = if status.is_success() {
182                let content_type = response
183                    .headers()
184                    .get("content-type")
185                    .and_then(|v| v.to_str().ok())
186                    .unwrap_or("");
187
188                if content_type.contains("text/html") {
189                    // Read at most 64 KB for title extraction to prevent OOM
190                    // on arbitrarily large responses
191                    const MAX_TITLE_BODY: usize = 64 * 1024;
192                    let bytes = response
193                        .bytes()
194                        .await
195                        .map_err(|e| SeerError::HttpError(e.to_string()))?;
196                    let body = String::from_utf8_lossy(&bytes[..bytes.len().min(MAX_TITLE_BODY)]);
197                    extract_title(&body)
198                } else {
199                    None
200                }
201            } else {
202                None
203            };
204
205            return Ok((status_code, status_text, title));
206        }
207
208        Err(SeerError::HttpError("too many redirects".to_string()))
209    }
210
211    /// Fetches SSL certificate information using native-tls.
212    ///
213    /// # Security Note
214    /// This connection uses `danger_accept_invalid_certs(true)` to inspect certificates
215    /// even when invalid. Data retrieved (issuer, subject, dates) comes from an
216    /// unauthenticated TLS connection and may have been tampered with by a MITM.
217    async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
218        // SSRF protection: resolve domain and check IPs before connecting
219        let addr = format!("{}:443", domain);
220        let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
221            .await
222            .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
223            .collect();
224
225        if socket_addrs.is_empty() {
226            return Err(SeerError::CertificateError(format!(
227                "DNS lookup returned no addresses for {}",
228                domain
229            )));
230        }
231
232        for socket_addr in &socket_addrs {
233            if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
234                return Err(SeerError::CertificateError(format!(
235                    "cannot connect to {}: {} — {}",
236                    domain,
237                    socket_addr.ip(),
238                    reason
239                )));
240            }
241        }
242
243        let connector = TlsConnector::builder()
244            .danger_accept_invalid_certs(true) // We want to see the cert even if invalid
245            .build()
246            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
247
248        let connector = tokio_native_tls::TlsConnector::from(connector);
249
250        // Connect directly to the validated socket address to prevent DNS
251        // rebinding (TOCTOU) between validation and connect.
252        let stream =
253            tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
254                .await
255                .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
256                .map_err(|e| SeerError::CertificateError(e.to_string()))?;
257
258        // Use the domain as SNI hostname for the TLS handshake.
259        let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
260            .await
261            .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
262            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
263
264        // Get the peer certificate
265        let cert = tls_stream
266            .get_ref()
267            .peer_certificate()
268            .map_err(|e| SeerError::CertificateError(e.to_string()))?
269            .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
270
271        // Parse certificate info
272        let der = cert
273            .to_der()
274            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
275
276        parse_certificate_der(&der, domain)
277    }
278
279    /// Fetches domain expiration info using WHOIS/RDAP.
280    async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
281        match self.smart_lookup.lookup(domain).await {
282            Ok(result) => {
283                let (expiration_date, registrar) = result.expiration_info();
284
285                if let Some(exp_date) = expiration_date {
286                    let days_until_expiry = (exp_date - Utc::now()).num_days();
287                    Ok(Some(DomainExpiration {
288                        expiration_date: exp_date,
289                        days_until_expiry,
290                        registrar,
291                    }))
292                } else {
293                    Ok(None)
294                }
295            }
296            Err(_) => Ok(None), // Don't fail the whole status check if WHOIS fails
297        }
298    }
299
300    /// Fetches DNS root record resolution (A, AAAA, CNAME, NS).
301    async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
302        let resolver = &self.dns_resolver;
303
304        // Query all record types concurrently
305        let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
306            resolver.resolve(domain, RecordType::A, None),
307            resolver.resolve(domain, RecordType::AAAA, None),
308            resolver.resolve(domain, RecordType::CNAME, None),
309            resolver.resolve(domain, RecordType::NS, None)
310        );
311
312        // Extract A records
313        let a_records: Vec<String> = a_result
314            .unwrap_or_default()
315            .into_iter()
316            .filter_map(|r| {
317                if let RecordData::A { address } = r.data {
318                    Some(address)
319                } else {
320                    None
321                }
322            })
323            .collect();
324
325        // Extract AAAA records
326        let aaaa_records: Vec<String> = aaaa_result
327            .unwrap_or_default()
328            .into_iter()
329            .filter_map(|r| {
330                if let RecordData::AAAA { address } = r.data {
331                    Some(address)
332                } else {
333                    None
334                }
335            })
336            .collect();
337
338        // Extract CNAME target (trim trailing dot)
339        let cname_target: Option<String> =
340            cname_result.unwrap_or_default().into_iter().find_map(|r| {
341                if let RecordData::CNAME { target } = r.data {
342                    Some(target.trim_end_matches('.').to_string())
343                } else {
344                    None
345                }
346            });
347
348        // Extract NS records (trim trailing dots)
349        let nameservers: Vec<String> = ns_result
350            .unwrap_or_default()
351            .into_iter()
352            .filter_map(|r| {
353                if let RecordData::NS { nameserver } = r.data {
354                    Some(nameserver.trim_end_matches('.').to_string())
355                } else {
356                    None
357                }
358            })
359            .collect();
360
361        // Domain resolves if it has A/AAAA records or a CNAME
362        let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
363
364        Ok(DnsResolution {
365            a_records,
366            aaaa_records,
367            cname_target,
368            nameservers,
369            resolves,
370        })
371    }
372}
373
// Domain normalization and validation are handled by `crate::validation` (see `normalize_domain`).
375
376/// Extracts the title from HTML content.
377fn extract_title(html: &str) -> Option<String> {
378    TITLE_REGEX
379        .captures(html)
380        .and_then(|caps| caps.get(1))
381        .map(|m| m.as_str().trim().to_string())
382        .filter(|s| !s.is_empty())
383}
384
385/// Validates that a URL target is safe (no private/reserved IPs, no credentials,
386/// supported scheme) and returns the resolved socket addresses.
387///
388/// The caller should pin these addresses on the HTTP client to prevent DNS
389/// rebinding between validation and the actual connection.
390async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
391    let scheme = url.scheme();
392    if scheme != "https" && scheme != "http" {
393        return Err(SeerError::HttpError(format!(
394            "unsupported URL scheme: {}",
395            scheme
396        )));
397    }
398
399    if !url.username().is_empty() || url.password().is_some() {
400        return Err(SeerError::HttpError(
401            "URL credentials are not allowed".to_string(),
402        ));
403    }
404
405    let host = url
406        .host_str()
407        .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
408    let port = url.port_or_known_default().unwrap_or(443);
409
410    // Only allow standard HTTP/HTTPS ports to prevent port scanning via redirects
411    if port != 80 && port != 443 {
412        return Err(SeerError::HttpError(format!(
413            "non-standard port {} is not allowed in redirects",
414            port
415        )));
416    }
417
418    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
419        if let Some(reason) = describe_reserved_ip(&ip) {
420            return Err(SeerError::HttpError(format!(
421                "cannot connect to {}: {} — {}",
422                host, ip, reason
423            )));
424        }
425        return Ok(vec![SocketAddr::new(ip, port)]);
426    }
427
428    let addr = format!("{}:{}", host, port);
429    let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
430        .await
431        .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
432        .collect();
433
434    if socket_addrs.is_empty() {
435        return Err(SeerError::HttpError(format!(
436            "DNS lookup returned no addresses for {}",
437            host
438        )));
439    }
440
441    for socket_addr in &socket_addrs {
442        if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
443            return Err(SeerError::HttpError(format!(
444                "cannot connect to {}: {} — {}",
445                host,
446                socket_addr.ip(),
447                reason
448            )));
449        }
450    }
451
452    Ok(socket_addrs)
453}
454
455/// Parses certificate information from DER-encoded certificate using x509-parser.
456fn parse_certificate_der(der: &[u8], _domain: &str) -> Result<CertificateInfo> {
457    use x509_parser::prelude::*;
458
459    let (_, cert) = X509Certificate::from_der(der)
460        .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
461
462    // Extract issuer - prefer CN, fall back to O (Organization)
463    let issuer =
464        extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
465
466    // Extract subject - prefer CN, fall back to O (Organization)
467    let subject =
468        extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
469
470    // Extract validity dates
471    let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
472    let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
473
474    let now = Utc::now();
475    let days_until_expiry = (valid_until - now).num_days();
476    let is_valid = now >= valid_from && now <= valid_until;
477
478    Ok(CertificateInfo {
479        issuer,
480        subject,
481        valid_from,
482        valid_until,
483        days_until_expiry,
484        is_valid,
485    })
486}
487
488/// Extracts the Common Name or Organization from an X.509 name.
489fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
490    use x509_parser::prelude::*;
491
492    // Try Common Name first (OID 2.5.4.3)
493    for rdn in name.iter() {
494        for attr in rdn.iter() {
495            if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
496                if let Some(s) = extract_attr_string(attr.attr_value()) {
497                    return Some(s);
498                }
499            }
500        }
501    }
502
503    // Fall back to Organization (OID 2.5.4.10)
504    for rdn in name.iter() {
505        for attr in rdn.iter() {
506            if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
507                if let Some(s) = extract_attr_string(attr.attr_value()) {
508                    return Some(s);
509                }
510            }
511        }
512    }
513
514    None
515}
516
517/// Extracts a string from an ASN.1 attribute value, handling different encodings.
518fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
519    // Try as_str() first (handles PrintableString, IA5String, etc.)
520    if let Ok(s) = value.as_str() {
521        return Some(s.to_string());
522    }
523
524    // Try UTF8String explicitly
525    if let Ok(utf8) = value.as_utf8string() {
526        return Some(utf8.string().to_string());
527    }
528
529    // Try raw bytes as UTF-8
530    if let Ok(s) = std::str::from_utf8(value.data) {
531        return Some(s.to_string());
532    }
533
534    None
535}
536
537/// Converts an x509-parser ASN1Time to a chrono DateTime.
538fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
539    let timestamp = time.timestamp();
540    chrono::DateTime::from_timestamp(timestamp, 0)
541        .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
542}