1use std::collections::HashSet;
10use std::net::SocketAddr;
11use std::time::Duration;
12
13use chrono::Utc;
14use native_tls::TlsConnector;
15use once_cell::sync::Lazy;
16use regex::Regex;
17use reqwest::{Client, Url};
18use tokio::net::TcpStream;
19use tracing::{debug, instrument};
20
21use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
22use crate::caa::{self, CaaPolicy};
23use crate::dns::{DnsResolver, RecordData, RecordType};
24use crate::error::{Result, SeerError};
25use crate::lookup::SmartLookup;
26use crate::validation::{describe_reserved_ip, normalize_domain};
27
/// Per-operation network timeout used by [`StatusClient::new`] (ten seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of HTTP redirects followed before the probe gives up.
const MAX_REDIRECTS: usize = 5;
32
33static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
35 Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
36});
37
/// Runs the combined domain status probes (HTTP, TLS certificate, registration expiry,
/// DNS resolution and CAA policy) for a single domain.
#[derive(Debug, Clone)]
pub struct StatusClient {
    /// Timeout applied to each HTTP request, and separately to the TCP connect and the
    /// TLS handshake of the certificate probe.
    timeout: Duration,
    /// Resolver used for the A/AAAA/CNAME/NS lookups and for the CAA policy lookup.
    dns_resolver: DnsResolver,
    /// Registration-data lookup used to obtain the domain expiration date and registrar.
    smart_lookup: SmartLookup,
}
47
48impl Default for StatusClient {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl StatusClient {
55 pub fn new() -> Self {
57 Self {
58 timeout: DEFAULT_TIMEOUT,
59 dns_resolver: DnsResolver::new(),
60 smart_lookup: SmartLookup::new(),
61 }
62 }
63
64 pub fn with_timeout(mut self, timeout: Duration) -> Self {
66 self.timeout = timeout;
67 self
68 }
69
70 #[instrument(skip(self), fields(domain = %domain))]
72 pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
73 let domain = normalize_domain(domain)?;
75 debug!("Checking status for domain: {}", domain);
76
77 let mut response = StatusResponse::new(domain.clone());
78
79 let (http_result, cert_result, expiry_result, dns_result, caa_policy) = tokio::join!(
84 self.fetch_http_info(&domain),
85 self.fetch_certificate_info(&domain),
86 self.fetch_domain_expiration(&domain),
87 self.fetch_dns_resolution(&domain),
88 caa::lookup_caa(&self.dns_resolver, &domain),
89 );
90
91 match http_result {
93 Ok((status, status_text, title)) => {
94 response.http_status = Some(status);
95 response.http_status_text = Some(status_text);
96 response.title = title;
97 }
98 Err(e) => response.errors.push(super::types::StatusError {
99 check: "http".to_string(),
100 message: e.to_string(),
101 }),
102 }
103
104 let mut caa_policy: CaaPolicy = caa_policy;
107 match cert_result {
108 Ok(cert_info) => {
109 caa_policy.issuer_match =
110 Some(caa::classify_issuer(&cert_info.issuer, &caa_policy));
111 response.certificate = Some(cert_info);
112 }
113 Err(e) => response.errors.push(super::types::StatusError {
114 check: "ssl".to_string(),
115 message: e.to_string(),
116 }),
117 }
118 response.caa = Some(caa_policy);
119
120 match expiry_result {
122 Ok(expiry_info) => response.domain_expiration = expiry_info,
123 Err(e) => response.errors.push(super::types::StatusError {
124 check: "expiration".to_string(),
125 message: e.to_string(),
126 }),
127 }
128
129 match dns_result {
131 Ok(dns_info) => response.dns_resolution = Some(dns_info),
132 Err(e) => response.errors.push(super::types::StatusError {
133 check: "dns".to_string(),
134 message: e.to_string(),
135 }),
136 }
137
138 Ok(response)
139 }
140
141 async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
160 let mut url = Url::parse(&format!("https://{}", domain))
161 .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
162 let mut visited = HashSet::new();
163
164 for _ in 0..=MAX_REDIRECTS {
165 let validated_addrs = validate_url_target(&url).await?;
166
167 if !visited.insert(url.clone()) {
168 return Err(SeerError::HttpError("redirect loop detected".to_string()));
169 }
170
171 let host = url
175 .host_str()
176 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
177 let client = Client::builder()
178 .redirect(reqwest::redirect::Policy::none())
179 .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
180 .resolve_to_addrs(host, &validated_addrs)
181 .build()
182 .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
183
184 let response = client
185 .get(url.clone())
186 .timeout(self.timeout)
187 .send()
188 .await
189 .map_err(|e| SeerError::HttpError(e.to_string()))?;
190
191 if response.status().is_redirection() {
192 let location = response.headers().get(reqwest::header::LOCATION);
193 let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
194 SeerError::HttpError("redirect missing location header".to_string())
195 })?;
196 let next_url = url
197 .join(location)
198 .or_else(|_| Url::parse(location))
199 .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
200 url = next_url;
201 continue;
202 }
203
204 let status = response.status();
205 let status_code = status.as_u16();
206 let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
207
208 let title = if status.is_success() {
210 let content_type = response
211 .headers()
212 .get("content-type")
213 .and_then(|v| v.to_str().ok())
214 .unwrap_or("");
215
216 if content_type.contains("text/html") {
217 const MAX_TITLE_BODY: usize = 64 * 1024;
222 use futures::StreamExt;
223 let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
224 let mut stream = response.bytes_stream();
225 while let Some(chunk) = stream.next().await {
226 let chunk = chunk
227 .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
228 let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
229 if remaining == 0 {
230 break;
231 }
232 let take = remaining.min(chunk.len());
233 buf.extend_from_slice(&chunk[..take]);
234 if buf.len() >= MAX_TITLE_BODY {
235 break;
236 }
237 }
238 let body = String::from_utf8_lossy(&buf);
239 extract_title(&body)
240 } else {
241 None
242 }
243 } else {
244 None
245 };
246
247 return Ok((status_code, status_text, title));
248 }
249
250 Err(SeerError::HttpError("too many redirects".to_string()))
251 }
252
253 async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
260 let socket_addrs = crate::net::resolve_public_host(domain, 443)
265 .await
266 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
267
268 let connector = TlsConnector::builder()
269 .danger_accept_invalid_certs(true) .build()
271 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
272
273 let connector = tokio_native_tls::TlsConnector::from(connector);
274
275 let stream =
278 tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
279 .await
280 .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
281 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
282
283 let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
285 .await
286 .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
287 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
288
289 let cert = tls_stream
291 .get_ref()
292 .peer_certificate()
293 .map_err(|e| SeerError::CertificateError(e.to_string()))?
294 .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
295
296 let der = cert
298 .to_der()
299 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
300
301 parse_certificate_der(&der, domain)
302 }
303
304 async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
306 match self.smart_lookup.lookup(domain).await {
307 Ok(result) => {
308 let (expiration_date, registrar) = result.expiration_info();
309
310 if let Some(exp_date) = expiration_date {
311 let days_until_expiry = (exp_date - Utc::now()).num_days();
312 Ok(Some(DomainExpiration {
313 expiration_date: exp_date,
314 days_until_expiry,
315 registrar,
316 }))
317 } else {
318 Ok(None)
319 }
320 }
321 Err(_) => Ok(None), }
323 }
324
325 async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
327 let resolver = &self.dns_resolver;
328
329 let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
331 resolver.resolve(domain, RecordType::A, None),
332 resolver.resolve(domain, RecordType::AAAA, None),
333 resolver.resolve(domain, RecordType::CNAME, None),
334 resolver.resolve(domain, RecordType::NS, None)
335 );
336
337 let a_records: Vec<String> = a_result
339 .unwrap_or_default()
340 .into_iter()
341 .filter_map(|r| {
342 if let RecordData::A { address } = r.data {
343 Some(address)
344 } else {
345 None
346 }
347 })
348 .collect();
349
350 let aaaa_records: Vec<String> = aaaa_result
352 .unwrap_or_default()
353 .into_iter()
354 .filter_map(|r| {
355 if let RecordData::AAAA { address } = r.data {
356 Some(address)
357 } else {
358 None
359 }
360 })
361 .collect();
362
363 let cname_target: Option<String> =
365 cname_result.unwrap_or_default().into_iter().find_map(|r| {
366 if let RecordData::CNAME { target } = r.data {
367 Some(target.trim_end_matches('.').to_string())
368 } else {
369 None
370 }
371 });
372
373 let nameservers: Vec<String> = ns_result
375 .unwrap_or_default()
376 .into_iter()
377 .filter_map(|r| {
378 if let RecordData::NS { nameserver } = r.data {
379 Some(nameserver.trim_end_matches('.').to_string())
380 } else {
381 None
382 }
383 })
384 .collect();
385
386 let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
388
389 Ok(DnsResolution {
390 a_records,
391 aaaa_records,
392 cname_target,
393 nameservers,
394 resolves,
395 })
396 }
397}
398
399fn extract_title(html: &str) -> Option<String> {
410 TITLE_REGEX
411 .captures(html)
412 .and_then(|caps| caps.get(1))
413 .map(|m| {
414 m.as_str()
420 .chars()
421 .filter(|c| !c.is_control())
422 .collect::<String>()
423 .trim()
424 .to_string()
425 })
426 .filter(|s| !s.is_empty())
427}
428
429async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
435 let scheme = url.scheme();
436 if scheme != "https" && scheme != "http" {
437 return Err(SeerError::HttpError(format!(
438 "unsupported URL scheme: {}",
439 scheme
440 )));
441 }
442
443 if !url.username().is_empty() || url.password().is_some() {
444 return Err(SeerError::HttpError(
445 "URL credentials are not allowed".to_string(),
446 ));
447 }
448
449 let host = url
450 .host_str()
451 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
452 let port = url.port_or_known_default().unwrap_or(443);
453
454 if port != 80 && port != 443 {
456 return Err(SeerError::HttpError(format!(
457 "non-standard port {} is not allowed in redirects",
458 port
459 )));
460 }
461
462 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
463 if let Some(reason) = describe_reserved_ip(&ip) {
464 return Err(SeerError::HttpError(format!(
465 "cannot connect to {}: {} — {}",
466 host, ip, reason
467 )));
468 }
469 return Ok(vec![SocketAddr::new(ip, port)]);
470 }
471
472 let addr = format!("{}:{}", host, port);
473 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
474 .await
475 .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
476 .collect();
477
478 if socket_addrs.is_empty() {
479 return Err(SeerError::HttpError(format!(
480 "DNS lookup returned no addresses for {}",
481 host
482 )));
483 }
484
485 for socket_addr in &socket_addrs {
486 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
487 return Err(SeerError::HttpError(format!(
488 "cannot connect to {}: {} — {}",
489 host,
490 socket_addr.ip(),
491 reason
492 )));
493 }
494 }
495
496 Ok(socket_addrs)
497}
498
499fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
501 use x509_parser::prelude::*;
502
503 let (_, cert) = X509Certificate::from_der(der)
504 .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
505
506 let issuer = format_issuer_name(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
511
512 let subject =
514 extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
515
516 let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
518 let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
519
520 let now = Utc::now();
521 let days_until_expiry = (valid_until - now).num_days();
522 let is_valid = now >= valid_from && now <= valid_until;
523
524 let hostname_verified = cert_matches_hostname(&cert, domain);
529
530 Ok(CertificateInfo {
531 issuer,
532 subject,
533 valid_from,
534 valid_until,
535 days_until_expiry,
536 is_valid,
537 hostname_verified,
538 })
539}
540
/// Returns true when `host` is covered by the certificate name `pattern`.
///
/// Matching is ASCII case-insensitive, and a single trailing root dot is ignored on both
/// sides. A pattern of the form `*.rest` matches exactly one additional, non-empty
/// leftmost label: `a.example.com` matches `*.example.com`, but `example.com`,
/// `a.b.example.com` and `.example.com` do not. Any other pattern must equal the host
/// exactly. Empty names never match.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    /// Drops one trailing root dot and lowercases ASCII letters.
    fn normalize(name: &str) -> String {
        name.strip_suffix('.').unwrap_or(name).to_ascii_lowercase()
    }

    let host = normalize(host);
    let pattern = normalize(pattern);
    if host.is_empty() || pattern.is_empty() {
        return false;
    }

    match pattern.strip_prefix("*.") {
        Some(suffix) => match host.split_once('.') {
            // The wildcard stands for one whole, non-empty label.
            Some((label, rest)) => !label.is_empty() && !suffix.is_empty() && rest == suffix,
            None => false,
        },
        None => host == pattern,
    }
}
560
561fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
567 use x509_parser::prelude::*;
568
569 if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
571 for name in &san_ext.value.general_names {
572 if let GeneralName::DNSName(n) = name {
573 if hostname_matches_pattern(host, n) {
574 return true;
575 }
576 }
577 }
578 }
579
580 for cn in cert.subject().iter_common_name() {
582 if let Ok(s) = cn.as_str() {
583 if hostname_matches_pattern(host, s) {
584 return true;
585 }
586 }
587 }
588
589 false
590}
591
592fn format_issuer_name(name: &x509_parser::prelude::X509Name) -> Option<String> {
596 use x509_parser::oid_registry;
597 let cn = extract_oid_value(name, &oid_registry::OID_X509_COMMON_NAME);
598 let org = extract_oid_value(name, &oid_registry::OID_X509_ORGANIZATION_NAME);
599 match (org, cn) {
600 (Some(o), Some(c)) if o != c => Some(format!("{} ({})", o, c)),
601 (Some(o), _) => Some(o),
602 (None, Some(c)) => Some(c),
603 (None, None) => None,
604 }
605}
606
607fn extract_oid_value(
609 name: &x509_parser::prelude::X509Name,
610 oid: &x509_parser::der_parser::oid::Oid<'static>,
611) -> Option<String> {
612 for rdn in name.iter() {
613 for attr in rdn.iter() {
614 if attr.attr_type() == oid {
615 if let Some(s) = extract_attr_string(attr.attr_value()) {
616 return Some(s);
617 }
618 }
619 }
620 }
621 None
622}
623
624fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
626 use x509_parser::prelude::*;
627
628 for rdn in name.iter() {
630 for attr in rdn.iter() {
631 if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
632 if let Some(s) = extract_attr_string(attr.attr_value()) {
633 return Some(s);
634 }
635 }
636 }
637 }
638
639 for rdn in name.iter() {
641 for attr in rdn.iter() {
642 if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
643 if let Some(s) = extract_attr_string(attr.attr_value()) {
644 return Some(s);
645 }
646 }
647 }
648 }
649
650 None
651}
652
653fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
655 if let Ok(s) = value.as_str() {
657 return Some(s.to_string());
658 }
659
660 if let Ok(utf8) = value.as_utf8string() {
662 return Some(utf8.string().to_string());
663 }
664
665 if let Ok(s) = std::str::from_utf8(value.data) {
667 return Some(s.to_string());
668 }
669
670 None
671}
672
673fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
675 let timestamp = time.timestamp();
676 chrono::DateTime::from_timestamp(timestamp, 0)
677 .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `hostname_matches_pattern(host, pattern) == expected` for every case,
    /// naming the failing pair in the panic message.
    fn check_cases(cases: &[(&str, &str, bool)]) {
        for &(host, pattern, expected) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host={host:?} pattern={pattern:?}"
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        check_cases(&[
            ("example.com", "example.com", true),
            ("EXAMPLE.COM", "example.com", true),
            ("example.com", "EXAMPLE.COM", true),
            ("evil.com", "example.com", false),
            ("example.com", "evil.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        check_cases(&[
            ("a.example.com", "*.example.com", true),
            ("A.EXAMPLE.COM", "*.example.com", true),
            // The wildcard needs a label of its own, and covers exactly one label.
            ("example.com", "*.example.com", false),
            ("a.b.example.com", "*.example.com", false),
            ("b.other.com", "*.example.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        check_cases(&[("localhost", "*.example.com", false)]);
    }
}