Skip to main content

seer_core/dns/
follow.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Configuration for DNS follow operation
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// Number of checks to perform
    pub iterations: usize,
    /// Interval between checks in seconds
    pub interval_secs: u64,
    /// Only output when records change
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// 10 checks, 60 seconds apart, with every iteration reported.
    fn default() -> Self {
        Self {
            iterations: 10,
            interval_secs: 60,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    /// Builds a configuration from an iteration count and an interval given in minutes.
    ///
    /// The interval is converted to seconds and rounded to the nearest whole second.
    /// Rounding (instead of truncating) matters because the multiplication is done in
    /// `f64`: e.g. `2.05 * 60.0` is slightly *below* `123.0`, which a bare `as u64`
    /// cast would turn into `122`. Negative and NaN intervals become `0` (the `as`
    /// cast saturates). `changes_only` starts out `false`.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        Self {
            iterations,
            // `as` from f64 saturates: negative/NaN -> 0, too-large -> u64::MAX.
            interval_secs: (interval_minutes * 60.0).round() as u64,
            changes_only: false,
        }
    }

    /// Sets whether progress should be reported only for iterations whose records
    /// changed (the first iteration is always reported).
    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
        self.changes_only = changes_only;
        self
    }
}
49
50/// Result of a single follow iteration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53    /// Iteration number (1-based)
54    pub iteration: usize,
55    /// Total number of iterations
56    pub total_iterations: usize,
57    /// Timestamp of the check
58    pub timestamp: DateTime<Utc>,
59    /// Records found (or empty if error/NXDOMAIN)
60    pub records: Vec<DnsRecord>,
61    /// Whether records changed from previous iteration
62    pub changed: bool,
63    /// Values added since previous iteration
64    pub added: Vec<String>,
65    /// Values removed since previous iteration
66    pub removed: Vec<String>,
67    /// Error message if the check failed
68    pub error: Option<String>,
69}
70
71impl FollowIteration {
72    pub fn success(&self) -> bool {
73        self.error.is_none()
74    }
75
76    pub fn record_count(&self) -> usize {
77        self.records.len()
78    }
79}
80
81/// Complete result of a follow operation
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84    /// Domain that was monitored
85    pub domain: String,
86    /// Record type that was monitored
87    pub record_type: RecordType,
88    /// Nameserver used (if custom)
89    pub nameserver: Option<String>,
90    /// Configuration used
91    pub iterations_requested: usize,
92    pub interval_secs: u64,
93    /// All iteration results
94    pub iterations: Vec<FollowIteration>,
95    /// Whether the operation was interrupted
96    pub interrupted: bool,
97    /// Total number of changes detected
98    pub total_changes: usize,
99    /// Start time
100    pub started_at: DateTime<Utc>,
101    /// End time
102    pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106    pub fn completed_iterations(&self) -> usize {
107        self.iterations.len()
108    }
109
110    pub fn successful_iterations(&self) -> usize {
111        self.iterations.iter().filter(|i| i.success()).count()
112    }
113
114    pub fn failed_iterations(&self) -> usize {
115        self.iterations.iter().filter(|i| !i.success()).count()
116    }
117}
118
119/// Callback type for real-time progress updates
120pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122/// DNS Follower - monitors DNS records over time
123#[derive(Clone)]
124pub struct DnsFollower {
125    resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl DnsFollower {
135    pub fn new() -> Self {
136        Self {
137            resolver: DnsResolver::new(),
138        }
139    }
140
141    pub fn with_resolver(resolver: DnsResolver) -> Self {
142        Self { resolver }
143    }
144
145    /// Follow DNS records over time
146    #[instrument(skip(self, config, callback, cancel_rx))]
147    pub async fn follow(
148        &self,
149        domain: &str,
150        record_type: RecordType,
151        nameserver: Option<&str>,
152        config: FollowConfig,
153        callback: Option<FollowProgressCallback>,
154        cancel_rx: Option<watch::Receiver<bool>>,
155    ) -> Result<FollowResult> {
156        let domain = crate::validation::normalize_domain(domain)?;
157        let started_at = Utc::now();
158        let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
159        let mut previous_values: HashSet<String> = HashSet::new();
160        let mut total_changes = 0;
161        let mut interrupted = false;
162
163        debug!(
164            domain = %domain,
165            record_type = %record_type,
166            iterations = config.iterations,
167            interval_secs = config.interval_secs,
168            "Starting DNS follow"
169        );
170
171        for i in 0..config.iterations {
172            // Check for cancellation
173            if let Some(ref rx) = cancel_rx {
174                if *rx.borrow() {
175                    debug!("Follow operation cancelled");
176                    interrupted = true;
177                    break;
178                }
179            }
180
181            let timestamp = Utc::now();
182            let iteration_num = i + 1;
183
184            // Perform DNS lookup
185            let (records, error) = match self
186                .resolver
187                .resolve(&domain, record_type, nameserver)
188                .await
189            {
190                Ok(records) => (records, None),
191                Err(e) => (Vec::new(), Some(e.to_string())),
192            };
193
194            // Extract record values for comparison
195            let current_values: HashSet<String> =
196                records.iter().map(|r| r.data.to_string()).collect();
197
198            // Compare with previous iteration
199            let (changed, added, removed) = if i == 0 {
200                // First iteration - no previous to compare
201                (false, Vec::new(), Vec::new())
202            } else {
203                let added: Vec<String> = current_values
204                    .difference(&previous_values)
205                    .cloned()
206                    .collect();
207                let removed: Vec<String> = previous_values
208                    .difference(&current_values)
209                    .cloned()
210                    .collect();
211                let changed = !added.is_empty() || !removed.is_empty();
212                (changed, added, removed)
213            };
214
215            if changed {
216                total_changes += 1;
217            }
218
219            let iteration = FollowIteration {
220                iteration: iteration_num,
221                total_iterations: config.iterations,
222                timestamp,
223                records,
224                changed,
225                added,
226                removed,
227                error,
228            };
229
230            // Call progress callback
231            if let Some(ref cb) = callback {
232                // Only call if not changes_only mode, or if this is first iteration or changed
233                if !config.changes_only || iteration_num == 1 || changed {
234                    cb(&iteration);
235                }
236            }
237
238            iterations.push(iteration);
239            previous_values = current_values;
240
241            // Sleep before next iteration (unless this is the last one)
242            if i < config.iterations - 1 {
243                let sleep_duration = Duration::from_secs(config.interval_secs);
244
245                // Use interruptible sleep
246                if let Some(ref rx) = cancel_rx {
247                    let mut rx_clone = rx.clone();
248                    tokio::select! {
249                        _ = tokio::time::sleep(sleep_duration) => {}
250                        _ = rx_clone.changed() => {
251                            if *rx_clone.borrow() {
252                                debug!("Follow operation cancelled during sleep");
253                                interrupted = true;
254                                break;
255                            }
256                        }
257                    }
258                } else {
259                    tokio::time::sleep(sleep_duration).await;
260                }
261            }
262        }
263
264        let ended_at = Utc::now();
265
266        Ok(FollowResult {
267            domain: domain.to_string(),
268            record_type,
269            nameserver: nameserver.map(|s| s.to_string()),
270            iterations_requested: config.iterations,
271            interval_secs: config.interval_secs,
272            iterations,
273            interrupted,
274            total_changes,
275            started_at,
276            ended_at,
277        })
278    }
279
280    /// Simple follow without callback or cancellation
281    #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
282    pub async fn follow_simple(
283        &self,
284        domain: &str,
285        record_type: RecordType,
286        nameserver: Option<&str>,
287        config: FollowConfig,
288    ) -> Result<FollowResult> {
289        self.follow(domain, record_type, nameserver, config, None, None)
290            .await
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    // Pure configuration checks: no async runtime needed.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5);
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    /// Builds a record-less iteration, failed when `error` is `Some`.
    fn make_iteration(iteration: usize, error: Option<&str>) -> FollowIteration {
        FollowIteration {
            iteration,
            total_iterations: 3,
            timestamp: Utc::now(),
            records: Vec::new(),
            changed: false,
            added: Vec::new(),
            removed: Vec::new(),
            error: error.map(str::to_string),
        }
    }

    #[test]
    fn test_follow_result_counts() {
        let now = Utc::now();
        let result = FollowResult {
            domain: "example.com".to_string(),
            record_type: RecordType::A,
            nameserver: None,
            iterations_requested: 3,
            interval_secs: 0,
            iterations: vec![
                make_iteration(1, None),
                make_iteration(2, Some("timeout")),
                make_iteration(3, None),
            ],
            interrupted: false,
            total_changes: 0,
            started_at: now,
            ended_at: now,
        };
        assert_eq!(result.completed_iterations(), 3);
        assert_eq!(result.successful_iterations(), 2);
        assert_eq!(result.failed_iterations(), 1);
    }

    // Hermetic: zero iterations performs no lookups.
    #[tokio::test]
    async fn test_follow_zero_iterations() {
        let follower = DnsFollower::new();
        let result = follower
            .follow_simple("example.com", RecordType::A, None, FollowConfig::new(0, 0.0))
            .await
            .expect("zero iterations should succeed");
        assert_eq!(result.completed_iterations(), 0);
        assert_eq!(result.total_changes, 0);
        assert!(!result.interrupted);
    }

    // Hermetic: cancellation is checked before the first lookup.
    #[tokio::test]
    async fn test_follow_cancelled_before_first_lookup() {
        let (_tx, rx) = watch::channel(true);
        let follower = DnsFollower::new();
        let result = follower
            .follow(
                "example.com",
                RecordType::A,
                None,
                FollowConfig::new(3, 0.0),
                None,
                Some(rx),
            )
            .await
            .expect("cancelled follow should still return a result");
        assert!(result.interrupted);
        assert_eq!(result.completed_iterations(), 0);
    }

    // Performs a real DNS lookup, so it needs network access.
    #[tokio::test]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0);

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }
}