Skip to main content

seer_core/status/
client.rs

1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Default per-operation timeout (10 seconds) for HTTP requests and TLS
/// connects/handshakes — long enough for slow servers, short enough to keep
/// a status check responsive.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Upper bound on the number of HTTP redirects followed for one check.
const MAX_REDIRECTS: usize = 5;
23
24/// Pre-compiled regex for extracting HTML title.
25static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
26    Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
27});
28
29/// Client for checking domain status (HTTP, SSL, expiration)
30#[derive(Debug, Clone)]
31pub struct StatusClient {
32    timeout: Duration,
33    /// Cached DNS resolver reused across status checks.
34    dns_resolver: DnsResolver,
35    /// Reusable SmartLookup for domain expiration checks.
36    smart_lookup: SmartLookup,
37}
38
39impl Default for StatusClient {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl StatusClient {
46    /// Creates a new StatusClient with default settings.
47    pub fn new() -> Self {
48        Self {
49            timeout: DEFAULT_TIMEOUT,
50            dns_resolver: DnsResolver::new(),
51            smart_lookup: SmartLookup::new(),
52        }
53    }
54
55    /// Sets the timeout for HTTP and TLS operations.
56    pub fn with_timeout(mut self, timeout: Duration) -> Self {
57        self.timeout = timeout;
58        self
59    }
60
61    /// Checks the status of a domain (HTTP, SSL, expiration, DNS).
62    #[instrument(skip(self), fields(domain = %domain))]
63    pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64        // Normalize domain format (doesn't require DNS resolution)
65        let domain = normalize_domain(domain)?;
66        debug!("Checking status for domain: {}", domain);
67
68        let mut response = StatusResponse::new(domain.clone());
69
70        // Fetch HTTP status, SSL cert info, domain expiration, and DNS resolution concurrently
71        // HTTP and SSL checks include SSRF protection internally
72        let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73            self.fetch_http_info(&domain),
74            self.fetch_certificate_info(&domain),
75            self.fetch_domain_expiration(&domain),
76            self.fetch_dns_resolution(&domain)
77        );
78
79        // Apply HTTP info
80        match http_result {
81            Ok((status, status_text, title)) => {
82                response.http_status = Some(status);
83                response.http_status_text = Some(status_text);
84                response.title = title;
85            }
86            Err(e) => response.errors.push(super::types::StatusError {
87                check: "http".to_string(),
88                message: e.to_string(),
89            }),
90        }
91
92        // Apply certificate info
93        match cert_result {
94            Ok(cert_info) => response.certificate = Some(cert_info),
95            Err(e) => response.errors.push(super::types::StatusError {
96                check: "ssl".to_string(),
97                message: e.to_string(),
98            }),
99        }
100
101        // Apply domain expiration info
102        match expiry_result {
103            Ok(expiry_info) => response.domain_expiration = expiry_info,
104            Err(e) => response.errors.push(super::types::StatusError {
105                check: "expiration".to_string(),
106                message: e.to_string(),
107            }),
108        }
109
110        // Apply DNS resolution info
111        match dns_result {
112            Ok(dns_info) => response.dns_resolution = Some(dns_info),
113            Err(e) => response.errors.push(super::types::StatusError {
114                check: "dns".to_string(),
115                message: e.to_string(),
116            }),
117        }
118
119        Ok(response)
120    }
121
122    /// Fetches the HTTP status code and page title.
123    ///
124    /// Redirects are followed manually with IP validation at each hop.
125    /// Resolved IPs are pinned on the HTTP client via `resolve_to_addrs` to
126    /// prevent DNS rebinding attacks (TOCTOU between validation and connect).
127    ///
128    /// # Security Note
129    /// Redirect targets are validated for SSRF but the HTTP response body (page title)
130    /// comes from an unauthenticated connection and should be treated as untrusted.
131    async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
132        let mut url = Url::parse(&format!("https://{}", domain))
133            .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
134        let mut visited = HashSet::new();
135
136        for _ in 0..=MAX_REDIRECTS {
137            let validated_addrs = validate_url_target(&url).await?;
138
139            if !visited.insert(url.clone()) {
140                return Err(SeerError::HttpError("redirect loop detected".to_string()));
141            }
142
143            // Build a per-hop client that pins the validated IPs so reqwest
144            // cannot re-resolve the hostname to a different (potentially
145            // private) address (DNS rebinding protection).
146            let host = url
147                .host_str()
148                .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
149            let client = Client::builder()
150                .redirect(reqwest::redirect::Policy::none())
151                .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
152                .resolve_to_addrs(host, &validated_addrs)
153                .build()
154                .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
155
156            let response = client
157                .get(url.clone())
158                .timeout(self.timeout)
159                .send()
160                .await
161                .map_err(|e| SeerError::HttpError(e.to_string()))?;
162
163            if response.status().is_redirection() {
164                let location = response.headers().get(reqwest::header::LOCATION);
165                let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
166                    SeerError::HttpError("redirect missing location header".to_string())
167                })?;
168                let next_url = url
169                    .join(location)
170                    .or_else(|_| Url::parse(location))
171                    .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
172                url = next_url;
173                continue;
174            }
175
176            let status = response.status();
177            let status_code = status.as_u16();
178            let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
179
180            // Only try to get title for successful HTML responses
181            let title = if status.is_success() {
182                let content_type = response
183                    .headers()
184                    .get("content-type")
185                    .and_then(|v| v.to_str().ok())
186                    .unwrap_or("");
187
188                if content_type.contains("text/html") {
189                    // Stream at most 64 KB for title extraction. Streaming
190                    // (rather than `response.bytes().await`) prevents a
191                    // malicious server from forcing us to buffer a huge
192                    // body before the cap is applied.
193                    const MAX_TITLE_BODY: usize = 64 * 1024;
194                    use futures::StreamExt;
195                    let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
196                    let mut stream = response.bytes_stream();
197                    while let Some(chunk) = stream.next().await {
198                        let chunk = chunk
199                            .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
200                        let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
201                        if remaining == 0 {
202                            break;
203                        }
204                        let take = remaining.min(chunk.len());
205                        buf.extend_from_slice(&chunk[..take]);
206                        if buf.len() >= MAX_TITLE_BODY {
207                            break;
208                        }
209                    }
210                    let body = String::from_utf8_lossy(&buf);
211                    extract_title(&body)
212                } else {
213                    None
214                }
215            } else {
216                None
217            };
218
219            return Ok((status_code, status_text, title));
220        }
221
222        Err(SeerError::HttpError("too many redirects".to_string()))
223    }
224
225    /// Fetches SSL certificate information using native-tls.
226    ///
227    /// # Security Note
228    /// This connection uses `danger_accept_invalid_certs(true)` to inspect certificates
229    /// even when invalid. Data retrieved (issuer, subject, dates) comes from an
230    /// unauthenticated TLS connection and may have been tampered with by a MITM.
231    async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
232        // SSRF protection: resolve domain and check IPs before connecting
233        let addr = format!("{}:443", domain);
234        let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
235            .await
236            .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
237            .collect();
238
239        if socket_addrs.is_empty() {
240            return Err(SeerError::CertificateError(format!(
241                "DNS lookup returned no addresses for {}",
242                domain
243            )));
244        }
245
246        for socket_addr in &socket_addrs {
247            if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
248                return Err(SeerError::CertificateError(format!(
249                    "cannot connect to {}: {} — {}",
250                    domain,
251                    socket_addr.ip(),
252                    reason
253                )));
254            }
255        }
256
257        let connector = TlsConnector::builder()
258            .danger_accept_invalid_certs(true) // We want to see the cert even if invalid
259            .build()
260            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
261
262        let connector = tokio_native_tls::TlsConnector::from(connector);
263
264        // Connect directly to the validated socket address to prevent DNS
265        // rebinding (TOCTOU) between validation and connect.
266        let stream =
267            tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
268                .await
269                .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
270                .map_err(|e| SeerError::CertificateError(e.to_string()))?;
271
272        // Use the domain as SNI hostname for the TLS handshake.
273        let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
274            .await
275            .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
276            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
277
278        // Get the peer certificate
279        let cert = tls_stream
280            .get_ref()
281            .peer_certificate()
282            .map_err(|e| SeerError::CertificateError(e.to_string()))?
283            .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
284
285        // Parse certificate info
286        let der = cert
287            .to_der()
288            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
289
290        parse_certificate_der(&der, domain)
291    }
292
293    /// Fetches domain expiration info using WHOIS/RDAP.
294    async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
295        match self.smart_lookup.lookup(domain).await {
296            Ok(result) => {
297                let (expiration_date, registrar) = result.expiration_info();
298
299                if let Some(exp_date) = expiration_date {
300                    let days_until_expiry = (exp_date - Utc::now()).num_days();
301                    Ok(Some(DomainExpiration {
302                        expiration_date: exp_date,
303                        days_until_expiry,
304                        registrar,
305                    }))
306                } else {
307                    Ok(None)
308                }
309            }
310            Err(_) => Ok(None), // Don't fail the whole status check if WHOIS fails
311        }
312    }
313
314    /// Fetches DNS root record resolution (A, AAAA, CNAME, NS).
315    async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
316        let resolver = &self.dns_resolver;
317
318        // Query all record types concurrently
319        let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
320            resolver.resolve(domain, RecordType::A, None),
321            resolver.resolve(domain, RecordType::AAAA, None),
322            resolver.resolve(domain, RecordType::CNAME, None),
323            resolver.resolve(domain, RecordType::NS, None)
324        );
325
326        // Extract A records
327        let a_records: Vec<String> = a_result
328            .unwrap_or_default()
329            .into_iter()
330            .filter_map(|r| {
331                if let RecordData::A { address } = r.data {
332                    Some(address)
333                } else {
334                    None
335                }
336            })
337            .collect();
338
339        // Extract AAAA records
340        let aaaa_records: Vec<String> = aaaa_result
341            .unwrap_or_default()
342            .into_iter()
343            .filter_map(|r| {
344                if let RecordData::AAAA { address } = r.data {
345                    Some(address)
346                } else {
347                    None
348                }
349            })
350            .collect();
351
352        // Extract CNAME target (trim trailing dot)
353        let cname_target: Option<String> =
354            cname_result.unwrap_or_default().into_iter().find_map(|r| {
355                if let RecordData::CNAME { target } = r.data {
356                    Some(target.trim_end_matches('.').to_string())
357                } else {
358                    None
359                }
360            });
361
362        // Extract NS records (trim trailing dots)
363        let nameservers: Vec<String> = ns_result
364            .unwrap_or_default()
365            .into_iter()
366            .filter_map(|r| {
367                if let RecordData::NS { nameserver } = r.data {
368                    Some(nameserver.trim_end_matches('.').to_string())
369                } else {
370                    None
371                }
372            })
373            .collect();
374
375        // Domain resolves if it has A/AAAA records or a CNAME
376        let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
377
378        Ok(DnsResolution {
379            a_records,
380            aaaa_records,
381            cname_target,
382            nameservers,
383            resolves,
384        })
385    }
386}
387
// Domain normalization and validation are handled by `crate::validation` (`normalize_domain`, `describe_reserved_ip`).
389
390/// Extracts the title from HTML content.
391fn extract_title(html: &str) -> Option<String> {
392    TITLE_REGEX
393        .captures(html)
394        .and_then(|caps| caps.get(1))
395        .map(|m| m.as_str().trim().to_string())
396        .filter(|s| !s.is_empty())
397}
398
399/// Validates that a URL target is safe (no private/reserved IPs, no credentials,
400/// supported scheme) and returns the resolved socket addresses.
401///
402/// The caller should pin these addresses on the HTTP client to prevent DNS
403/// rebinding between validation and the actual connection.
404async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
405    let scheme = url.scheme();
406    if scheme != "https" && scheme != "http" {
407        return Err(SeerError::HttpError(format!(
408            "unsupported URL scheme: {}",
409            scheme
410        )));
411    }
412
413    if !url.username().is_empty() || url.password().is_some() {
414        return Err(SeerError::HttpError(
415            "URL credentials are not allowed".to_string(),
416        ));
417    }
418
419    let host = url
420        .host_str()
421        .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
422    let port = url.port_or_known_default().unwrap_or(443);
423
424    // Only allow standard HTTP/HTTPS ports to prevent port scanning via redirects
425    if port != 80 && port != 443 {
426        return Err(SeerError::HttpError(format!(
427            "non-standard port {} is not allowed in redirects",
428            port
429        )));
430    }
431
432    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
433        if let Some(reason) = describe_reserved_ip(&ip) {
434            return Err(SeerError::HttpError(format!(
435                "cannot connect to {}: {} — {}",
436                host, ip, reason
437            )));
438        }
439        return Ok(vec![SocketAddr::new(ip, port)]);
440    }
441
442    let addr = format!("{}:{}", host, port);
443    let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
444        .await
445        .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
446        .collect();
447
448    if socket_addrs.is_empty() {
449        return Err(SeerError::HttpError(format!(
450            "DNS lookup returned no addresses for {}",
451            host
452        )));
453    }
454
455    for socket_addr in &socket_addrs {
456        if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
457            return Err(SeerError::HttpError(format!(
458                "cannot connect to {}: {} — {}",
459                host,
460                socket_addr.ip(),
461                reason
462            )));
463        }
464    }
465
466    Ok(socket_addrs)
467}
468
469/// Parses certificate information from DER-encoded certificate using x509-parser.
470fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
471    use x509_parser::prelude::*;
472
473    let (_, cert) = X509Certificate::from_der(der)
474        .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
475
476    // Extract issuer - prefer CN, fall back to O (Organization)
477    let issuer =
478        extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
479
480    // Extract subject - prefer CN, fall back to O (Organization)
481    let subject =
482        extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
483
484    // Extract validity dates
485    let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
486    let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
487
488    let now = Utc::now();
489    let days_until_expiry = (valid_until - now).num_days();
490    let is_valid = now >= valid_from && now <= valid_until;
491
492    // Hostname verification is performed manually because the TLS connector
493    // was configured with danger_accept_invalid_certs(true) to allow cert
494    // inspection on mildly-broken sites. Without this check any cert — even
495    // one issued for an unrelated domain — would be accepted.
496    let hostname_verified = cert_matches_hostname(&cert, domain);
497
498    Ok(CertificateInfo {
499        issuer,
500        subject,
501        valid_from,
502        valid_until,
503        days_until_expiry,
504        is_valid,
505        hostname_verified,
506    })
507}
508
/// Matches a hostname against a certificate name pattern.
///
/// Comparison is ASCII case-insensitive and ignores a trailing root dot on
/// either side (`example.com.` equals `example.com`). A pattern of the form
/// `*.<base>` is a single-label wildcard per RFC 6125 §6.4.3:
/// `*.example.com` matches `a.example.com` but not `example.com`,
/// `a.b.example.com`, or a host with an empty first label (`.example.com`).
/// Wildcards over a single-label base (e.g. `*.com`) never match, following
/// the RFC 6125 §7.2 guidance against overly broad wildcards. Empty names
/// never match.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.').to_ascii_lowercase();
    let pattern = pattern.trim_end_matches('.').to_ascii_lowercase();
    if host.is_empty() || pattern.is_empty() {
        return false;
    }

    match pattern.strip_prefix("*.") {
        Some(base) => {
            // The wildcard base must have at least two non-empty labels.
            if !base.contains('.') || base.split('.').any(str::is_empty) {
                return false;
            }
            // The wildcard replaces exactly one non-empty leftmost label.
            match host.split_once('.') {
                Some((first_label, rest)) => !first_label.is_empty() && rest == base,
                None => false,
            }
        }
        None => host == pattern,
    }
}
528
529/// Checks whether a certificate's SAN dNSName entries (or CN as fallback)
530/// match the queried hostname.
531///
532/// Per RFC 6125, SAN dNSName is the authoritative source; CN is only checked
533/// as a legacy fallback.
534fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
535    use x509_parser::prelude::*;
536
537    // SAN dNSName entries (preferred per RFC 6125)
538    if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
539        for name in &san_ext.value.general_names {
540            if let GeneralName::DNSName(n) = name {
541                if hostname_matches_pattern(host, n) {
542                    return true;
543                }
544            }
545        }
546    }
547
548    // CN fallback (legacy)
549    for cn in cert.subject().iter_common_name() {
550        if let Ok(s) = cn.as_str() {
551            if hostname_matches_pattern(host, s) {
552                return true;
553            }
554        }
555    }
556
557    false
558}
559
560/// Extracts the Common Name or Organization from an X.509 name.
561fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
562    use x509_parser::prelude::*;
563
564    // Try Common Name first (OID 2.5.4.3)
565    for rdn in name.iter() {
566        for attr in rdn.iter() {
567            if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
568                if let Some(s) = extract_attr_string(attr.attr_value()) {
569                    return Some(s);
570                }
571            }
572        }
573    }
574
575    // Fall back to Organization (OID 2.5.4.10)
576    for rdn in name.iter() {
577        for attr in rdn.iter() {
578            if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
579                if let Some(s) = extract_attr_string(attr.attr_value()) {
580                    return Some(s);
581                }
582            }
583        }
584    }
585
586    None
587}
588
589/// Extracts a string from an ASN.1 attribute value, handling different encodings.
590fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
591    // Try as_str() first (handles PrintableString, IA5String, etc.)
592    if let Ok(s) = value.as_str() {
593        return Some(s.to_string());
594    }
595
596    // Try UTF8String explicitly
597    if let Ok(utf8) = value.as_utf8string() {
598        return Some(utf8.string().to_string());
599    }
600
601    // Try raw bytes as UTF-8
602    if let Ok(s) = std::str::from_utf8(value.data) {
603        return Some(s.to_string());
604    }
605
606    None
607}
608
609/// Converts an x509-parser ASN1Time to a chrono DateTime.
610fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
611    let timestamp = time.timestamp();
612    chrono::DateTime::from_timestamp(timestamp, 0)
613        .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `hostname_matches_pattern(host, pattern) == expected` for every case.
    fn check_cases(cases: &[(&str, &str, bool)]) {
        for &(host, pattern, expected) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host={:?} pattern={:?}",
                host,
                pattern
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        check_cases(&[
            ("example.com", "example.com", true),
            ("EXAMPLE.COM", "example.com", true),
            ("example.com", "EXAMPLE.COM", true),
            ("evil.com", "example.com", false),
            ("example.com", "evil.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        check_cases(&[
            ("a.example.com", "*.example.com", true),
            ("A.EXAMPLE.COM", "*.example.com", true),
            // Apex must not match a wildcard (RFC 6125).
            ("example.com", "*.example.com", false),
            // A wildcard covers exactly one label.
            ("a.b.example.com", "*.example.com", false),
            ("b.other.com", "*.example.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        // A bare host with no dot cannot match a wildcard pattern.
        check_cases(&[("localhost", "*.example.com", false)]);
    }
}