1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures::stream::{self, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{debug, warn};
10
11use crate::dns::{DnsRecord, DnsResolver, PropagationChecker, PropagationResult, RecordType};
12use crate::error::Result;
13use crate::lookup::{LookupResult, SmartLookup};
14use crate::rdap::{RdapClient, RdapResponse};
15use crate::status::{StatusClient, StatusResponse};
16use crate::whois::{WhoisClient, WhoisResponse};
17
/// Callback invoked by [`BulkExecutor::execute`] after each operation finishes,
/// receiving `(completed_so_far, total, domain)`.
pub type ProgressCallback = Box<dyn Fn(usize, usize, &str) + Sync + Send>;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(tag = "type", rename_all = "snake_case")]
23pub enum BulkOperation {
24 Whois {
25 domain: String,
26 },
27 Rdap {
28 domain: String,
29 },
30 Dns {
31 domain: String,
32 record_type: RecordType,
33 },
34 Propagation {
35 domain: String,
36 record_type: RecordType,
37 },
38 Lookup {
39 domain: String,
40 },
41 Status {
42 domain: String,
43 },
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(untagged)]
49pub enum BulkResultData {
50 Whois(WhoisResponse),
51 Rdap(Box<RdapResponse>),
52 Dns(Vec<DnsRecord>),
53 Propagation(PropagationResult),
54 Lookup(LookupResult),
55 Status(StatusResponse),
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct BulkResult {
61 pub operation: BulkOperation,
62 pub success: bool,
63 pub data: Option<BulkResultData>,
64 pub error: Option<String>,
65 pub duration_ms: u64,
66}
67
68#[derive(Debug, Clone)]
70pub struct BulkExecutor {
71 concurrency: usize,
72 rate_limit_delay: Duration,
73 whois_client: WhoisClient,
74 rdap_client: RdapClient,
75 dns_resolver: DnsResolver,
76 propagation_checker: PropagationChecker,
77 smart_lookup: SmartLookup,
78 status_client: StatusClient,
79}
80
81impl Default for BulkExecutor {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl BulkExecutor {
88 pub fn new() -> Self {
89 Self {
90 concurrency: 10,
91 rate_limit_delay: Duration::from_millis(100),
92 whois_client: WhoisClient::new(),
93 rdap_client: RdapClient::new(),
94 dns_resolver: DnsResolver::new(),
95 propagation_checker: PropagationChecker::new(),
96 smart_lookup: SmartLookup::new(),
97 status_client: StatusClient::new(),
98 }
99 }
100
101 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
102 self.concurrency = concurrency.max(1);
103 self
104 }
105
106 pub fn with_rate_limit(mut self, delay: Duration) -> Self {
107 self.rate_limit_delay = delay;
108 self
109 }
110
111 pub async fn execute(
112 &self,
113 operations: Vec<BulkOperation>,
114 progress: Option<ProgressCallback>,
115 ) -> Vec<BulkResult> {
116 let total = operations.len();
117 let completed = Arc::new(AtomicUsize::new(0));
118 let semaphore = Arc::new(Semaphore::new(self.concurrency));
119
120 debug!(
121 total = total,
122 concurrency = self.concurrency,
123 "Starting bulk execution"
124 );
125
126 let results: Vec<BulkResult> = stream::iter(operations)
127 .map(|op| {
128 let semaphore = semaphore.clone();
129 let completed = completed.clone();
130 let progress = progress.as_ref();
131 let rate_limit_delay = self.rate_limit_delay;
132 let whois_client = &self.whois_client;
133 let rdap_client = &self.rdap_client;
134 let dns_resolver = &self.dns_resolver;
135 let propagation_checker = &self.propagation_checker;
136 let smart_lookup = &self.smart_lookup;
137 let status_client = &self.status_client;
138
139 async move {
140 let _permit = match semaphore.acquire().await {
141 Ok(permit) => permit,
142 Err(_) => {
143 return BulkResult {
144 operation: op,
145 success: false,
146 data: None,
147 error: Some("Operation cancelled".to_string()),
148 duration_ms: 0,
149 };
150 }
151 };
152
153 if !rate_limit_delay.is_zero() {
155 sleep(rate_limit_delay).await;
156 }
157
158 let start = std::time::Instant::now();
159 let result = execute_operation(
160 &op,
161 whois_client,
162 rdap_client,
163 dns_resolver,
164 propagation_checker,
165 smart_lookup,
166 status_client,
167 )
168 .await;
169 let duration_ms = start.elapsed().as_millis() as u64;
170
171 let count = completed.fetch_add(1, Ordering::Relaxed) + 1;
172
173 if let Some(progress) = progress {
174 let desc = match &op {
175 BulkOperation::Whois { domain }
176 | BulkOperation::Rdap { domain }
177 | BulkOperation::Dns { domain, .. }
178 | BulkOperation::Propagation { domain, .. }
179 | BulkOperation::Lookup { domain }
180 | BulkOperation::Status { domain } => domain.as_str(),
181 };
182 progress(count, total, desc);
183 }
184
185 match result {
186 Ok(data) => BulkResult {
187 operation: op,
188 success: true,
189 data: Some(data),
190 error: None,
191 duration_ms,
192 },
193 Err(e) => {
194 warn!(error = %e, "Bulk operation failed");
195 BulkResult {
196 operation: op,
197 success: false,
198 data: None,
199 error: Some(e.to_string()),
200 duration_ms,
201 }
202 }
203 }
204 }
205 })
206 .buffer_unordered(self.concurrency)
207 .collect()
208 .await;
209
210 results
211 }
212
213 pub async fn execute_whois(&self, domains: Vec<String>) -> Vec<BulkResult> {
214 let operations = domains
215 .into_iter()
216 .map(|domain| BulkOperation::Whois { domain })
217 .collect();
218 self.execute(operations, None).await
219 }
220
221 pub async fn execute_rdap(&self, domains: Vec<String>) -> Vec<BulkResult> {
222 let operations = domains
223 .into_iter()
224 .map(|domain| BulkOperation::Rdap { domain })
225 .collect();
226 self.execute(operations, None).await
227 }
228
229 pub async fn execute_dns(
230 &self,
231 domains: Vec<String>,
232 record_type: RecordType,
233 ) -> Vec<BulkResult> {
234 let operations = domains
235 .into_iter()
236 .map(|domain| BulkOperation::Dns {
237 domain,
238 record_type,
239 })
240 .collect();
241 self.execute(operations, None).await
242 }
243
244 pub async fn execute_propagation(
245 &self,
246 domains: Vec<String>,
247 record_type: RecordType,
248 ) -> Vec<BulkResult> {
249 let operations = domains
250 .into_iter()
251 .map(|domain| BulkOperation::Propagation {
252 domain,
253 record_type,
254 })
255 .collect();
256 self.execute(operations, None).await
257 }
258
259 pub async fn execute_lookup(&self, domains: Vec<String>) -> Vec<BulkResult> {
260 let operations = domains
261 .into_iter()
262 .map(|domain| BulkOperation::Lookup { domain })
263 .collect();
264 self.execute(operations, None).await
265 }
266
267 pub async fn execute_status(&self, domains: Vec<String>) -> Vec<BulkResult> {
268 let operations = domains
269 .into_iter()
270 .map(|domain| BulkOperation::Status { domain })
271 .collect();
272 self.execute(operations, None).await
273 }
274}
275
276async fn execute_operation(
277 op: &BulkOperation,
278 whois_client: &WhoisClient,
279 rdap_client: &RdapClient,
280 dns_resolver: &DnsResolver,
281 propagation_checker: &PropagationChecker,
282 smart_lookup: &SmartLookup,
283 status_client: &StatusClient,
284) -> Result<BulkResultData> {
285 match op {
286 BulkOperation::Whois { domain } => {
287 let result = whois_client.lookup(domain).await?;
288 Ok(BulkResultData::Whois(result))
289 }
290 BulkOperation::Rdap { domain } => {
291 let result = rdap_client.lookup_domain(domain).await?;
292 Ok(BulkResultData::Rdap(Box::new(result)))
293 }
294 BulkOperation::Dns {
295 domain,
296 record_type,
297 } => {
298 let result = dns_resolver.resolve(domain, *record_type, None).await?;
299 Ok(BulkResultData::Dns(result))
300 }
301 BulkOperation::Propagation {
302 domain,
303 record_type,
304 } => {
305 let result = propagation_checker.check(domain, *record_type).await?;
306 Ok(BulkResultData::Propagation(result))
307 }
308 BulkOperation::Lookup { domain } => {
309 let result = smart_lookup.lookup(domain).await?;
310 Ok(BulkResultData::Lookup(result))
311 }
312 BulkOperation::Status { domain } => {
313 let result = status_client.check(domain).await?;
314 Ok(BulkResultData::Status(result))
315 }
316 }
317}
318
/// Extracts domain names from the text of a domain-list file.
///
/// Parsing rules:
/// - A leading UTF-8 byte-order mark is ignored. `str::trim` does not strip it, so
///   without this the first entry would carry an invisible prefix.
/// - Everything from a `#` to the end of the line is a comment, whether it is a whole
///   line or a trailing comment. `#` can never occur in a domain name.
/// - For CSV-style lines only the first column is used, with surrounding whitespace
///   and double quotes removed.
/// - Entries without a `.` are discarded. This drops headers such as `domain` and
///   bare words.
///
/// Input order is preserved and duplicates are kept. Both `\n` and `\r\n` line endings
/// are accepted.
pub fn parse_domains_from_file(content: &str) -> Vec<String> {
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);

    content
        .lines()
        .filter_map(|line| {
            // `split` always yields at least one piece, so the fallback is never used.
            let without_comment = line.split('#').next().unwrap_or("");
            let first_column = without_comment.split(',').next().unwrap_or("").trim();
            let domain = first_column.trim_matches('"').trim();
            domain.contains('.').then(|| domain.to_string())
        })
        .collect()
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_domains_from_file() {
        let input = r#"
# This is a comment
example.com
google.com
   whitespace.com
invalid
csv,format,example.org
"#;

        let parsed = parse_domains_from_file(input);
        let expected = ["example.com", "google.com", "whitespace.com"];

        // Comment lines, dot-less words and CSV rows whose first column lacks a dot are dropped.
        assert_eq!(parsed.len(), expected.len());
        for want in expected.iter() {
            assert!(
                parsed.iter().any(|got| got == want),
                "expected {} in parsed domains",
                want
            );
        }
    }
}