// seer_core/dns/follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Configuration for a DNS follow operation.
///
/// Prefer [`FollowConfig::new`], which validates its inputs. The fields are
/// public, so a value built with a struct literal bypasses that validation.
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// Number of lookups to perform.
    pub iterations: usize,
    /// Delay between consecutive lookups, in seconds.
    pub interval_secs: u64,
    /// When `true`, the progress callback is invoked only for the first lookup
    /// and for lookups whose record values changed; otherwise it is invoked
    /// for every lookup.
    pub changes_only: bool,
}
24
25impl Default for FollowConfig {
26    fn default() -> Self {
27        Self {
28            iterations: 10,
29            interval_secs: 60,
30            changes_only: false,
31        }
32    }
33}
34
35impl FollowConfig {
36    /// Construct a new `FollowConfig`.
37    ///
38    /// Validates:
39    /// - `iterations` must be >= 1
40    /// - `interval_minutes` must be finite (not NaN / infinity)
41    /// - `interval_minutes` must be non-negative
42    /// - `interval_minutes` must be at most 60
43    pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
44        if iterations == 0 {
45            return Err(SeerError::InvalidInput(
46                "iterations must be at least 1".into(),
47            ));
48        }
49        if !interval_minutes.is_finite() {
50            return Err(SeerError::InvalidInput(
51                "interval_minutes must be a finite number".into(),
52            ));
53        }
54        if interval_minutes < 0.0 {
55            return Err(SeerError::InvalidInput(
56                "interval_minutes must be non-negative".into(),
57            ));
58        }
59        if interval_minutes > 60.0 {
60            return Err(SeerError::InvalidInput(
61                "interval_minutes must be at most 60".into(),
62            ));
63        }
64        Ok(Self {
65            iterations,
66            interval_secs: (interval_minutes * 60.0) as u64,
67            changes_only: false,
68        })
69    }
70
71    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
72        self.changes_only = changes_only;
73        self
74    }
75}
76
77/// Result of a single follow iteration
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct FollowIteration {
80    /// Iteration number (1-based)
81    pub iteration: usize,
82    /// Total number of iterations
83    pub total_iterations: usize,
84    /// Timestamp of the check
85    pub timestamp: DateTime<Utc>,
86    /// Records found (or empty if error/NXDOMAIN)
87    pub records: Vec<DnsRecord>,
88    /// Whether records changed from previous iteration
89    pub changed: bool,
90    /// Values added since previous iteration
91    pub added: Vec<String>,
92    /// Values removed since previous iteration
93    pub removed: Vec<String>,
94    /// Error message if the check failed
95    pub error: Option<String>,
96}
97
98impl FollowIteration {
99    pub fn success(&self) -> bool {
100        self.error.is_none()
101    }
102
103    pub fn record_count(&self) -> usize {
104        self.records.len()
105    }
106}
107
108/// Complete result of a follow operation
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct FollowResult {
111    /// Domain that was monitored
112    pub domain: String,
113    /// Record type that was monitored
114    pub record_type: RecordType,
115    /// Nameserver used (if custom)
116    pub nameserver: Option<String>,
117    /// Configuration used
118    pub iterations_requested: usize,
119    pub interval_secs: u64,
120    /// All iteration results
121    pub iterations: Vec<FollowIteration>,
122    /// Whether the operation was interrupted
123    pub interrupted: bool,
124    /// Total number of changes detected
125    pub total_changes: usize,
126    /// Start time
127    pub started_at: DateTime<Utc>,
128    /// End time
129    pub ended_at: DateTime<Utc>,
130}
131
132impl FollowResult {
133    pub fn completed_iterations(&self) -> usize {
134        self.iterations.len()
135    }
136
137    pub fn successful_iterations(&self) -> usize {
138        self.iterations.iter().filter(|i| i.success()).count()
139    }
140
141    pub fn failed_iterations(&self) -> usize {
142        self.iterations.iter().filter(|i| !i.success()).count()
143    }
144}
145
146/// Callback type for real-time progress updates
147pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
148
149/// DNS Follower - monitors DNS records over time
150#[derive(Clone)]
151pub struct DnsFollower {
152    resolver: DnsResolver,
153}
154
155impl Default for DnsFollower {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl DnsFollower {
162    pub fn new() -> Self {
163        Self {
164            resolver: DnsResolver::new(),
165        }
166    }
167
168    pub fn with_resolver(resolver: DnsResolver) -> Self {
169        Self { resolver }
170    }
171
172    /// Follow DNS records over time
173    #[instrument(skip(self, config, callback, cancel_rx))]
174    pub async fn follow(
175        &self,
176        domain: &str,
177        record_type: RecordType,
178        nameserver: Option<&str>,
179        config: FollowConfig,
180        callback: Option<FollowProgressCallback>,
181        cancel_rx: Option<watch::Receiver<bool>>,
182    ) -> Result<FollowResult> {
183        let domain = crate::validation::normalize_domain(domain)?;
184        let started_at = Utc::now();
185        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
186        let mut previous_values: HashSet<String> = HashSet::new();
187        let mut total_changes = 0;
188        let mut interrupted = false;
189
190        debug!(
191            domain = %domain,
192            record_type = %record_type,
193            iterations = config.iterations,
194            interval_secs = config.interval_secs,
195            "Starting DNS follow"
196        );
197
198        for i in 0..config.iterations {
199            // Check for cancellation
200            if let Some(ref rx) = cancel_rx {
201                if *rx.borrow() {
202                    debug!("Follow operation cancelled");
203                    interrupted = true;
204                    break;
205                }
206            }
207
208            let timestamp = Utc::now();
209            let iteration_num = i + 1;
210
211            // Perform DNS lookup
212            let (records, error) = match self
213                .resolver
214                .resolve(&domain, record_type, nameserver)
215                .await
216            {
217                Ok(records) => (records, None),
218                Err(e) => (Vec::new(), Some(e.to_string())),
219            };
220
221            // Extract record values for comparison
222            let current_values: HashSet<String> =
223                records.iter().map(|r| r.data.to_string()).collect();
224
225            // Compare with previous iteration
226            let (changed, added, removed) = if i == 0 {
227                // First iteration - no previous to compare
228                (false, Vec::new(), Vec::new())
229            } else {
230                let added: Vec<String> = current_values
231                    .difference(&previous_values)
232                    .cloned()
233                    .collect();
234                let removed: Vec<String> = previous_values
235                    .difference(&current_values)
236                    .cloned()
237                    .collect();
238                let changed = !added.is_empty() || !removed.is_empty();
239                (changed, added, removed)
240            };
241
242            if changed {
243                total_changes += 1;
244            }
245
246            let iteration = FollowIteration {
247                iteration: iteration_num,
248                total_iterations: config.iterations,
249                timestamp,
250                records,
251                changed,
252                added,
253                removed,
254                error,
255            };
256
257            // Call progress callback
258            if let Some(ref cb) = callback {
259                // Only call if not changes_only mode, or if this is first iteration or changed
260                if !config.changes_only || iteration_num == 1 || changed {
261                    cb(&iteration);
262                }
263            }
264
265            iterations.push(iteration);
266            previous_values = current_values;
267
268            // Sleep before next iteration (unless this is the last one)
269            if i < config.iterations - 1 {
270                let sleep_duration = Duration::from_secs(config.interval_secs);
271
272                // Use interruptible sleep
273                if let Some(ref rx) = cancel_rx {
274                    let mut rx_clone = rx.clone();
275                    tokio::select! {
276                        _ = tokio::time::sleep(sleep_duration) => {}
277                        _ = rx_clone.changed() => {
278                            if *rx_clone.borrow() {
279                                debug!("Follow operation cancelled during sleep");
280                                interrupted = true;
281                                break;
282                            }
283                        }
284                    }
285                } else {
286                    tokio::time::sleep(sleep_duration).await;
287                }
288            }
289        }
290
291        let ended_at = Utc::now();
292
293        Ok(FollowResult {
294            domain: domain.to_string(),
295            record_type,
296            nameserver: nameserver.map(|s| s.to_string()),
297            iterations_requested: config.iterations,
298            interval_secs: config.interval_secs,
299            iterations,
300            interrupted,
301            total_changes,
302            started_at,
303            ended_at,
304        })
305    }
306
307    /// Simple follow without callback or cancellation
308    #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
309    pub async fn follow_simple(
310        &self,
311        domain: &str,
312        record_type: RecordType,
313        nameserver: Option<&str>,
314        config: FollowConfig,
315    ) -> Result<FollowResult> {
316        self.follow(domain, record_type, nameserver, config, None, None)
317            .await
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_follow_config_default() {
        let FollowConfig {
            iterations,
            interval_secs,
            changes_only,
        } = FollowConfig::default();
        assert_eq!(iterations, 10);
        assert_eq!(interval_secs, 60);
        assert!(!changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!((config.iterations, config.interval_secs), (5, 30));
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let config = FollowConfig::new(1, 0.0).unwrap();
        let result = DnsFollower::new()
            .follow_simple("example.com", RecordType::A, None, config)
            .await
            .unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        for minutes in [f64::INFINITY, f64::NEG_INFINITY] {
            assert!(FollowConfig::new(10, minutes).is_err());
        }
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        let cases = [(10, 1.5), (1, 0.0), (1, 60.0)];
        for (iterations, minutes) in cases {
            assert!(FollowConfig::new(iterations, minutes).is_ok());
        }
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
        // 100 iterations at 30-second intervals would take ~50 minutes if
        // cancellation were ignored.
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let task = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(cancel_rx))
                .await
        });

        // Let the follow get through its first lookup and into its first sleep.
        tokio::time::sleep(Duration::from_millis(200)).await;
        cancel_tx.send(true).unwrap();

        let result = tokio::time::timeout(Duration::from_secs(10), task)
            .await
            .expect("follow should return promptly after cancel")
            .expect("join")
            .expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}