1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::caa::{self, CaaPolicy};
15use crate::dns::{DnsResolver, RecordData, RecordType};
16use crate::error::{Result, SeerError};
17use crate::lookup::SmartLookup;
18use crate::validation::{describe_reserved_ip, normalize_domain};
19
/// Default per-operation timeout (10 s) used by [`StatusClient::new`].
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of HTTP redirects followed before the HTTP check gives up.
const MAX_REDIRECTS: usize = 5;
24
25static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
27 Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
28});
29
/// Runs a website status check for a domain: HTTP reachability, TLS
/// certificate, registration expiry, DNS resolution and CAA policy.
///
/// Construct with [`StatusClient::new`] (or `Default`) and optionally adjust
/// the timeout with [`StatusClient::with_timeout`].
#[derive(Debug, Clone)]
pub struct StatusClient {
    /// Timeout applied to the HTTP request and to the TCP connect and TLS
    /// handshake of the certificate check.
    timeout: Duration,
    /// Resolver used for the A/AAAA/CNAME/NS lookups and the CAA lookup.
    dns_resolver: DnsResolver,
    /// Registration-data lookup used to obtain the domain's expiration date.
    smart_lookup: SmartLookup,
}
39
40impl Default for StatusClient {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl StatusClient {
47 pub fn new() -> Self {
49 Self {
50 timeout: DEFAULT_TIMEOUT,
51 dns_resolver: DnsResolver::new(),
52 smart_lookup: SmartLookup::new(),
53 }
54 }
55
56 pub fn with_timeout(mut self, timeout: Duration) -> Self {
58 self.timeout = timeout;
59 self
60 }
61
62 #[instrument(skip(self), fields(domain = %domain))]
64 pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
65 let domain = normalize_domain(domain)?;
67 debug!("Checking status for domain: {}", domain);
68
69 let mut response = StatusResponse::new(domain.clone());
70
71 let (http_result, cert_result, expiry_result, dns_result, caa_policy) = tokio::join!(
76 self.fetch_http_info(&domain),
77 self.fetch_certificate_info(&domain),
78 self.fetch_domain_expiration(&domain),
79 self.fetch_dns_resolution(&domain),
80 caa::lookup_caa(&self.dns_resolver, &domain),
81 );
82
83 match http_result {
85 Ok((status, status_text, title)) => {
86 response.http_status = Some(status);
87 response.http_status_text = Some(status_text);
88 response.title = title;
89 }
90 Err(e) => response.errors.push(super::types::StatusError {
91 check: "http".to_string(),
92 message: e.to_string(),
93 }),
94 }
95
96 let mut caa_policy: CaaPolicy = caa_policy;
99 match cert_result {
100 Ok(cert_info) => {
101 caa_policy.issuer_match =
102 Some(caa::classify_issuer(&cert_info.issuer, &caa_policy));
103 response.certificate = Some(cert_info);
104 }
105 Err(e) => response.errors.push(super::types::StatusError {
106 check: "ssl".to_string(),
107 message: e.to_string(),
108 }),
109 }
110 response.caa = Some(caa_policy);
111
112 match expiry_result {
114 Ok(expiry_info) => response.domain_expiration = expiry_info,
115 Err(e) => response.errors.push(super::types::StatusError {
116 check: "expiration".to_string(),
117 message: e.to_string(),
118 }),
119 }
120
121 match dns_result {
123 Ok(dns_info) => response.dns_resolution = Some(dns_info),
124 Err(e) => response.errors.push(super::types::StatusError {
125 check: "dns".to_string(),
126 message: e.to_string(),
127 }),
128 }
129
130 Ok(response)
131 }
132
133 async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
152 let mut url = Url::parse(&format!("https://{}", domain))
153 .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
154 let mut visited = HashSet::new();
155
156 for _ in 0..=MAX_REDIRECTS {
157 let validated_addrs = validate_url_target(&url).await?;
158
159 if !visited.insert(url.clone()) {
160 return Err(SeerError::HttpError("redirect loop detected".to_string()));
161 }
162
163 let host = url
167 .host_str()
168 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
169 let client = Client::builder()
170 .redirect(reqwest::redirect::Policy::none())
171 .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
172 .resolve_to_addrs(host, &validated_addrs)
173 .build()
174 .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
175
176 let response = client
177 .get(url.clone())
178 .timeout(self.timeout)
179 .send()
180 .await
181 .map_err(|e| SeerError::HttpError(e.to_string()))?;
182
183 if response.status().is_redirection() {
184 let location = response.headers().get(reqwest::header::LOCATION);
185 let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
186 SeerError::HttpError("redirect missing location header".to_string())
187 })?;
188 let next_url = url
189 .join(location)
190 .or_else(|_| Url::parse(location))
191 .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
192 url = next_url;
193 continue;
194 }
195
196 let status = response.status();
197 let status_code = status.as_u16();
198 let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
199
200 let title = if status.is_success() {
202 let content_type = response
203 .headers()
204 .get("content-type")
205 .and_then(|v| v.to_str().ok())
206 .unwrap_or("");
207
208 if content_type.contains("text/html") {
209 const MAX_TITLE_BODY: usize = 64 * 1024;
214 use futures::StreamExt;
215 let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
216 let mut stream = response.bytes_stream();
217 while let Some(chunk) = stream.next().await {
218 let chunk = chunk
219 .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
220 let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
221 if remaining == 0 {
222 break;
223 }
224 let take = remaining.min(chunk.len());
225 buf.extend_from_slice(&chunk[..take]);
226 if buf.len() >= MAX_TITLE_BODY {
227 break;
228 }
229 }
230 let body = String::from_utf8_lossy(&buf);
231 extract_title(&body)
232 } else {
233 None
234 }
235 } else {
236 None
237 };
238
239 return Ok((status_code, status_text, title));
240 }
241
242 Err(SeerError::HttpError("too many redirects".to_string()))
243 }
244
245 async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
252 let addr = format!("{}:443", domain);
254 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
255 .await
256 .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
257 .collect();
258
259 if socket_addrs.is_empty() {
260 return Err(SeerError::CertificateError(format!(
261 "DNS lookup returned no addresses for {}",
262 domain
263 )));
264 }
265
266 for socket_addr in &socket_addrs {
267 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
268 return Err(SeerError::CertificateError(format!(
269 "cannot connect to {}: {} — {}",
270 domain,
271 socket_addr.ip(),
272 reason
273 )));
274 }
275 }
276
277 let connector = TlsConnector::builder()
278 .danger_accept_invalid_certs(true) .build()
280 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
281
282 let connector = tokio_native_tls::TlsConnector::from(connector);
283
284 let stream =
287 tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
288 .await
289 .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
290 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
291
292 let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
294 .await
295 .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
296 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
297
298 let cert = tls_stream
300 .get_ref()
301 .peer_certificate()
302 .map_err(|e| SeerError::CertificateError(e.to_string()))?
303 .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
304
305 let der = cert
307 .to_der()
308 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
309
310 parse_certificate_der(&der, domain)
311 }
312
313 async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
315 match self.smart_lookup.lookup(domain).await {
316 Ok(result) => {
317 let (expiration_date, registrar) = result.expiration_info();
318
319 if let Some(exp_date) = expiration_date {
320 let days_until_expiry = (exp_date - Utc::now()).num_days();
321 Ok(Some(DomainExpiration {
322 expiration_date: exp_date,
323 days_until_expiry,
324 registrar,
325 }))
326 } else {
327 Ok(None)
328 }
329 }
330 Err(_) => Ok(None), }
332 }
333
334 async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
336 let resolver = &self.dns_resolver;
337
338 let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
340 resolver.resolve(domain, RecordType::A, None),
341 resolver.resolve(domain, RecordType::AAAA, None),
342 resolver.resolve(domain, RecordType::CNAME, None),
343 resolver.resolve(domain, RecordType::NS, None)
344 );
345
346 let a_records: Vec<String> = a_result
348 .unwrap_or_default()
349 .into_iter()
350 .filter_map(|r| {
351 if let RecordData::A { address } = r.data {
352 Some(address)
353 } else {
354 None
355 }
356 })
357 .collect();
358
359 let aaaa_records: Vec<String> = aaaa_result
361 .unwrap_or_default()
362 .into_iter()
363 .filter_map(|r| {
364 if let RecordData::AAAA { address } = r.data {
365 Some(address)
366 } else {
367 None
368 }
369 })
370 .collect();
371
372 let cname_target: Option<String> =
374 cname_result.unwrap_or_default().into_iter().find_map(|r| {
375 if let RecordData::CNAME { target } = r.data {
376 Some(target.trim_end_matches('.').to_string())
377 } else {
378 None
379 }
380 });
381
382 let nameservers: Vec<String> = ns_result
384 .unwrap_or_default()
385 .into_iter()
386 .filter_map(|r| {
387 if let RecordData::NS { nameserver } = r.data {
388 Some(nameserver.trim_end_matches('.').to_string())
389 } else {
390 None
391 }
392 })
393 .collect();
394
395 let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
397
398 Ok(DnsResolution {
399 a_records,
400 aaaa_records,
401 cname_target,
402 nameservers,
403 resolves,
404 })
405 }
406}
407
408fn extract_title(html: &str) -> Option<String> {
412 TITLE_REGEX
413 .captures(html)
414 .and_then(|caps| caps.get(1))
415 .map(|m| m.as_str().trim().to_string())
416 .filter(|s| !s.is_empty())
417}
418
419async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
425 let scheme = url.scheme();
426 if scheme != "https" && scheme != "http" {
427 return Err(SeerError::HttpError(format!(
428 "unsupported URL scheme: {}",
429 scheme
430 )));
431 }
432
433 if !url.username().is_empty() || url.password().is_some() {
434 return Err(SeerError::HttpError(
435 "URL credentials are not allowed".to_string(),
436 ));
437 }
438
439 let host = url
440 .host_str()
441 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
442 let port = url.port_or_known_default().unwrap_or(443);
443
444 if port != 80 && port != 443 {
446 return Err(SeerError::HttpError(format!(
447 "non-standard port {} is not allowed in redirects",
448 port
449 )));
450 }
451
452 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
453 if let Some(reason) = describe_reserved_ip(&ip) {
454 return Err(SeerError::HttpError(format!(
455 "cannot connect to {}: {} — {}",
456 host, ip, reason
457 )));
458 }
459 return Ok(vec![SocketAddr::new(ip, port)]);
460 }
461
462 let addr = format!("{}:{}", host, port);
463 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
464 .await
465 .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
466 .collect();
467
468 if socket_addrs.is_empty() {
469 return Err(SeerError::HttpError(format!(
470 "DNS lookup returned no addresses for {}",
471 host
472 )));
473 }
474
475 for socket_addr in &socket_addrs {
476 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
477 return Err(SeerError::HttpError(format!(
478 "cannot connect to {}: {} — {}",
479 host,
480 socket_addr.ip(),
481 reason
482 )));
483 }
484 }
485
486 Ok(socket_addrs)
487}
488
489fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
491 use x509_parser::prelude::*;
492
493 let (_, cert) = X509Certificate::from_der(der)
494 .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
495
496 let issuer = format_issuer_name(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
501
502 let subject =
504 extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
505
506 let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
508 let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
509
510 let now = Utc::now();
511 let days_until_expiry = (valid_until - now).num_days();
512 let is_valid = now >= valid_from && now <= valid_until;
513
514 let hostname_verified = cert_matches_hostname(&cert, domain);
519
520 Ok(CertificateInfo {
521 issuer,
522 subject,
523 valid_from,
524 valid_until,
525 days_until_expiry,
526 is_valid,
527 hostname_verified,
528 })
529}
530
/// Reports whether `host` matches the certificate name `pattern`, ignoring ASCII case.
///
/// Non-wildcard names must be equal; a single trailing root dot on either side
/// is ignored. A wildcard is honored only as the entire left-most label
/// (`*.example.com`). It then stands for exactly one non-empty label of `host`,
/// and the rest of the pattern must consist of at least two non-empty labels,
/// so `*.com`-style patterns never match.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    // A trailing dot denotes the DNS root and does not change the name.
    let host = host.strip_suffix('.').unwrap_or(host);
    let pattern = pattern.strip_suffix('.').unwrap_or(pattern);

    let Some(suffix) = pattern.strip_prefix("*.") else {
        return host.eq_ignore_ascii_case(pattern);
    };

    // Refuse wildcards over a single label (`*.com`) or malformed suffixes
    // with empty labels (`*..com`).
    let suffix_is_valid =
        suffix.contains('.') && suffix.split('.').all(|label| !label.is_empty());
    if !suffix_is_valid {
        return false;
    }

    // The wildcard replaces exactly one non-empty left-most label of `host`.
    match host.split_once('.') {
        Some((first_label, rest)) => {
            !first_label.is_empty() && rest.eq_ignore_ascii_case(suffix)
        }
        None => false,
    }
}
550
551fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
557 use x509_parser::prelude::*;
558
559 if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
561 for name in &san_ext.value.general_names {
562 if let GeneralName::DNSName(n) = name {
563 if hostname_matches_pattern(host, n) {
564 return true;
565 }
566 }
567 }
568 }
569
570 for cn in cert.subject().iter_common_name() {
572 if let Ok(s) = cn.as_str() {
573 if hostname_matches_pattern(host, s) {
574 return true;
575 }
576 }
577 }
578
579 false
580}
581
582fn format_issuer_name(name: &x509_parser::prelude::X509Name) -> Option<String> {
586 use x509_parser::oid_registry;
587 let cn = extract_oid_value(name, &oid_registry::OID_X509_COMMON_NAME);
588 let org = extract_oid_value(name, &oid_registry::OID_X509_ORGANIZATION_NAME);
589 match (org, cn) {
590 (Some(o), Some(c)) if o != c => Some(format!("{} ({})", o, c)),
591 (Some(o), _) => Some(o),
592 (None, Some(c)) => Some(c),
593 (None, None) => None,
594 }
595}
596
597fn extract_oid_value(
599 name: &x509_parser::prelude::X509Name,
600 oid: &x509_parser::der_parser::oid::Oid<'static>,
601) -> Option<String> {
602 for rdn in name.iter() {
603 for attr in rdn.iter() {
604 if attr.attr_type() == oid {
605 if let Some(s) = extract_attr_string(attr.attr_value()) {
606 return Some(s);
607 }
608 }
609 }
610 }
611 None
612}
613
614fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
616 use x509_parser::prelude::*;
617
618 for rdn in name.iter() {
620 for attr in rdn.iter() {
621 if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
622 if let Some(s) = extract_attr_string(attr.attr_value()) {
623 return Some(s);
624 }
625 }
626 }
627 }
628
629 for rdn in name.iter() {
631 for attr in rdn.iter() {
632 if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
633 if let Some(s) = extract_attr_string(attr.attr_value()) {
634 return Some(s);
635 }
636 }
637 }
638 }
639
640 None
641}
642
643fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
645 if let Ok(s) = value.as_str() {
647 return Some(s.to_string());
648 }
649
650 if let Ok(utf8) = value.as_utf8string() {
652 return Some(utf8.string().to_string());
653 }
654
655 if let Ok(s) = std::str::from_utf8(value.data) {
657 return Some(s.to_string());
658 }
659
660 None
661}
662
663fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
665 let timestamp = time.timestamp();
666 chrono::DateTime::from_timestamp(timestamp, 0)
667 .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `hostname_matches_pattern(host, pattern) == expected` for every row.
    fn check_cases(cases: &[(&str, &str, bool)]) {
        for &(host, pattern, expected) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host {:?} vs pattern {:?}",
                host,
                pattern
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        check_cases(&[
            ("example.com", "example.com", true),
            ("EXAMPLE.COM", "example.com", true),
            ("example.com", "EXAMPLE.COM", true),
            ("evil.com", "example.com", false),
            ("example.com", "evil.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        check_cases(&[
            ("a.example.com", "*.example.com", true),
            ("A.EXAMPLE.COM", "*.example.com", true),
            // The wildcard needs a label to stand in for...
            ("example.com", "*.example.com", false),
            // ...and covers exactly one label.
            ("a.b.example.com", "*.example.com", false),
            ("b.other.com", "*.example.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        check_cases(&[("localhost", "*.example.com", false)]);
    }
}