1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
/// Settings controlling how many lookups a follow run performs and how long it
/// pauses between them.
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// Total number of lookups to perform.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// When `true`, only noteworthy iterations (such as the first one and those
    /// where the answer changed) are passed to the progress callback; every
    /// iteration is still recorded in the result.
    pub changes_only: bool,
}
24
25impl Default for FollowConfig {
26 fn default() -> Self {
27 Self {
28 iterations: 10,
29 interval_secs: 60,
30 changes_only: false,
31 }
32 }
33}
34
35impl FollowConfig {
36 pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
44 if iterations == 0 {
45 return Err(SeerError::InvalidInput(
46 "iterations must be at least 1".into(),
47 ));
48 }
49 if !interval_minutes.is_finite() {
50 return Err(SeerError::InvalidInput(
51 "interval_minutes must be a finite number".into(),
52 ));
53 }
54 if interval_minutes < 0.0 {
55 return Err(SeerError::InvalidInput(
56 "interval_minutes must be non-negative".into(),
57 ));
58 }
59 if interval_minutes > 60.0 {
60 return Err(SeerError::InvalidInput(
61 "interval_minutes must be at most 60".into(),
62 ));
63 }
64 Ok(Self {
65 iterations,
66 interval_secs: (interval_minutes * 60.0) as u64,
67 changes_only: false,
68 })
69 }
70
71 pub fn with_changes_only(mut self, changes_only: bool) -> Self {
72 self.changes_only = changes_only;
73 self
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct FollowIteration {
80 pub iteration: usize,
82 pub total_iterations: usize,
84 pub timestamp: DateTime<Utc>,
86 pub records: Vec<DnsRecord>,
88 pub changed: bool,
90 pub added: Vec<String>,
92 pub removed: Vec<String>,
94 pub error: Option<String>,
96}
97
98impl FollowIteration {
99 pub fn success(&self) -> bool {
100 self.error.is_none()
101 }
102
103 pub fn record_count(&self) -> usize {
104 self.records.len()
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct FollowResult {
111 pub domain: String,
113 pub record_type: RecordType,
115 pub nameserver: Option<String>,
117 pub iterations_requested: usize,
119 pub interval_secs: u64,
120 pub iterations: Vec<FollowIteration>,
122 pub interrupted: bool,
124 pub total_changes: usize,
126 pub started_at: DateTime<Utc>,
128 pub ended_at: DateTime<Utc>,
130}
131
132impl FollowResult {
133 pub fn completed_iterations(&self) -> usize {
134 self.iterations.len()
135 }
136
137 pub fn successful_iterations(&self) -> usize {
138 self.iterations.iter().filter(|i| i.success()).count()
139 }
140
141 pub fn failed_iterations(&self) -> usize {
142 self.iterations.iter().filter(|i| !i.success()).count()
143 }
144}
145
146pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
148
149#[derive(Clone)]
151pub struct DnsFollower {
152 resolver: DnsResolver,
153}
154
155impl Default for DnsFollower {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl DnsFollower {
162 pub fn new() -> Self {
163 Self {
164 resolver: DnsResolver::new(),
165 }
166 }
167
168 pub fn with_resolver(resolver: DnsResolver) -> Self {
169 Self { resolver }
170 }
171
172 #[instrument(skip(self, config, callback, cancel_rx))]
174 pub async fn follow(
175 &self,
176 domain: &str,
177 record_type: RecordType,
178 nameserver: Option<&str>,
179 config: FollowConfig,
180 callback: Option<FollowProgressCallback>,
181 cancel_rx: Option<watch::Receiver<bool>>,
182 ) -> Result<FollowResult> {
183 let domain = crate::validation::normalize_domain(domain)?;
184 let started_at = Utc::now();
185 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
186 let mut previous_values: HashSet<String> = HashSet::new();
187 let mut total_changes = 0;
188 let mut interrupted = false;
189
190 debug!(
191 domain = %domain,
192 record_type = %record_type,
193 iterations = config.iterations,
194 interval_secs = config.interval_secs,
195 "Starting DNS follow"
196 );
197
198 for i in 0..config.iterations {
199 if let Some(ref rx) = cancel_rx {
201 if *rx.borrow() {
202 debug!("Follow operation cancelled");
203 interrupted = true;
204 break;
205 }
206 }
207
208 let timestamp = Utc::now();
209 let iteration_num = i + 1;
210
211 let (records, error) = match self
213 .resolver
214 .resolve(&domain, record_type, nameserver)
215 .await
216 {
217 Ok(records) => (records, None),
218 Err(e) => (Vec::new(), Some(e.to_string())),
219 };
220
221 let current_values: HashSet<String> =
223 records.iter().map(|r| r.data.to_string()).collect();
224
225 let (changed, added, removed) = if i == 0 {
227 (false, Vec::new(), Vec::new())
229 } else {
230 let added: Vec<String> = current_values
231 .difference(&previous_values)
232 .cloned()
233 .collect();
234 let removed: Vec<String> = previous_values
235 .difference(¤t_values)
236 .cloned()
237 .collect();
238 let changed = !added.is_empty() || !removed.is_empty();
239 (changed, added, removed)
240 };
241
242 if changed {
243 total_changes += 1;
244 }
245
246 let iteration = FollowIteration {
247 iteration: iteration_num,
248 total_iterations: config.iterations,
249 timestamp,
250 records,
251 changed,
252 added,
253 removed,
254 error,
255 };
256
257 if let Some(ref cb) = callback {
259 if !config.changes_only || iteration_num == 1 || changed {
261 cb(&iteration);
262 }
263 }
264
265 iterations.push(iteration);
266 previous_values = current_values;
267
268 if i < config.iterations - 1 {
270 let sleep_duration = Duration::from_secs(config.interval_secs);
271
272 if let Some(ref rx) = cancel_rx {
274 let mut rx_clone = rx.clone();
275 tokio::select! {
276 _ = tokio::time::sleep(sleep_duration) => {}
277 _ = rx_clone.changed() => {
278 if *rx_clone.borrow() {
279 debug!("Follow operation cancelled during sleep");
280 interrupted = true;
281 break;
282 }
283 }
284 }
285 } else {
286 tokio::time::sleep(sleep_duration).await;
287 }
288 }
289 }
290
291 let ended_at = Utc::now();
292
293 Ok(FollowResult {
294 domain: domain.to_string(),
295 record_type,
296 nameserver: nameserver.map(|s| s.to_string()),
297 iterations_requested: config.iterations,
298 interval_secs: config.interval_secs,
299 iterations,
300 interrupted,
301 total_changes,
302 started_at,
303 ended_at,
304 })
305 }
306
307 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
309 pub async fn follow_simple(
310 &self,
311 domain: &str,
312 record_type: RecordType,
313 nameserver: Option<&str>,
314 config: FollowConfig,
315 ) -> Result<FollowResult> {
316 self.follow(domain, record_type, nameserver, config, None, None)
317 .await
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    // Pure configuration checks need no async runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[test]
    fn follow_config_with_changes_only_toggles_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap();
        assert!(!config.changes_only);
        assert!(config.clone().with_changes_only(true).changes_only);
        assert!(
            !config
                .with_changes_only(true)
                .with_changes_only(false)
                .changes_only
        );
    }

    // Offline coverage of the result/iteration helpers (no resolver needed).
    #[test]
    fn follow_result_counts_successes_and_failures() {
        let now = Utc::now();
        let make = |n: usize, error: Option<&str>| FollowIteration {
            iteration: n,
            total_iterations: 3,
            timestamp: now,
            records: Vec::new(),
            changed: false,
            added: Vec::new(),
            removed: Vec::new(),
            error: error.map(str::to_string),
        };
        let result = FollowResult {
            domain: "example.com".to_string(),
            record_type: RecordType::A,
            nameserver: None,
            iterations_requested: 3,
            interval_secs: 0,
            iterations: vec![make(1, None), make(2, Some("timeout")), make(3, None)],
            interrupted: false,
            total_changes: 0,
            started_at: now,
            ended_at: now,
        };

        assert_eq!(result.completed_iterations(), 3);
        assert_eq!(result.successful_iterations(), 2);
        assert_eq!(result.failed_iterations(), 1);
        assert!(result.iterations[0].success());
        assert!(!result.iterations[1].success());
        assert_eq!(result.iterations[0].record_count(), 0);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        let (tx, rx) = watch::channel(false);
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        // Give the first lookup a moment to start before cancelling.
        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}