1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Per-operation timeout that [`StatusClient::new`] starts with (ten seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Upper bound on HTTP redirects followed by a single status check.
const MAX_REDIRECTS: usize = 5;
23
24static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
26 Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
27});
28
29#[derive(Debug, Clone)]
31pub struct StatusClient {
32 timeout: Duration,
33 dns_resolver: DnsResolver,
35 smart_lookup: SmartLookup,
37}
38
39impl Default for StatusClient {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl StatusClient {
46 pub fn new() -> Self {
48 Self {
49 timeout: DEFAULT_TIMEOUT,
50 dns_resolver: DnsResolver::new(),
51 smart_lookup: SmartLookup::new(),
52 }
53 }
54
55 pub fn with_timeout(mut self, timeout: Duration) -> Self {
57 self.timeout = timeout;
58 self
59 }
60
61 #[instrument(skip(self), fields(domain = %domain))]
63 pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64 let domain = normalize_domain(domain)?;
66 debug!("Checking status for domain: {}", domain);
67
68 let mut response = StatusResponse::new(domain.clone());
69
70 let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73 self.fetch_http_info(&domain),
74 self.fetch_certificate_info(&domain),
75 self.fetch_domain_expiration(&domain),
76 self.fetch_dns_resolution(&domain)
77 );
78
79 match http_result {
81 Ok((status, status_text, title)) => {
82 response.http_status = Some(status);
83 response.http_status_text = Some(status_text);
84 response.title = title;
85 }
86 Err(e) => response.errors.push(super::types::StatusError {
87 check: "http".to_string(),
88 message: e.to_string(),
89 }),
90 }
91
92 match cert_result {
94 Ok(cert_info) => response.certificate = Some(cert_info),
95 Err(e) => response.errors.push(super::types::StatusError {
96 check: "ssl".to_string(),
97 message: e.to_string(),
98 }),
99 }
100
101 match expiry_result {
103 Ok(expiry_info) => response.domain_expiration = expiry_info,
104 Err(e) => response.errors.push(super::types::StatusError {
105 check: "expiration".to_string(),
106 message: e.to_string(),
107 }),
108 }
109
110 match dns_result {
112 Ok(dns_info) => response.dns_resolution = Some(dns_info),
113 Err(e) => response.errors.push(super::types::StatusError {
114 check: "dns".to_string(),
115 message: e.to_string(),
116 }),
117 }
118
119 Ok(response)
120 }
121
122 async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
132 let mut url = Url::parse(&format!("https://{}", domain))
133 .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
134 let mut visited = HashSet::new();
135
136 for _ in 0..=MAX_REDIRECTS {
137 let validated_addrs = validate_url_target(&url).await?;
138
139 if !visited.insert(url.clone()) {
140 return Err(SeerError::HttpError("redirect loop detected".to_string()));
141 }
142
143 let host = url
147 .host_str()
148 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
149 let client = Client::builder()
150 .redirect(reqwest::redirect::Policy::none())
151 .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
152 .resolve_to_addrs(host, &validated_addrs)
153 .build()
154 .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
155
156 let response = client
157 .get(url.clone())
158 .timeout(self.timeout)
159 .send()
160 .await
161 .map_err(|e| SeerError::HttpError(e.to_string()))?;
162
163 if response.status().is_redirection() {
164 let location = response.headers().get(reqwest::header::LOCATION);
165 let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
166 SeerError::HttpError("redirect missing location header".to_string())
167 })?;
168 let next_url = url
169 .join(location)
170 .or_else(|_| Url::parse(location))
171 .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
172 url = next_url;
173 continue;
174 }
175
176 let status = response.status();
177 let status_code = status.as_u16();
178 let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
179
180 let title = if status.is_success() {
182 let content_type = response
183 .headers()
184 .get("content-type")
185 .and_then(|v| v.to_str().ok())
186 .unwrap_or("");
187
188 if content_type.contains("text/html") {
189 const MAX_TITLE_BODY: usize = 64 * 1024;
194 use futures::StreamExt;
195 let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
196 let mut stream = response.bytes_stream();
197 while let Some(chunk) = stream.next().await {
198 let chunk = chunk
199 .map_err(|e| SeerError::HttpError(format!("body chunk: {}", e)))?;
200 let remaining = MAX_TITLE_BODY.saturating_sub(buf.len());
201 if remaining == 0 {
202 break;
203 }
204 let take = remaining.min(chunk.len());
205 buf.extend_from_slice(&chunk[..take]);
206 if buf.len() >= MAX_TITLE_BODY {
207 break;
208 }
209 }
210 let body = String::from_utf8_lossy(&buf);
211 extract_title(&body)
212 } else {
213 None
214 }
215 } else {
216 None
217 };
218
219 return Ok((status_code, status_text, title));
220 }
221
222 Err(SeerError::HttpError("too many redirects".to_string()))
223 }
224
225 async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
232 let addr = format!("{}:443", domain);
234 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
235 .await
236 .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
237 .collect();
238
239 if socket_addrs.is_empty() {
240 return Err(SeerError::CertificateError(format!(
241 "DNS lookup returned no addresses for {}",
242 domain
243 )));
244 }
245
246 for socket_addr in &socket_addrs {
247 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
248 return Err(SeerError::CertificateError(format!(
249 "cannot connect to {}: {} — {}",
250 domain,
251 socket_addr.ip(),
252 reason
253 )));
254 }
255 }
256
257 let connector = TlsConnector::builder()
258 .danger_accept_invalid_certs(true) .build()
260 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
261
262 let connector = tokio_native_tls::TlsConnector::from(connector);
263
264 let stream =
267 tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
268 .await
269 .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
270 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
271
272 let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
274 .await
275 .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
276 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
277
278 let cert = tls_stream
280 .get_ref()
281 .peer_certificate()
282 .map_err(|e| SeerError::CertificateError(e.to_string()))?
283 .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
284
285 let der = cert
287 .to_der()
288 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
289
290 parse_certificate_der(&der, domain)
291 }
292
293 async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
295 match self.smart_lookup.lookup(domain).await {
296 Ok(result) => {
297 let (expiration_date, registrar) = result.expiration_info();
298
299 if let Some(exp_date) = expiration_date {
300 let days_until_expiry = (exp_date - Utc::now()).num_days();
301 Ok(Some(DomainExpiration {
302 expiration_date: exp_date,
303 days_until_expiry,
304 registrar,
305 }))
306 } else {
307 Ok(None)
308 }
309 }
310 Err(_) => Ok(None), }
312 }
313
314 async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
316 let resolver = &self.dns_resolver;
317
318 let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
320 resolver.resolve(domain, RecordType::A, None),
321 resolver.resolve(domain, RecordType::AAAA, None),
322 resolver.resolve(domain, RecordType::CNAME, None),
323 resolver.resolve(domain, RecordType::NS, None)
324 );
325
326 let a_records: Vec<String> = a_result
328 .unwrap_or_default()
329 .into_iter()
330 .filter_map(|r| {
331 if let RecordData::A { address } = r.data {
332 Some(address)
333 } else {
334 None
335 }
336 })
337 .collect();
338
339 let aaaa_records: Vec<String> = aaaa_result
341 .unwrap_or_default()
342 .into_iter()
343 .filter_map(|r| {
344 if let RecordData::AAAA { address } = r.data {
345 Some(address)
346 } else {
347 None
348 }
349 })
350 .collect();
351
352 let cname_target: Option<String> =
354 cname_result.unwrap_or_default().into_iter().find_map(|r| {
355 if let RecordData::CNAME { target } = r.data {
356 Some(target.trim_end_matches('.').to_string())
357 } else {
358 None
359 }
360 });
361
362 let nameservers: Vec<String> = ns_result
364 .unwrap_or_default()
365 .into_iter()
366 .filter_map(|r| {
367 if let RecordData::NS { nameserver } = r.data {
368 Some(nameserver.trim_end_matches('.').to_string())
369 } else {
370 None
371 }
372 })
373 .collect();
374
375 let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
377
378 Ok(DnsResolution {
379 a_records,
380 aaaa_records,
381 cname_target,
382 nameservers,
383 resolves,
384 })
385 }
386}
387
388fn extract_title(html: &str) -> Option<String> {
392 TITLE_REGEX
393 .captures(html)
394 .and_then(|caps| caps.get(1))
395 .map(|m| m.as_str().trim().to_string())
396 .filter(|s| !s.is_empty())
397}
398
399async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
405 let scheme = url.scheme();
406 if scheme != "https" && scheme != "http" {
407 return Err(SeerError::HttpError(format!(
408 "unsupported URL scheme: {}",
409 scheme
410 )));
411 }
412
413 if !url.username().is_empty() || url.password().is_some() {
414 return Err(SeerError::HttpError(
415 "URL credentials are not allowed".to_string(),
416 ));
417 }
418
419 let host = url
420 .host_str()
421 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
422 let port = url.port_or_known_default().unwrap_or(443);
423
424 if port != 80 && port != 443 {
426 return Err(SeerError::HttpError(format!(
427 "non-standard port {} is not allowed in redirects",
428 port
429 )));
430 }
431
432 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
433 if let Some(reason) = describe_reserved_ip(&ip) {
434 return Err(SeerError::HttpError(format!(
435 "cannot connect to {}: {} — {}",
436 host, ip, reason
437 )));
438 }
439 return Ok(vec![SocketAddr::new(ip, port)]);
440 }
441
442 let addr = format!("{}:{}", host, port);
443 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
444 .await
445 .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
446 .collect();
447
448 if socket_addrs.is_empty() {
449 return Err(SeerError::HttpError(format!(
450 "DNS lookup returned no addresses for {}",
451 host
452 )));
453 }
454
455 for socket_addr in &socket_addrs {
456 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
457 return Err(SeerError::HttpError(format!(
458 "cannot connect to {}: {} — {}",
459 host,
460 socket_addr.ip(),
461 reason
462 )));
463 }
464 }
465
466 Ok(socket_addrs)
467}
468
469fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
471 use x509_parser::prelude::*;
472
473 let (_, cert) = X509Certificate::from_der(der)
474 .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
475
476 let issuer =
478 extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
479
480 let subject =
482 extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
483
484 let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
486 let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
487
488 let now = Utc::now();
489 let days_until_expiry = (valid_until - now).num_days();
490 let is_valid = now >= valid_from && now <= valid_until;
491
492 let hostname_verified = cert_matches_hostname(&cert, domain);
497
498 Ok(CertificateInfo {
499 issuer,
500 subject,
501 valid_from,
502 valid_until,
503 days_until_expiry,
504 is_valid,
505 hostname_verified,
506 })
507}
508
/// Returns whether `host` matches a certificate name `pattern`, ignoring ASCII case.
///
/// A pattern of the form `*.suffix` matches exactly one non-empty leftmost label
/// followed by `suffix`. So `*.example.com` matches `a.example.com`, but not
/// `example.com`, `a.b.example.com` or `.example.com`. A wildcard with an empty
/// suffix (`*.`) never matches. Any other pattern must equal `host`.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    match pattern.strip_prefix("*.") {
        Some(suffix) => {
            if suffix.is_empty() {
                return false;
            }
            match host.split_once('.') {
                // The wildcard must stand in for a real (non-empty) label.
                Some((label, rest)) => !label.is_empty() && rest.eq_ignore_ascii_case(suffix),
                None => false,
            }
        }
        None => host.eq_ignore_ascii_case(pattern),
    }
}
528
529fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
535 use x509_parser::prelude::*;
536
537 if let Ok(Some(san_ext)) = cert.tbs_certificate.subject_alternative_name() {
539 for name in &san_ext.value.general_names {
540 if let GeneralName::DNSName(n) = name {
541 if hostname_matches_pattern(host, n) {
542 return true;
543 }
544 }
545 }
546 }
547
548 for cn in cert.subject().iter_common_name() {
550 if let Ok(s) = cn.as_str() {
551 if hostname_matches_pattern(host, s) {
552 return true;
553 }
554 }
555 }
556
557 false
558}
559
560fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
562 use x509_parser::prelude::*;
563
564 for rdn in name.iter() {
566 for attr in rdn.iter() {
567 if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
568 if let Some(s) = extract_attr_string(attr.attr_value()) {
569 return Some(s);
570 }
571 }
572 }
573 }
574
575 for rdn in name.iter() {
577 for attr in rdn.iter() {
578 if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
579 if let Some(s) = extract_attr_string(attr.attr_value()) {
580 return Some(s);
581 }
582 }
583 }
584 }
585
586 None
587}
588
589fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
591 if let Ok(s) = value.as_str() {
593 return Some(s.to_string());
594 }
595
596 if let Ok(utf8) = value.as_utf8string() {
598 return Some(utf8.string().to_string());
599 }
600
601 if let Ok(s) = std::str::from_utf8(value.data) {
603 return Some(s.to_string());
604 }
605
606 None
607}
608
609fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
611 let timestamp = time.timestamp();
612 chrono::DateTime::from_timestamp(timestamp, 0)
613 .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(host, pattern, expected)` triple against `hostname_matches_pattern`,
    /// naming the offending pair on failure.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(host, pattern, expected) in cases {
            assert_eq!(
                hostname_matches_pattern(host, pattern),
                expected,
                "host = {host:?}, pattern = {pattern:?}"
            );
        }
    }

    #[test]
    fn hostname_matches_pattern_exact() {
        assert_cases(&[
            ("example.com", "example.com", true),
            ("EXAMPLE.COM", "example.com", true),
            ("example.com", "EXAMPLE.COM", true),
            ("evil.com", "example.com", false),
            ("example.com", "evil.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        assert_cases(&[
            ("a.example.com", "*.example.com", true),
            ("A.EXAMPLE.COM", "*.example.com", true),
            // A wildcard never covers the bare parent domain...
            ("example.com", "*.example.com", false),
            // ...nor more than a single label.
            ("a.b.example.com", "*.example.com", false),
            ("b.other.com", "*.example.com", false),
        ]);
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        assert_cases(&[("localhost", "*.example.com", false)]);
    }
}