1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use tokio::sync::watch;
8use tracing::{debug, instrument};
9
10use super::records::{DnsRecord, RecordType};
11use super::resolver::DnsResolver;
12use crate::error::Result;
13
/// Settings controlling how [`DnsFollower::follow`] polls a record.
#[derive(Clone, Debug)]
pub struct FollowConfig {
    /// Total number of queries to perform.
    pub iterations: usize,
    /// Pause between consecutive queries, in whole seconds.
    pub interval_secs: u64,
    /// Restricts progress-callback notifications to "interesting" iterations
    /// (such as the first one and those whose record set changed) instead of
    /// every iteration.
    pub changes_only: bool,
}

impl Default for FollowConfig {
    /// Ten queries, one minute apart, reporting every iteration.
    fn default() -> Self {
        Self {
            iterations: Self::DEFAULT_ITERATIONS,
            interval_secs: Self::DEFAULT_INTERVAL_SECS,
            changes_only: false,
        }
    }
}

impl FollowConfig {
    const DEFAULT_ITERATIONS: usize = 10;
    const DEFAULT_INTERVAL_SECS: u64 = 60;

    /// Builds a config from an iteration count and an interval given in
    /// (possibly fractional) minutes.
    ///
    /// The interval is converted to seconds with a float-to-integer `as` cast:
    /// fractional seconds are truncated, negative or NaN inputs become `0`, and
    /// values too large for `u64` saturate at `u64::MAX`.
    pub fn new(iterations: usize, interval_minutes: f64) -> Self {
        let seconds = 60.0 * interval_minutes;
        Self {
            changes_only: false,
            iterations,
            interval_secs: seconds as u64,
        }
    }

