Skip to main content

seer_core/
watchlist.rs

1//! Domain watchlist for monitoring expiration and health.
2//!
3//! Loads a list of domains from `~/.seer/watchlist.toml` and checks their
4//! SSL certificates, domain expiration, and HTTP status.
5
6use std::path::PathBuf;
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Result, SeerError};
12use crate::status::StatusClient;
13
14/// Persistent list of domains to monitor.
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct Watchlist {
17    #[serde(default)]
18    pub domains: Vec<String>,
19}
20
21/// Status result for a single watched domain.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct WatchResult {
24    pub domain: String,
25    pub ssl_days_remaining: Option<i64>,
26    pub domain_days_remaining: Option<i64>,
27    pub registrar: Option<String>,
28    pub http_status: Option<u16>,
29    pub issues: Vec<String>,
30}
31
32/// Aggregated report from checking all watched domains.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct WatchReport {
35    pub checked_at: DateTime<Utc>,
36    pub results: Vec<WatchResult>,
37    pub total: usize,
38    pub warnings: usize,
39    pub critical: usize,
40}
41
42impl Watchlist {
43    /// Returns the path to the watchlist file (`~/.seer/watchlist.toml`).
44    pub fn path() -> Option<PathBuf> {
45        dirs::home_dir().map(|h| h.join(".seer").join("watchlist.toml"))
46    }
47
48    /// Loads the watchlist from disk, returning an empty list on any failure.
49    ///
50    /// When the file exists but fails to parse, it is renamed to
51    /// `<path>.corrupt` (preserving the user's data for recovery/forensics)
52    /// and a warning is logged — previously the file was silently
53    /// overwritten on the next save, dropping the user's watchlist.
54    pub fn load() -> Self {
55        let Some(path) = Self::path() else {
56            return Self::default();
57        };
58        Self::load_from_path(&path)
59    }
60
61    /// Like [`Self::load`] but reads from an explicit path. Split out so
62    /// tests can exercise the corrupt-file handling without depending on
63    /// the real `~/.seer/watchlist.toml` location.
64    pub(crate) fn load_from_path(path: &std::path::Path) -> Self {
65        if !path.exists() {
66            return Self::default();
67        }
68        match std::fs::read_to_string(path) {
69            Ok(content) => match toml::from_str::<Watchlist>(&content) {
70                Ok(w) => w,
71                Err(e) => {
72                    let backup = path.with_extension("corrupt");
73                    if let Err(rename_err) = std::fs::rename(path, &backup) {
74                        tracing::error!(
75                            path = %path.display(),
76                            error = %rename_err,
77                            "failed to back up corrupt watchlist",
78                        );
79                    } else {
80                        tracing::warn!(
81                            path = %path.display(),
82                            backup = %backup.display(),
83                            error = %e,
84                            "watchlist file corrupt; moved to backup",
85                        );
86                    }
87                    Watchlist::default()
88                }
89            },
90            Err(_) => Self::default(),
91        }
92    }
93
94    /// Persists the watchlist to disk via write-and-rename so a crash mid-write
95    /// cannot leave the file truncated (the next `load()` would see corrupt
96    /// TOML and silently fall back to the default empty watchlist, losing
97    /// the user's domains). Mirrors `LookupHistory::save`.
98    ///
99    /// The temp filename is suffixed with the current PID so two concurrent
100    /// `seer` processes don't write to the same intermediate path and race
101    /// each other's `rename`s.
102    pub fn save(&self) -> Result<()> {
103        let path = Self::path()
104            .ok_or_else(|| SeerError::ConfigError("Cannot determine home directory".to_string()))?;
105        if let Some(parent) = path.parent() {
106            std::fs::create_dir_all(parent).map_err(|e| SeerError::ConfigError(e.to_string()))?;
107            #[cfg(unix)]
108            {
109                use std::os::unix::fs::PermissionsExt;
110                let _ = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
111            }
112        }
113        let content =
114            toml::to_string_pretty(self).map_err(|e| SeerError::ConfigError(e.to_string()))?;
115        let tmp_path = path.with_extension(format!("toml.{}.tmp", std::process::id()));
116        std::fs::write(&tmp_path, content).map_err(|e| SeerError::ConfigError(e.to_string()))?;
117        #[cfg(unix)]
118        {
119            use std::os::unix::fs::PermissionsExt;
120            let _ = std::fs::set_permissions(&tmp_path, std::fs::Permissions::from_mode(0o600));
121        }
122        std::fs::rename(&tmp_path, &path).map_err(|e| {
123            let _ = std::fs::remove_file(&tmp_path);
124            SeerError::ConfigError(e.to_string())
125        })?;
126        Ok(())
127    }
128
129    /// Adds a domain to the watchlist. Returns `Ok(true)` if the domain was newly added.
130    pub fn add(&mut self, domain: &str) -> Result<bool> {
131        let domain = crate::validation::normalize_domain(domain)?;
132        if self.domains.contains(&domain) {
133            return Ok(false);
134        }
135        self.domains.push(domain);
136        self.domains.sort();
137        Ok(true)
138    }
139
140    /// Removes a domain from the watchlist. Returns `true` if the domain was present.
141    pub fn remove(&mut self, domain: &str) -> bool {
142        let domain =
143            crate::validation::normalize_domain(domain).unwrap_or_else(|_| domain.to_lowercase());
144        let len_before = self.domains.len();
145        self.domains.retain(|d| d != &domain);
146        self.domains.len() < len_before
147    }
148}
149
150/// Checks all given domains concurrently and produces a [`WatchReport`].
151pub async fn check_watchlist(domains: &[String]) -> WatchReport {
152    use futures::stream::{self, StreamExt};
153
154    // Each per-domain future owns its `client` (via `Arc`) and `domain`
155    // (owned `String`) so the `buffer_unordered` futures are `Send + 'static`
156    // and the whole `check_watchlist` future can be used from `tokio::spawn`
157    // (e.g. the TUI). Borrowing `&client`/`&String` here makes the closure fail
158    // the higher-ranked `FnOnce` bound `tokio::spawn` requires.
159    let client = std::sync::Arc::new(StatusClient::new());
160
161    let results: Vec<WatchResult> = stream::iter(domains.iter().cloned())
162        .map(|domain| {
163            let client = client.clone();
164            async move {
165                let mut watch_result = WatchResult {
166                    domain: domain.clone(),
167                    ssl_days_remaining: None,
168                    domain_days_remaining: None,
169                    registrar: None,
170                    http_status: None,
171                    issues: vec![],
172                };
173
174                match client.check(&domain).await {
175                    Ok(status) => {
176                        watch_result.http_status = status.http_status;
177
178                        if let Some(ref cert) = status.certificate {
179                            watch_result.ssl_days_remaining = Some(cert.days_until_expiry);
180                            if cert.days_until_expiry < 30 {
181                                watch_result.issues.push(format!(
182                                    "SSL expires in {} days",
183                                    cert.days_until_expiry
184                                ));
185                            }
186                            if !cert.is_valid {
187                                watch_result
188                                    .issues
189                                    .push("SSL certificate invalid".to_string());
190                            }
191                        }
192
193                        if let Some(ref exp) = status.domain_expiration {
194                            watch_result.domain_days_remaining = Some(exp.days_until_expiry);
195                            watch_result.registrar = exp.registrar.clone();
196                            if exp.days_until_expiry < 90 {
197                                watch_result.issues.push(format!(
198                                    "Domain expires in {} days",
199                                    exp.days_until_expiry
200                                ));
201                            }
202                        }
203
204                        if let Some(status_code) = status.http_status {
205                            if !(200..300).contains(&status_code) {
206                                watch_result
207                                    .issues
208                                    .push(format!("HTTP status {}", status_code));
209                            }
210                        }
211                    }
212                    Err(e) => {
213                        watch_result.issues.push(format!("Check failed: {}", e));
214                    }
215                }
216
217                watch_result
218            }
219        })
220        .buffer_unordered(10)
221        .collect()
222        .await;
223
224    let total = results.len();
225    // A result counts as critical if ANY of: SSL expires within 30 days,
226    // domain registration expires within 30 days, or an issue string mentions
227    // "invalid" / "failed". Flat OR of three predicates, evaluated
228    // independently — the previous version nested the numeric checks inside
229    // an `any()` over issue strings, which made the critical count depend
230    // on iteration order over `issues` and could short-circuit incorrectly.
231    let critical = results
232        .iter()
233        .filter(|r| {
234            let bad_ssl = r.ssl_days_remaining.is_some_and(|d| d < 30);
235            let bad_domain = r.domain_days_remaining.is_some_and(|d| d < 30);
236            let bad_issue = r
237                .issues
238                .iter()
239                .any(|i| i.contains("invalid") || i.contains("failed"));
240            bad_ssl || bad_domain || bad_issue
241        })
242        .count();
243    let warnings = results.iter().filter(|r| !r.issues.is_empty()).count();
244
245    WatchReport {
246        checked_at: Utc::now(),
247        results,
248        total,
249        warnings,
250        critical,
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// A per-test temporary directory containing a `watchlist.toml` path.
    /// The directory is removed on drop, so a failing assertion does not
    /// leave it behind.
    struct TempWatchlist {
        dir: PathBuf,
        path: PathBuf,
    }

    impl TempWatchlist {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "seer-watchlist-test-{}-{}",
                tag,
                std::process::id()
            ));
            // Clear leftovers from an earlier aborted run with the same PID.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("create temp test dir");
            let path = dir.join("watchlist.toml");
            Self { dir, path }
        }
    }

    impl Drop for TempWatchlist {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn test_watchlist_default() {
        let wl = Watchlist::default();
        assert!(wl.domains.is_empty());
    }

    #[test]
    fn test_watchlist_add_remove() {
        let mut wl = Watchlist::default();
        assert!(wl.add("example.com").unwrap());
        assert!(!wl.add("example.com").unwrap()); // duplicate
        assert_eq!(wl.domains.len(), 1);

        assert!(wl.add("test.org").unwrap());
        assert_eq!(wl.domains.len(), 2);
        // Should be sorted
        assert_eq!(wl.domains[0], "example.com");
        assert_eq!(wl.domains[1], "test.org");

        assert!(wl.remove("example.com"));
        assert!(!wl.remove("example.com")); // already removed
        assert_eq!(wl.domains.len(), 1);
    }

    #[test]
    fn test_watchlist_add_normalizes_case() {
        let mut wl = Watchlist::default();
        wl.add("EXAMPLE.COM").unwrap();
        assert_eq!(wl.domains[0], "example.com");
    }

    #[test]
    fn test_watchlist_remove_normalizes_case() {
        let mut wl = Watchlist::default();
        wl.add("example.com").unwrap();
        assert!(wl.remove("EXAMPLE.COM"));
        assert!(wl.domains.is_empty());
    }

    #[test]
    fn test_watchlist_serialization() {
        let mut wl = Watchlist::default();
        wl.add("a.com").unwrap();
        wl.add("b.org").unwrap();
        let toml_str = toml::to_string_pretty(&wl).unwrap();
        assert!(toml_str.contains("a.com"));
        assert!(toml_str.contains("b.org"));

        let parsed: Watchlist = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.domains.len(), 2);
    }

    #[test]
    fn load_from_path_reads_valid_file() {
        let tmp = TempWatchlist::new("valid");
        std::fs::write(&tmp.path, b"domains = [\"a.com\", \"b.org\"]\n")
            .expect("seed valid watchlist file");

        let loaded = Watchlist::load_from_path(&tmp.path);
        assert_eq!(loaded.domains, vec!["a.com".to_string(), "b.org".to_string()]);
        assert!(tmp.path.exists(), "a valid file must not be moved away");
    }

    #[test]
    fn load_from_path_returns_default_and_backs_up_corrupt_file() {
        let tmp = TempWatchlist::new("corrupt");
        let backup = tmp.path.with_extension("corrupt");

        // TOML parsers reject stray garbage on the value side of `=`.
        std::fs::write(&tmp.path, b"domains = not-an-array-\n")
            .expect("seed corrupt watchlist file");

        let loaded = Watchlist::load_from_path(&tmp.path);
        assert!(
            loaded.domains.is_empty(),
            "corrupt watchlist must load as empty default"
        );
        assert!(
            !tmp.path.exists(),
            "original corrupt file should have been renamed away"
        );
        assert!(
            backup.exists(),
            "backup .corrupt file should exist at {}",
            backup.display()
        );
    }

    #[test]
    fn load_from_path_returns_default_when_missing() {
        let tmp = TempWatchlist::new("missing");

        let loaded = Watchlist::load_from_path(&tmp.path);
        assert!(loaded.domains.is_empty());
    }

    #[test]
    fn test_watch_result_serialization() {
        let result = WatchResult {
            domain: "example.com".to_string(),
            ssl_days_remaining: Some(45),
            domain_days_remaining: Some(120),
            registrar: Some("Test Registrar".to_string()),
            http_status: Some(200),
            issues: vec![],
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("example.com"));
        assert!(json.contains("45"));
    }
}