1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
14const MAX_FOLLOW_ITERATIONS: usize = 10_000;
18
19#[derive(Debug, Clone)]
21pub struct FollowConfig {
22 pub iterations: usize,
24 pub interval_secs: u64,
26 pub changes_only: bool,
28}
29
30impl Default for FollowConfig {
31 fn default() -> Self {
32 Self {
33 iterations: 10,
34 interval_secs: 60,
35 changes_only: false,
36 }
37 }
38}
39
40impl FollowConfig {
41 pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49 if iterations == 0 {
50 return Err(SeerError::InvalidInput(
51 "iterations must be at least 1".into(),
52 ));
53 }
54 if iterations > MAX_FOLLOW_ITERATIONS {
55 return Err(SeerError::InvalidInput(format!(
56 "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57 )));
58 }
59 if !interval_minutes.is_finite() {
60 return Err(SeerError::InvalidInput(
61 "interval_minutes must be a finite number".into(),
62 ));
63 }
64 if interval_minutes < 0.0 {
65 return Err(SeerError::InvalidInput(
66 "interval_minutes must be non-negative".into(),
67 ));
68 }
69 if interval_minutes > 60.0 {
70 return Err(SeerError::InvalidInput(
71 "interval_minutes must be at most 60".into(),
72 ));
73 }
74 Ok(Self {
75 iterations,
76 interval_secs: (interval_minutes * 60.0) as u64,
77 changes_only: false,
78 })
79 }
80
81 pub fn with_changes_only(mut self, changes_only: bool) -> Self {
82 self.changes_only = changes_only;
83 self
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FollowIteration {
90 pub iteration: usize,
92 pub total_iterations: usize,
94 pub timestamp: DateTime<Utc>,
96 pub records: Vec<DnsRecord>,
98 pub changed: bool,
100 pub added: Vec<String>,
102 pub removed: Vec<String>,
104 pub error: Option<String>,
106}
107
108impl FollowIteration {
109 pub fn success(&self) -> bool {
110 self.error.is_none()
111 }
112
113 pub fn record_count(&self) -> usize {
114 self.records.len()
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct FollowResult {
121 pub domain: String,
123 pub record_type: RecordType,
125 pub nameserver: Option<String>,
127 pub iterations_requested: usize,
129 pub interval_secs: u64,
130 pub iterations: Vec<FollowIteration>,
132 pub interrupted: bool,
134 pub total_changes: usize,
136 pub started_at: DateTime<Utc>,
138 pub ended_at: DateTime<Utc>,
140}
141
142impl FollowResult {
143 pub fn completed_iterations(&self) -> usize {
144 self.iterations.len()
145 }
146
147 pub fn successful_iterations(&self) -> usize {
148 self.iterations.iter().filter(|i| i.success()).count()
149 }
150
151 pub fn failed_iterations(&self) -> usize {
152 self.iterations.iter().filter(|i| !i.success()).count()
153 }
154}
155
156pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
158
159#[derive(Clone)]
161pub struct DnsFollower {
162 resolver: DnsResolver,
163}
164
165impl Default for DnsFollower {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl DnsFollower {
172 pub fn new() -> Self {
173 Self {
174 resolver: DnsResolver::new(),
175 }
176 }
177
178 pub fn with_resolver(resolver: DnsResolver) -> Self {
179 Self { resolver }
180 }
181
182 #[instrument(skip(self, config, callback, cancel_rx))]
184 pub async fn follow(
185 &self,
186 domain: &str,
187 record_type: RecordType,
188 nameserver: Option<&str>,
189 config: FollowConfig,
190 callback: Option<FollowProgressCallback>,
191 cancel_rx: Option<watch::Receiver<bool>>,
192 ) -> Result<FollowResult> {
193 let domain = crate::validation::normalize_domain(domain)?;
194 let started_at = Utc::now();
195 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
196 let mut previous_values: HashSet<String> = HashSet::new();
197 let mut total_changes = 0;
198 let mut interrupted = false;
199
200 debug!(
201 domain = %domain,
202 record_type = %record_type,
203 iterations = config.iterations,
204 interval_secs = config.interval_secs,
205 "Starting DNS follow"
206 );
207
208 for i in 0..config.iterations {
209 if let Some(ref rx) = cancel_rx {
211 if *rx.borrow() {
212 debug!("Follow operation cancelled");
213 interrupted = true;
214 break;
215 }
216 }
217
218 let timestamp = Utc::now();
219 let iteration_num = i + 1;
220
221 let (records, error) = match self
223 .resolver
224 .resolve(&domain, record_type, nameserver)
225 .await
226 {
227 Ok(records) => (records, None),
228 Err(e) => (Vec::new(), Some(e.to_string())),
229 };
230
231 let current_values: HashSet<String> =
233 records.iter().map(|r| r.data.to_string()).collect();
234
235 let (changed, added, removed) = if i == 0 {
237 (false, Vec::new(), Vec::new())
239 } else {
240 let added: Vec<String> = current_values
241 .difference(&previous_values)
242 .cloned()
243 .collect();
244 let removed: Vec<String> = previous_values
245 .difference(¤t_values)
246 .cloned()
247 .collect();
248 let changed = !added.is_empty() || !removed.is_empty();
249 (changed, added, removed)
250 };
251
252 if changed {
253 total_changes += 1;
254 }
255
256 let iteration = FollowIteration {
257 iteration: iteration_num,
258 total_iterations: config.iterations,
259 timestamp,
260 records,
261 changed,
262 added,
263 removed,
264 error,
265 };
266
267 if let Some(ref cb) = callback {
269 if !config.changes_only || iteration_num == 1 || changed {
271 cb(&iteration);
272 }
273 }
274
275 iterations.push(iteration);
276 previous_values = current_values;
277
278 if i < config.iterations - 1 {
280 let sleep_duration = Duration::from_secs(config.interval_secs);
281
282 if let Some(ref rx) = cancel_rx {
284 let mut rx_clone = rx.clone();
285 tokio::select! {
286 _ = tokio::time::sleep(sleep_duration) => {}
287 _ = rx_clone.changed() => {
288 if *rx_clone.borrow() {
289 debug!("Follow operation cancelled during sleep");
290 interrupted = true;
291 break;
292 }
293 }
294 }
295 } else {
296 tokio::time::sleep(sleep_duration).await;
297 }
298 }
299 }
300
301 let ended_at = Utc::now();
302
303 Ok(FollowResult {
304 domain: domain.to_string(),
305 record_type,
306 nameserver: nameserver.map(|s| s.to_string()),
307 iterations_requested: config.iterations,
308 interval_secs: config.interval_secs,
309 iterations,
310 interrupted,
311 total_changes,
312 started_at,
313 ended_at,
314 })
315 }
316
317 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
319 pub async fn follow_simple(
320 &self,
321 domain: &str,
322 record_type: RecordType,
323 nameserver: Option<&str>,
324 config: FollowConfig,
325 ) -> Result<FollowResult> {
326 self.follow(domain, record_type, nameserver, config, None, None)
327 .await
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    // Synchronous config checks use plain `#[test]`; only the live-network tests need a
    // Tokio runtime.

    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[test]
    fn follow_config_interval_converts_to_whole_seconds() {
        assert_eq!(FollowConfig::new(1, 60.0).unwrap().interval_secs, 3600);
        // 0.01 minutes is 0.6 seconds; the conversion drops the fractional part.
        assert_eq!(FollowConfig::new(1, 0.01).unwrap().interval_secs, 0);
    }

    #[test]
    fn follow_config_with_changes_only_sets_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap().with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.with_changes_only(false).changes_only);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        // `watch` and `Duration` are already in scope through `use super::*`.
        let (tx, rx) = watch::channel(false);
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}