1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Largest iteration count accepted by `FollowConfig::new`.
const MAX_FOLLOW_ITERATIONS: usize = 10_000;
18
/// Parameters controlling a DNS follow run.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// How many queries to run.
    pub iterations: usize,
    /// Seconds to pause between consecutive queries.
    pub interval_secs: u64,
    /// When `true`, the progress callback only receives the first iteration and
    /// iterations whose record set changed.
    pub changes_only: bool,
}
29
30impl Default for FollowConfig {
31 fn default() -> Self {
32 Self {
33 iterations: 10,
34 interval_secs: 60,
35 changes_only: false,
36 }
37 }
38}
39
40impl FollowConfig {
41 pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
49 if iterations == 0 {
50 return Err(SeerError::InvalidInput(
51 "iterations must be at least 1".into(),
52 ));
53 }
54 if iterations > MAX_FOLLOW_ITERATIONS {
55 return Err(SeerError::InvalidInput(format!(
56 "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
57 )));
58 }
59 if !interval_minutes.is_finite() {
60 return Err(SeerError::InvalidInput(
61 "interval_minutes must be a finite number".into(),
62 ));
63 }
64 if interval_minutes < 0.0 {
65 return Err(SeerError::InvalidInput(
66 "interval_minutes must be non-negative".into(),
67 ));
68 }
69 if interval_minutes > 60.0 {
70 return Err(SeerError::InvalidInput(
71 "interval_minutes must be at most 60".into(),
72 ));
73 }
74 Ok(Self {
75 iterations,
76 interval_secs: (interval_minutes * 60.0) as u64,
77 changes_only: false,
78 })
79 }
80
81 pub fn with_changes_only(mut self, changes_only: bool) -> Self {
82 self.changes_only = changes_only;
83 self
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct FollowIteration {
90 pub iteration: usize,
92 pub total_iterations: usize,
94 pub timestamp: DateTime<Utc>,
96 pub records: Vec<DnsRecord>,
98 pub changed: bool,
100 pub added: Vec<String>,
102 pub removed: Vec<String>,
104 pub error: Option<String>,
106}
107
108impl FollowIteration {
109 pub fn success(&self) -> bool {
110 self.error.is_none()
111 }
112
113 pub fn record_count(&self) -> usize {
114 self.records.len()
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct FollowResult {
121 pub domain: String,
123 pub record_type: RecordType,
125 pub nameserver: Option<String>,
127 pub iterations_requested: usize,
129 pub interval_secs: u64,
130 pub iterations: Vec<FollowIteration>,
132 pub interrupted: bool,
134 pub total_changes: usize,
136 pub started_at: DateTime<Utc>,
138 pub ended_at: DateTime<Utc>,
140}
141
142impl FollowResult {
143 pub fn completed_iterations(&self) -> usize {
144 self.iterations.len()
145 }
146
147 pub fn successful_iterations(&self) -> usize {
148 self.iterations.iter().filter(|i| i.success()).count()
149 }
150
151 pub fn failed_iterations(&self) -> usize {
152 self.iterations.iter().filter(|i| !i.success()).count()
153 }
154}
155
156pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
158
159#[derive(Clone)]
161pub struct DnsFollower {
162 resolver: DnsResolver,
163}
164
165impl Default for DnsFollower {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl DnsFollower {
172 pub fn new() -> Self {
173 Self {
174 resolver: DnsResolver::new(),
175 }
176 }
177
178 pub fn with_resolver(resolver: DnsResolver) -> Self {
179 Self { resolver }
180 }
181
182 #[instrument(skip(self, config, callback, cancel_rx))]
184 pub async fn follow(
185 &self,
186 domain: &str,
187 record_type: RecordType,
188 nameserver: Option<&str>,
189 config: FollowConfig,
190 callback: Option<FollowProgressCallback>,
191 cancel_rx: Option<watch::Receiver<bool>>,
192 ) -> Result<FollowResult> {
193 let domain = crate::validation::normalize_domain(domain)?;
194 let started_at = Utc::now();
195 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
196 let mut previous_values: HashSet<String> = HashSet::new();
197 let mut total_changes = 0;
198 let mut interrupted = false;
199
200 debug!(
201 domain = %domain,
202 record_type = %record_type,
203 iterations = config.iterations,
204 interval_secs = config.interval_secs,
205 "Starting DNS follow"
206 );
207
208 for i in 0..config.iterations {
209 if let Some(ref rx) = cancel_rx {
211 if *rx.borrow() {
212 debug!("Follow operation cancelled");
213 interrupted = true;
214 break;
215 }
216 }
217
218 let timestamp = Utc::now();
219 let iteration_num = i + 1;
220
221 let (records, error) = match self
223 .resolver
224 .resolve(&domain, record_type, nameserver)
225 .await
226 {
227 Ok(records) => (records, None),
228 Err(e) => {
229 debug!(domain = %domain, error = %e, "DNS follow query failed");
230 (Vec::new(), Some(e.sanitized_message()))
232 }
233 };
234
235 let current_values: HashSet<String> =
237 records.iter().map(|r| r.data.to_string()).collect();
238
239 let (changed, added, removed) = if i == 0 {
241 (false, Vec::new(), Vec::new())
243 } else {
244 let added: Vec<String> = current_values
245 .difference(&previous_values)
246 .cloned()
247 .collect();
248 let removed: Vec<String> = previous_values
249 .difference(¤t_values)
250 .cloned()
251 .collect();
252 let changed = !added.is_empty() || !removed.is_empty();
253 (changed, added, removed)
254 };
255
256 if changed {
257 total_changes += 1;
258 }
259
260 let iteration = FollowIteration {
261 iteration: iteration_num,
262 total_iterations: config.iterations,
263 timestamp,
264 records,
265 changed,
266 added,
267 removed,
268 error,
269 };
270
271 if let Some(ref cb) = callback {
273 if !config.changes_only || iteration_num == 1 || changed {
275 cb(&iteration);
276 }
277 }
278
279 iterations.push(iteration);
280 previous_values = current_values;
281
282 if i < config.iterations - 1 {
284 let sleep_duration = Duration::from_secs(config.interval_secs);
285
286 if let Some(ref rx) = cancel_rx {
288 let mut rx_clone = rx.clone();
289 tokio::select! {
290 _ = tokio::time::sleep(sleep_duration) => {}
291 _ = rx_clone.changed() => {
292 if *rx_clone.borrow() {
293 debug!("Follow operation cancelled during sleep");
294 interrupted = true;
295 break;
296 }
297 }
298 }
299 } else {
300 tokio::time::sleep(sleep_duration).await;
301 }
302 }
303 }
304
305 let ended_at = Utc::now();
306
307 Ok(FollowResult {
308 domain: domain.to_string(),
309 record_type,
310 nameserver: nameserver.map(|s| s.to_string()),
311 iterations_requested: config.iterations,
312 interval_secs: config.interval_secs,
313 iterations,
314 interrupted,
315 total_changes,
316 started_at,
317 ended_at,
318 })
319 }
320
321 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
323 pub async fn follow_simple(
324 &self,
325 domain: &str,
326 record_type: RecordType,
327 nameserver: Option<&str>,
328 config: FollowConfig,
329 ) -> Result<FollowResult> {
330 self.follow(domain, record_type, nameserver, config, None, None)
331 .await
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an iteration with no records; it is a failure when `error` is `Some`.
    fn iteration(n: usize, error: Option<&str>) -> FollowIteration {
        FollowIteration {
            iteration: n,
            total_iterations: 3,
            timestamp: Utc::now(),
            records: Vec::new(),
            changed: false,
            added: Vec::new(),
            removed: Vec::new(),
            error: error.map(String::from),
        }
    }

    // Pure configuration tests are synchronous, so they use plain `#[test]` instead of
    // building a Tokio runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[test]
    fn follow_config_with_changes_only_sets_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap();
        assert!(!config.changes_only);
        assert!(config.with_changes_only(true).changes_only);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[test]
    fn follow_result_counts_iteration_outcomes() {
        let now = Utc::now();
        let result = FollowResult {
            domain: "example.com".into(),
            record_type: RecordType::A,
            nameserver: None,
            iterations_requested: 3,
            interval_secs: 0,
            iterations: vec![
                iteration(1, None),
                iteration(2, Some("timeout")),
                iteration(3, None),
            ],
            interrupted: false,
            total_changes: 0,
            started_at: now,
            ended_at: now,
        };
        assert_eq!(result.completed_iterations(), 3);
        assert_eq!(result.successful_iterations(), 2);
        assert_eq!(result.failed_iterations(), 1);
        assert!(!result.iterations[1].success());
        assert_eq!(result.iterations[1].record_count(), 0);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        let (tx, rx) = watch::channel(false);
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}