1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::{Result, SeerError};
13
14#[derive(Debug, Clone)]
16pub struct FollowConfig {
17 pub iterations: usize,
19 pub interval_secs: u64,
21 pub changes_only: bool,
23}
24
25impl Default for FollowConfig {
26 fn default() -> Self {
27 Self {
28 iterations: 10,
29 interval_secs: 60,
30 changes_only: false,
31 }
32 }
33}
34
35impl FollowConfig {
36 pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
44 if iterations == 0 {
45 return Err(SeerError::InvalidInput(
46 "iterations must be at least 1".into(),
47 ));
48 }
49 if !interval_minutes.is_finite() {
50 return Err(SeerError::InvalidInput(
51 "interval_minutes must be a finite number".into(),
52 ));
53 }
54 if interval_minutes < 0.0 {
55 return Err(SeerError::InvalidInput(
56 "interval_minutes must be non-negative".into(),
57 ));
58 }
59 if interval_minutes > 60.0 {
60 return Err(SeerError::InvalidInput(
61 "interval_minutes must be at most 60".into(),
62 ));
63 }
64 Ok(Self {
65 iterations,
66 interval_secs: (interval_minutes * 60.0) as u64,
67 changes_only: false,
68 })
69 }
70
71 pub fn with_changes_only(mut self, changes_only: bool) -> Self {
72 self.changes_only = changes_only;
73 self
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct FollowIteration {
80 pub iteration: usize,
82 pub total_iterations: usize,
84 pub timestamp: DateTime<Utc>,
86 pub records: Vec<DnsRecord>,
88 pub changed: bool,
90 pub added: Vec<String>,
92 pub removed: Vec<String>,
94 pub error: Option<String>,
96}
97
98impl FollowIteration {
99 pub fn success(&self) -> bool {
100 self.error.is_none()
101 }
102
103 pub fn record_count(&self) -> usize {
104 self.records.len()
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct FollowResult {
111 pub domain: String,
113 pub record_type: RecordType,
115 pub nameserver: Option<String>,
117 pub iterations_requested: usize,
119 pub interval_secs: u64,
120 pub iterations: Vec<FollowIteration>,
122 pub interrupted: bool,
124 pub total_changes: usize,
126 pub started_at: DateTime<Utc>,
128 pub ended_at: DateTime<Utc>,
130}
131
132impl FollowResult {
133 pub fn completed_iterations(&self) -> usize {
134 self.iterations.len()
135 }
136
137 pub fn successful_iterations(&self) -> usize {
138 self.iterations.iter().filter(|i| i.success()).count()
139 }
140
141 pub fn failed_iterations(&self) -> usize {
142 self.iterations.iter().filter(|i| !i.success()).count()
143 }
144}
145
146pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
148
149#[derive(Clone)]
151pub struct DnsFollower {
152 resolver: DnsResolver,
153}
154
155impl Default for DnsFollower {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl DnsFollower {
162 pub fn new() -> Self {
163 Self {
164 resolver: DnsResolver::new(),
165 }
166 }
167
168 pub fn with_resolver(resolver: DnsResolver) -> Self {
169 Self { resolver }
170 }
171
172 #[instrument(skip(self, config, callback, cancel_rx))]
174 pub async fn follow(
175 &self,
176 domain: &str,
177 record_type: RecordType,
178 nameserver: Option<&str>,
179 config: FollowConfig,
180 callback: Option<FollowProgressCallback>,
181 cancel_rx: Option<watch::Receiver<bool>>,
182 ) -> Result<FollowResult> {
183 let domain = crate::validation::normalize_domain(domain)?;
184 let started_at = Utc::now();
185 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
186 let mut previous_values: HashSet<String> = HashSet::new();
187 let mut total_changes = 0;
188 let mut interrupted = false;
189
190 debug!(
191 domain = %domain,
192 record_type = %record_type,
193 iterations = config.iterations,
194 interval_secs = config.interval_secs,
195 "Starting DNS follow"
196 );
197
198 for i in 0..config.iterations {
199 if let Some(ref rx) = cancel_rx {
201 if *rx.borrow() {
202 debug!("Follow operation cancelled");
203 interrupted = true;
204 break;
205 }
206 }
207
208 let timestamp = Utc::now();
209 let iteration_num = i + 1;
210
211 let (records, error) = match self
213 .resolver
214 .resolve(&domain, record_type, nameserver)
215 .await
216 {
217 Ok(records) => (records, None),
218 Err(e) => (Vec::new(), Some(e.to_string())),
219 };
220
221 let current_values: HashSet<String> =
223 records.iter().map(|r| r.data.to_string()).collect();
224
225 let (changed, added, removed) = if i == 0 {
227 (false, Vec::new(), Vec::new())
229 } else {
230 let added: Vec<String> = current_values
231 .difference(&previous_values)
232 .cloned()
233 .collect();
234 let removed: Vec<String> = previous_values
235 .difference(¤t_values)
236 .cloned()
237 .collect();
238 let changed = !added.is_empty() || !removed.is_empty();
239 (changed, added, removed)
240 };
241
242 if changed {
243 total_changes += 1;
244 }
245
246 let iteration = FollowIteration {
247 iteration: iteration_num,
248 total_iterations: config.iterations,
249 timestamp,
250 records,
251 changed,
252 added,
253 removed,
254 error,
255 };
256
257 if let Some(ref cb) = callback {
259 if !config.changes_only || iteration_num == 1 || changed {
261 cb(&iteration);
262 }
263 }
264
265 iterations.push(iteration);
266 previous_values = current_values;
267
268 if i < config.iterations - 1 {
270 let sleep_duration = Duration::from_secs(config.interval_secs);
271
272 if let Some(ref rx) = cancel_rx {
274 let mut rx_clone = rx.clone();
275 tokio::select! {
276 _ = tokio::time::sleep(sleep_duration) => {}
277 _ = rx_clone.changed() => {
278 if *rx_clone.borrow() {
279 debug!("Follow operation cancelled during sleep");
280 interrupted = true;
281 break;
282 }
283 }
284 }
285 } else {
286 tokio::time::sleep(sleep_duration).await;
287 }
288 }
289 }
290
291 let ended_at = Utc::now();
292
293 Ok(FollowResult {
294 domain: domain.to_string(),
295 record_type,
296 nameserver: nameserver.map(|s| s.to_string()),
297 iterations_requested: config.iterations,
298 interval_secs: config.interval_secs,
299 iterations,
300 interrupted,
301 total_changes,
302 started_at,
303 ended_at,
304 })
305 }
306
307 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
309 pub async fn follow_simple(
310 &self,
311 domain: &str,
312 record_type: RecordType,
313 nameserver: Option<&str>,
314 config: FollowConfig,
315 ) -> Result<FollowResult> {
316 self.follow(domain, record_type, nameserver, config, None, None)
317 .await
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    // Pure configuration checks: these need no async runtime.

    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    #[test]
    fn follow_config_maximum_interval_is_one_hour() {
        assert_eq!(FollowConfig::new(1, 60.0).unwrap().interval_secs, 3600);
    }

    #[test]
    fn follow_config_with_changes_only_toggles_flag() {
        let config = FollowConfig::new(3, 1.0).unwrap();
        assert!(!config.changes_only);
        let config = config.with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
    }

    // The tests below drive the real resolver (presumably live lookups of
    // example.com); resolver failures are recorded per iteration, not returned.

    #[tokio::test]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[tokio::test]
    async fn follow_honors_cancel() {
        use tokio::sync::watch;

        let (tx, rx) = watch::channel(false);
        // 100 iterations, 30 seconds apart: only cancellation can end this quickly.
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();

        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });

        // Give the first lookup a head start, then request cancellation.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        tx.send(true).unwrap();

        let joined = tokio::time::timeout(std::time::Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}