Skip to main content

seer_core/dns/
follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Configuration for a DNS follow operation.
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// Number of checks to perform.
    pub iterations: usize,
    /// Interval between checks, in whole seconds.
    pub interval_secs: u64,
    /// Only output when records change.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten checks, sixty seconds apart, reporting every iteration.
    fn default() -> Self {
        Self {
            iterations: 10,
            interval_secs: 60,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    /// Builds a configuration from an iteration count and an interval in minutes.
    ///
    /// The interval is converted to the nearest whole second. Negative and NaN
    /// intervals become `0`, and values too large for `u64` saturate at
    /// `u64::MAX` (the defined behaviour of a float-to-integer `as` cast).
    /// `changes_only` starts out as `false`.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        // Round before casting: a bare `as u64` truncates, so binary floating
        // point error can silently drop a whole second (2.05 * 60.0 evaluates
        // to 122.99999999999999, which would truncate to 122 instead of 123).
        let interval_secs = (interval_minutes * 60.0).round() as u64;

        Self {
            iterations,
            interval_secs,
            changes_only: false,
        }
    }

    /// Returns the configuration with `changes_only` replaced by the given value.
    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
        self.changes_only = changes_only;
        self
    }
}
49
50/// Result of a single follow iteration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53    /// Iteration number (1-based)
54    pub iteration: usize,
55    /// Total number of iterations
56    pub total_iterations: usize,
57    /// Timestamp of the check
58    pub timestamp: DateTime<Utc>,
59    /// Records found (or empty if error/NXDOMAIN)
60    pub records: Vec<DnsRecord>,
61    /// Whether records changed from previous iteration
62    pub changed: bool,
63    /// Values added since previous iteration
64    pub added: Vec<String>,
65    /// Values removed since previous iteration
66    pub removed: Vec<String>,
67    /// Error message if the check failed
68    pub error: Option<String>,
69}
70
71impl FollowIteration {
72    pub fn success(&self) -> bool {
73        self.error.is_none()
74    }
75
76    pub fn record_count(&self) -> usize {
77        self.records.len()
78    }
79}
80
81/// Complete result of a follow operation
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84    /// Domain that was monitored
85    pub domain: String,
86    /// Record type that was monitored
87    pub record_type: RecordType,
88    /// Nameserver used (if custom)
89    pub nameserver: Option<String>,
90    /// Configuration used
91    pub iterations_requested: usize,
92    pub interval_secs: u64,
93    /// All iteration results
94    pub iterations: Vec<FollowIteration>,
95    /// Whether the operation was interrupted
96    pub interrupted: bool,
97    /// Total number of changes detected
98    pub total_changes: usize,
99    /// Start time
100    pub started_at: DateTime<Utc>,
101    /// End time
102    pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106    pub fn completed_iterations(&self) -> usize {
107        self.iterations.len()
108    }
109
110    pub fn successful_iterations(&self) -> usize {
111        self.iterations.iter().filter(|i| i.success()).count()
112    }
113
114    pub fn failed_iterations(&self) -> usize {
115        self.iterations.iter().filter(|i| !i.success()).count()
116    }
117}
118
/// Callback type for real-time progress updates.
///
/// A reference-counted closure that is handed a reference to a
/// [`FollowIteration`]. The `Send + Sync` bounds allow the callback to be
/// shared across threads.
pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122/// DNS Follower - monitors DNS records over time
123#[derive(Clone)]
124pub struct DnsFollower {
125    resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl DnsFollower {
135    pub fn new() -> Self {
136        Self {
137            resolver: DnsResolver::new(),
138        }
139    }
140
141    pub fn with_resolver(resolver: DnsResolver) -> Self {
142        Self { resolver }
143    }
144
145    /// Follow DNS records over time
146    #[instrument(skip(self, config, callback, cancel_rx))]
147    pub async fn follow(
148        &self,
149        domain: &str,
150        record_type: RecordType,
151        nameserver: Option<&str>,
152        config: FollowConfig,
153        callback: Option<FollowProgressCallback>,
154        cancel_rx: Option<watch::Receiver<bool>>,
155    ) -> Result<FollowResult> {
156        let started_at = Utc::now();
157        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
158        let mut previous_values: HashSet<String> = HashSet::new();
159        let mut total_changes = 0;
160        let mut interrupted = false;
161
162        debug!(
163            domain = %domain,
164            record_type = %record_type,
165            iterations = config.iterations,
166            interval_secs = config.interval_secs,
167            "Starting DNS follow"
168        );
169
170        for i in 0..config.iterations {
171            // Check for cancellation
172            if let Some(ref rx) = cancel_rx {
173                if *rx.borrow() {
174                    debug!("Follow operation cancelled");
175                    interrupted = true;
176                    break;
177                }
178            }
179
180            let timestamp = Utc::now();
181            let iteration_num = i + 1;
182
183            // Perform DNS lookup
184            let (records, error) =
185                match self.resolver.resolve(domain, record_type, nameserver).await {
186                    Ok(records) => (records, None),
187                    Err(e) => (Vec::new(), Some(e.to_string())),
188                };
189
190            // Extract record values for comparison
191            let current_values: HashSet<String> =
192                records.iter().map(|r| r.data.to_string()).collect();
193
194            // Compare with previous iteration
195            let (changed, added, removed) = if i == 0 {
196                // First iteration - no previous to compare
197                (false, Vec::new(), Vec::new())
198            } else {
199                let added: Vec<String> = current_values
200                    .difference(&previous_values)
201                    .cloned()
202                    .collect();
203                let removed: Vec<String> = previous_values
204                    .difference(&current_values)
205                    .cloned()
206                    .collect();
207                let changed = !added.is_empty() || !removed.is_empty();
208                (changed, added, removed)
209            };
210
211            if changed {
212                total_changes += 1;
213            }
214
215            let iteration = FollowIteration {
216                iteration: iteration_num,
217                total_iterations: config.iterations,
218                timestamp,
219                records,
220                changed,
221                added,
222                removed,
223                error,
224            };
225
226            // Call progress callback
227            if let Some(ref cb) = callback {
228                // Only call if not changes_only mode, or if this is first iteration or changed
229                if !config.changes_only || iteration_num == 1 || changed {
230                    cb(&iteration);
231                }
232            }
233
234            iterations.push(iteration);
235            previous_values = current_values;
236
237            // Sleep before next iteration (unless this is the last one)
238            if i < config.iterations - 1 {
239                let sleep_duration = Duration::from_secs(config.interval_secs);
240
241                // Use interruptible sleep
242                if let Some(ref rx) = cancel_rx {
243                    let mut rx_clone = rx.clone();
244                    tokio::select! {
245                        _ = tokio::time::sleep(sleep_duration) => {}
246                        _ = rx_clone.changed() => {
247                            if *rx_clone.borrow() {
248                                debug!("Follow operation cancelled during sleep");
249                                interrupted = true;
250                                break;
251                            }
252                        }
253                    }
254                } else {
255                    tokio::time::sleep(sleep_duration).await;
256                }
257            }
258        }
259
260        let ended_at = Utc::now();
261
262        Ok(FollowResult {
263            domain: domain.to_string(),
264            record_type,
265            nameserver: nameserver.map(|s| s.to_string()),
266            iterations_requested: config.iterations,
267            interval_secs: config.interval_secs,
268            iterations,
269            interrupted,
270            total_changes,
271            started_at,
272            ended_at,
273        })
274    }
275
276    /// Simple follow without callback or cancellation
277    pub async fn follow_simple(
278        &self,
279        domain: &str,
280        record_type: RecordType,
281        nameserver: Option<&str>,
282        config: FollowConfig,
283    ) -> Result<FollowResult> {
284        self.follow(domain, record_type, nameserver, config, None, None)
285            .await
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is ten checks, sixty seconds apart, reporting everything.
    #[tokio::test]
    async fn test_follow_config_default() {
        let FollowConfig {
            iterations,
            interval_secs,
            changes_only,
        } = FollowConfig::default();

        assert_eq!(iterations, 10);
        assert_eq!(interval_secs, 60);
        assert!(!changes_only);
    }

    /// `new` stores the iteration count and converts minutes to seconds.
    #[tokio::test]
    async fn test_follow_config_new() {
        let built = FollowConfig::new(5, 0.5);

        assert_eq!(built.iterations, 5);
        assert_eq!(built.interval_secs, 30);
    }

    /// A one-iteration run completes exactly one check and is not interrupted.
    #[tokio::test]
    async fn test_follow_single_iteration() {
        let single_check = FollowConfig::new(1, 0.0);

        let outcome = DnsFollower::new()
            .follow_simple("example.com", RecordType::A, None, single_check)
            .await;

        assert!(outcome.is_ok());
        let summary = outcome.unwrap();
        assert_eq!(summary.completed_iterations(), 1);
        assert!(!summary.interrupted);
    }
}
322}