1use std::collections::HashSet;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use chrono::Utc;
6use native_tls::TlsConnector;
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::{Client, Url};
10use tokio::net::TcpStream;
11use tracing::{debug, instrument};
12
13use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
14use crate::dns::{DnsResolver, RecordData, RecordType};
15use crate::error::{Result, SeerError};
16use crate::lookup::SmartLookup;
17use crate::validation::{describe_reserved_ip, normalize_domain};
18
/// Default `StatusClient` timeout applied to each network step of a check.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of HTTP redirects followed before the HTTP check gives up.
const MAX_REDIRECTS: usize = 5;
23
24static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
26 Regex::new(r"(?i)<title[^>]*>([^<]+)</title>").expect("Invalid regex for HTML title extraction")
27});
28
29#[derive(Debug, Clone)]
31pub struct StatusClient {
32 timeout: Duration,
33 dns_resolver: DnsResolver,
35 smart_lookup: SmartLookup,
37}
38
39impl Default for StatusClient {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl StatusClient {
46 pub fn new() -> Self {
48 Self {
49 timeout: DEFAULT_TIMEOUT,
50 dns_resolver: DnsResolver::new(),
51 smart_lookup: SmartLookup::new(),
52 }
53 }
54
55 pub fn with_timeout(mut self, timeout: Duration) -> Self {
57 self.timeout = timeout;
58 self
59 }
60
61 #[instrument(skip(self), fields(domain = %domain))]
63 pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
64 let domain = normalize_domain(domain)?;
66 debug!("Checking status for domain: {}", domain);
67
68 let mut response = StatusResponse::new(domain.clone());
69
70 let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
73 self.fetch_http_info(&domain),
74 self.fetch_certificate_info(&domain),
75 self.fetch_domain_expiration(&domain),
76 self.fetch_dns_resolution(&domain)
77 );
78
79 match http_result {
81 Ok((status, status_text, title)) => {
82 response.http_status = Some(status);
83 response.http_status_text = Some(status_text);
84 response.title = title;
85 }
86 Err(e) => response.errors.push(super::types::StatusError {
87 check: "http".to_string(),
88 message: e.to_string(),
89 }),
90 }
91
92 match cert_result {
94 Ok(cert_info) => response.certificate = Some(cert_info),
95 Err(e) => response.errors.push(super::types::StatusError {
96 check: "ssl".to_string(),
97 message: e.to_string(),
98 }),
99 }
100
101 match expiry_result {
103 Ok(expiry_info) => response.domain_expiration = expiry_info,
104 Err(e) => response.errors.push(super::types::StatusError {
105 check: "expiration".to_string(),
106 message: e.to_string(),
107 }),
108 }
109
110 match dns_result {
112 Ok(dns_info) => response.dns_resolution = Some(dns_info),
113 Err(e) => response.errors.push(super::types::StatusError {
114 check: "dns".to_string(),
115 message: e.to_string(),
116 }),
117 }
118
119 Ok(response)
120 }
121
122 async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
132 let mut url = Url::parse(&format!("https://{}", domain))
133 .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
134 let mut visited = HashSet::new();
135
136 for _ in 0..=MAX_REDIRECTS {
137 let validated_addrs = validate_url_target(&url).await?;
138
139 if !visited.insert(url.clone()) {
140 return Err(SeerError::HttpError("redirect loop detected".to_string()));
141 }
142
143 let host = url
147 .host_str()
148 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
149 let client = Client::builder()
150 .redirect(reqwest::redirect::Policy::none())
151 .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
152 .resolve_to_addrs(host, &validated_addrs)
153 .build()
154 .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
155
156 let response = client
157 .get(url.clone())
158 .timeout(self.timeout)
159 .send()
160 .await
161 .map_err(|e| SeerError::HttpError(e.to_string()))?;
162
163 if response.status().is_redirection() {
164 let location = response.headers().get(reqwest::header::LOCATION);
165 let location = location.and_then(|v| v.to_str().ok()).ok_or_else(|| {
166 SeerError::HttpError("redirect missing location header".to_string())
167 })?;
168 let next_url = url
169 .join(location)
170 .or_else(|_| Url::parse(location))
171 .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
172 url = next_url;
173 continue;
174 }
175
176 let status = response.status();
177 let status_code = status.as_u16();
178 let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
179
180 let title = if status.is_success() {
182 let content_type = response
183 .headers()
184 .get("content-type")
185 .and_then(|v| v.to_str().ok())
186 .unwrap_or("");
187
188 if content_type.contains("text/html") {
189 const MAX_TITLE_BODY: usize = 64 * 1024;
192 let bytes = response
193 .bytes()
194 .await
195 .map_err(|e| SeerError::HttpError(e.to_string()))?;
196 let body = String::from_utf8_lossy(&bytes[..bytes.len().min(MAX_TITLE_BODY)]);
197 extract_title(&body)
198 } else {
199 None
200 }
201 } else {
202 None
203 };
204
205 return Ok((status_code, status_text, title));
206 }
207
208 Err(SeerError::HttpError("too many redirects".to_string()))
209 }
210
211 async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
218 let addr = format!("{}:443", domain);
220 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
221 .await
222 .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
223 .collect();
224
225 if socket_addrs.is_empty() {
226 return Err(SeerError::CertificateError(format!(
227 "DNS lookup returned no addresses for {}",
228 domain
229 )));
230 }
231
232 for socket_addr in &socket_addrs {
233 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
234 return Err(SeerError::CertificateError(format!(
235 "cannot connect to {}: {} — {}",
236 domain,
237 socket_addr.ip(),
238 reason
239 )));
240 }
241 }
242
243 let connector = TlsConnector::builder()
244 .danger_accept_invalid_certs(true) .build()
246 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
247
248 let connector = tokio_native_tls::TlsConnector::from(connector);
249
250 let stream =
253 tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
254 .await
255 .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
256 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
257
258 let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
260 .await
261 .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
262 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
263
264 let cert = tls_stream
266 .get_ref()
267 .peer_certificate()
268 .map_err(|e| SeerError::CertificateError(e.to_string()))?
269 .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
270
271 let der = cert
273 .to_der()
274 .map_err(|e| SeerError::CertificateError(e.to_string()))?;
275
276 parse_certificate_der(&der, domain)
277 }
278
279 async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
281 match self.smart_lookup.lookup(domain).await {
282 Ok(result) => {
283 let (expiration_date, registrar) = result.expiration_info();
284
285 if let Some(exp_date) = expiration_date {
286 let days_until_expiry = (exp_date - Utc::now()).num_days();
287 Ok(Some(DomainExpiration {
288 expiration_date: exp_date,
289 days_until_expiry,
290 registrar,
291 }))
292 } else {
293 Ok(None)
294 }
295 }
296 Err(_) => Ok(None), }
298 }
299
300 async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
302 let resolver = &self.dns_resolver;
303
304 let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
306 resolver.resolve(domain, RecordType::A, None),
307 resolver.resolve(domain, RecordType::AAAA, None),
308 resolver.resolve(domain, RecordType::CNAME, None),
309 resolver.resolve(domain, RecordType::NS, None)
310 );
311
312 let a_records: Vec<String> = a_result
314 .unwrap_or_default()
315 .into_iter()
316 .filter_map(|r| {
317 if let RecordData::A { address } = r.data {
318 Some(address)
319 } else {
320 None
321 }
322 })
323 .collect();
324
325 let aaaa_records: Vec<String> = aaaa_result
327 .unwrap_or_default()
328 .into_iter()
329 .filter_map(|r| {
330 if let RecordData::AAAA { address } = r.data {
331 Some(address)
332 } else {
333 None
334 }
335 })
336 .collect();
337
338 let cname_target: Option<String> =
340 cname_result.unwrap_or_default().into_iter().find_map(|r| {
341 if let RecordData::CNAME { target } = r.data {
342 Some(target.trim_end_matches('.').to_string())
343 } else {
344 None
345 }
346 });
347
348 let nameservers: Vec<String> = ns_result
350 .unwrap_or_default()
351 .into_iter()
352 .filter_map(|r| {
353 if let RecordData::NS { nameserver } = r.data {
354 Some(nameserver.trim_end_matches('.').to_string())
355 } else {
356 None
357 }
358 })
359 .collect();
360
361 let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
363
364 Ok(DnsResolution {
365 a_records,
366 aaaa_records,
367 cname_target,
368 nameservers,
369 resolves,
370 })
371 }
372}
373
374fn extract_title(html: &str) -> Option<String> {
378 TITLE_REGEX
379 .captures(html)
380 .and_then(|caps| caps.get(1))
381 .map(|m| m.as_str().trim().to_string())
382 .filter(|s| !s.is_empty())
383}
384
385async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
391 let scheme = url.scheme();
392 if scheme != "https" && scheme != "http" {
393 return Err(SeerError::HttpError(format!(
394 "unsupported URL scheme: {}",
395 scheme
396 )));
397 }
398
399 if !url.username().is_empty() || url.password().is_some() {
400 return Err(SeerError::HttpError(
401 "URL credentials are not allowed".to_string(),
402 ));
403 }
404
405 let host = url
406 .host_str()
407 .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;
408 let port = url.port_or_known_default().unwrap_or(443);
409
410 if port != 80 && port != 443 {
412 return Err(SeerError::HttpError(format!(
413 "non-standard port {} is not allowed in redirects",
414 port
415 )));
416 }
417
418 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
419 if let Some(reason) = describe_reserved_ip(&ip) {
420 return Err(SeerError::HttpError(format!(
421 "cannot connect to {}: {} — {}",
422 host, ip, reason
423 )));
424 }
425 return Ok(vec![SocketAddr::new(ip, port)]);
426 }
427
428 let addr = format!("{}:{}", host, port);
429 let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
430 .await
431 .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
432 .collect();
433
434 if socket_addrs.is_empty() {
435 return Err(SeerError::HttpError(format!(
436 "DNS lookup returned no addresses for {}",
437 host
438 )));
439 }
440
441 for socket_addr in &socket_addrs {
442 if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
443 return Err(SeerError::HttpError(format!(
444 "cannot connect to {}: {} — {}",
445 host,
446 socket_addr.ip(),
447 reason
448 )));
449 }
450 }
451
452 Ok(socket_addrs)
453}
454
455fn parse_certificate_der(der: &[u8], _domain: &str) -> Result<CertificateInfo> {
457 use x509_parser::prelude::*;
458
459 let (_, cert) = X509Certificate::from_der(der)
460 .map_err(|e| SeerError::CertificateError(format!("failed to parse certificate: {}", e)))?;
461
462 let issuer =
464 extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
465
466 let subject =
468 extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());
469
470 let valid_from = asn1_time_to_chrono(cert.validity().not_before)?;
472 let valid_until = asn1_time_to_chrono(cert.validity().not_after)?;
473
474 let now = Utc::now();
475 let days_until_expiry = (valid_until - now).num_days();
476 let is_valid = now >= valid_from && now <= valid_until;
477
478 Ok(CertificateInfo {
479 issuer,
480 subject,
481 valid_from,
482 valid_until,
483 days_until_expiry,
484 is_valid,
485 })
486}
487
488fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
490 use x509_parser::prelude::*;
491
492 for rdn in name.iter() {
494 for attr in rdn.iter() {
495 if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
496 if let Some(s) = extract_attr_string(attr.attr_value()) {
497 return Some(s);
498 }
499 }
500 }
501 }
502
503 for rdn in name.iter() {
505 for attr in rdn.iter() {
506 if attr.attr_type() == &oid_registry::OID_X509_ORGANIZATION_NAME {
507 if let Some(s) = extract_attr_string(attr.attr_value()) {
508 return Some(s);
509 }
510 }
511 }
512 }
513
514 None
515}
516
517fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
519 if let Ok(s) = value.as_str() {
521 return Some(s.to_string());
522 }
523
524 if let Ok(utf8) = value.as_utf8string() {
526 return Some(utf8.string().to_string());
527 }
528
529 if let Ok(s) = std::str::from_utf8(value.data) {
531 return Some(s.to_string());
532 }
533
534 None
535}
536
537fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
539 let timestamp = time.timestamp();
540 chrono::DateTime::from_timestamp(timestamp, 0)
541 .ok_or_else(|| SeerError::CertificateError("invalid certificate timestamp".to_string()))
542}