1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Timeout a `StatusClient` starts with unless it is replaced via `with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum number of HTTP redirects followed before the HTTP check gives up.
const MAX_REDIRECTS: usize = 5;
23
/// Case-insensitive pattern whose first capture group is the text between
/// `<title ...>` and `</title>`.
///
/// Compiled lazily on first use. The pattern is a fixed literal, so a
/// compilation failure would be a programming error and panics via `expect`.
static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
});
28
/// Client that gathers HTTP, TLS certificate, registration-expiry and DNS
/// information about a domain.
#[derive(Debug, Clone)]
pub struct StatusClient {
    /// Timeout applied to each HTTP request, TCP connect and TLS handshake.
    timeout: Duration,
    /// Resolver used for the A / AAAA / CNAME / NS queries.
    dns_resolver: DnsResolver,
    /// Lookup used to obtain the domain's expiration date and registrar.
    smart_lookup: SmartLookup,
}
38
39impl Default for StatusClient {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl StatusClient {
46 pub fn new() -> Self {
48 Self {
49 timeout: DEFAULT_TIMEOUT,
50 dns_resolver: DnsResolver::new(),
51 smart_lookup: SmartLookup::new(),
52 }
53 }
54
55 pub fn with_timeout(mut self, timeout: Duration) -> Self {
57 self.timeout = timeout;
58 self
59 }
60
61 #[instrument(skip(self), fields(domain = %domain))]
63 pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64 let domain = normalize_domain(domain)?;
66 debug!("Checking status for domain: {}", domain);
67
68 let mut response = StatusResponse::new(domain.clone());
69
70 let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73 self.fetch_http_info(&domain),
74 self.fetch_certificate_info(&domain),
75 self.fetch_domain_expiration(&domain),
76 self.fetch_dns_resolution(&domain)
77 );
78
79 match http_result {
81 Ok((status, status_text, title)) => {
82 response.http_status = Some(status);
83 response.http_status_text = Some(status_text);
84 response.title = title;
85 }
86 Err(e) => response.errors.push(super::types::StatusError {
87 check: "http".to_string(),
88 message: e.to_string(),
89 }),
90 }
91
92 match cert_result {
94 Ok(cert_info) => response.certificate = Some(cert_info),
95 Err(e) => response.errors.push(super::types::StatusError {
96 check: "ssl".to_string(),
97 message: e.to_string(),
98 }),
99 }
100
101 match expiry_result {
103 Ok(expiry_info) => response.domain_expiration = expiry_info,
104 Err(e) => response.errors.push(super::types::StatusError {
105 check: "expiration".to_string(),
106 message: e.to_string(),
107 }),
108 }
109
110 match dns_result {
112 Ok(dns_info) => response.dns_resolution = Some(dns_info),
113 Err(e) => response.errors.push(super::types::StatusError {
114 check: "dns".to_string(),
115 message: e.to_string(),
116 }),
117 }
118
119 Ok(response)
120 }
121
122 async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
141 let mut url = Url::parse(&format!("https://{}", domain))
142 .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
143 let mut visited = HashSet::new();
144
145 for _ in 0..=MAX_REDIRECTS {
146 let validated_addrs = validate_url_target(&url).await?;
147
148 if !visited.insert(url.clone()) {
149 return Err(SeerError::HttpError("redirect loop detected".to_string()));
150 }
151
152 let host = url
156 .host_str()
157 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
158 let client = Client::builder()
159 .redirect(reqwest::redirect::Policy::none())
160 .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
161 .resolve_to_addrs(host, &validated_addrs)
162 .build()
163 .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
164
165 let response = client
166 .get(url.clone())
167 .timeout(self.timeout)
168 .send()
169 .await
170 .map_err(|e| SeerError::HttpError(e.to_string()))?;
171
172 if response.status().is_redirection() {
173 let location = response.headers().get(reqwest::header::LOCATION);
174 let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
175 SeerError::HttpError("redirect missing location header".to_string())
176 })?;
177 let next_url = url
178 .join(location)
179 .or_else(|_| Url::parse(location))
180 .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
181 url = next_url;
182 continue;
183 }
184
185 let status = response.status();
186 let status_code = status.as_u16();
187 let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
188
189 let title = if status.is_success() {
191 let content_type = response
192 .headers()
193 .get("content-type")
194 .and_then(|v| v.to_str().ok())
195 .unwrap_or("");
196
197 if content_type.contains("text/html") {
198 const MAX_TITLE_BODY: usize = 64 * 1024;
203 use futures::StreamExt;
204 let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
205 let mut stream = response.bytes_stream();
206 while let Some(chunk) = stream.next().await {
207 let chunk = chunk
208 .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
209 let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
210 if remaining == 0 {
211 break;
212 }
213 let take = remaining.min(chunk.len());
214 buf.extend_from_slice(&chunk[..take]);
215 if buf.len() >= MAX_TITLE_BODY {
216 break;
217 }
218 }
219 let body = String::from_utf8_lossy(&buf);
220 extract_title(&body)
221 } else {
222 None
223 }
224 } else {
225 None
226 };
227
228 return Ok((status_code, status_text, title));
229 }
230
231 Err(SeerError::HttpError("too many redirects".to_string()))
232 }
233
234 async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
241 let addr = format!("{}:443", domain);
243 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
244 .await
245 .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
246 .collect();
247
248 if socket_addrs.is_empty() {
249 return Err(SeerError::CertificateError(format!(
250 "DNS lookup returned no addresses for {}",
251 domain
252 )));
253 }
254
255 for socket_addr in &socket_addrs {
256 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
257 return Err(SeerError::CertificateError(format!(
258 "cannot connect to {}: {} — {}",
259 domain,
260 socket_addr.ip(),
261 reason
262 )));
263 }
264 }
265
266 let connector = TlsConnector::builder()
267 .danger_accept_invalid_certs(true) .build()
269 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
270
271 let connector = tokio_native_tls::TlsConnector::from(connector);
272
273 let stream =
276 tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
277 .await
278 .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
279 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
280
281 let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
283 .await
284 .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
285 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
286
287 let cert = tls_stream
289 .get_ref()
290 .peer_certificate()
291 .map_err(|e| SeerError::CertificateError(e.to_string()))?
292 .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
293
294 let der = cert
296 .to_der()
297 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
298
299 parse_certificate_der(&der, domain)
300 }
301
302 async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
304 match self.smart_lookup.lookup(domain).await {
305 Ok(result) => {
306 let (expiration_date, registrar) = result.expiration_info();
307
308 if let Some(exp_date) = expiration_date {
309 let days_until_expiry = (exp_date - Utc::now()).num_days();
310 Ok(Some(DomainExpiration {
311 expiration_date: exp_date,
312 days_until_expiry,
313 registrar,
314 }))
315 } else {
316 Ok(None)
317 }
318 }
319 Err(_) => Ok(None), }
321 }
322
323 async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
325 let resolver = &self.dns_resolver;
326
327 let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
329 resolver.resolve(domain, RecordType::A, None),
330 resolver.resolve(domain, RecordType::AAAA, None),
331 resolver.resolve(domain, RecordType::CNAME, None),
332 resolver.resolve(domain, RecordType::NS, None)
333 );
334
335 let a_records: Vec<String> = a_result
337 .unwrap_or_default()
338 .into_iter()
339 .filter_map(|r| {
340 if let RecordData::A { address } = r.data {
341 Some(address)
342 } else {
343 None
344 }
345 })
346 .collect();
347
348 let aaaa_records: Vec<String> = aaaa_result
350 .unwrap_or_default()
351 .into_iter()
352 .filter_map(|r| {
353 if let RecordData::AAAA { address } = r.data {
354 Some(address)
355 } else {
356 None
357 }
358 })
359 .collect();
360
361 let cname_target: Option<String> =
363 cname_result.unwrap_or_default().into_iter().find_map(|r| {
364 if let RecordData::CNAME { target } = r.data {
365 Some(target.trim_end_matches('.').to_string())
366 } else {
367 None
368 }
369 });
370
371 let nameservers: Vec<String> = ns_result
373 .unwrap_or_default()
374 .into_iter()
375 .filter_map(|r| {
376 if let RecordData::NS { nameserver } = r.data {
377 Some(nameserver.trim_end_matches('.').to_string())
378 } else {
379 None
380 }
381 })
382 .collect();
383
384 let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
386
387 Ok(DnsResolution {
388 a_records,
389 aaaa_records,
390 cname_target,
391 nameservers,
392 resolves,
393 })
394 }
395}
396
397fn extract_title(html: &str) -> Option<String> {
401 TITLE_REGEX
402 .captures(html)
403 .and_then(|caps| caps.get(1))
404 .map(|m| m.as_str().trim().to_string())
405 .filter(|s| !s.is_empty())
406}
407
408async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
414 let scheme = url.scheme();
415 if scheme != "https" && scheme != "http" {
416 return Err(SeerError::HttpError(format!(
417 "unsupported URL scheme: {}",
418 scheme
419 )));
420 }
421
422 if !url.username().is_empty() || url.password().is_some() {
423 return Err(SeerError::HttpError(
424 "URL credentials are not allowed".to_string(),
425 ));
426 }
427
428 let host = url
429 .host_str()
430 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
431 let port = url.port_or_known_default().unwrap_or(443);
432
433 if port != 80 && port != 443 {
435 return Err(SeerError::HttpError(format!(
436 "non-standard port {} is not allowed in redirects",
437 port
438 )));
439 }
440
441 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
442 if let Some(reason) = describe_reserved_ip(&ip) {
443 return Err(SeerError::HttpError(format!(
444 "cannot connect to {}: {} — {}",
445 host, ip, reason
446 )));
447 }
448 return Ok(vec![SocketAddr::new(ip, port)]);
449 }
450
451 let addr = format!("{}:{}", host, port);
452 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
453 .await
454 .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
455 .collect();
456
457 if socket_addrs.is_empty() {
458 return Err(SeerError::HttpError(format!(
459 "DNS lookup returned no addresses for {}",
460 host
461 )));
462 }
463
464 for socket_addr in &socket_addrs {
465 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
466 return Err(SeerError::HttpError(format!(
467 "cannot connect to {}: {} — {}",
468 host,
469 socket_addr.ip(),
470 reason
471 )));
472 }
473 }
474
475 Ok(socket_addrs)
476}
477
478fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
480 use x509_parser::prelude::*;
481
482 let (_, cert) = X509Certificate::from_der(der)
483 .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
484
485 let issuer =
487 extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
488
489 let subject =
491 extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
492
493 let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
495 let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
496
497 let now = Utc::now();
498 let days_until_expiry = (valid_until - now).num_days();
499 let is_valid = now >= valid_from && now <= valid_until;
500
501 let hostname_verified = cert_matches_hostname(&cert, domain);
506
507 Ok(CertificateInfo {
508 issuer,
509 subject,
510 valid_from,
511 valid_until,
512 days_until_expiry,
513 is_valid,
514 hostname_verified,
515 })
516}
517
/// Reports whether `host` is covered by the certificate name `pattern`.
///
/// Comparison is ASCII case-insensitive. A pattern of the form `*.suffix` is
/// a wildcard that stands for exactly one non-empty leftmost label: it
/// matches `a.suffix`, but not `suffix` itself, not `a.b.suffix`, and not
/// `.suffix` (empty label). Any other pattern must equal `host`.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    match pattern.strip_prefix("*.") {
        Some(suffix) => match host.split_once('.') {
            // The wildcard replaces one whole label, which must not be empty;
            // everything after the first dot has to equal the suffix.
            Some((label, rest)) => !label.is_empty() && rest.eq_ignore_ascii_case(suffix),
            // A host without any dot has no label for the wildcard to cover.
            None => false,
        },
        None => host.eq_ignore_ascii_case(pattern),
    }
}
537
538fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
544 use x509_parser::prelude::*;
545
546 if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
548 for name in &san_ext.value.general_names {
549 if let GeneralName::DNSName(n) = name {
550 if hostname_matches_pattern(host, n) {
551 return true;
552 }
553 }
554 }
555 }
556
557 for cn in cert.subject().iter_common_name() {
559 if let Ok(s) = cn.as_str() {
560 if hostname_matches_pattern(host, s) {
561 return true;
562 }
563 }
564 }
565
566 false
567}
568
569fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
571 use x509_parser::prelude::*;
572
573 for rdn in name.iter() {
575 for attr in rdn.iter() {
576 if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
577 if let Some(s) = extract_attr_string(attr.attr_value()) {
578 return Some(s);
579 }
580 }
581 }
582 }
583
584 for rdn in name.iter() {
586 for attr in rdn.iter() {
587 if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
588 if let Some(s) = extract_attr_string(attr.attr_value()) {
589 return Some(s);
590 }
591 }
592 }
593 }
594
595 None
596}
597
598fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
600 if let Ok(s) = value.as_str() {
602 return Some(s.to_string());
603 }
604
605 if let Ok(utf8) = value.as_utf8string() {
607 return Some(utf8.string().to_string());
608 }
609
610 if let Ok(s) = std::str::from_utf8(value.data) {
612 return Some(s.to_string());
613 }
614
615 None
616}
617
618fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
620 let timestamp = time.timestamp();
621 chrono::DateTime::from_timestamp(timestamp, 0)
622 .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `hostname_matches_pattern` returns `expected` for every
    /// `(host, pattern)` pair in `cases`.
    fn assert_all(cases: &[(&str, &str)], expected: bool) {
        for &(host, pattern) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host={host:?} pattern={pattern:?}"
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        assert_all(
            &[
                ("example.com", "example.com"),
                ("EXAMPLE.COM", "example.com"),
                ("example.com", "EXAMPLE.COM"),
            ],
            true,
        );
        assert_all(
            &[("evil.com", "example.com"), ("example.com", "evil.com")],
            false,
        );
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        assert_all(
            &[
                ("a.example.com", "*.example.com"),
                ("A.EXAMPLE.COM", "*.example.com"),
            ],
            true,
        );
        assert_all(
            &[
                // The bare apex is not covered by the wildcard.
                ("example.com", "*.example.com"),
                // The wildcard spans a single label only.
                ("a.b.example.com", "*.example.com"),
                ("b.other.com", "*.example.com"),
            ],
            false,
        );
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        assert_all(&[("localhost", "*.example.com")], false);
    }
}