1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Largest iteration count that [`FollowConfig::new`] accepts.
const MAX_FOLLOW_ITERATIONS: usize = 10 * 1_000;
18
/// Settings controlling how often and how many times a domain is re-queried.
///
/// Prefer [`FollowConfig::new`], which validates the values; the fields are public,
/// so a directly constructed value bypasses those checks.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Total number of queries to issue.
    pub iterations: usize,
    /// Seconds to wait between consecutive queries.
    pub interval_secs: u64,
    /// When `true`, the progress callback is only invoked for notable iterations
    /// rather than for every iteration.
    pub changes_only: bool,
}
29
30impl Default for FollowConfig {
31 fn default() -> Self {
32 Self {
33 iterations: 10,
34 interval_secs: 60,
35 changes_only: false,
36 }
37 }
38}
39
40impl FollowConfig {
41 pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49 if iterations == 0 {
50 return Err(SeerError::InvalidInput(
51 "iterations must be at least 1".into(),
52 ));
53 }
54 if iterations > MAX_FOLLOW_ITERATIONS {
55 return Err(SeerError::InvalidInput(format!(
56 "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57 )));
58 }
59 if !interval_minutes.is_finite() {
60 return Err(SeerError::InvalidInput(
61 "interval_minutes must be a finite number".into(),
62 ));
63 }
64 if interval_minutes < 0.0 {
65 return Err(SeerError::InvalidInput(
66 "interval_minutes must be non-negative".into(),
67 ));
68 }
69 if interval_minutes > 60.0 {
70 return Err(SeerError::InvalidInput(
71 "interval_minutes must be at most 60".into(),
72 ));
73 }
74 let mut interval_secs = (interval_minutes * 60.0) as u64;
80 if iterations > 1 {
81 interval_secs = interval_secs.max(1);
82 }
83 Ok(Self {
84 iterations,
85 interval_secs,
86 changes_only: false,
87 })
88 }
89
90 pub fn with_changes_only(mut self, changes_only: bool) -> Self {
91 self.changes_only = changes_only;
92 self
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct FollowIteration {
99 pub iteration: usize,
101 pub total_iterations: usize,
103 pub timestamp: DateTime<Utc>,
105 pub records: Vec<DnsRecord>,
107 pub changed: bool,
109 pub added: Vec<String>,
111 pub removed: Vec<String>,
113 pub error: Option<String>,
115}
116
117impl FollowIteration {
118 pub fn success(&self) -> bool {
119 self.error.is_none()
120 }
121
122 pub fn record_count(&self) -> usize {
123 self.records.len()
124 }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FollowResult {
130 pub domain: String,
132 pub record_type: RecordType,
134 pub nameserver: Option<String>,
136 pub iterations_requested: usize,
138 pub interval_secs: u64,
139 pub iterations: Vec<FollowIteration>,
141 pub interrupted: bool,
143 pub total_changes: usize,
145 pub started_at: DateTime<Utc>,
147 pub ended_at: DateTime<Utc>,
149}
150
151impl FollowResult {
152 pub fn completed_iterations(&self) -> usize {
153 self.iterations.len()
154 }
155
156 pub fn successful_iterations(&self) -> usize {
157 self.iterations.iter().filter(|i| i.success()).count()
158 }
159
160 pub fn failed_iterations(&self) -> usize {
161 self.iterations.iter().filter(|i| !i.success()).count()
162 }
163}
164
165pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
167
168#[derive(Clone)]
170pub struct DnsFollower {
171 resolver: DnsResolver,
172}
173
174impl Default for DnsFollower {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180impl DnsFollower {
181 pub fn new() -> Self {
182 Self {
183 resolver: DnsResolver::new(),
184 }
185 }
186
187 pub fn with_resolver(resolver: DnsResolver) -> Self {
188 Self { resolver }
189 }
190
191 #[instrument(skip(self, config, callback, cancel_rx))]
193 pub async fn follow(
194 &self,
195 domain: &str,
196 record_type: RecordType,
197 nameserver: Option<&str>,
198 config: FollowConfig,
199 callback: Option<FollowProgressCallback>,
200 cancel_rx: Option<watch::Receiver<bool>>,
201 ) -> Result<FollowResult> {
202 let domain = crate::validation::normalize_domain(domain)?;
203 let started_at = Utc::now();
204 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
205 let mut previous_values: HashSet<String> = HashSet::new();
206 let mut total_changes = 0;
207 let mut interrupted = false;
208
209 debug!(
210 domain = %domain,
211 record_type = %record_type,
212 iterations = config.iterations,
213 interval_secs = config.interval_secs,
214 "Starting DNS follow"
215 );
216
217 for i in 0..config.iterations {
218 if let Some(ref rx) = cancel_rx {
220 if *rx.borrow() {
221 debug!("Follow operation cancelled");
222 interrupted = true;
223 break;
224 }
225 }
226
227 let timestamp = Utc::now();
228 let iteration_num = i + 1;
229
230 let (records, error) = match self
232 .resolver
233 .resolve(&domain, record_type, nameserver)
234 .await
235 {
236 Ok(records) => (records, None),
237 Err(e) => {
238 debug!(domain = %domain, error = %e, "DNS follow query failed");
239 (Vec::new(), Some(e.sanitized_message()))
241 }
242 };
243
244 let current_values: HashSet<String> =
246 records.iter().map(|r| r.data.to_string()).collect();
247
248 let (changed, added, removed) = if i == 0 {
250 (false, Vec::new(), Vec::new())
252 } else {
253 let added: Vec<String> = current_values
254 .difference(&previous_values)
255 .cloned()
256 .collect();
257 let removed: Vec<String> = previous_values
258 .difference(¤t_values)
259 .cloned()
260 .collect();
261 let changed = !added.is_empty() || !removed.is_empty();
262 (changed, added, removed)
263 };
264
265 if changed {
266 total_changes += 1;
267 }
268
269 let iteration = FollowIteration {
270 iteration: iteration_num,
271 total_iterations: config.iterations,
272 timestamp,
273 records,
274 changed,
275 added,
276 removed,
277 error,
278 };
279
280 if let Some(ref cb) = callback {
282 if !config.changes_only || iteration_num == 1 || changed {
284 cb(&iteration);
285 }
286 }
287
288 iterations.push(iteration);
289 previous_values = current_values;
290
291 if i < config.iterations - 1 {
293 let sleep_duration = Duration::from_secs(config.interval_secs);
294
295 if let Some(ref rx) = cancel_rx {
297 let mut rx_clone = rx.clone();
298 tokio::select! {
299 _ = tokio::time::sleep(sleep_duration) => {}
300 _ = rx_clone.changed() => {
301 if *rx_clone.borrow() {
302 debug!("Follow operation cancelled during sleep");
303 interrupted = true;
304 break;
305 }
306 }
307 }
308 } else {
309 tokio::time::sleep(sleep_duration).await;
310 }
311 }
312 }
313
314 let ended_at = Utc::now();
315
316 Ok(FollowResult {
317 domain: domain.to_string(),
318 record_type,
319 nameserver: nameserver.map(|s| s.to_string()),
320 iterations_requested: config.iterations,
321 interval_secs: config.interval_secs,
322 iterations,
323 interrupted,
324 total_changes,
325 started_at,
326 ended_at,
327 })
328 }
329
330 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
332 pub async fn follow_simple(
333 &self,
334 domain: &str,
335 record_type: RecordType,
336 nameserver: Option<&str>,
337 config: FollowConfig,
338 ) -> Result<FollowResult> {
339 self.follow(domain, record_type, nameserver, config, None, None)
340 .await
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal iteration for exercising the `FollowResult` counters offline.
    fn iteration_with_error(iteration: usize, error: Option<&str>) -> FollowIteration {
        FollowIteration {
            iteration,
            total_iterations: 3,
            timestamp: Utc::now(),
            records: Vec::new(),
            changed: false,
            added: Vec::new(),
            removed: Vec::new(),
            error: error.map(str::to_string),
        }
    }

    // Synchronous tests: no async runtime is needed to build a config.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[test]
    fn follow_config_with_changes_only_preserves_other_fields() {
        let config = FollowConfig::new(3, 1.0).unwrap().with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
    }

    #[test]
    fn follow_result_counts_successes_and_failures() {
        let now = Utc::now();
        let result = FollowResult {
            domain: "example.com".to_string(),
            record_type: RecordType::A,
            nameserver: None,
            iterations_requested: 3,
            interval_secs: 60,
            iterations: vec![
                iteration_with_error(1, None),
                iteration_with_error(2, Some("timeout")),
                iteration_with_error(3, None),
            ],
            interrupted: false,
            total_changes: 0,
            started_at: now,
            ended_at: now,
        };
        assert_eq!(result.completed_iterations(), 3);
        assert_eq!(result.successful_iterations(), 2);
        assert_eq!(result.failed_iterations(), 1);
        assert!(!result.iterations[1].success());
        assert_eq!(result.iterations[0].record_count(), 0);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await
            .expect("single-iteration follow should succeed");

        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[test]
    fn follow_config_floors_subsecond_interval_for_multi_iteration() {
        let config = FollowConfig::new(10_000, 0.001).unwrap();
        assert!(
            config.interval_secs >= 1,
            "multi-iteration interval must be floored to >= 1s, got {}",
            config.interval_secs
        );
    }

    #[test]
    fn follow_config_allows_zero_interval_for_single_iteration() {
        let config = FollowConfig::new(1, 0.0).unwrap();
        assert_eq!(config.interval_secs, 0);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        use tokio::sync::watch;

        let (tx, rx) = watch::channel(false);
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}
471}