    /// Returns this config with `changes_only` replaced by the given value.
    pub fn with_changes_only(self, changes_only: bool) -> Self {
        Self {
            changes_only,
            ..self
        }
    }
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FollowIteration {
53 pub iteration: usize,
55 pub total_iterations: usize,
57 pub timestamp: DateTime<Utc>,
59 pub records: Vec<DnsRecord>,
61 pub changed: bool,
63 pub added: Vec<String>,
65 pub removed: Vec<String>,
67 pub error: Option<String>,
69}
70
71impl FollowIteration {
72 pub fn success(&self) -> bool {
73 self.error.is_none()
74 }
75
76 pub fn record_count(&self) -> usize {
77 self.records.len()
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct FollowResult {
84 pub domain: String,
86 pub record_type: RecordType,
88 pub nameserver: Option<String>,
90 pub iterations_requested: usize,
92 pub interval_secs: u64,
93 pub iterations: Vec<FollowIteration>,
95 pub interrupted: bool,
97 pub total_changes: usize,
99 pub started_at: DateTime<Utc>,
101 pub ended_at: DateTime<Utc>,
103}
104
105impl FollowResult {
106 pub fn completed_iterations(&self) -> usize {
107 self.iterations.len()
108 }
109
110 pub fn successful_iterations(&self) -> usize {
111 self.iterations.iter().filter(|i| i.success()).count()
112 }
113
114 pub fn failed_iterations(&self) -> usize {
115 self.iterations.iter().filter(|i| !i.success()).count()
116 }
117}
118
119pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
121
122#[derive(Clone)]
124pub struct DnsFollower {
125 resolver: DnsResolver,
126}
127
128impl Default for DnsFollower {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl DnsFollower {
135 pub fn new() -> Self {
136 Self {
137 resolver: DnsResolver::new(),
138 }
139 }
140
141 pub fn with_resolver(resolver: DnsResolver) -> Self {
142 Self { resolver }
143 }
144
145 #[instrument(skip(self, config, callback, cancel_rx))]
147 pub async fn follow(
148 &self,
149 domain: &str,
150 record_type: RecordType,
151 nameserver: Option<&str>,
152 config: FollowConfig,
153 callback: Option<FollowProgressCallback>,
154 cancel_rx: Option<watch::Receiver<bool>>,
155 ) -> Result<FollowResult> {
156 let started_at = Utc::now();
157 let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
158 let mut previous_values: HashSet<String> = HashSet::new();
159 let mut total_changes = 0;
160 let mut interrupted = false;
161
162 debug!(
163 domain = %domain,
164 record_type = %record_type,
165 iterations = config.iterations,
166 interval_secs = config.interval_secs,
167 "Starting DNS follow"
168 );
169
170 for i in 0..config.iterations {
171 if let Some(ref rx) = cancel_rx {
173 if *rx.borrow() {
174 debug!("Follow operation cancelled");
175 interrupted = true;
176 break;
177 }
178 }
179
180 let timestamp = Utc::now();
181 let iteration_num = i + 1;
182
183 let (records, error) = match self
185 .resolver
186 .resolve(domain, record_type, nameserver)
187 .await
188 {
189 Ok(records) => (records, None),
190 Err(e) => (Vec::new(), Some(e.to_string())),
191 };
192
193 let current_values: HashSet<String> = records
195 .iter()
196 .map(|r| r.data.to_string())
197 .collect();
198
199 let (changed, added, removed) = if i == 0 {
201 (false, Vec::new(), Vec::new())
203 } else {
204 let added: Vec<String> = current_values
205 .difference(&previous_values)
206 .cloned()
207 .collect();
208 let removed: Vec<String> = previous_values
209 .difference(¤t_values)
210 .cloned()
211 .collect();
212 let changed = !added.is_empty() || !removed.is_empty();
213 (changed, added, removed)
214 };
215
216 if changed {
217 total_changes += 1;
218 }
219
220 let iteration = FollowIteration {
221 iteration: iteration_num,
222 total_iterations: config.iterations,
223 timestamp,
224 records,
225 changed,
226 added,
227 removed,
228 error,
229 };
230
231 if let Some(ref cb) = callback {
233 if !config.changes_only || iteration_num == 1 || changed {
235 cb(&iteration);
236 }
237 }
238
239 iterations.push(iteration);
240 previous_values = current_values;
241
242 if i < config.iterations - 1 {
244 let sleep_duration = Duration::from_secs(config.interval_secs);
245
246 if let Some(ref rx) = cancel_rx {
248 let mut rx_clone = rx.clone();
249 tokio::select! {
250 _ = tokio::time::sleep(sleep_duration) => {}
251 _ = rx_clone.changed() => {
252 if *rx_clone.borrow() {
253 debug!("Follow operation cancelled during sleep");
254 interrupted = true;
255 break;
256 }
257 }
258 }
259 } else {
260 tokio::time::sleep(sleep_duration).await;
261 }
262 }
263 }
264
265 let ended_at = Utc::now();
266
267 Ok(FollowResult {
268 domain: domain.to_string(),
269 record_type,
270 nameserver: nameserver.map(|s| s.to_string()),
271 iterations_requested: config.iterations,
272 interval_secs: config.interval_secs,
273 iterations,
274 interrupted,
275 total_changes,
276 started_at,
277 ended_at,
278 })
279 }
280
281 pub async fn follow_simple(
283 &self,
284 domain: &str,
285 record_type: RecordType,
286 nameserver: Option<&str>,
287 config: FollowConfig,
288 ) -> Result<FollowResult> {
289 self.follow(domain, record_type, nameserver, config, None, None)
290 .await
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    // Pure configuration tests need no async runtime.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5);
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
        assert!(!config.changes_only);
    }

    #[test]
    fn test_follow_config_with_changes_only() {
        let config = FollowConfig::new(2, 1.0).with_changes_only(true);
        assert!(config.changes_only);
        assert_eq!(config.iterations, 2);
        assert_eq!(config.interval_secs, 60);
    }

    /// A cancellation signalled before the run starts must stop it before any
    /// query is sent, so this test needs no network access.
    #[tokio::test]
    async fn test_follow_cancelled_before_first_query() {
        let (_tx, rx) = tokio::sync::watch::channel(true);
        let result = DnsFollower::new()
            .follow(
                "example.com",
                RecordType::A,
                None,
                FollowConfig::new(3, 0.0),
                None,
                Some(rx),
            )
            .await
            .expect("follow should not fail when cancelled up front");

        assert!(result.interrupted);
        assert_eq!(result.completed_iterations(), 0);
        assert_eq!(result.total_changes, 0);
    }

    // Performs a real DNS lookup, so it is excluded from the default (offline)
    // test run.
    #[tokio::test]
    #[ignore = "performs a live DNS query; run with `cargo test -- --ignored`"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0);

        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }
}