Skip to main content

seer_core/status/
client.rs

1use std::collections::HashSet;
2use std::time::Duration;
3
4use chrono::Utc;
5use native_tls::TlsConnector;
6use once_cell::sync::Lazy;
7use regex::Regex;
8use reqwest::{Client, Url};
9use tokio::net::TcpStream;
10use tracing::{debug, instrument};
11
12use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
13use crate::dns::{DnsResolver, RecordData, RecordType};
14use crate::error::{Result, SeerError};
15use crate::lookup::SmartLookup;
16use crate::validation::{is_private_or_reserved_ip, normalize_domain};
17
/// Default per-operation timeout (10 seconds) for HTTP requests and TLS connections.
/// Long enough for slow servers, short enough that a status check never feels stuck.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of HTTP redirects followed before a status check gives up.
const MAX_REDIRECTS: usize = 5;
22
23/// Pre-compiled regex for extracting HTML title.
24static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
25    Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
26});
27
28/// Shared HTTP client for status checks. Reusing a single Client enables
29/// connection pooling and avoids per-request TLS handshake overhead.
30static STATUS_HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
31    Client::builder()
32        .redirect(reqwest::redirect::Policy::none())
33        .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
34        .pool_max_idle_per_host(10)
35        .build()
36        .expect("Failed to build status HTTP client - invalid configuration")
37});
38
39/// Client for checking domain status (HTTP, SSL, expiration)
40#[derive(Debug, Clone)]
41pub struct StatusClient {
42    timeout: Duration,
43    /// Cached DNS resolver reused across status checks.
44    dns_resolver: DnsResolver,
45}
46
47impl Default for StatusClient {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl StatusClient {
54    /// Creates a new StatusClient with default settings.
55    pub fn new() -> Self {
56        Self {
57            timeout: DEFAULT_TIMEOUT,
58            dns_resolver: DnsResolver::new(),
59        }
60    }
61
62    /// Sets the timeout for HTTP and TLS operations.
63    pub fn with_timeout(mut self, timeout: Duration) -> Self {
64        self.timeout = timeout;
65        self
66    }
67
68    /// Checks the status of a domain (HTTP, SSL, expiration, DNS).
69    #[instrument(skip(self), fields(domain = %domain))]
70    pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
71        // Normalize domain format (doesn't require DNS resolution)
72        let domain = normalize_domain(domain)?;
73        debug!("Checking status for domain: {}", domain);
74
75        let mut response = StatusResponse::new(domain.clone());
76
77        // Fetch HTTP status, SSL cert info, domain expiration, and DNS resolution concurrently
78        // HTTP and SSL checks include SSRF protection internally
79        let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
80            self.fetch_http_info(&domain),
81            self.fetch_certificate_info(&domain),
82            self.fetch_domain_expiration(&domain),
83            self.fetch_dns_resolution(&domain)
84        );
85
86        // Apply HTTP info
87        if let Ok((status, status_text, title)) = http_result {
88            response.http_status = Some(status);
89            response.http_status_text = Some(status_text);
90            response.title = title;
91        }
92
93        // Apply certificate info
94        if let Ok(cert_info) = cert_result {
95            response.certificate = Some(cert_info);
96        }
97
98        // Apply domain expiration info
99        if let Ok(expiry_info) = expiry_result {
100            response.domain_expiration = expiry_info;
101        }
102
103        // Apply DNS resolution info
104        if let Ok(dns_info) = dns_result {
105            response.dns_resolution = Some(dns_info);
106        }
107
108        Ok(response)
109    }
110
111    /// Fetches the HTTP status code and page title.
112    async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
113        let mut url = Url::parse(&format!("https://{}", domain))
114            .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
115        let mut visited = HashSet::new();
116
117        for _ in 0..=MAX_REDIRECTS {
118            validate_url_target(&url).await?;
119
120            if !visited.insert(url.clone()) {
121                return Err(SeerError::HttpError("redirect loop detected".to_string()));
122            }
123
124            let response = STATUS_HTTP_CLIENT
125                .get(url.clone())
126                .timeout(self.timeout)
127                .send()
128                .await
129                .map_err(|e| SeerError::HttpError(e.to_string()))?;
130
131            if response.status().is_redirection() {
132                let location = response.headers().get(reqwest::header::LOCATION);
133                let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
134                    SeerError::HttpError("redirect missing location header".to_string())
135                })?;
136                let next_url = url
137                    .join(location)
138                    .or_else(|_| Url::parse(location))
139                    .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
140                url = next_url;
141                continue;
142            }
143
144            let status = response.status();
145            let status_code = status.as_u16();
146            let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
147
148            // Only try to get title for successful HTML responses
149            let title = if status.is_success() {
150                let content_type = response
151                    .headers()
152                    .get("content-type")
153                    .and_then(|v| v.to_str().ok())
154                    .unwrap_or("");
155
156                if content_type.contains("text/html") {
157                    let body = response
158                        .text()
159                        .await
160                        .map_err(|e| SeerError::HttpError(e.to_string()))?;
161                    extract_title(&body)
162                } else {
163                    None
164                }
165            } else {
166                None
167            };
168
169            return Ok((status_code, status_text, title));
170        }
171
172        Err(SeerError::HttpError("too many redirects".to_string()))
173    }
174
175    /// Fetches SSL certificate information using native-tls.
176    async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
177        // SSRF protection: resolve domain and check IPs before connecting
178        let addr = format!("{}:443", domain);
179        let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
180            .await
181            .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
182            .collect();
183
184        for socket_addr in &socket_addrs {
185            if is_private_or_reserved_ip(&socket_addr.ip()) {
186                return Err(SeerError::CertificateError(format!(
187                    "domain resolves to private/reserved IP: {}",
188                    socket_addr.ip()
189                )));
190            }
191        }
192
193        let connector = TlsConnector::builder()
194            .danger_accept_invalid_certs(true) // We want to see the cert even if invalid
195            .build()
196            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
197
198        let connector = tokio_native_tls::TlsConnector::from(connector);
199
200        let stream = tokio::time::timeout(self.timeout, TcpStream::connect(&addr))
201            .await
202            .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
203            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
204
205        let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
206            .await
207            .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
208            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
209
210        // Get the peer certificate
211        let cert = tls_stream
212            .get_ref()
213            .peer_certificate()
214            .map_err(|e| SeerError::CertificateError(e.to_string()))?
215            .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
216
217        // Parse certificate info
218        let der = cert
219            .to_der()
220            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
221
222        parse_certificate_der(&der, domain)
223    }
224
225    /// Fetches domain expiration info using WHOIS/RDAP.
226    async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
227        let lookup = SmartLookup::new();
228
229        match lookup.lookup(domain).await {
230            Ok(result) => {
231                let (expiration_date, registrar) = result.expiration_info();
232
233                if let Some(exp_date) = expiration_date {
234                    let days_until_expiry = (exp_date - Utc::now()).num_days();
235                    Ok(Some(DomainExpiration {
236                        expiration_date: exp_date,
237                        days_until_expiry,
238                        registrar,
239                    }))
240                } else {
241                    Ok(None)
242                }
243            }
244            Err(_) => Ok(None), // Don't fail the whole status check if WHOIS fails
245        }
246    }
247
248    /// Fetches DNS root record resolution (A, AAAA, CNAME, NS).
249    async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
250        let resolver = &self.dns_resolver;
251
252        // Query all record types concurrently
253        let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
254            resolver.resolve(domain, RecordType::A, None),
255            resolver.resolve(domain, RecordType::AAAA, None),
256            resolver.resolve(domain, RecordType::CNAME, None),
257            resolver.resolve(domain, RecordType::NS, None)
258        );
259
260        // Extract A records
261        let a_records: Vec<String> = a_result
262            .unwrap_or_default()
263            .into_iter()
264            .filter_map(|r| {
265                if let RecordData::A { address } = r.data {
266                    Some(address)
267                } else {
268                    None
269                }
270            })
271            .collect();
272
273        // Extract AAAA records
274        let aaaa_records: Vec<String> = aaaa_result
275            .unwrap_or_default()
276            .into_iter()
277            .filter_map(|r| {
278                if let RecordData::AAAA { address } = r.data {
279                    Some(address)
280                } else {
281                    None
282                }
283            })
284            .collect();
285
286        // Extract CNAME target (trim trailing dot)
287        let cname_target: Option<String> =
288            cname_result.unwrap_or_default().into_iter().find_map(|r| {
289                if let RecordData::CNAME { target } = r.data {
290                    Some(target.trim_end_matches('.').to_string())
291                } else {
292                    None
293                }
294            });
295
296        // Extract NS records (trim trailing dots)
297        let nameservers: Vec<String> = ns_result
298            .unwrap_or_default()
299            .into_iter()
300            .filter_map(|r| {
301                if let RecordData::NS { nameserver } = r.data {
302                    Some(nameserver.trim_end_matches('.').to_string())
303                } else {
304                    None
305                }
306            })
307            .collect();
308
309        // Domain resolves if it has A/AAAA records or a CNAME
310        let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
311
312        Ok(DnsResolution {
313            a_records,
314            aaaa_records,
315            cname_target,
316            nameservers,
317            resolves,
318        })
319    }
320}
321
// Domain normalization and validation are handled by `crate::validation`.
323
324/// Extracts the title from HTML content.
325fn extract_title(html: &str) -> Option<String> {
326    TITLE_REGEX
327        .captures(html)
328        .and_then(|caps| caps.get(1))
329        .map(|m| m.as_str().trim().to_string())
330        .filter(|s| !s.is_empty())
331}
332
333async fn validate_url_target(url: &Url) -> Result<()> {
334    let scheme = url.scheme();
335    if scheme != "https" && scheme != "http" {
336        return Err(SeerError::HttpError(format!(
337            "unsupported URL scheme: {}",
338            scheme
339        )));
340    }
341
342    if !url.username().is_empty() || url.password().is_some() {
343        return Err(SeerError::HttpError(
344            "URL credentials are not allowed".to_string(),
345        ));
346    }
347
348    let host = url
349        .host_str()
350        .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
351    let port = url.port_or_known_default().unwrap_or(443);
352
353    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
354        if is_private_or_reserved_ip(&ip) {
355            return Err(SeerError::HttpError(format!(
356                "URL resolves to private/reserved IP: {}",
357                ip
358            )));
359        }
360        return Ok(());
361    }
362
363    let addr = format!("{}:{}", host, port);
364    let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
365        .await
366        .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
367        .collect();
368
369    for socket_addr in &socket_addrs {
370        if is_private_or_reserved_ip(&socket_addr.ip()) {
371            return Err(SeerError::HttpError(format!(
372                "URL resolves to private/reserved IP: {}",
373                socket_addr.ip()
374            )));
375        }
376    }
377
378    Ok(())
379}
380
381/// Parses certificate information from DER-encoded certificate using x509-parser.
382fn parse_certificate_der(der: &[u8], _domain: &str) -> Result<CertificateInfo> {
383    use x509_parser::prelude::*;
384
385    let (_, cert) = X509Certificate::from_der(der)
386        .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
387
388    // Extract issuer - prefer CN, fall back to O (Organization)
389    let issuer =
390        extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
391
392    // Extract subject - prefer CN, fall back to O (Organization)
393    let subject =
394        extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
395
396    // Extract validity dates
397    let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
398    let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
399
400    let now = Utc::now();
401    let days_until_expiry = (valid_until - now).num_days();
402    let is_valid = now >= valid_from && now <= valid_until;
403
404    Ok(CertificateInfo {
405        issuer,
406        subject,
407        valid_from,
408        valid_until,
409        days_until_expiry,
410        is_valid,
411    })
412}
413
414/// Extracts the Common Name or Organization from an X.509 name.
415fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
416    use x509_parser::prelude::*;
417
418    // Try Common Name first (OID 2.5.4.3)
419    for rdn in name.iter() {
420        for attr in rdn.iter() {
421            if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
422                if let Some(s) = extract_attr_string(attr.attr_value()) {
423                    return Some(s);
424                }
425            }
426        }
427    }
428
429    // Fall back to Organization (OID 2.5.4.10)
430    for rdn in name.iter() {
431        for attr in rdn.iter() {
432            if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
433                if let Some(s) = extract_attr_string(attr.attr_value()) {
434                    return Some(s);
435                }
436            }
437        }
438    }
439
440    None
441}
442
443/// Extracts a string from an ASN.1 attribute value, handling different encodings.
444fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
445    // Try as_str() first (handles PrintableString, IA5String, etc.)
446    if let Ok(s) = value.as_str() {
447        return Some(s.to_string());
448    }
449
450    // Try UTF8String explicitly
451    if let Ok(utf8) = value.as_utf8string() {
452        return Some(utf8.string().to_string());
453    }
454
455    // Try raw bytes as UTF-8
456    if let Ok(s) = std::str::from_utf8(value.data) {
457        return Some(s.to_string());
458    }
459
460    None
461}
462
463/// Converts an x509-parser ASN1Time to a chrono DateTime.
464fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
465    let timestamp = time.timestamp();
466    chrono::DateTime::from_timestamp(timestamp, 0)
467        .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
468}