1use crate::{Error, Result};
4use futures_util::stream::{self, StreamExt};
5use reqwest::{Client, Response};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{Mutex, mpsc};
10use tracing::{debug, info, warn};
11
/// Default number of queued requests that triggers an immediate batch flush.
const DEFAULT_BATCH_SIZE: usize = 20;

/// Default period, in milliseconds, of the timer that flushes a partial batch.
const DEFAULT_BATCH_TIMEOUT_MS: u64 = 100;

/// Default cap on how many per-host groups of one batch run concurrently.
const DEFAULT_MAX_CONCURRENT_BATCHES: usize = 4;
20
/// Tunable parameters that control how requests are grouped and executed.
///
/// Implements `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchConfig {
    /// Number of queued requests that triggers an immediate flush of the batch.
    pub batch_size: usize,
    /// Period, in milliseconds, of the timer that flushes a partially filled batch.
    pub batch_timeout_ms: u64,
    /// Maximum number of per-host request groups of one batch processed concurrently.
    pub max_concurrent_batches: usize,
    /// Timeout applied to each individual HTTP request; despite the name it is
    /// not a deadline for the batch as a whole.
    pub batch_execution_timeout: Duration,
}
33
34impl Default for BatchConfig {
35 fn default() -> Self {
36 Self {
37 batch_size: DEFAULT_BATCH_SIZE,
38 batch_timeout_ms: DEFAULT_BATCH_TIMEOUT_MS,
39 max_concurrent_batches: DEFAULT_MAX_CONCURRENT_BATCHES,
40 batch_execution_timeout: Duration::from_secs(60),
41 }
42 }
43}
44
/// Running counters describing the work done by the batch processor.
///
/// `Default` yields the all-zero state a freshly created batcher starts with.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Number of batches that have finished processing.
    pub batches_processed: u64,
    /// Total number of requests contained in the processed batches.
    pub requests_processed: u64,
    /// `requests_processed / batches_processed`, refreshed after every batch.
    pub avg_batch_size: f64,
    /// Sum of the wall-clock time spent processing each batch.
    pub total_batch_time: Duration,
    /// `total_batch_time / batches_processed`, refreshed after every batch.
    pub avg_batch_time: Duration,
    /// Number of per-host groups whose host was judged HTTP/2-capable by the
    /// processor's heuristic; this is not a count of real connections.
    pub http2_connections: u64,
}
61
/// A single GET request queued for batched execution.
///
/// Implements `PartialEq`/`Eq` so requests can be compared in tests and
/// de-duplicated by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchRequest {
    /// Caller-chosen identifier, echoed back as `BatchResponse::request_id`.
    pub id: String,
    /// Absolute URL to fetch.
    pub url: String,
    /// Extra HTTP headers added to the request.
    pub headers: HashMap<String, String>,
}
72
73#[derive(Debug)]
75pub struct BatchResponse {
76 pub request_id: String,
78 pub result: Result<Response>,
80 pub duration: Duration,
82}
83
84#[allow(dead_code)]
86#[derive(Debug)]
87struct RequestBatch {
88 requests: Vec<BatchRequest>,
90 response_tx: mpsc::UnboundedSender<BatchResponse>,
92 created_at: Instant,
94}
95
96#[derive(Debug)]
98pub struct RequestBatcher {
99 #[allow(dead_code)]
101 client: Client,
102 #[allow(dead_code)]
104 config: BatchConfig,
105 request_tx: mpsc::UnboundedSender<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
107 stats: Arc<Mutex<BatchStats>>,
109}
110
111impl RequestBatcher {
112 pub fn new(client: Client, config: BatchConfig) -> Self {
114 let (request_tx, request_rx) = mpsc::unbounded_channel();
115 let stats = Arc::new(Mutex::new(BatchStats {
116 batches_processed: 0,
117 requests_processed: 0,
118 avg_batch_size: 0.0,
119 total_batch_time: Duration::ZERO,
120 avg_batch_time: Duration::ZERO,
121 http2_connections: 0,
122 }));
123
124 let batcher = Self {
125 client: client.clone(),
126 config: config.clone(),
127 request_tx,
128 stats: Arc::clone(&stats),
129 };
130
131 let batch_processor = BatchProcessor {
133 client,
134 config,
135 request_rx: Arc::new(Mutex::new(request_rx)),
136 stats,
137 };
138
139 tokio::spawn(batch_processor.run());
140
141 batcher
142 }
143
144 pub async fn submit_request(
149 &self,
150 request: BatchRequest,
151 ) -> Result<mpsc::UnboundedReceiver<BatchResponse>> {
152 let (response_tx, response_rx) = mpsc::unbounded_channel();
153
154 self.request_tx
155 .send((request, response_tx))
156 .map_err(|_| Error::InvalidResponse)?;
157
158 Ok(response_rx)
159 }
160
161 pub async fn submit_requests_and_wait(
163 &self,
164 requests: Vec<BatchRequest>,
165 ) -> Vec<BatchResponse> {
166 let mut receivers = Vec::new();
167
168 for request in requests {
170 match self.submit_request(request).await {
171 Ok(rx) => receivers.push(rx),
172 Err(e) => {
173 receivers.push({
175 let (tx, rx) = mpsc::unbounded_channel();
176 let _ = tx.send(BatchResponse {
177 request_id: "unknown".to_string(),
178 result: Err(e),
179 duration: Duration::ZERO,
180 });
181 rx
182 });
183 }
184 }
185 }
186
187 let mut responses = Vec::new();
189 for mut rx in receivers {
190 if let Some(response) = rx.recv().await {
191 responses.push(response);
192 } else {
193 responses.push(BatchResponse {
194 request_id: "unknown".to_string(),
195 result: Err(Error::InvalidResponse),
196 duration: Duration::ZERO,
197 });
198 }
199 }
200
201 responses
202 }
203
204 pub async fn get_stats(&self) -> BatchStats {
206 self.stats.lock().await.clone()
207 }
208
209 pub fn create_cdn_requests(cdn_host: &str, path: &str, hashes: &[String]) -> Vec<BatchRequest> {
211 hashes
212 .iter()
213 .map(|hash| BatchRequest {
214 id: hash.clone(),
215 url: format!(
216 "http://{}/{}/{}/{}/{}",
217 cdn_host,
218 path.trim_matches('/'),
219 &hash[0..2],
220 &hash[2..4],
221 hash
222 ),
223 headers: HashMap::new(),
224 })
225 .collect()
226 }
227}
228
229type RequestReceiver =
231 Arc<Mutex<mpsc::UnboundedReceiver<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>>>;
232
233struct BatchProcessor {
235 client: Client,
236 config: BatchConfig,
237 request_rx: RequestReceiver,
238 stats: Arc<Mutex<BatchStats>>,
239}
240
241impl BatchProcessor {
242 async fn run(self) {
243 let mut current_batch: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)> =
244 Vec::new();
245 let mut batch_timer =
246 tokio::time::interval(Duration::from_millis(self.config.batch_timeout_ms));
247 let mut request_rx = self.request_rx.lock().await;
248
249 debug!(
250 "Starting batch processor with config: batch_size={}, timeout={}ms, max_concurrent={}",
251 self.config.batch_size,
252 self.config.batch_timeout_ms,
253 self.config.max_concurrent_batches
254 );
255
256 loop {
257 tokio::select! {
258 maybe_request = request_rx.recv() => {
260 match maybe_request {
261 Some((request, response_tx)) => {
262 current_batch.push((request, response_tx));
263
264 if current_batch.len() >= self.config.batch_size {
266 let batch = std::mem::take(&mut current_batch);
267 self.process_batch(batch).await;
268 }
269 }
270 None => {
271 if !current_batch.is_empty() {
273 let batch = std::mem::take(&mut current_batch);
274 self.process_batch(batch).await;
275 }
276 break;
277 }
278 }
279 }
280
281 _ = batch_timer.tick() => {
283 if !current_batch.is_empty() {
284 let batch = std::mem::take(&mut current_batch);
285 self.process_batch(batch).await;
286 }
287 }
288 }
289 }
290
291 debug!("Batch processor shutting down");
292 }
293
294 async fn process_batch(
295 &self,
296 batch: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
297 ) {
298 if batch.is_empty() {
299 return;
300 }
301
302 let batch_start = Instant::now();
303 let batch_size = batch.len();
304
305 debug!("Processing batch of {} requests", batch_size);
306
307 let mut requests_by_host: HashMap<String, Vec<_>> = HashMap::new();
309
310 for (request, response_tx) in batch {
311 let host = self
312 .extract_host(&request.url)
313 .unwrap_or_else(|| "unknown".to_string());
314 requests_by_host
315 .entry(host)
316 .or_default()
317 .push((request, response_tx));
318 }
319
320 let host_groups: Vec<_> = requests_by_host.into_iter().collect();
322 let concurrent_limit = self.config.max_concurrent_batches.min(host_groups.len());
323
324 stream::iter(host_groups)
325 .map(|(host, requests)| async move {
326 self.process_host_batch(host, requests).await;
327 })
328 .buffer_unordered(concurrent_limit)
329 .collect::<Vec<_>>()
330 .await;
331
332 let batch_duration = batch_start.elapsed();
333
334 let mut stats = self.stats.lock().await;
336 stats.batches_processed += 1;
337 stats.requests_processed += batch_size as u64;
338 stats.total_batch_time += batch_duration;
339 stats.avg_batch_size = stats.requests_processed as f64 / stats.batches_processed as f64;
340 stats.avg_batch_time = stats.total_batch_time / stats.batches_processed as u32;
341
342 info!(
343 "Processed batch: {} requests in {:?} (avg: {:.1} reqs/batch, {:?}/batch)",
344 batch_size, batch_duration, stats.avg_batch_size, stats.avg_batch_time
345 );
346 }
347
348 async fn process_host_batch(
349 &self,
350 host: String,
351 requests: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
352 ) {
353 debug!("Processing {} requests for host: {}", requests.len(), host);
354
355 let supports_http2 = self.check_http2_support(&host).await;
357 if supports_http2 {
358 let mut stats = self.stats.lock().await;
359 stats.http2_connections += 1;
360 debug!("HTTP/2 support confirmed for host: {}", host);
361 }
362
363 let num_requests = requests.len();
366 let futures = requests
367 .into_iter()
368 .map(|(request, response_tx)| async move {
369 let start_time = Instant::now();
370 let request_id = request.id.clone();
371
372 let result = self.execute_request(request).await;
373 let duration = start_time.elapsed();
374
375 let response = BatchResponse {
376 request_id,
377 result,
378 duration,
379 };
380
381 if response_tx.send(response).is_err() {
382 warn!("Failed to send batch response - receiver dropped");
383 }
384 });
385
386 stream::iter(futures)
388 .buffer_unordered(num_requests) .collect::<Vec<_>>()
390 .await;
391 }
392
393 async fn execute_request(&self, request: BatchRequest) -> Result<Response> {
394 let mut req_builder = self.client.get(&request.url);
395
396 for (key, value) in &request.headers {
398 req_builder = req_builder.header(key, value);
399 }
400
401 let response =
403 tokio::time::timeout(self.config.batch_execution_timeout, req_builder.send()).await;
404
405 match response {
406 Ok(Ok(response)) => {
407 if response.status().is_success() {
408 Ok(response)
409 } else {
410 Err(Error::Http(response.error_for_status().unwrap_err()))
411 }
412 }
413 Ok(Err(e)) => Err(Error::Http(e)),
414 Err(_) => Err(Error::InvalidResponse),
415 }
416 }
417
418 async fn check_http2_support(&self, host: &str) -> bool {
419 host.starts_with("https://") ||
423 host.contains("akamai") || host.contains("cloudflare") || host.contains("blizzard")
425 }
426
427 fn extract_host(&self, url: &str) -> Option<String> {
428 if let Ok(parsed) = url::Url::parse(url) {
429 parsed.host_str().map(|s| s.to_string())
430 } else {
431 None
432 }
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// The default configuration must mirror the module-level constants.
    #[test]
    fn test_batch_config_default() {
        let defaults = BatchConfig::default();

        assert_eq!(defaults.batch_size, DEFAULT_BATCH_SIZE);
        assert_eq!(defaults.batch_timeout_ms, DEFAULT_BATCH_TIMEOUT_MS);
        assert_eq!(
            defaults.max_concurrent_batches,
            DEFAULT_MAX_CONCURRENT_BATCHES
        );
    }

    /// Every hash becomes one request whose URL nests the hash under its
    /// first two byte pairs.
    #[test]
    fn test_create_cdn_requests() {
        let hashes = vec![String::from("abcd1234"), String::from("efgh5678")];
        let built = RequestBatcher::create_cdn_requests("example.com", "data", &hashes);

        assert_eq!(built.len(), 2);

        let first = &built[0];
        assert_eq!(first.id, "abcd1234");
        assert_eq!(first.url, "http://example.com/data/ab/cd/abcd1234");

        let second = &built[1];
        assert_eq!(second.id, "efgh5678");
        assert_eq!(second.url, "http://example.com/data/ef/gh/efgh5678");
    }

    /// A freshly created batcher reports zeroed statistics.
    #[tokio::test]
    async fn test_batch_stats_initialization() {
        let batcher = RequestBatcher::new(reqwest::Client::new(), BatchConfig::default());

        let snapshot = batcher.get_stats().await;
        assert_eq!(snapshot.batches_processed, 0);
        assert_eq!(snapshot.requests_processed, 0);
        assert_eq!(snapshot.avg_batch_size, 0.0);
    }

    /// Submits one real request and expects the statistics to move once the
    /// flush timer has had time to fire.
    #[tokio::test]
    #[ignore = "Test depends on actual network requests for stats validation"]
    async fn test_request_submission() {
        let config = BatchConfig {
            batch_timeout_ms: 50,
            ..BatchConfig::default()
        };
        let batcher = RequestBatcher::new(reqwest::Client::new(), config);

        let request = BatchRequest {
            id: "test123".to_string(),
            url: "http://httpbin.org/status/200".to_string(),
            headers: HashMap::new(),
        };

        // Keep the submission result alive so the response channel stays open.
        let _receiver = batcher.submit_request(request).await;

        // Wait longer than the 50 ms flush period configured above.
        sleep(Duration::from_millis(100)).await;

        let snapshot = batcher.get_stats().await;
        assert!(snapshot.batches_processed > 0 || snapshot.requests_processed > 0);
    }
}