seer_core/dns/follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Settings that control how a DNS follow operation polls.
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// How many lookups to run in total.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// Report only the first lookup and lookups whose records changed.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten lookups, sixty seconds apart, reporting every lookup.
    fn default() -> Self {
        FollowConfig {
            iterations: 10,
            interval_secs: 60,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    /// Builds a config that runs `iterations` lookups spaced
    /// `interval_minutes` minutes apart, with `changes_only` disabled.
    ///
    /// The interval is converted to seconds with an `as` cast, which truncates
    /// toward zero, maps negative/NaN values to 0 and saturates on overflow.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        let interval_secs = (interval_minutes * 60.0) as u64;
        FollowConfig {
            iterations,
            interval_secs,
            changes_only: false,
        }
    }

    /// Returns this config with `changes_only` replaced by the given flag.
    pub fn with_changes_only(self, changes_only: bool) -> Self {
        FollowConfig {
            changes_only,
            ..self
        }
    }
}
49
50/// Result of a single follow iteration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53    /// Iteration number (1-based)
54    pub iteration: usize,
55    /// Total number of iterations
56    pub total_iterations: usize,
57    /// Timestamp of the check
58    pub timestamp: DateTime<Utc>,
59    /// Records found (or empty if error/NXDOMAIN)
60    pub records: Vec<DnsRecord>,
61    /// Whether records changed from previous iteration
62    pub changed: bool,
63    /// Values added since previous iteration
64    pub added: Vec<String>,
65    /// Values removed since previous iteration
66    pub removed: Vec<String>,
67    /// Error message if the check failed
68    pub error: Option<String>,
69}
70
71impl FollowIteration {
72    pub fn success(&self) -> bool {
73        self.error.is_none()
74    }
75
76    pub fn record_count(&self) -> usize {
77        self.records.len()
78    }
79}
80
81/// Complete result of a follow operation
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84    /// Domain that was monitored
85    pub domain: String,
86    /// Record type that was monitored
87    pub record_type: RecordType,
88    /// Nameserver used (if custom)
89    pub nameserver: Option<String>,
90    /// Configuration used
91    pub iterations_requested: usize,
92    pub interval_secs: u64,
93    /// All iteration results
94    pub iterations: Vec<FollowIteration>,
95    /// Whether the operation was interrupted
96    pub interrupted: bool,
97    /// Total number of changes detected
98    pub total_changes: usize,
99    /// Start time
100    pub started_at: DateTime<Utc>,
101    /// End time
102    pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106    pub fn completed_iterations(&self) -> usize {
107        self.iterations.len()
108    }
109
110    pub fn successful_iterations(&self) -> usize {
111        self.iterations.iter().filter(|i| i.success()).count()
112    }
113
114    pub fn failed_iterations(&self) -> usize {
115        self.iterations.iter().filter(|i| !i.success()).count()
116    }
117}
118
119/// Callback type for real-time progress updates
120pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122/// DNS Follower - monitors DNS records over time
123#[derive(Clone)]
124pub struct DnsFollower {
125    resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl DnsFollower {
135    pub fn new() -> Self {
136        Self {
137            resolver: DnsResolver::new(),
138        }
139    }
140
141    pub fn with_resolver(resolver: DnsResolver) -> Self {
142        Self { resolver }
143    }
144
145    /// Follow DNS records over time
146    #[instrument(skip(self, config, callback, cancel_rx))]
147    pub async fn follow(
148        &self,
149        domain: &str,
150        record_type: RecordType,
151        nameserver: Option<&str>,
152        config: FollowConfig,
153        callback: Option<FollowProgressCallback>,
154        cancel_rx: Option<watch::Receiver<bool>>,
155    ) -> Result<FollowResult> {
156        let started_at = Utc::now();
157        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
158        let mut previous_values: HashSet<String> = HashSet::new();
159        let mut total_changes = 0;
160        let mut interrupted = false;
161
162        debug!(
163            domain = %domain,
164            record_type = %record_type,
165            iterations = config.iterations,
166            interval_secs = config.interval_secs,
167            "Starting DNS follow"
168        );
169
170        for i in 0..config.iterations {
171            // Check for cancellation
172            if let Some(ref rx) = cancel_rx {
173                if *rx.borrow() {
174                    debug!("Follow operation cancelled");
175                    interrupted = true;
176                    break;
177                }
178            }
179
180            let timestamp = Utc::now();
181            let iteration_num = i + 1;
182
183            // Perform DNS lookup
184            let (records, error) = match self
185                .resolver
186                .resolve(domain, record_type, nameserver)
187                .await
188            {
189                Ok(records) => (records, None),
190                Err(e) => (Vec::new(), Some(e.to_string())),
191            };
192
193            // Extract record values for comparison
194            let current_values: HashSet<String> = records
195                .iter()
196                .map(|r| r.data.to_string())
197                .collect();
198
199            // Compare with previous iteration
200            let (changed, added, removed) = if i == 0 {
201                // First iteration - no previous to compare
202                (false, Vec::new(), Vec::new())
203            } else {
204                let added: Vec<String> = current_values
205                    .difference(&previous_values)
206                    .cloned()
207                    .collect();
208                let removed: Vec<String> = previous_values
209                    .difference(&current_values)
210                    .cloned()
211                    .collect();
212                let changed = !added.is_empty() || !removed.is_empty();
213                (changed, added, removed)
214            };
215
216            if changed {
217                total_changes += 1;
218            }
219
220            let iteration = FollowIteration {
221                iteration: iteration_num,
222                total_iterations: config.iterations,
223                timestamp,
224                records,
225                changed,
226                added,
227                removed,
228                error,
229            };
230
231            // Call progress callback
232            if let Some(ref cb) = callback {
233                // Only call if not changes_only mode, or if this is first iteration or changed
234                if !config.changes_only || iteration_num == 1 || changed {
235                    cb(&iteration);
236                }
237            }
238
239            iterations.push(iteration);
240            previous_values = current_values;
241
242            // Sleep before next iteration (unless this is the last one)
243            if i < config.iterations - 1 {
244                let sleep_duration = Duration::from_secs(config.interval_secs);
245
246                // Use interruptible sleep
247                if let Some(ref rx) = cancel_rx {
248                    let mut rx_clone = rx.clone();
249                    tokio::select! {
250                        _ = tokio::time::sleep(sleep_duration) => {}
251                        _ = rx_clone.changed() => {
252                            if *rx_clone.borrow() {
253                                debug!("Follow operation cancelled during sleep");
254                                interrupted = true;
255                                break;
256                            }
257                        }
258                    }
259                } else {
260                    tokio::time::sleep(sleep_duration).await;
261                }
262            }
263        }
264
265        let ended_at = Utc::now();
266
267        Ok(FollowResult {
268            domain: domain.to_string(),
269            record_type,
270            nameserver: nameserver.map(|s| s.to_string()),
271            iterations_requested: config.iterations,
272            interval_secs: config.interval_secs,
273            iterations,
274            interrupted,
275            total_changes,
276            started_at,
277            ended_at,
278        })
279    }
280
281    /// Simple follow without callback or cancellation
282    pub async fn follow_simple(
283        &self,
284        domain: &str,
285        record_type: RecordType,
286        nameserver: Option<&str>,
287        config: FollowConfig,
288    ) -> Result<FollowResult> {
289        self.follow(domain, record_type, nameserver, config, None, None)
290            .await
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    // Pure config tests need no async runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5);
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_with_changes_only() {
        let config = FollowConfig::new(3, 1.0).with_changes_only(true);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
        assert!(config.changes_only);
    }

    #[tokio::test]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0);

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    // A cancellation already requested before the first lookup must stop the
    // run without performing any lookups (so no network access is needed).
    #[tokio::test]
    async fn test_follow_cancelled_before_start() {
        let follower = DnsFollower::new();
        let (_cancel_tx, cancel_rx) = tokio::sync::watch::channel(true);

        let result = follower
            .follow(
                "example.com",
                RecordType::A,
                None,
                FollowConfig::new(3, 1.0),
                None,
                Some(cancel_rx),
            )
            .await
            .expect("follow reports lookup errors per iteration, not as Err");

        assert!(result.interrupted);
        assert_eq!(result.completed_iterations(), 0);
        assert_eq!(result.total_changes, 0);
    }
}