1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures::stream::{self, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{debug, warn};
10
11use crate::dns::{DnsRecord, DnsResolver, PropagationChecker, PropagationResult, RecordType};
12use crate::error::Result;
13use crate::lookup::{LookupResult, SmartLookup};
14use crate::rdap::{RdapClient, RdapResponse};
15use crate::status::{StatusClient, StatusResponse};
16use crate::whois::{WhoisClient, WhoisResponse};
17
/// Callback that [`BulkExecutor::execute`] invokes once per finished operation.
///
/// The arguments are `(completed, total, domain)`:
/// - the number of operations finished so far;
/// - the size of the batch;
/// - the domain of the operation that just finished.
pub type ProgressCallback = Box<dyn Fn(usize, usize, &str) + Sync + Send>;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(tag = "type", rename_all = "snake_case")]
22pub enum BulkOperation {
23 Whois { domain: String },
24 Rdap { domain: String },
25 Dns { domain: String, record_type: RecordType },
26 Propagation { domain: String, record_type: RecordType },
27 Lookup { domain: String },
28 Status { domain: String },
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(untagged)]
33pub enum BulkResultData {
34 Whois(WhoisResponse),
35 Rdap(Box<RdapResponse>),
36 Dns(Vec<DnsRecord>),
37 Propagation(PropagationResult),
38 Lookup(LookupResult),
39 Status(StatusResponse),
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct BulkResult {
44 pub operation: BulkOperation,
45 pub success: bool,
46 pub data: Option<BulkResultData>,
47 pub error: Option<String>,
48 pub duration_ms: u64,
49}
50
51#[derive(Debug, Clone)]
52pub struct BulkExecutor {
53 concurrency: usize,
54 rate_limit_delay: Duration,
55 whois_client: WhoisClient,
56 rdap_client: RdapClient,
57 dns_resolver: DnsResolver,
58 propagation_checker: PropagationChecker,
59 smart_lookup: SmartLookup,
60 status_client: StatusClient,
61}
62
63impl Default for BulkExecutor {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl BulkExecutor {
70 pub fn new() -> Self {
71 Self {
72 concurrency: 10,
73 rate_limit_delay: Duration::from_millis(100),
74 whois_client: WhoisClient::new(),
75 rdap_client: RdapClient::new(),
76 dns_resolver: DnsResolver::new(),
77 propagation_checker: PropagationChecker::new(),
78 smart_lookup: SmartLookup::new(),
79 status_client: StatusClient::new(),
80 }
81 }
82
83 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
84 self.concurrency = concurrency.max(1);
85 self
86 }
87
88 pub fn with_rate_limit(mut self, delay: Duration) -> Self {
89 self.rate_limit_delay = delay;
90 self
91 }
92
93 pub async fn execute(
94 &self,
95 operations: Vec<BulkOperation>,
96 progress: Option<ProgressCallback>,
97 ) -> Vec<BulkResult> {
98 let total = operations.len();
99 let completed = Arc::new(AtomicUsize::new(0));
100 let semaphore = Arc::new(Semaphore::new(self.concurrency));
101
102 debug!(
103 total = total,
104 concurrency = self.concurrency,
105 "Starting bulk execution"
106 );
107
108 let results: Vec<BulkResult> = stream::iter(operations)
109 .map(|op| {
110 let semaphore = semaphore.clone();
111 let completed = completed.clone();
112 let progress = progress.as_ref();
113 let rate_limit_delay = self.rate_limit_delay;
114 let whois_client = &self.whois_client;
115 let rdap_client = &self.rdap_client;
116 let dns_resolver = &self.dns_resolver;
117 let propagation_checker = &self.propagation_checker;
118 let smart_lookup = &self.smart_lookup;
119 let status_client = &self.status_client;
120
121 async move {
122 let _permit = match semaphore.acquire().await {
123 Ok(permit) => permit,
124 Err(_) => {
125 return BulkResult {
126 operation: op,
127 success: false,
128 data: None,
129 error: Some("Operation cancelled".to_string()),
130 duration_ms: 0,
131 };
132 }
133 };
134
135 if !rate_limit_delay.is_zero() {
137 sleep(rate_limit_delay).await;
138 }
139
140 let start = std::time::Instant::now();
141 let result = execute_operation(
142 &op,
143 whois_client,
144 rdap_client,
145 dns_resolver,
146 propagation_checker,
147 smart_lookup,
148 status_client,
149 )
150 .await;
151 let duration_ms = start.elapsed().as_millis() as u64;
152
153 let count = completed.fetch_add(1, Ordering::Relaxed) + 1;
154
155 if let Some(progress) = progress {
156 let desc = match &op {
157 BulkOperation::Whois { domain } => domain.clone(),
158 BulkOperation::Rdap { domain } => domain.clone(),
159 BulkOperation::Dns { domain, .. } => domain.clone(),
160 BulkOperation::Propagation { domain, .. } => domain.clone(),
161 BulkOperation::Lookup { domain } => domain.clone(),
162 BulkOperation::Status { domain } => domain.clone(),
163 };
164 progress(count, total, &desc);
165 }
166
167 match result {
168 Ok(data) => BulkResult {
169 operation: op,
170 success: true,
171 data: Some(data),
172 error: None,
173 duration_ms,
174 },
175 Err(e) => {
176 warn!(error = %e, "Bulk operation failed");
177 BulkResult {
178 operation: op,
179 success: false,
180 data: None,
181 error: Some(e.to_string()),
182 duration_ms,
183 }
184 }
185 }
186 }
187 })
188 .buffer_unordered(self.concurrency)
189 .collect()
190 .await;
191
192 results
193 }
194
195 pub async fn execute_whois(&self, domains: Vec<String>) -> Vec<BulkResult> {
196 let operations = domains
197 .into_iter()
198 .map(|domain| BulkOperation::Whois { domain })
199 .collect();
200 self.execute(operations, None).await
201 }
202
203 pub async fn execute_rdap(&self, domains: Vec<String>) -> Vec<BulkResult> {
204 let operations = domains
205 .into_iter()
206 .map(|domain| BulkOperation::Rdap { domain })
207 .collect();
208 self.execute(operations, None).await
209 }
210
211 pub async fn execute_dns(
212 &self,
213 domains: Vec<String>,
214 record_type: RecordType,
215 ) -> Vec<BulkResult> {
216 let operations = domains
217 .into_iter()
218 .map(|domain| BulkOperation::Dns {
219 domain,
220 record_type,
221 })
222 .collect();
223 self.execute(operations, None).await
224 }
225
226 pub async fn execute_propagation(
227 &self,
228 domains: Vec<String>,
229 record_type: RecordType,
230 ) -> Vec<BulkResult> {
231 let operations = domains
232 .into_iter()
233 .map(|domain| BulkOperation::Propagation {
234 domain,
235 record_type,
236 })
237 .collect();
238 self.execute(operations, None).await
239 }
240
241 pub async fn execute_lookup(&self, domains: Vec<String>) -> Vec<BulkResult> {
242 let operations = domains
243 .into_iter()
244 .map(|domain| BulkOperation::Lookup { domain })
245 .collect();
246 self.execute(operations, None).await
247 }
248
249 pub async fn execute_status(&self, domains: Vec<String>) -> Vec<BulkResult> {
250 let operations = domains
251 .into_iter()
252 .map(|domain| BulkOperation::Status { domain })
253 .collect();
254 self.execute(operations, None).await
255 }
256}
257
258async fn execute_operation(
259 op: &BulkOperation,
260 whois_client: &WhoisClient,
261 rdap_client: &RdapClient,
262 dns_resolver: &DnsResolver,
263 propagation_checker: &PropagationChecker,
264 smart_lookup: &SmartLookup,
265 status_client: &StatusClient,
266) -> Result<BulkResultData> {
267 match op {
268 BulkOperation::Whois { domain } => {
269 let result = whois_client.lookup(domain).await?;
270 Ok(BulkResultData::Whois(result))
271 }
272 BulkOperation::Rdap { domain } => {
273 let result = rdap_client.lookup_domain(domain).await?;
274 Ok(BulkResultData::Rdap(Box::new(result)))
275 }
276 BulkOperation::Dns {
277 domain,
278 record_type,
279 } => {
280 let result = dns_resolver.resolve(domain, *record_type, None).await?;
281 Ok(BulkResultData::Dns(result))
282 }
283 BulkOperation::Propagation {
284 domain,
285 record_type,
286 } => {
287 let result = propagation_checker.check(domain, *record_type).await?;
288 Ok(BulkResultData::Propagation(result))
289 }
290 BulkOperation::Lookup { domain } => {
291 let result = smart_lookup.lookup(domain).await?;
292 Ok(BulkResultData::Lookup(result))
293 }
294 BulkOperation::Status { domain } => {
295 let result = status_client.check(domain).await?;
296 Ok(BulkResultData::Status(result))
297 }
298 }
299}
300
/// Extracts domain names from the text of a domain-list file.
///
/// Parsing rules:
/// - Every line is trimmed.
/// - Blank lines and lines starting with `#` are skipped.
/// - Only the first comma-separated field of a line is used, so simple CSV
///   exports work.
/// - A candidate is kept only if it contains a `.`, which drops header words
///   and other obvious non-domains.
/// - A leading UTF-8 byte order mark, as written by some editors, is ignored.
///
/// Domains are returned in file order; duplicates are not removed.
pub fn parse_domains_from_file(content: &str) -> Vec<String> {
    // `str::trim` does not treat U+FEFF as whitespace, so without this the BOM
    // would stay glued to the first domain and break its lookup.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);

    content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split(',').next())
        .map(str::trim)
        .filter(|domain| domain.contains('.'))
        .map(str::to_string)
        .collect()
}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Comments, dotless words and CSV lines whose first field has no dot are
    /// skipped; surrounding whitespace is trimmed.
    #[test]
    fn test_parse_domains_from_file() {
        let input = r#"
# This is a comment
example.com
google.com
    whitespace.com
invalid
csv,format,example.org
"#;

        let parsed = parse_domains_from_file(input);

        assert_eq!(parsed.len(), 3);
        for expected in ["example.com", "google.com", "whitespace.com"] {
            assert!(parsed.contains(&expected.to_string()));
        }
    }
}