1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Controls how many lookups [`DnsFollower::follow`] performs and how far apart they are.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Number of lookups to perform.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// When `true`, the progress callback is only invoked for the first
    /// iteration and for iterations whose record values changed.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten lookups, one minute apart, reporting every iteration.
    fn default() -> Self {
        const DEFAULT_ITERATIONS: usize = 10;
        const DEFAULT_INTERVAL_SECS: u64 = 60;

        FollowConfig {
            iterations: DEFAULT_ITERATIONS,
            interval_secs: DEFAULT_INTERVAL_SECS,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    /// Builds a config from an interval given in (possibly fractional) minutes.
    ///
    /// The interval is converted to whole seconds by truncation. Float-to-int
    /// `as` casts saturate, so negative or NaN inputs yield an interval of `0`.
    /// `changes_only` starts out `false`.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        let interval_secs = (interval_minutes * 60.0) as u64;
        Self {
            iterations,
            interval_secs,
            ..Self::default()
        }
    }

    /// Returns this config with `changes_only` replaced by the given flag.
    pub fn with_changes_only(self, changes_only: bool) -> Self {
        Self {
            changes_only,
            ..self
        }
    }
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53 pub iteration: usize,
55 pub total_iterations: usize,
57 pub timestamp: DateTime<Utc>,
59 pub records: Vec<DnsRecord>,
61 pub changed: bool,
63 pub added: Vec<String>,
65 pub removed: Vec<String>,
67 pub error: Option<String>,
69}
70
71impl FollowIteration {
72 pub fn success(&self) -> bool {
73 self.error.is_none()
74 }
75
76 pub fn record_count(&self) -> usize {
77 self.records.len()
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84 pub domain: String,
86 pub record_type: RecordType,
88 pub nameserver: Option<String>,
90 pub iterations_requested: usize,
92 pub interval_secs: u64,
93 pub iterations: Vec<FollowIteration>,
95 pub interrupted: bool,
97 pub total_changes: usize,
99 pub started_at: DateTime<Utc>,
101 pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106 pub fn completed_iterations(&self) -> usize {
107 self.iterations.len()
108 }
109
110 pub fn successful_iterations(&self) -> usize {
111 self.iterations.iter().filter(|i| i.success()).count()
112 }
113
114 pub fn failed_iterations(&self) -> usize {
115 self.iterations.iter().filter(|i| !i.success()).count()
116 }
117}
118
119pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122#[derive(Clone)]
124pub struct DnsFollower {
125 resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl DnsFollower {
135 pub fn new() -> Self {
136 Self {
137 resolver: DnsResolver::new(),
138 }
139 }
140
141 pub fn with_resolver(resolver: DnsResolver) -> Self {
142 Self { resolver }
143 }
144
145 #[instrument(skip(self, config, callback, cancel_rx))]
147 pub async fn follow(
148 &self,
149 domain: &str,
150 record_type: RecordType,
151 nameserver: Option<&str>,
152 config: FollowConfig,
153 callback: Option<FollowProgressCallback>,
154 cancel_rx: Option<watch::Receiver<bool>>,
155 ) -> Result<FollowResult> {
156 let domain = crate::validation::normalize_domain(domain)?;
157 let started_at = Utc::now();
158 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
159 let mut previous_values: HashSet<String> = HashSet::new();
160 let mut total_changes = 0;
161 let mut interrupted = false;
162
163 debug!(
164 domain = %domain,
165 record_type = %record_type,
166 iterations = config.iterations,
167 interval_secs = config.interval_secs,
168 "Starting DNS follow"
169 );
170
171 for i in 0..config.iterations {
172 if let Some(ref rx) = cancel_rx {
174 if *rx.borrow() {
175 debug!("Follow operation cancelled");
176 interrupted = true;
177 break;
178 }
179 }
180
181 let timestamp = Utc::now();
182 let iteration_num = i + 1;
183
184 let (records, error) = match self
186 .resolver
187 .resolve(&domain, record_type, nameserver)
188 .await
189 {
190 Ok(records) => (records, None),
191 Err(e) => (Vec::new(), Some(e.to_string())),
192 };
193
194 let current_values: HashSet<String> =
196 records.iter().map(|r| r.data.to_string()).collect();
197
198 let (changed, added, removed) = if i == 0 {
200 (false, Vec::new(), Vec::new())
202 } else {
203 let added: Vec<String> = current_values
204 .difference(&previous_values)
205 .cloned()
206 .collect();
207 let removed: Vec<String> = previous_values
208 .difference(¤t_values)
209 .cloned()
210 .collect();
211 let changed = !added.is_empty() || !removed.is_empty();
212 (changed, added, removed)
213 };
214
215 if changed {
216 total_changes += 1;
217 }
218
219 let iteration = FollowIteration {
220 iteration: iteration_num,
221 total_iterations: config.iterations,
222 timestamp,
223 records,
224 changed,
225 added,
226 removed,
227 error,
228 };
229
230 if let Some(ref cb) = callback {
232 if !config.changes_only || iteration_num == 1 || changed {
234 cb(&iteration);
235 }
236 }
237
238 iterations.push(iteration);
239 previous_values = current_values;
240
241 if i < config.iterations - 1 {
243 let sleep_duration = Duration::from_secs(config.interval_secs);
244
245 if let Some(ref rx) = cancel_rx {
247 let mut rx_clone = rx.clone();
248 tokio::select! {
249 _ = tokio::time::sleep(sleep_duration) => {}
250 _ = rx_clone.changed() => {
251 if *rx_clone.borrow() {
252 debug!("Follow operation cancelled during sleep");
253 interrupted = true;
254 break;
255 }
256 }
257 }
258 } else {
259 tokio::time::sleep(sleep_duration).await;
260 }
261 }
262 }
263
264 let ended_at = Utc::now();
265
266 Ok(FollowResult {
267 domain: domain.to_string(),
268 record_type,
269 nameserver: nameserver.map(|s| s.to_string()),
270 iterations_requested: config.iterations,
271 interval_secs: config.interval_secs,
272 iterations,
273 interrupted,
274 total_changes,
275 started_at,
276 ended_at,
277 })
278 }
279
280 #[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
282 pub async fn follow_simple(
283 &self,
284 domain: &str,
285 record_type: RecordType,
286 nameserver: Option<&str>,
287 config: FollowConfig,
288 ) -> Result<FollowResult> {
289 self.follow(domain, record_type, nameserver, config, None, None)
290 .await
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an iteration with no records; `error` marks it as failed.
    fn iteration(n: usize, error: Option<&str>) -> FollowIteration {
        FollowIteration {
            iteration: n,
            total_iterations: 2,
            timestamp: Utc::now(),
            records: Vec::new(),
            changed: false,
            added: Vec::new(),
            removed: Vec::new(),
            error: error.map(str::to_string),
        }
    }

    // Pure config tests need no async runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5);
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_with_changes_only() {
        let config = FollowConfig::new(3, 1.0).with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 3);
        assert_eq!(config.interval_secs, 60);
    }

    #[test]
    fn test_follow_result_iteration_counts() {
        let now = Utc::now();
        let result = FollowResult {
            domain: "example.com".to_string(),
            record_type: RecordType::A,
            nameserver: None,
            iterations_requested: 2,
            interval_secs: 0,
            iterations: vec![iteration(1, None), iteration(2, Some("timeout"))],
            interrupted: false,
            total_changes: 0,
            started_at: now,
            ended_at: now,
        };
        assert_eq!(result.completed_iterations(), 2);
        assert_eq!(result.successful_iterations(), 1);
        assert_eq!(result.failed_iterations(), 1);
    }

    #[tokio::test]
    async fn test_follow_zero_iterations_performs_no_lookups() {
        let follower = DnsFollower::new();
        let result = follower
            .follow_simple("example.com", RecordType::A, None, FollowConfig::new(0, 0.0))
            .await
            .expect("follow with zero iterations should succeed");
        assert_eq!(result.completed_iterations(), 0);
        assert_eq!(result.total_changes, 0);
        assert!(!result.interrupted);
    }

    #[tokio::test]
    async fn test_follow_cancelled_before_first_lookup() {
        // Keep the sender alive so the channel stays open for the whole run.
        let (_tx, rx) = tokio::sync::watch::channel(true);
        let follower = DnsFollower::new();
        let result = follower
            .follow(
                "example.com",
                RecordType::A,
                None,
                FollowConfig::new(3, 0.0),
                None,
                Some(rx),
            )
            .await
            .expect("cancelled follow should still return a result");
        assert!(result.interrupted);
        assert_eq!(result.completed_iterations(), 0);
    }

    // Performs a live DNS lookup, so it needs network access.
    #[tokio::test]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0);

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await
            .expect("single-iteration follow should succeed");

        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }
}