Skip to main content

seer_core/status/
client.rs

1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Default timeout for HTTP and TLS operations (10 seconds).
/// Balances responsiveness with allowing slow servers to respond.
///
/// Used as the initial `StatusClient::timeout`; callers can override it with
/// `StatusClient::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum number of HTTP redirects followed by the HTTP check. If the
/// response after the last allowed hop is still a redirect, the check fails
/// with a "too many redirects" error.
const MAX_REDIRECTS: usize = 5;
23
/// Pre-compiled regex for extracting HTML title.
///
/// Matches case-insensitively (`(?i)`) an opening `<title ...>` tag (any
/// attributes allowed), captures the run of non-`<` characters that follows
/// as group 1, and requires an immediately following `</title>`. Titles that
/// are empty or contain nested markup therefore do not match.
///
/// Compiled lazily on first use; the `expect` can only fire if the pattern
/// literal itself is invalid.
static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
});
28
/// Client for checking domain status (HTTP, SSL, expiration, DNS).
///
/// Cloning the client clones the resolver and lookup handles it holds, so a
/// single configured client can be reused for many domains.
#[derive(Debug, Clone)]
pub struct StatusClient {
    /// Upper bound applied to each HTTP request, to the TCP connect and to the
    /// TLS handshake of the certificate check. The WHOIS/RDAP and DNS record
    /// lookups in this file do not consult it.
    timeout: Duration,
    /// Cached DNS resolver reused across status checks.
    dns_resolver: DnsResolver,
    /// Reusable SmartLookup for domain expiration checks.
    smart_lookup: SmartLookup,
}
38
39impl Default for StatusClient {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl StatusClient {
46    /// Creates a new StatusClient with default settings.
47    pub fn new() -> Self {
48        Self {
49            timeout: DEFAULT_TIMEOUT,
50            dns_resolver: DnsResolver::new(),
51            smart_lookup: SmartLookup::new(),
52        }
53    }
54
55    /// Sets the timeout for HTTP and TLS operations.
56    pub fn with_timeout(mut self, timeout: Duration) -> Self {
57        self.timeout = timeout;
58        self
59    }
60
61    /// Checks the status of a domain (HTTP, SSL, expiration, DNS).
62    #[instrument(skip(self), fields(domain = %domain))]
63    pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64        // Normalize domain format (doesn't require DNS resolution)
65        let domain = normalize_domain(domain)?;
66        debug!("Checking status for domain: {}", domain);
67
68        let mut response = StatusResponse::new(domain.clone());
69
70        // Fetch HTTP status, SSL cert info, domain expiration, and DNS resolution concurrently
71        // HTTP and SSL checks include SSRF protection internally
72        let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73            self.fetch_http_info(&domain),
74            self.fetch_certificate_info(&domain),
75            self.fetch_domain_expiration(&domain),
76            self.fetch_dns_resolution(&domain)
77        );
78
79        // Apply HTTP info
80        match http_result {
81            Ok((status, status_text, title)) => {
82                response.http_status = Some(status);
83                response.http_status_text = Some(status_text);
84                response.title = title;
85            }
86            Err(e) => response.errors.push(super::types::StatusError {
87                check: "http".to_string(),
88                message: e.to_string(),
89            }),
90        }
91
92        // Apply certificate info
93        match cert_result {
94            Ok(cert_info) => response.certificate = Some(cert_info),
95            Err(e) => response.errors.push(super::types::StatusError {
96                check: "ssl".to_string(),
97                message: e.to_string(),
98            }),
99        }
100
101        // Apply domain expiration info
102        match expiry_result {
103            Ok(expiry_info) => response.domain_expiration = expiry_info,
104            Err(e) => response.errors.push(super::types::StatusError {
105                check: "expiration".to_string(),
106                message: e.to_string(),
107            }),
108        }
109
110        // Apply DNS resolution info
111        match dns_result {
112            Ok(dns_info) => response.dns_resolution = Some(dns_info),
113            Err(e) => response.errors.push(super::types::StatusError {
114                check: "dns".to_string(),
115                message: e.to_string(),
116            }),
117        }
118
119        Ok(response)
120    }
121
122    /// Fetches the HTTP status code and page title.
123    ///
124    /// Redirects are followed manually with IP validation at each hop.
125    /// Resolved IPs are pinned on the HTTP client via `resolve_to_addrs` to
126    /// prevent DNS rebinding attacks (TOCTOU between validation and connect).
127    ///
128    /// # Security Note
129    /// This path uses reqwest's default (validating) TLS configuration — a
130    /// bad certificate surfaces as a typed `SeerError::HttpError` and the
131    /// status check reports it as a failed "http" sub-check instead of
132    /// silently returning attacker-controlled body content as "successful".
133    /// The SSL inspection path in `ssl.rs` (and `fetch_certificate_info`
134    /// below) intentionally relaxes verification because inspecting an
135    /// invalid cert is the whole point of that code; this path MUST NOT.
136    ///
137    /// Redirect targets are validated for SSRF but the HTTP response body
138    /// (page title) comes from an unauthenticated connection and should be
139    /// treated as untrusted.
140    async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
141        let mut url = Url::parse(&format!("https://{}", domain))
142            .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
143        let mut visited = HashSet::new();
144
145        for _ in 0..=MAX_REDIRECTS {
146            let validated_addrs = validate_url_target(&url).await?;
147
148            if !visited.insert(url.clone()) {
149                return Err(SeerError::HttpError("redirect loop detected".to_string()));
150            }
151
152            // Build a per-hop client that pins the validated IPs so reqwest
153            // cannot re-resolve the hostname to a different (potentially
154            // private) address (DNS rebinding protection).
155            let host = url
156                .host_str()
157                .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
158            let client = Client::builder()
159                .redirect(reqwest::redirect::Policy::none())
160                .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
161                .resolve_to_addrs(host, &validated_addrs)
162                .build()
163                .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
164
165            let response = client
166                .get(url.clone())
167                .timeout(self.timeout)
168                .send()
169                .await
170                .map_err(|e| SeerError::HttpError(e.to_string()))?;
171
172            if response.status().is_redirection() {
173                let location = response.headers().get(reqwest::header::LOCATION);
174                let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
175                    SeerError::HttpError("redirect missing location header".to_string())
176                })?;
177                let next_url = url
178                    .join(location)
179                    .or_else(|_| Url::parse(location))
180                    .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
181                url = next_url;
182                continue;
183            }
184
185            let status = response.status();
186            let status_code = status.as_u16();
187            let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
188
189            // Only try to get title for successful HTML responses
190            let title = if status.is_success() {
191                let content_type = response
192                    .headers()
193                    .get("content-type")
194                    .and_then(|v| v.to_str().ok())
195                    .unwrap_or("");
196
197                if content_type.contains("text/html") {
198                    // Stream at most 64 KB for title extraction. Streaming
199                    // (rather than `response.bytes().await`) prevents a
200                    // malicious server from forcing us to buffer a huge
201                    // body before the cap is applied.
202                    const MAX_TITLE_BODY: usize = 64 * 1024;
203                    use futures::StreamExt;
204                    let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
205                    let mut stream = response.bytes_stream();
206                    while let Some(chunk) = stream.next().await {
207                        let chunk = chunk
208                            .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
209                        let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
210                        if remaining == 0 {
211                            break;
212                        }
213                        let take = remaining.min(chunk.len());
214                        buf.extend_from_slice(&chunk[..take]);
215                        if buf.len() >= MAX_TITLE_BODY {
216                            break;
217                        }
218                    }
219                    let body = String::from_utf8_lossy(&buf);
220                    extract_title(&body)
221                } else {
222                    None
223                }
224            } else {
225                None
226            };
227
228            return Ok((status_code, status_text, title));
229        }
230
231        Err(SeerError::HttpError("too many redirects".to_string()))
232    }
233
234    /// Fetches SSL certificate information using native-tls.
235    ///
236    /// # Security Note
237    /// This connection uses `danger_accept_invalid_certs(true)` to inspect certificates
238    /// even when invalid. Data retrieved (issuer, subject, dates) comes from an
239    /// unauthenticated TLS connection and may have been tampered with by a MITM.
240    async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
241        // SSRF protection: resolve domain and check IPs before connecting
242        let addr = format!("{}:443", domain);
243        let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
244            .await
245            .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
246            .collect();
247
248        if socket_addrs.is_empty() {
249            return Err(SeerError::CertificateError(format!(
250                "DNS lookup returned no addresses for {}",
251                domain
252            )));
253        }
254
255        for socket_addr in &socket_addrs {
256            if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
257                return Err(SeerError::CertificateError(format!(
258                    "cannot connect to {}: {} — {}",
259                    domain,
260                    socket_addr.ip(),
261                    reason
262                )));
263            }
264        }
265
266        let connector = TlsConnector::builder()
267            .danger_accept_invalid_certs(true) // We want to see the cert even if invalid
268            .build()
269            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
270
271        let connector = tokio_native_tls::TlsConnector::from(connector);
272
273        // Connect directly to the validated socket address to prevent DNS
274        // rebinding (TOCTOU) between validation and connect.
275        let stream =
276            tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
277                .await
278                .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
279                .map_err(|e| SeerError::CertificateError(e.to_string()))?;
280
281        // Use the domain as SNI hostname for the TLS handshake.
282        let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
283            .await
284            .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
285            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
286
287        // Get the peer certificate
288        let cert = tls_stream
289            .get_ref()
290            .peer_certificate()
291            .map_err(|e| SeerError::CertificateError(e.to_string()))?
292            .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
293
294        // Parse certificate info
295        let der = cert
296            .to_der()
297            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
298
299        parse_certificate_der(&der, domain)
300    }
301
302    /// Fetches domain expiration info using WHOIS/RDAP.
303    async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
304        match self.smart_lookup.lookup(domain).await {
305            Ok(result) => {
306                let (expiration_date, registrar) = result.expiration_info();
307
308                if let Some(exp_date) = expiration_date {
309                    let days_until_expiry = (exp_date - Utc::now()).num_days();
310                    Ok(Some(DomainExpiration {
311                        expiration_date: exp_date,
312                        days_until_expiry,
313                        registrar,
314                    }))
315                } else {
316                    Ok(None)
317                }
318            }
319            Err(_) => Ok(None), // Don't fail the whole status check if WHOIS fails
320        }
321    }
322
323    /// Fetches DNS root record resolution (A, AAAA, CNAME, NS).
324    async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
325        let resolver = &self.dns_resolver;
326
327        // Query all record types concurrently
328        let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
329            resolver.resolve(domain, RecordType::A, None),
330            resolver.resolve(domain, RecordType::AAAA, None),
331            resolver.resolve(domain, RecordType::CNAME, None),
332            resolver.resolve(domain, RecordType::NS, None)
333        );
334
335        // Extract A records
336        let a_records: Vec<String> = a_result
337            .unwrap_or_default()
338            .into_iter()
339            .filter_map(|r| {
340                if let RecordData::A { address } = r.data {
341                    Some(address)
342                } else {
343                    None
344                }
345            })
346            .collect();
347
348        // Extract AAAA records
349        let aaaa_records: Vec<String> = aaaa_result
350            .unwrap_or_default()
351            .into_iter()
352            .filter_map(|r| {
353                if let RecordData::AAAA { address } = r.data {
354                    Some(address)
355                } else {
356                    None
357                }
358            })
359            .collect();
360
361        // Extract CNAME target (trim trailing dot)
362        let cname_target: Option<String> =
363            cname_result.unwrap_or_default().into_iter().find_map(|r| {
364                if let RecordData::CNAME { target } = r.data {
365                    Some(target.trim_end_matches('.').to_string())
366                } else {
367                    None
368                }
369            });
370
371        // Extract NS records (trim trailing dots)
372        let nameservers: Vec<String> = ns_result
373            .unwrap_or_default()
374            .into_iter()
375            .filter_map(|r| {
376                if let RecordData::NS { nameserver } = r.data {
377                    Some(nameserver.trim_end_matches('.').to_string())
378                } else {
379                    None
380                }
381            })
382            .collect();
383
384        // Domain resolves if it has A/AAAA records or a CNAME
385        let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
386
387        Ok(DnsResolution {
388            a_records,
389            aaaa_records,
390            cname_target,
391            nameservers,
392            resolves,
393        })
394    }
395}
396
// Domain normalization and validation are handled by the `crate::validation` module.
398
399/// Extracts the title from HTML content.
400fn extract_title(html: &str) -> Option<String> {
401    TITLE_REGEX
402        .captures(html)
403        .and_then(|caps| caps.get(1))
404        .map(|m| m.as_str().trim().to_string())
405        .filter(|s| !s.is_empty())
406}
407
408/// Validates that a URL target is safe (no private/reserved IPs, no credentials,
409/// supported scheme) and returns the resolved socket addresses.
410///
411/// The caller should pin these addresses on the HTTP client to prevent DNS
412/// rebinding between validation and the actual connection.
413async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
414    let scheme = url.scheme();
415    if scheme != "https" && scheme != "http" {
416        return Err(SeerError::HttpError(format!(
417            "unsupported URL scheme: {}",
418            scheme
419        )));
420    }
421
422    if !url.username().is_empty() || url.password().is_some() {
423        return Err(SeerError::HttpError(
424            "URL credentials are not allowed".to_string(),
425        ));
426    }
427
428    let host = url
429        .host_str()
430        .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
431    let port = url.port_or_known_default().unwrap_or(443);
432
433    // Only allow standard HTTP/HTTPS ports to prevent port scanning via redirects
434    if port != 80 && port != 443 {
435        return Err(SeerError::HttpError(format!(
436            "non-standard port {} is not allowed in redirects",
437            port
438        )));
439    }
440
441    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
442        if let Some(reason) = describe_reserved_ip(&ip) {
443            return Err(SeerError::HttpError(format!(
444                "cannot connect to {}: {} — {}",
445                host, ip, reason
446            )));
447        }
448        return Ok(vec![SocketAddr::new(ip, port)]);
449    }
450
451    let addr = format!("{}:{}", host, port);
452    let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
453        .await
454        .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
455        .collect();
456
457    if socket_addrs.is_empty() {
458        return Err(SeerError::HttpError(format!(
459            "DNS lookup returned no addresses for {}",
460            host
461        )));
462    }
463
464    for socket_addr in &socket_addrs {
465        if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
466            return Err(SeerError::HttpError(format!(
467                "cannot connect to {}: {} — {}",
468                host,
469                socket_addr.ip(),
470                reason
471            )));
472        }
473    }
474
475    Ok(socket_addrs)
476}
477
478/// Parses certificate information from DER-encoded certificate using x509-parser.
479fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
480    use x509_parser::prelude::*;
481
482    let (_, cert) = X509Certificate::from_der(der)
483        .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
484
485    // Extract issuer - prefer CN, fall back to O (Organization)
486    let issuer =
487        extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
488
489    // Extract subject - prefer CN, fall back to O (Organization)
490    let subject =
491        extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
492
493    // Extract validity dates
494    let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
495    let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
496
497    let now = Utc::now();
498    let days_until_expiry = (valid_until - now).num_days();
499    let is_valid = now >= valid_from && now <= valid_until;
500
501    // Hostname verification is performed manually because the TLS connector
502    // was configured with danger_accept_invalid_certs(true) to allow cert
503    // inspection on mildly-broken sites. Without this check any cert — even
504    // one issued for an unrelated domain — would be accepted.
505    let hostname_verified = cert_matches_hostname(&cert, domain);
506
507    Ok(CertificateInfo {
508        issuer,
509        subject,
510        valid_from,
511        valid_until,
512        days_until_expiry,
513        is_valid,
514        hostname_verified,
515    })
516}
517
/// Matches a hostname against a certificate name pattern.
///
/// Supports exact matches (ASCII case-insensitive) and single-label
/// wildcards per RFC 6125 — `*.example.com` matches `a.example.com` but not
/// `example.com` or `a.b.example.com`.
///
/// The wildcard must stand in for a non-empty label, so `.example.com` does
/// not match `*.example.com`, and a degenerate `*.` pattern matches nothing.
///
/// Comparison is done in place with `eq_ignore_ascii_case`; no lowercase
/// copies of the inputs are allocated.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    match pattern.strip_prefix("*.") {
        // Wildcard: the host's first label is consumed by `*`, and everything
        // after the first dot must equal the rest of the pattern.
        Some(suffix) => match host.split_once('.') {
            Some((label, rest)) => {
                !label.is_empty() && !suffix.is_empty() && rest.eq_ignore_ascii_case(suffix)
            }
            // A host without a dot has no label for the wildcard to cover.
            None => false,
        },
        None => host.eq_ignore_ascii_case(pattern),
    }
}
537
538/// Checks whether a certificate's SAN dNSName entries (or CN as fallback)
539/// match the queried hostname.
540///
541/// Per RFC 6125, SAN dNSName is the authoritative source; CN is only checked
542/// as a legacy fallback.
543fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
544    use x509_parser::prelude::*;
545
546    // SAN dNSName entries (preferred per RFC 6125)
547    if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
548        for name in &san_ext.value.general_names {
549            if let GeneralName::DNSName(n) = name {
550                if hostname_matches_pattern(host, n) {
551                    return true;
552                }
553            }
554        }
555    }
556
557    // CN fallback (legacy)
558    for cn in cert.subject().iter_common_name() {
559        if let Ok(s) = cn.as_str() {
560            if hostname_matches_pattern(host, s) {
561                return true;
562            }
563        }
564    }
565
566    false
567}
568
569/// Extracts the Common Name or Organization from an X.509 name.
570fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
571    use x509_parser::prelude::*;
572
573    // Try Common Name first (OID 2.5.4.3)
574    for rdn in name.iter() {
575        for attr in rdn.iter() {
576            if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
577                if let Some(s) = extract_attr_string(attr.attr_value()) {
578                    return Some(s);
579                }
580            }
581        }
582    }
583
584    // Fall back to Organization (OID 2.5.4.10)
585    for rdn in name.iter() {
586        for attr in rdn.iter() {
587            if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
588                if let Some(s) = extract_attr_string(attr.attr_value()) {
589                    return Some(s);
590                }
591            }
592        }
593    }
594
595    None
596}
597
598/// Extracts a string from an ASN.1 attribute value, handling different encodings.
599fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
600    // Try as_str() first (handles PrintableString, IA5String, etc.)
601    if let Ok(s) = value.as_str() {
602        return Some(s.to_string());
603    }
604
605    // Try UTF8String explicitly
606    if let Ok(utf8) = value.as_utf8string() {
607        return Some(utf8.string().to_string());
608    }
609
610    // Try raw bytes as UTF-8
611    if let Ok(s) = std::str::from_utf8(value.data) {
612        return Some(s.to_string());
613    }
614
615    None
616}
617
618/// Converts an x509-parser ASN1Time to a chrono DateTime.
619fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
620    let timestamp = time.timestamp();
621    chrono::DateTime::from_timestamp(timestamp, 0)
622        .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `hostname_matches_pattern` for each `(host, pattern, expected)` row.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(host, pattern, expected) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host={host:?} pattern={pattern:?}"
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        assert_cases(&[
            ("example.com", "example.com", true),
            ("EXAMPLE.COM", "example.com", true),
            ("example.com", "EXAMPLE.COM", true),
            ("evil.com", "example.com", false),
            ("example.com", "evil.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        assert_cases(&[
            ("a.example.com", "*.example.com", true),
            ("A.EXAMPLE.COM", "*.example.com", true),
            // Apex must not match wildcard (RFC 6125)
            ("example.com", "*.example.com", false),
            // Wildcard only covers a single label
            ("a.b.example.com", "*.example.com", false),
            ("b.other.com", "*.example.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        // A bare host with no dot cannot match a wildcard pattern
        assert_cases(&[("localhost", "*.example.com", false)]);
    }
}
657}