1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Settings that control a DNS follow run.
#[derive(Debug, Clone)]
pub struct FollowConfig {
    /// Number of lookups to perform.
    pub iterations: usize,
    /// Pause between consecutive lookups, in whole seconds.
    pub interval_secs: u64,
    /// When `true`, progress is only reported for the first iteration and for
    /// iterations whose record set differs from the previous one.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten iterations, sixty seconds apart, reporting every iteration.
    fn default() -> Self {
        Self {
            iterations: 10,
            interval_secs: 60,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    /// Builds a configuration from an iteration count and an interval given
    /// in (possibly fractional) minutes.
    ///
    /// The interval is converted to seconds and rounded to the nearest whole
    /// second. `changes_only` starts out as `false`; use
    /// [`FollowConfig::with_changes_only`] to enable it.
    ///
    /// Out-of-range input saturates: `NaN` and negative intervals become `0`
    /// seconds, and intervals too large for `u64` become `u64::MAX`.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        // Round before casting. A bare `as u64` truncates, and binary
        // floating-point error then produces off-by-one results: for example
        // `2.05 * 60.0` evaluates to 122.99999999999999, which would become
        // 122 instead of the intended 123. The float-to-integer `as` cast
        // itself saturates, which yields the out-of-range behaviour above.
        let interval_secs = (interval_minutes * 60.0).round() as u64;

        Self {
            iterations,
            interval_secs,
            changes_only: false,
        }
    }

    /// Returns the configuration with `changes_only` set to the given value.
    pub fn with_changes_only(mut self, changes_only: bool) -> Self {
        self.changes_only = changes_only;
        self
    }
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53 pub iteration: usize,
55 pub total_iterations: usize,
57 pub timestamp: DateTime<Utc>,
59 pub records: Vec<DnsRecord>,
61 pub changed: bool,
63 pub added: Vec<String>,
65 pub removed: Vec<String>,
67 pub error: Option<String>,
69}
70
71impl FollowIteration {
72 pub fn success(&self) -> bool {
73 self.error.is_none()
74 }
75
76 pub fn record_count(&self) -> usize {
77 self.records.len()
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84 pub domain: String,
86 pub record_type: RecordType,
88 pub nameserver: Option<String>,
90 pub iterations_requested: usize,
92 pub interval_secs: u64,
93 pub iterations: Vec<FollowIteration>,
95 pub interrupted: bool,
97 pub total_changes: usize,
99 pub started_at: DateTime<Utc>,
101 pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106 pub fn completed_iterations(&self) -> usize {
107 self.iterations.len()
108 }
109
110 pub fn successful_iterations(&self) -> usize {
111 self.iterations.iter().filter(|i| i.success()).count()
112 }
113
114 pub fn failed_iterations(&self) -> usize {
115 self.iterations.iter().filter(|i| !i.success()).count()
116 }
117}
118
119pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122#[derive(Clone)]
124pub struct DnsFollower {
125 resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl DnsFollower {
135 pub fn new() -> Self {
136 Self {
137 resolver: DnsResolver::new(),
138 }
139 }
140
141 pub fn with_resolver(resolver: DnsResolver) -> Self {
142 Self { resolver }
143 }
144
145 #[instrument(skip(self, config, callback, cancel_rx))]
147 pub async fn follow(
148 &self,
149 domain: &str,
150 record_type: RecordType,
151 nameserver: Option<&str>,
152 config: FollowConfig,
153 callback: Option<FollowProgressCallback>,
154 cancel_rx: Option<watch::Receiver<bool>>,
155 ) -> Result<FollowResult> {
156 let started_at = Utc::now();
157 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
158 let mut previous_values: HashSet<String> = HashSet::new();
159 let mut total_changes = 0;
160 let mut interrupted = false;
161
162 debug!(
163 domain = %domain,
164 record_type = %record_type,
165 iterations = config.iterations,
166 interval_secs = config.interval_secs,
167 "Starting DNS follow"
168 );
169
170 for i in 0..config.iterations {
171 if let Some(ref rx) = cancel_rx {
173 if *rx.borrow() {
174 debug!("Follow operation cancelled");
175 interrupted = true;
176 break;
177 }
178 }
179
180 let timestamp = Utc::now();
181 let iteration_num = i + 1;
182
183 let (records, error) =
185 match self.resolver.resolve(domain, record_type, nameserver).await {
186 Ok(records) => (records, None),
187 Err(e) => (Vec::new(), Some(e.to_string())),
188 };
189
190 let current_values: HashSet<String> =
192 records.iter().map(|r| r.data.to_string()).collect();
193
194 let (changed, added, removed) = if i == 0 {
196 (false, Vec::new(), Vec::new())
198 } else {
199 let added: Vec<String> = current_values
200 .difference(&previous_values)
201 .cloned()
202 .collect();
203 let removed: Vec<String> = previous_values
204 .difference(¤t_values)
205 .cloned()
206 .collect();
207 let changed = !added.is_empty() || !removed.is_empty();
208 (changed, added, removed)
209 };
210
211 if changed {
212 total_changes += 1;
213 }
214
215 let iteration = FollowIteration {
216 iteration: iteration_num,
217 total_iterations: config.iterations,
218 timestamp,
219 records,
220 changed,
221 added,
222 removed,
223 error,
224 };
225
226 if let Some(ref cb) = callback {
228 if !config.changes_only || iteration_num == 1 || changed {
230 cb(&iteration);
231 }
232 }
233
234 iterations.push(iteration);
235 previous_values = current_values;
236
237 if i < config.iterations - 1 {
239 let sleep_duration = Duration::from_secs(config.interval_secs);
240
241 if let Some(ref rx) = cancel_rx {
243 let mut rx_clone = rx.clone();
244 tokio::select! {
245 _ = tokio::time::sleep(sleep_duration) => {}
246 _ = rx_clone.changed() => {
247 if *rx_clone.borrow() {
248 debug!("Follow operation cancelled during sleep");
249 interrupted = true;
250 break;
251 }
252 }
253 }
254 } else {
255 tokio::time::sleep(sleep_duration).await;
256 }
257 }
258 }
259
260 let ended_at = Utc::now();
261
262 Ok(FollowResult {
263 domain: domain.to_string(),
264 record_type,
265 nameserver: nameserver.map(|s| s.to_string()),
266 iterations_requested: config.iterations,
267 interval_secs: config.interval_secs,
268 iterations,
269 interrupted,
270 total_changes,
271 started_at,
272 ended_at,
273 })
274 }
275
276 pub async fn follow_simple(
278 &self,
279 domain: &str,
280 record_type: RecordType,
281 nameserver: Option<&str>,
282 config: FollowConfig,
283 ) -> Result<FollowResult> {
284 self.follow(domain, record_type, nameserver, config, None, None)
285 .await
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls ten times, one minute apart, and
    /// reports every iteration.
    #[tokio::test]
    async fn test_follow_config_default() {
        let FollowConfig {
            iterations,
            interval_secs,
            changes_only,
        } = FollowConfig::default();

        assert_eq!(iterations, 10);
        assert_eq!(interval_secs, 60);
        assert!(!changes_only);
    }

    /// Half a minute is stored as thirty seconds.
    #[tokio::test]
    async fn test_follow_config_new() {
        let half_minute = FollowConfig::new(5, 0.5);

        assert_eq!(half_minute.iterations, 5);
        assert_eq!(half_minute.interval_secs, 30);
    }

    /// A one-iteration run completes exactly one lookup and is not marked
    /// as interrupted.
    #[tokio::test]
    async fn test_follow_single_iteration() {
        let single_shot = FollowConfig::new(1, 0.0);

        let outcome = DnsFollower::new()
            .follow_simple("example.com", RecordType::A, None, single_shot)
            .await;

        assert!(outcome.is_ok());
        let summary = outcome.unwrap();
        assert_eq!(summary.completed_iterations(), 1);
        assert!(!summary.interrupted);
    }
}