Skip to main content

seer_core/dns/
follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Hard ceiling on how many lookups a single follow may request.
///
/// `FollowConfig::new` rejects anything above this, so a non-interactive
/// caller (API / Python bindings) cannot schedule an effectively unbounded
/// run; the interval between lookups is separately capped at 60 minutes.
const MAX_FOLLOW_ITERATIONS: usize = 10000;
18
/// Settings that control a DNS follow run.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Total number of lookups to perform.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// When set, progress is only reported for the first lookup and for
    /// lookups whose records changed.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten lookups, one minute apart, reporting every lookup.
    fn default() -> Self {
        let iterations = 10;
        let interval_secs = 60;
        FollowConfig {
            changes_only: false,
            interval_secs,
            iterations,
        }
    }
}
39
40impl FollowConfig {
41    /// Construct a new `FollowConfig`.
42    ///
43    /// Validates:
44    /// - `iterations` must be >= 1
45    /// - `interval_minutes` must be finite (not NaN / infinity)
46    /// - `interval_minutes` must be non-negative
47    /// - `interval_minutes` must be at most 60
48    pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49        if iterations == 0 {
50            return Err(SeerError::InvalidInput(
51                "iterations must be at least 1".into(),
52            ));
53        }
54        if iterations > MAX_FOLLOW_ITERATIONS {
55            return Err(SeerError::InvalidInput(format!(
56                "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57            )));
58        }
59        if !interval_minutes.is_finite() {
60            return Err(SeerError::InvalidInput(
61                "interval_minutes must be a finite number".into(),
62            ));
63        }
64        if interval_minutes < 0.0 {
65            return Err(SeerError::InvalidInput(
66                "interval_minutes must be non-negative".into(),
67            ));
68        }
69        if interval_minutes > 60.0 {
70            return Err(SeerError::InvalidInput(
71                "interval_minutes must be at most 60".into(),
72            ));
73        }
74        Ok(Self {
75            iterations,
76            interval_secs: (interval_minutes * 60.0) as u64,
77            changes_only: false,
78        })
79    }
80
81    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
82        self.changes_only = changes_only;
83        self
84    }
85}
86
/// Result of a single follow iteration, i.e. one DNS lookup plus its diff
/// against the previous lookup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FollowIteration {
    /// Iteration number (1-based)
    pub iteration: usize,
    /// Total number of iterations requested for the whole follow
    pub total_iterations: usize,
    /// Time (UTC) at which this iteration's lookup was started
    pub timestamp: DateTime<Utc>,
    /// Records returned by the lookup (empty if the lookup failed / NXDOMAIN)
    pub records: Vec<DnsRecord>,
    /// Whether the record values differ from the previous iteration;
    /// always `false` for the first iteration
    pub changed: bool,
    /// Record values (each record's data rendered as a string) present now
    /// but not in the previous iteration; empty for the first iteration
    pub added: Vec<String>,
    /// Record values present in the previous iteration but not now; empty
    /// for the first iteration
    pub removed: Vec<String>,
    /// Sanitized error message if the lookup failed; `None` on success
    pub error: Option<String>,
}
107
108impl FollowIteration {
109    pub fn success(&self) -> bool {
110        self.error.is_none()
111    }
112
113    pub fn record_count(&self) -> usize {
114        self.records.len()
115    }
116}
117
/// Complete result of a follow operation: the inputs that were used plus
/// every iteration that actually ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FollowResult {
    /// Domain that was monitored (after normalization)
    pub domain: String,
    /// Record type that was monitored
    pub record_type: RecordType,
    /// Nameserver used, if a custom one was requested
    pub nameserver: Option<String>,
    /// Number of iterations requested by the configuration
    pub iterations_requested: usize,
    /// Interval between lookups, in seconds, taken from the configuration
    pub interval_secs: u64,
    /// One entry per iteration that actually ran (fewer than requested if
    /// the follow was interrupted)
    pub iterations: Vec<FollowIteration>,
    /// Whether the operation stopped early because cancellation was signalled
    pub interrupted: bool,
    /// Number of iterations whose record values differed from the
    /// previous iteration
    pub total_changes: usize,
    /// Start time (UTC)
    pub started_at: DateTime<Utc>,
    /// End time (UTC)
    pub ended_at: DateTime<Utc>,
}
141
142impl FollowResult {
143    pub fn completed_iterations(&self) -> usize {
144        self.iterations.len()
145    }
146
147    pub fn successful_iterations(&self) -> usize {
148        self.iterations.iter().filter(|i| i.success()).count()
149    }
150
151    pub fn failed_iterations(&self) -> usize {
152        self.iterations.iter().filter(|i| !i.success()).count()
153    }
154}
155
156/// Callback type for real-time progress updates
157pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
158
159/// DNS Follower - monitors DNS records over time
160#[derive(Clone)]
161pub struct DnsFollower {
162    resolver: DnsResolver,
163}
164
165impl Default for DnsFollower {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl DnsFollower {
172    pub fn new() -> Self {
173        Self {
174            resolver: DnsResolver::new(),
175        }
176    }
177
178    pub fn with_resolver(resolver: DnsResolver) -> Self {
179        Self { resolver }
180    }
181
182    /// Follow DNS records over time
183    #[instrument(skip(self, config, callback, cancel_rx))]
184    pub async fn follow(
185        &self,
186        domain: &str,
187        record_type: RecordType,
188        nameserver: Option<&str>,
189        config: FollowConfig,
190        callback: Option<FollowProgressCallback>,
191        cancel_rx: Option<watch::Receiver<bool>>,
192    ) -> Result<FollowResult> {
193        let domain = crate::validation::normalize_domain(domain)?;
194        let started_at = Utc::now();
195        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
196        let mut previous_values: HashSet<String> = HashSet::new();
197        let mut total_changes = 0;
198        let mut interrupted = false;
199
200        debug!(
201            domain = %domain,
202            record_type = %record_type,
203            iterations = config.iterations,
204            interval_secs = config.interval_secs,
205            "Starting DNS follow"
206        );
207
208        for i in 0..config.iterations {
209            // Check for cancellation
210            if let Some(ref rx) = cancel_rx {
211                if *rx.borrow() {
212                    debug!("Follow operation cancelled");
213                    interrupted = true;
214                    break;
215                }
216            }
217
218            let timestamp = Utc::now();
219            let iteration_num = i + 1;
220
221            // Perform DNS lookup
222            let (records, error) = match self
223                .resolver
224                .resolve(&domain, record_type, nameserver)
225                .await
226            {
227                Ok(records) => (records, None),
228                Err(e) => {
229                    debug!(domain = %domain, error = %e, "DNS follow query failed");
230                    // Sanitized for external return; full detail logged above.
231                    (Vec::new(), Some(e.sanitized_message()))
232                }
233            };
234
235            // Extract record values for comparison
236            let current_values: HashSet<String> =
237                records.iter().map(|r| r.data.to_string()).collect();
238
239            // Compare with previous iteration
240            let (changed, added, removed) = if i == 0 {
241                // First iteration - no previous to compare
242                (false, Vec::new(), Vec::new())
243            } else {
244                let added: Vec<String> = current_values
245                    .difference(&previous_values)
246                    .cloned()
247                    .collect();
248                let removed: Vec<String> = previous_values
249                    .difference(&current_values)
250                    .cloned()
251                    .collect();
252                let changed = !added.is_empty() || !removed.is_empty();
253                (changed, added, removed)
254            };
255
256            if changed {
257                total_changes += 1;
258            }
259
260            let iteration = FollowIteration {
261                iteration: iteration_num,
262                total_iterations: config.iterations,
263                timestamp,
264                records,
265                changed,
266                added,
267                removed,
268                error,
269            };
270
271            // Call progress callback
272            if let Some(ref cb) = callback {
273                // Only call if not changes_only mode, or if this is first iteration or changed
274                if !config.changes_only || iteration_num == 1 || changed {
275                    cb(&iteration);
276                }
277            }
278
279            iterations.push(iteration);
280            previous_values = current_values;
281
282            // Sleep before next iteration (unless this is the last one)
283            if i < config.iterations - 1 {
284                let sleep_duration = Duration::from_secs(config.interval_secs);
285
286                // Use interruptible sleep
287                if let Some(ref rx) = cancel_rx {
288                    let mut rx_clone = rx.clone();
289                    tokio::select! {
290                        _ = tokio::time::sleep(sleep_duration) => {}
291                        _ = rx_clone.changed() => {
292                            if *rx_clone.borrow() {
293                                debug!("Follow operation cancelled during sleep");
294                                interrupted = true;
295                                break;
296                            }
297                        }
298                    }
299                } else {
300                    tokio::time::sleep(sleep_duration).await;
301                }
302            }
303        }
304
305        let ended_at = Utc::now();
306
307        Ok(FollowResult {
308            domain: domain.to_string(),
309            record_type,
310            nameserver: nameserver.map(|s| s.to_string()),
311            iterations_requested: config.iterations,
312            interval_secs: config.interval_secs,
313            iterations,
314            interrupted,
315            total_changes,
316            started_at,
317            ended_at,
318        })
319    }
320
321    /// Simple follow without callback or cancellation
322    #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
323    pub async fn follow_simple(
324        &self,
325        domain: &str,
326        record_type: RecordType,
327        nameserver: Option<&str>,
328        config: FollowConfig,
329    ) -> Result<FollowResult> {
330        self.follow(domain, record_type, nameserver, config, None, None)
331            .await
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    // Pure configuration checks need no async runtime, so they use plain
    // `#[test]`; only the live-network tests below spin up tokio.

    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_with_changes_only_sets_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap().with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.with_changes_only(false).changes_only);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        let (tx, rx) = watch::channel(false);
        // 100 iterations with 30s intervals would take ~50 minutes.
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        // Give the follow a tick to start and get into its first sleep.
        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}
437            result.completed_iterations() < 100,
438            "should not complete all iterations"
439        );
440    }
441}