Skip to main content

seer_core/dns/
follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Hard ceiling on the number of iterations a single follow may request.
///
/// Non-interactive callers (the API, the Python bindings) can pass arbitrary
/// counts; without this limit a huge count would schedule an effectively
/// endless loop. Combined with the 60-minute interval cap enforced by
/// `FollowConfig::new`, it bounds how long one follow can run.
const MAX_FOLLOW_ITERATIONS: usize = 10 * 1_000;
18
/// Settings that control a DNS follow run.
///
/// `FollowConfig::new` builds a validated instance; the fields are public,
/// so a literal construction skips that validation.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Total number of lookups to perform.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// When `true`, progress is only reported for the first iteration and
    /// for iterations whose records changed.
    pub changes_only: bool,
}
29
30impl Default for FollowConfig {
31    fn default() -> Self {
32        Self {
33            iterations: 10,
34            interval_secs: 60,
35            changes_only: false,
36        }
37    }
38}
39
40impl FollowConfig {
41    /// Construct a new `FollowConfig`.
42    ///
43    /// Validates:
44    /// - `iterations` must be >= 1
45    /// - `interval_minutes` must be finite (not NaN / infinity)
46    /// - `interval_minutes` must be non-negative
47    /// - `interval_minutes` must be at most 60
48    pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49        if iterations == 0 {
50            return Err(SeerError::InvalidInput(
51                "iterations must be at least 1".into(),
52            ));
53        }
54        if iterations > MAX_FOLLOW_ITERATIONS {
55            return Err(SeerError::InvalidInput(format!(
56                "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57            )));
58        }
59        if !interval_minutes.is_finite() {
60            return Err(SeerError::InvalidInput(
61                "interval_minutes must be a finite number".into(),
62            ));
63        }
64        if interval_minutes < 0.0 {
65            return Err(SeerError::InvalidInput(
66                "interval_minutes must be non-negative".into(),
67            ));
68        }
69        if interval_minutes > 60.0 {
70            return Err(SeerError::InvalidInput(
71                "interval_minutes must be at most 60".into(),
72            ));
73        }
74        // A sub-second interval truncates to 0 seconds. For a multi-iteration
75        // follow that means back-to-back live DNS queries with no spacing — a
76        // self-inflicted query flood. Floor to 1s whenever more than one
77        // iteration will run; a single-shot follow (iterations == 1) does no
78        // looping and may keep a 0s interval.
79        let mut interval_secs = (interval_minutes * 60.0) as u64;
80        if iterations > 1 {
81            interval_secs = interval_secs.max(1);
82        }
83        Ok(Self {
84            iterations,
85            interval_secs,
86            changes_only: false,
87        })
88    }
89
90    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
91        self.changes_only = changes_only;
92        self
93    }
94}
95
96/// Result of a single follow iteration
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct FollowIteration {
99    /// Iteration number (1-based)
100    pub iteration: usize,
101    /// Total number of iterations
102    pub total_iterations: usize,
103    /// Timestamp of the check
104    pub timestamp: DateTime<Utc>,
105    /// Records found (or empty if error/NXDOMAIN)
106    pub records: Vec<DnsRecord>,
107    /// Whether records changed from previous iteration
108    pub changed: bool,
109    /// Values added since previous iteration
110    pub added: Vec<String>,
111    /// Values removed since previous iteration
112    pub removed: Vec<String>,
113    /// Error message if the check failed
114    pub error: Option<String>,
115}
116
117impl FollowIteration {
118    pub fn success(&self) -> bool {
119        self.error.is_none()
120    }
121
122    pub fn record_count(&self) -> usize {
123        self.records.len()
124    }
125}
126
127/// Complete result of a follow operation
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FollowResult {
130    /// Domain that was monitored
131    pub domain: String,
132    /// Record type that was monitored
133    pub record_type: RecordType,
134    /// Nameserver used (if custom)
135    pub nameserver: Option<String>,
136    /// Configuration used
137    pub iterations_requested: usize,
138    pub interval_secs: u64,
139    /// All iteration results
140    pub iterations: Vec<FollowIteration>,
141    /// Whether the operation was interrupted
142    pub interrupted: bool,
143    /// Total number of changes detected
144    pub total_changes: usize,
145    /// Start time
146    pub started_at: DateTime<Utc>,
147    /// End time
148    pub ended_at: DateTime<Utc>,
149}
150
151impl FollowResult {
152    pub fn completed_iterations(&self) -> usize {
153        self.iterations.len()
154    }
155
156    pub fn successful_iterations(&self) -> usize {
157        self.iterations.iter().filter(|i| i.success()).count()
158    }
159
160    pub fn failed_iterations(&self) -> usize {
161        self.iterations.iter().filter(|i| !i.success()).count()
162    }
163}
164
165/// Callback type for real-time progress updates
166pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
167
168/// DNS Follower - monitors DNS records over time
169#[derive(Clone)]
170pub struct DnsFollower {
171    resolver: DnsResolver,
172}
173
174impl Default for DnsFollower {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180impl DnsFollower {
181    pub fn new() -> Self {
182        Self {
183            resolver: DnsResolver::new(),
184        }
185    }
186
187    pub fn with_resolver(resolver: DnsResolver) -> Self {
188        Self { resolver }
189    }
190
191    /// Follow DNS records over time
192    #[instrument(skip(self, config, callback, cancel_rx))]
193    pub async fn follow(
194        &self,
195        domain: &str,
196        record_type: RecordType,
197        nameserver: Option<&str>,
198        config: FollowConfig,
199        callback: Option<FollowProgressCallback>,
200        cancel_rx: Option<watch::Receiver<bool>>,
201    ) -> Result<FollowResult> {
202        let domain = crate::validation::normalize_domain(domain)?;
203        let started_at = Utc::now();
204        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
205        let mut previous_values: HashSet<String> = HashSet::new();
206        let mut total_changes = 0;
207        let mut interrupted = false;
208
209        debug!(
210            domain = %domain,
211            record_type = %record_type,
212            iterations = config.iterations,
213            interval_secs = config.interval_secs,
214            "Starting DNS follow"
215        );
216
217        for i in 0..config.iterations {
218            // Check for cancellation
219            if let Some(ref rx) = cancel_rx {
220                if *rx.borrow() {
221                    debug!("Follow operation cancelled");
222                    interrupted = true;
223                    break;
224                }
225            }
226
227            let timestamp = Utc::now();
228            let iteration_num = i + 1;
229
230            // Perform DNS lookup
231            let (records, error) = match self
232                .resolver
233                .resolve(&domain, record_type, nameserver)
234                .await
235            {
236                Ok(records) => (records, None),
237                Err(e) => {
238                    debug!(domain = %domain, error = %e, "DNS follow query failed");
239                    // Sanitized for external return; full detail logged above.
240                    (Vec::new(), Some(e.sanitized_message()))
241                }
242            };
243
244            // Extract record values for comparison
245            let current_values: HashSet<String> =
246                records.iter().map(|r| r.data.to_string()).collect();
247
248            // Compare with previous iteration
249            let (changed, added, removed) = if i == 0 {
250                // First iteration - no previous to compare
251                (false, Vec::new(), Vec::new())
252            } else {
253                let added: Vec<String> = current_values
254                    .difference(&previous_values)
255                    .cloned()
256                    .collect();
257                let removed: Vec<String> = previous_values
258                    .difference(&current_values)
259                    .cloned()
260                    .collect();
261                let changed = !added.is_empty() || !removed.is_empty();
262                (changed, added, removed)
263            };
264
265            if changed {
266                total_changes += 1;
267            }
268
269            let iteration = FollowIteration {
270                iteration: iteration_num,
271                total_iterations: config.iterations,
272                timestamp,
273                records,
274                changed,
275                added,
276                removed,
277                error,
278            };
279
280            // Call progress callback
281            if let Some(ref cb) = callback {
282                // Only call if not changes_only mode, or if this is first iteration or changed
283                if !config.changes_only || iteration_num == 1 || changed {
284                    cb(&iteration);
285                }
286            }
287
288            iterations.push(iteration);
289            previous_values = current_values;
290
291            // Sleep before next iteration (unless this is the last one)
292            if i < config.iterations - 1 {
293                let sleep_duration = Duration::from_secs(config.interval_secs);
294
295                // Use interruptible sleep
296                if let Some(ref rx) = cancel_rx {
297                    let mut rx_clone = rx.clone();
298                    tokio::select! {
299                        _ = tokio::time::sleep(sleep_duration) => {}
300                        _ = rx_clone.changed() => {
301                            if *rx_clone.borrow() {
302                                debug!("Follow operation cancelled during sleep");
303                                interrupted = true;
304                                break;
305                            }
306                        }
307                    }
308                } else {
309                    tokio::time::sleep(sleep_duration).await;
310                }
311            }
312        }
313
314        let ended_at = Utc::now();
315
316        Ok(FollowResult {
317            domain: domain.to_string(),
318            record_type,
319            nameserver: nameserver.map(|s| s.to_string()),
320            iterations_requested: config.iterations,
321            interval_secs: config.interval_secs,
322            iterations,
323            interrupted,
324            total_changes,
325            started_at,
326            ended_at,
327        })
328    }
329
330    /// Simple follow without callback or cancellation
331    #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
332    pub async fn follow_simple(
333        &self,
334        domain: &str,
335        record_type: RecordType,
336        nameserver: Option<&str>,
337        config: FollowConfig,
338    ) -> Result<FollowResult> {
339        self.follow(domain, record_type, nameserver, config, None, None)
340            .await
341    }
342}
343
#[cfg(test)]
mod tests {
    //! Unit tests for `FollowConfig` validation plus live-network follow tests
    //! (ignored by default).

    use super::*;

    // Pure, synchronous checks use `#[test]`; only tests that actually await
    // need a Tokio runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[test]
    fn follow_config_new_matches_default_for_same_inputs() {
        let built = FollowConfig::new(10, 1.0).unwrap();
        let default = FollowConfig::default();
        assert_eq!(built.iterations, default.iterations);
        assert_eq!(built.interval_secs, default.interval_secs);
        assert_eq!(built.changes_only, default.changes_only);
    }

    #[test]
    fn follow_config_accepts_interval_cap_exactly() {
        // 60 minutes is inclusive and converts to exactly 3600 seconds.
        let config = FollowConfig::new(2, 60.0).unwrap();
        assert_eq!(config.interval_secs, 3600);
    }

    #[test]
    fn follow_config_with_changes_only_sets_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap();
        assert!(!config.changes_only);
        let config = config.with_changes_only(true);
        assert!(config.changes_only);
        let config = config.with_changes_only(false);
        assert!(!config.changes_only);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[test]
    fn follow_config_floors_subsecond_interval_for_multi_iteration() {
        // A sub-second interval truncates to 0s; with many iterations that is
        // a back-to-back live-DNS query flood. Multi-iteration follows must
        // be floored to at least 1s between queries.
        let config = FollowConfig::new(10_000, 0.001).unwrap();
        assert!(
            config.interval_secs >= 1,
            "multi-iteration interval must be floored to >= 1s, got {}",
            config.interval_secs
        );
    }

    #[test]
    fn follow_config_allows_zero_interval_for_single_iteration() {
        // A single-shot follow does no looping, so a 0s interval is harmless
        // and must not be forced to 1s.
        let config = FollowConfig::new(1, 0.0).unwrap();
        assert_eq!(config.interval_secs, 0);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        use tokio::sync::watch;

        let (tx, rx) = watch::channel(false);
        // 100 iterations with 30s intervals would take ~50 minutes.
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        // Give the follow a tick to start and get into its first sleep.
        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}