Skip to main content

seer_core/dns/
follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Hard ceiling on the number of checks a single follow may perform.
///
/// The interval between checks is already limited to 60 minutes; without this
/// cap, a non-interactive caller (API / Python bindings) could request a count
/// so large that the follow loop would effectively never finish.
const MAX_FOLLOW_ITERATIONS: usize = 10_000;

/// Settings that control a DNS follow run.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Total number of checks to run.
    pub iterations: usize,
    /// Seconds to wait between consecutive checks.
    pub interval_secs: u64,
    /// Report only checks whose records changed (the first check is always reported).
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten checks, sixty seconds apart, reporting every check.
    fn default() -> Self {
        let iterations = 10;
        let interval_secs = 60;
        FollowConfig {
            iterations,
            interval_secs,
            changes_only: false,
        }
    }
}
39
40impl FollowConfig {
41    /// Construct a new `FollowConfig`.
42    ///
43    /// Validates:
44    /// - `iterations` must be >= 1
45    /// - `interval_minutes` must be finite (not NaN / infinity)
46    /// - `interval_minutes` must be non-negative
47    /// - `interval_minutes` must be at most 60
48    pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49        if iterations == 0 {
50            return Err(SeerError::InvalidInput(
51                "iterations must be at least 1".into(),
52            ));
53        }
54        if iterations > MAX_FOLLOW_ITERATIONS {
55            return Err(SeerError::InvalidInput(format!(
56                "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57            )));
58        }
59        if !interval_minutes.is_finite() {
60            return Err(SeerError::InvalidInput(
61                "interval_minutes must be a finite number".into(),
62            ));
63        }
64        if interval_minutes < 0.0 {
65            return Err(SeerError::InvalidInput(
66                "interval_minutes must be non-negative".into(),
67            ));
68        }
69        if interval_minutes > 60.0 {
70            return Err(SeerError::InvalidInput(
71                "interval_minutes must be at most 60".into(),
72            ));
73        }
74        Ok(Self {
75            iterations,
76            interval_secs: (interval_minutes * 60.0) as u64,
77            changes_only: false,
78        })
79    }
80
81    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
82        self.changes_only = changes_only;
83        self
84    }
85}
86
87/// Result of a single follow iteration
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FollowIteration {
90    /// Iteration number (1-based)
91    pub iteration: usize,
92    /// Total number of iterations
93    pub total_iterations: usize,
94    /// Timestamp of the check
95    pub timestamp: DateTime<Utc>,
96    /// Records found (or empty if error/NXDOMAIN)
97    pub records: Vec<DnsRecord>,
98    /// Whether records changed from previous iteration
99    pub changed: bool,
100    /// Values added since previous iteration
101    pub added: Vec<String>,
102    /// Values removed since previous iteration
103    pub removed: Vec<String>,
104    /// Error message if the check failed
105    pub error: Option<String>,
106}
107
108impl FollowIteration {
109    pub fn success(&self) -> bool {
110        self.error.is_none()
111    }
112
113    pub fn record_count(&self) -> usize {
114        self.records.len()
115    }
116}
117
118/// Complete result of a follow operation
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct FollowResult {
121    /// Domain that was monitored
122    pub domain: String,
123    /// Record type that was monitored
124    pub record_type: RecordType,
125    /// Nameserver used (if custom)
126    pub nameserver: Option<String>,
127    /// Configuration used
128    pub iterations_requested: usize,
129    pub interval_secs: u64,
130    /// All iteration results
131    pub iterations: Vec<FollowIteration>,
132    /// Whether the operation was interrupted
133    pub interrupted: bool,
134    /// Total number of changes detected
135    pub total_changes: usize,
136    /// Start time
137    pub started_at: DateTime<Utc>,
138    /// End time
139    pub ended_at: DateTime<Utc>,
140}
141
142impl FollowResult {
143    pub fn completed_iterations(&self) -> usize {
144        self.iterations.len()
145    }
146
147    pub fn successful_iterations(&self) -> usize {
148        self.iterations.iter().filter(|i| i.success()).count()
149    }
150
151    pub fn failed_iterations(&self) -> usize {
152        self.iterations.iter().filter(|i| !i.success()).count()
153    }
154}
155
156/// Callback type for real-time progress updates
157pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
158
159/// DNS Follower - monitors DNS records over time
160#[derive(Clone)]
161pub struct DnsFollower {
162    resolver: DnsResolver,
163}
164
165impl Default for DnsFollower {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl DnsFollower {
172    pub fn new() -> Self {
173        Self {
174            resolver: DnsResolver::new(),
175        }
176    }
177
178    pub fn with_resolver(resolver: DnsResolver) -> Self {
179        Self { resolver }
180    }
181
182    /// Follow DNS records over time
183    #[instrument(skip(self, config, callback, cancel_rx))]
184    pub async fn follow(
185        &self,
186        domain: &str,
187        record_type: RecordType,
188        nameserver: Option<&str>,
189        config: FollowConfig,
190        callback: Option<FollowProgressCallback>,
191        cancel_rx: Option<watch::Receiver<bool>>,
192    ) -> Result<FollowResult> {
193        let domain = crate::validation::normalize_domain(domain)?;
194        let started_at = Utc::now();
195        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
196        let mut previous_values: HashSet<String> = HashSet::new();
197        let mut total_changes = 0;
198        let mut interrupted = false;
199
200        debug!(
201            domain = %domain,
202            record_type = %record_type,
203            iterations = config.iterations,
204            interval_secs = config.interval_secs,
205            "Starting DNS follow"
206        );
207
208        for i in 0..config.iterations {
209            // Check for cancellation
210            if let Some(ref rx) = cancel_rx {
211                if *rx.borrow() {
212                    debug!("Follow operation cancelled");
213                    interrupted = true;
214                    break;
215                }
216            }
217
218            let timestamp = Utc::now();
219            let iteration_num = i + 1;
220
221            // Perform DNS lookup
222            let (records, error) = match self
223                .resolver
224                .resolve(&domain, record_type, nameserver)
225                .await
226            {
227                Ok(records) => (records, None),
228                Err(e) => (Vec::new(), Some(e.to_string())),
229            };
230
231            // Extract record values for comparison
232            let current_values: HashSet<String> =
233                records.iter().map(|r| r.data.to_string()).collect();
234
235            // Compare with previous iteration
236            let (changed, added, removed) = if i == 0 {
237                // First iteration - no previous to compare
238                (false, Vec::new(), Vec::new())
239            } else {
240                let added: Vec<String> = current_values
241                    .difference(&previous_values)
242                    .cloned()
243                    .collect();
244                let removed: Vec<String> = previous_values
245                    .difference(&current_values)
246                    .cloned()
247                    .collect();
248                let changed = !added.is_empty() || !removed.is_empty();
249                (changed, added, removed)
250            };
251
252            if changed {
253                total_changes += 1;
254            }
255
256            let iteration = FollowIteration {
257                iteration: iteration_num,
258                total_iterations: config.iterations,
259                timestamp,
260                records,
261                changed,
262                added,
263                removed,
264                error,
265            };
266
267            // Call progress callback
268            if let Some(ref cb) = callback {
269                // Only call if not changes_only mode, or if this is first iteration or changed
270                if !config.changes_only || iteration_num == 1 || changed {
271                    cb(&iteration);
272                }
273            }
274
275            iterations.push(iteration);
276            previous_values = current_values;
277
278            // Sleep before next iteration (unless this is the last one)
279            if i < config.iterations - 1 {
280                let sleep_duration = Duration::from_secs(config.interval_secs);
281
282                // Use interruptible sleep
283                if let Some(ref rx) = cancel_rx {
284                    let mut rx_clone = rx.clone();
285                    tokio::select! {
286                        _ = tokio::time::sleep(sleep_duration) => {}
287                        _ = rx_clone.changed() => {
288                            if *rx_clone.borrow() {
289                                debug!("Follow operation cancelled during sleep");
290                                interrupted = true;
291                                break;
292                            }
293                        }
294                    }
295                } else {
296                    tokio::time::sleep(sleep_duration).await;
297                }
298            }
299        }
300
301        let ended_at = Utc::now();
302
303        Ok(FollowResult {
304            domain: domain.to_string(),
305            record_type,
306            nameserver: nameserver.map(|s| s.to_string()),
307            iterations_requested: config.iterations,
308            interval_secs: config.interval_secs,
309            iterations,
310            interrupted,
311            total_changes,
312            started_at,
313            ended_at,
314        })
315    }
316
317    /// Simple follow without callback or cancellation
318    #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
319    pub async fn follow_simple(
320        &self,
321        domain: &str,
322        record_type: RecordType,
323        nameserver: Option<&str>,
324        config: FollowConfig,
325    ) -> Result<FollowResult> {
326        self.follow(domain, record_type, nameserver, config, None, None)
327            .await
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    // Pure config tests need no async runtime, so they use plain `#[test]`.

    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_truncates_to_whole_seconds() {
        assert_eq!(FollowConfig::new(1, 1.5).unwrap().interval_secs, 90);
        assert_eq!(FollowConfig::new(1, 60.0).unwrap().interval_secs, 3600);
        // 0.01 minutes is 0.6 s, which truncates to 0.
        assert_eq!(FollowConfig::new(1, 0.01).unwrap().interval_secs, 0);
    }

    #[test]
    fn follow_config_with_changes_only_toggles_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap().with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        let config = config.with_changes_only(false);
        assert!(!config.changes_only);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(matches!(
            FollowConfig::new(0, 1.0),
            Err(SeerError::InvalidInput(_))
        ));
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(matches!(
            FollowConfig::new(10, f64::NAN),
            Err(SeerError::InvalidInput(_))
        ));
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        // `watch` and `Duration` come from the parent module via `use super::*`.
        let (tx, rx) = watch::channel(false);
        // 100 iterations with 30s intervals would take ~50 minutes.
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        // Give the follow a tick to start and get into its first sleep.
        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}