tact_client/
batch.rs

//! Request batching for CDN downloads using HTTP/2 multiplexing
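//!
//! # Example
//!
//! A minimal usage sketch. The host, path, and hash are illustrative, and the
//! import path assumes this module is exported as `tact_client::batch`:
//!
//! ```ignore
//! use tact_client::batch::{BatchConfig, RequestBatcher};
//!
//! // Inside an async context on a Tokio runtime.
//! // reqwest negotiates HTTP/2 via ALPN automatically on HTTPS connections.
//! let client = reqwest::Client::new();
//! let batcher = RequestBatcher::new(client, BatchConfig::default());
//!
//! let requests = RequestBatcher::create_cdn_requests(
//!     "cdn.example.com",
//!     "tpr/wow/data",
//!     &["abcd1234".to_string()],
//! );
//! let responses = batcher.submit_requests_and_wait(requests).await;
//! println!("fetched {} responses", responses.len());
//! ```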

use crate::{Error, Result};
use futures_util::stream::{self, StreamExt};
use reqwest::{Client, Response};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use tracing::{debug, info, warn};

/// Default batch size for request batching
const DEFAULT_BATCH_SIZE: usize = 20;

/// Default batch timeout in milliseconds
const DEFAULT_BATCH_TIMEOUT_MS: u64 = 100;

/// Default maximum concurrent batches
const DEFAULT_MAX_CONCURRENT_BATCHES: usize = 4;

/// Configuration for request batching
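///
/// # Example
///
/// A sketch overriding a single field via struct-update syntax (the value is
/// illustrative):
///
/// ```ignore
/// let config = BatchConfig {
///     batch_size: 50,
///     ..BatchConfig::default()
/// };
/// ```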
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of requests per batch
    pub batch_size: usize,
    /// Maximum time to wait for a batch to fill (milliseconds)
    pub batch_timeout_ms: u64,
    /// Maximum number of concurrent batches
    pub max_concurrent_batches: usize,
    /// Maximum time to wait for all requests in a batch to complete
    pub batch_execution_timeout: Duration,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            batch_timeout_ms: DEFAULT_BATCH_TIMEOUT_MS,
            max_concurrent_batches: DEFAULT_MAX_CONCURRENT_BATCHES,
            batch_execution_timeout: Duration::from_secs(60),
        }
    }
}

/// Statistics for batch operations
#[derive(Debug, Clone)]
pub struct BatchStats {
    /// Total number of batches processed
    pub batches_processed: u64,
    /// Total number of requests processed
    pub requests_processed: u64,
    /// Average batch size
    pub avg_batch_size: f64,
    /// Total time spent in batch processing
    pub total_batch_time: Duration,
    /// Average time per batch
    pub avg_batch_time: Duration,
    /// HTTP/2 connections established
    pub http2_connections: u64,
}

/// A single request in a batch
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Unique ID for this request
    pub id: String,
    /// Full URL to request
    pub url: String,
    /// Optional headers
    pub headers: HashMap<String, String>,
}

/// Result of a batch request
#[derive(Debug)]
pub struct BatchResponse {
    /// Request ID this response corresponds to
    pub request_id: String,
    /// The HTTP response (or error)
    pub result: Result<Response>,
    /// Time taken for this request
    pub duration: Duration,
}

/// A batch of requests to be executed together
#[allow(dead_code)]
#[derive(Debug)]
struct RequestBatch {
    /// Requests in this batch
    requests: Vec<BatchRequest>,
    /// Channel to send responses back
    response_tx: mpsc::UnboundedSender<BatchResponse>,
    /// When this batch was created
    created_at: Instant,
}

/// HTTP/2 request batcher for CDN downloads
#[derive(Debug)]
pub struct RequestBatcher {
    /// HTTP client with HTTP/2 support
    #[allow(dead_code)]
    client: Client,
    /// Configuration
    #[allow(dead_code)]
    config: BatchConfig,
    /// Channel for incoming requests
    request_tx: mpsc::UnboundedSender<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
    /// Statistics
    stats: Arc<Mutex<BatchStats>>,
}

impl RequestBatcher {
    /// Create a new request batcher
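    ///
    /// Spawns the background batch-processing task, so this must be called
    /// from within a Tokio runtime.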
    pub fn new(client: Client, config: BatchConfig) -> Self {
        let (request_tx, request_rx) = mpsc::unbounded_channel();
        let stats = Arc::new(Mutex::new(BatchStats {
            batches_processed: 0,
            requests_processed: 0,
            avg_batch_size: 0.0,
            total_batch_time: Duration::ZERO,
            avg_batch_time: Duration::ZERO,
            http2_connections: 0,
        }));

        let batcher = Self {
            client: client.clone(),
            config: config.clone(),
            request_tx,
            stats: Arc::clone(&stats),
        };

        // Start the batch processing task
        let batch_processor = BatchProcessor {
            client,
            config,
            request_rx: Arc::new(Mutex::new(request_rx)),
            stats,
        };

        tokio::spawn(batch_processor.run());

        batcher
    }

    /// Submit a request to be batched
    ///
    /// Returns a receiver for the response. The request will be batched with others
    /// and executed when the batch is full or the timeout expires.
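    ///
    /// # Example
    ///
    /// A sketch, assuming a `batcher` in scope and an illustrative URL:
    ///
    /// ```ignore
    /// let request = BatchRequest {
    ///     id: "abcd1234".to_string(),
    ///     url: "http://cdn.example.com/data/ab/cd/abcd1234".to_string(),
    ///     headers: std::collections::HashMap::new(),
    /// };
    /// let mut rx = batcher.submit_request(request).await?;
    /// if let Some(response) = rx.recv().await {
    ///     println!("{} finished in {:?}", response.request_id, response.duration);
    /// }
    /// ```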
    pub async fn submit_request(
        &self,
        request: BatchRequest,
    ) -> Result<mpsc::UnboundedReceiver<BatchResponse>> {
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        self.request_tx
            .send((request, response_tx))
            .map_err(|_| Error::InvalidResponse)?;

        Ok(response_rx)
    }

    /// Submit multiple requests and wait for all responses
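    ///
    /// Responses are returned in submission order. A sketch, assuming a
    /// `batcher` and a `hashes` vector in scope:
    ///
    /// ```ignore
    /// let requests = RequestBatcher::create_cdn_requests("cdn.example.com", "data", &hashes);
    /// for response in batcher.submit_requests_and_wait(requests).await {
    ///     match response.result {
    ///         Ok(resp) => println!("{}: HTTP {}", response.request_id, resp.status()),
    ///         Err(e) => eprintln!("{}: {:?}", response.request_id, e),
    ///     }
    /// }
    /// ```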
    pub async fn submit_requests_and_wait(
        &self,
        requests: Vec<BatchRequest>,
    ) -> Vec<BatchResponse> {
        let mut receivers = Vec::new();

        // Submit all requests
        for request in requests {
            match self.submit_request(request).await {
                Ok(rx) => receivers.push(rx),
                Err(e) => {
                    // Create error response for failed submission
                    receivers.push({
                        let (tx, rx) = mpsc::unbounded_channel();
                        let _ = tx.send(BatchResponse {
                            request_id: "unknown".to_string(),
                            result: Err(e),
                            duration: Duration::ZERO,
                        });
                        rx
                    });
                }
            }
        }

        // Collect all responses
        let mut responses = Vec::new();
        for mut rx in receivers {
            if let Some(response) = rx.recv().await {
                responses.push(response);
            } else {
                responses.push(BatchResponse {
                    request_id: "unknown".to_string(),
                    result: Err(Error::InvalidResponse),
                    duration: Duration::ZERO,
                });
            }
        }

        responses
    }

    /// Get current statistics
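    ///
    /// # Example
    ///
    /// ```ignore
    /// let stats = batcher.get_stats().await;
    /// println!(
    ///     "{} requests across {} batches (avg {:.1}/batch)",
    ///     stats.requests_processed, stats.batches_processed, stats.avg_batch_size
    /// );
    /// ```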
    pub async fn get_stats(&self) -> BatchStats {
        self.stats.lock().await.clone()
    }

    /// Create batch requests for CDN file downloads
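    ///
    /// URLs follow the conventional CDN layout, where the first two byte pairs
    /// of the hex hash become directory levels:
    /// `http://{host}/{path}/{hash[0..2]}/{hash[2..4]}/{hash}`.
    ///
    /// # Panics
    ///
    /// Panics if any hash is shorter than four bytes.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let requests = RequestBatcher::create_cdn_requests(
    ///     "cdn.example.com",
    ///     "data",
    ///     &["abcd1234".to_string()],
    /// );
    /// assert_eq!(requests[0].url, "http://cdn.example.com/data/ab/cd/abcd1234");
    /// ```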
    pub fn create_cdn_requests(cdn_host: &str, path: &str, hashes: &[String]) -> Vec<BatchRequest> {
        hashes
            .iter()
            .map(|hash| BatchRequest {
                id: hash.clone(),
                url: format!(
                    "http://{}/{}/{}/{}/{}",
                    cdn_host,
                    path.trim_matches('/'),
                    &hash[0..2],
                    &hash[2..4],
                    hash
                ),
                headers: HashMap::new(),
            })
            .collect()
    }
}

/// Type alias for the request receiver channel
type RequestReceiver =
    Arc<Mutex<mpsc::UnboundedReceiver<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>>>;

/// Internal batch processor
struct BatchProcessor {
    client: Client,
    config: BatchConfig,
    request_rx: RequestReceiver,
    stats: Arc<Mutex<BatchStats>>,
}

impl BatchProcessor {
    async fn run(self) {
        let mut current_batch: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)> =
            Vec::new();
        let mut batch_timer =
            tokio::time::interval(Duration::from_millis(self.config.batch_timeout_ms));
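        // Note: the interval is free-running, so a partially filled batch is
        // flushed on the next tick, which can arrive anywhere from immediately
        // up to `batch_timeout_ms` after the first request was queued.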
        let mut request_rx = self.request_rx.lock().await;

        debug!(
            "Starting batch processor with config: batch_size={}, timeout={}ms, max_concurrent={}",
            self.config.batch_size,
            self.config.batch_timeout_ms,
            self.config.max_concurrent_batches
        );

        loop {
            tokio::select! {
                // New request received
                maybe_request = request_rx.recv() => {
                    match maybe_request {
                        Some((request, response_tx)) => {
                            current_batch.push((request, response_tx));

                            // If batch is full, process it immediately
                            if current_batch.len() >= self.config.batch_size {
                                let batch = std::mem::take(&mut current_batch);
                                self.process_batch(batch).await;
                            }
                        }
                        None => {
                            // Channel closed, process remaining batch and exit
                            if !current_batch.is_empty() {
                                let batch = std::mem::take(&mut current_batch);
                                self.process_batch(batch).await;
                            }
                            break;
                        }
                    }
                }

                // Batch timeout expired
                _ = batch_timer.tick() => {
                    if !current_batch.is_empty() {
                        let batch = std::mem::take(&mut current_batch);
                        self.process_batch(batch).await;
                    }
                }
            }
        }

        debug!("Batch processor shutting down");
    }

    async fn process_batch(
        &self,
        batch: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
    ) {
        if batch.is_empty() {
            return;
        }

        let batch_start = Instant::now();
        let batch_size = batch.len();

        debug!("Processing batch of {} requests", batch_size);

        // Group requests by host to maximize HTTP/2 connection reuse
        let mut requests_by_host: HashMap<String, Vec<_>> = HashMap::new();

        for (request, response_tx) in batch {
            let host = self
                .extract_host(&request.url)
                .unwrap_or_else(|| "unknown".to_string());
            requests_by_host
                .entry(host)
                .or_default()
                .push((request, response_tx));
        }

        // Process each host group concurrently (up to max_concurrent_batches)
        let host_groups: Vec<_> = requests_by_host.into_iter().collect();
        let concurrent_limit = self.config.max_concurrent_batches.min(host_groups.len());

        stream::iter(host_groups)
            .map(|(host, requests)| async move {
                self.process_host_batch(host, requests).await;
            })
            .buffer_unordered(concurrent_limit)
            .collect::<Vec<_>>()
            .await;

        let batch_duration = batch_start.elapsed();

        // Update statistics
        let mut stats = self.stats.lock().await;
        stats.batches_processed += 1;
        stats.requests_processed += batch_size as u64;
        stats.total_batch_time += batch_duration;
        stats.avg_batch_size = stats.requests_processed as f64 / stats.batches_processed as f64;
        stats.avg_batch_time = stats.total_batch_time / stats.batches_processed as u32;

        info!(
            "Processed batch: {} requests in {:?} (avg: {:.1} reqs/batch, {:?}/batch)",
            batch_size, batch_duration, stats.avg_batch_size, stats.avg_batch_time
        );
    }

    async fn process_host_batch(
        &self,
        host: String,
        requests: Vec<(BatchRequest, mpsc::UnboundedSender<BatchResponse>)>,
    ) {
        debug!("Processing {} requests for host: {}", requests.len(), host);

        // Heuristically check whether the server likely supports HTTP/2
        let supports_http2 = self.check_http2_support(&host).await;
        if supports_http2 {
            let mut stats = self.stats.lock().await;
            stats.http2_connections += 1;
            debug!("HTTP/2 support assumed for host: {}", host);
        }

        // Execute all requests for this host concurrently
        // HTTP/2 multiplexing allows multiple requests on the same connection
        let num_requests = requests.len();
        let futures = requests
            .into_iter()
            .map(|(request, response_tx)| async move {
                let start_time = Instant::now();
                let request_id = request.id.clone();

                let result = self.execute_request(request).await;
                let duration = start_time.elapsed();

                let response = BatchResponse {
                    request_id,
                    result,
                    duration,
                };

                if response_tx.send(response).is_err() {
                    warn!("Failed to send batch response - receiver dropped");
                }
            });

        // Execute all requests concurrently
        stream::iter(futures)
            .buffer_unordered(num_requests) // Use HTTP/2 multiplexing
            .collect::<Vec<_>>()
            .await;
    }

    async fn execute_request(&self, request: BatchRequest) -> Result<Response> {
        let mut req_builder = self.client.get(&request.url);

        // Add custom headers
        for (key, value) in &request.headers {
            req_builder = req_builder.header(key, value);
        }

        // Execute with timeout
        let response =
            tokio::time::timeout(self.config.batch_execution_timeout, req_builder.send()).await;

        match response {
            Ok(Ok(response)) => {
                if response.status().is_success() {
                    Ok(response)
                } else {
                    Err(Error::Http(response.error_for_status().unwrap_err()))
                }
            }
            Ok(Err(e)) => Err(Error::Http(e)),
            Err(_) => Err(Error::InvalidResponse),
        }
    }

    async fn check_http2_support(&self, host: &str) -> bool {
        // Optimistic heuristic: `host` is a bare hostname here (extract_host
        // strips the scheme), so a scheme-prefix check would never match.
        // Assume HTTP/2 for well-known CDN providers and rely on reqwest's
        // automatic ALPN negotiation to pick the protocol on the wire.
        host.contains("akamai") || host.contains("cloudflare") || host.contains("blizzard")
    }

    fn extract_host(&self, url: &str) -> Option<String> {
        url::Url::parse(url)
            .ok()
            .and_then(|parsed| parsed.host_str().map(ToString::to_string))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.batch_size, DEFAULT_BATCH_SIZE);
        assert_eq!(config.batch_timeout_ms, DEFAULT_BATCH_TIMEOUT_MS);
        assert_eq!(
            config.max_concurrent_batches,
            DEFAULT_MAX_CONCURRENT_BATCHES
        );
    }

    #[test]
    fn test_create_cdn_requests() {
        let hashes = vec!["abcd1234".to_string(), "efgh5678".to_string()];

        let requests = RequestBatcher::create_cdn_requests("example.com", "data", &hashes);

        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].id, "abcd1234");
        assert_eq!(requests[0].url, "http://example.com/data/ab/cd/abcd1234");
        assert_eq!(requests[1].id, "efgh5678");
        assert_eq!(requests[1].url, "http://example.com/data/ef/gh/efgh5678");
    }

    #[tokio::test]
    async fn test_batch_stats_initialization() {
        let client = reqwest::Client::new();
        let config = BatchConfig::default();
        let batcher = RequestBatcher::new(client, config);

        let stats = batcher.get_stats().await;
        assert_eq!(stats.batches_processed, 0);
        assert_eq!(stats.requests_processed, 0);
        assert_eq!(stats.avg_batch_size, 0.0);
    }

    #[tokio::test]
    #[ignore = "Test depends on actual network requests for stats validation"]
    async fn test_request_submission() {
        let client = reqwest::Client::new();
        let config = BatchConfig {
            batch_timeout_ms: 50, // Short timeout for test
            ..BatchConfig::default()
        };
        let batcher = RequestBatcher::new(client, config);

        let request = BatchRequest {
            id: "test123".to_string(),
            url: "http://httpbin.org/status/200".to_string(), // Test endpoint
            headers: HashMap::new(),
        };

        // This will fail in tests without network, but tests the API
        let _receiver = batcher.submit_request(request).await;
        // Just test that submission doesn't panic

        // Wait a bit for batch processing
        sleep(Duration::from_millis(100)).await;

        // Stats should be updated (even if requests failed)
        let stats = batcher.get_stats().await;
        assert!(stats.batches_processed > 0 || stats.requests_processed > 0);
    }
}