ruvector_scipix/optimize/batch.rs

//! Dynamic batching for throughput optimization
//!
//! Provides intelligent batching to maximize GPU/CPU utilization while
//! maintaining acceptable latency.
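//!
//! A minimal usage sketch (the import path assumes this module is exposed as
//! `ruvector_scipix::optimize::batch`; adjust to the actual crate layout):
//!
//! ```ignore
//! use std::sync::Arc;
//! use ruvector_scipix::optimize::batch::{BatchConfig, DynamicBatcher};
//!
//! async fn demo() {
//!     // Processor that doubles each input; a real one would run a model or kernel.
//!     let batcher = Arc::new(DynamicBatcher::new(BatchConfig::default(), |items: Vec<i32>| {
//!         items.into_iter().map(|x| Ok(x * 2)).collect()
//!     }));
//!
//!     // Drive the batching loop on a background task.
//!     let worker = batcher.clone();
//!     tokio::spawn(async move { worker.run().await });
//!
//!     // Each caller awaits its own result.
//!     assert_eq!(batcher.add(21).await.unwrap(), 42);
//! }
//! ```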

use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, oneshot};
use tokio::time::sleep;
/// Item in the batching queue
pub struct BatchItem<T, R> {
    pub data: T,
    pub response: oneshot::Sender<BatchResult<R>>,
    pub enqueued_at: Instant,
}

/// Result of batch processing
pub type BatchResult<T> = std::result::Result<T, BatchError>;

/// Batch processing errors
#[derive(Debug, Clone)]
pub enum BatchError {
    Timeout,
    ProcessingFailed(String),
    QueueFull,
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            BatchError::Timeout => write!(f, "Batch processing timeout"),
            BatchError::ProcessingFailed(msg) => write!(f, "Processing failed: {}", msg),
            BatchError::QueueFull => write!(f, "Queue is full"),
        }
    }
}

impl std::error::Error for BatchError {}

/// Dynamic batcher configuration
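///
/// The defaults below are a starting point; a typical caller overrides only a
/// few fields, for example (values are illustrative):
///
/// ```ignore
/// let config = BatchConfig {
///     max_batch_size: 8,
///     max_wait_ms: 10,
///     ..Default::default()
/// };
/// ```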
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of items in a single batch
    pub max_batch_size: usize,
    /// Maximum time, in milliseconds, to wait before processing a partial batch
    pub max_wait_ms: u64,
    /// Maximum number of queued items before new requests are rejected with `QueueFull`
    pub max_queue_size: usize,
    /// Minimum batch size to prefer when flushing a partial batch
    pub preferred_batch_size: usize,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            max_queue_size: 1000,
            preferred_batch_size: 16,
        }
    }
}

/// Dynamic batcher for throughput optimization
pub struct DynamicBatcher<T, R> {
    /// Shared so that wrappers such as `AdaptiveBatcher` can retune batching at runtime
    config: Arc<Mutex<BatchConfig>>,
    queue: Arc<Mutex<VecDeque<BatchItem<T, R>>>>,
    processor: Arc<dyn Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync>,
    shutdown: Arc<Mutex<bool>>,
}

impl<T, R> DynamicBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    /// Create a new dynamic batcher
    pub fn new<F>(config: BatchConfig, processor: F) -> Self
    where
        F: Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync + 'static,
    {
        Self {
            config: Arc::new(Mutex::new(config)),
            queue: Arc::new(Mutex::new(VecDeque::new())),
            processor: Arc::new(processor),
            shutdown: Arc::new(Mutex::new(false)),
        }
    }

    /// Add an item to the batch queue and wait for its result
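    ///
    /// # Errors
    ///
    /// Returns [`BatchError::QueueFull`] if the queue is at capacity, and
    /// [`BatchError::Timeout`] if the response channel is dropped before the
    /// item is processed (for example, when the batcher shuts down).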
    pub async fn add(&self, item: T) -> BatchResult<R> {
        let (tx, rx) = oneshot::channel();

        let batch_item = BatchItem {
            data: item,
            response: tx,
            enqueued_at: Instant::now(),
        };

        {
            let max_queue_size = self.config.lock().await.max_queue_size;
            let mut queue = self.queue.lock().await;
            if queue.len() >= max_queue_size {
                return Err(BatchError::QueueFull);
            }
            queue.push_back(batch_item);
        }

        // Wait for the response; a dropped sender is reported as a timeout
        rx.await.map_err(|_| BatchError::Timeout)?
    }

    /// Start the batch processing loop
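    ///
    /// Runs until [`shutdown`](Self::shutdown) is requested, then flushes any
    /// remaining queued items as a final batch before returning.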
    pub async fn run(&self) {
        let mut last_process = Instant::now();

        loop {
            // Check if shutdown was requested
            {
                let shutdown = self.shutdown.lock().await;
                if *shutdown {
                    break;
                }
            }

            // Flush when the batch is full, when the preferred size is reached
            // after the wait budget, or when anything has waited past the budget.
            let should_process = {
                let config = self.config.lock().await.clone();
                let queue = self.queue.lock().await;
                queue.len() >= config.max_batch_size
                    || (queue.len() >= config.preferred_batch_size
                        && last_process.elapsed().as_millis() >= config.max_wait_ms as u128)
                    || (!queue.is_empty()
                        && last_process.elapsed().as_millis() >= config.max_wait_ms as u128)
            };

            if should_process {
                self.process_batch().await;
                last_process = Instant::now();
            } else {
                // Sleep briefly to avoid busy waiting
                sleep(Duration::from_millis(1)).await;
            }
        }

        // Process remaining items before shutdown
        self.process_batch().await;
    }

    /// Process the current batch
    async fn process_batch(&self) {
        let items = {
            let max_batch_size = self.config.lock().await.max_batch_size;
            let mut queue = self.queue.lock().await;
            let batch_size = max_batch_size.min(queue.len());
            if batch_size == 0 {
                return;
            }
            queue.drain(..batch_size).collect::<Vec<_>>()
        };

        if items.is_empty() {
            return;
        }

        // Extract data and response channels
        let (data, responses): (Vec<_>, Vec<_>) = items
            .into_iter()
            .map(|item| (item.data, item.response))
            .unzip();

        // Process the batch
        let results = (self.processor)(data);

        // Send responses; a receiver that has already been dropped is ignored
        for (response_tx, result) in responses.into_iter().zip(results.into_iter()) {
            let batch_result = result.map_err(BatchError::ProcessingFailed);
            let _ = response_tx.send(batch_result);
        }
    }

    /// Gracefully shut down the batcher
    pub async fn shutdown(&self) {
        let mut shutdown = self.shutdown.lock().await;
        *shutdown = true;
    }

    /// Get the current queue size
    pub async fn queue_size(&self) -> usize {
        self.queue.lock().await.len()
    }

    /// Get current queue statistics
    pub async fn stats(&self) -> BatchStats {
        let queue = self.queue.lock().await;
        let queue_size = queue.len();

        // The front of the queue is the oldest item, so its age is the longest wait
        let max_wait = queue
            .front()
            .map(|item| item.enqueued_at.elapsed())
            .unwrap_or_default();

        BatchStats {
            queue_size,
            max_wait_time: max_wait,
        }
    }
}

/// Batch statistics
#[derive(Debug, Clone)]
pub struct BatchStats {
    pub queue_size: usize,
    pub max_wait_time: Duration,
}

/// Adaptive batcher that adjusts batch size based on latency
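///
/// Every 10 requests, the average latency over a sliding window of the last
/// 100 requests is compared to the target: above the target, `max_batch_size`
/// shrinks by roughly 10% (never below 1); below half the target, it grows by
/// roughly 10% (capped at 128).
///
/// A construction sketch (the target value is illustrative):
///
/// ```ignore
/// let batcher = AdaptiveBatcher::new(
///     BatchConfig::default(),
///     std::time::Duration::from_millis(20), // target end-to-end latency
///     |items: Vec<i32>| items.into_iter().map(|x| Ok(x * 2)).collect(),
/// );
/// ```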
pub struct AdaptiveBatcher<T, R> {
    inner: DynamicBatcher<T, R>,
    /// Shared with `inner`, so adjustments change how the inner loop forms batches
    config: Arc<Mutex<BatchConfig>>,
    latency_history: Arc<Mutex<VecDeque<Duration>>>,
    target_latency: Duration,
}

impl<T, R> AdaptiveBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    /// Create an adaptive batcher with a target latency
    pub fn new<F>(
        initial_config: BatchConfig,
        target_latency: Duration,
        processor: F,
    ) -> Self
    where
        F: Fn(Vec<T>) -> Vec<Result<R, String>> + Send + Sync + 'static,
    {
        let inner = DynamicBatcher::new(initial_config, processor);
        // Share the inner batcher's config so that adaptation actually changes
        // how batches are formed, rather than updating a detached copy.
        let config = inner.config.clone();

        Self {
            inner,
            config,
            latency_history: Arc::new(Mutex::new(VecDeque::with_capacity(100))),
            target_latency,
        }
    }

    /// Add an item and adapt the batch size
    pub async fn add(&self, item: T) -> Result<R, BatchError> {
        let start = Instant::now();
        let result = self.inner.add(item).await;
        let latency = start.elapsed();

        // Record the latency in a bounded history of the last 100 requests
        {
            let mut history = self.latency_history.lock().await;
            history.push_back(latency);
            if history.len() > 100 {
                history.pop_front();
            }
        }

        // Adapt the batch size every 10 requests
        {
            let history = self.latency_history.lock().await;
            if history.len() >= 10 && history.len() % 10 == 0 {
                let avg_latency: Duration =
                    history.iter().sum::<Duration>() / history.len() as u32;

                let mut config = self.config.lock().await;
                if avg_latency > self.target_latency {
                    // Reduce batch size to lower latency
                    config.max_batch_size = (config.max_batch_size * 9 / 10).max(1);
                } else if avg_latency < self.target_latency / 2 {
                    // Increase batch size for better throughput; the extra
                    // `max(+1)` guarantees progress when `x * 11 / 10 == x`.
                    config.max_batch_size = (config.max_batch_size * 11 / 10)
                        .max(config.max_batch_size + 1)
                        .min(128);
                }
            }
        }

        result
    }

    /// Run the batcher
    pub async fn run(&self) {
        self.inner.run().await;
    }

    /// Get current configuration
    pub async fn current_config(&self) -> BatchConfig {
        self.config.lock().await.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dynamic_batcher() {
        let config = BatchConfig {
            max_batch_size: 4,
            max_wait_ms: 100,
            max_queue_size: 100,
            preferred_batch_size: 2,
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x * 2)).collect()
        }));

        // Start processing loop
        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        // Add items
        let mut handles = vec![];
        for i in 0..8 {
            let batcher = batcher.clone();
            handles.push(tokio::spawn(async move {
                batcher.add(i).await
            }));
        }

        // Wait for results
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap().unwrap();
            assert_eq!(result, (i as i32) * 2);
        }

        batcher.shutdown().await;
    }

    #[tokio::test]
    async fn test_batch_stats() {
        let config = BatchConfig::default();
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x)).collect()
        }));

        // Queue some items without starting the processing loop. `add` only
        // resolves once a batch is processed, so enqueue from detached tasks
        // and give them a moment to run.
        for i in 1..=3 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        tokio::time::sleep(Duration::from_millis(10)).await;

        let stats = batcher.stats().await;
        assert_eq!(stats.queue_size, 3);
    }

    #[tokio::test]
    async fn test_queue_full() {
        let config = BatchConfig {
            max_queue_size: 2,
            ..Default::default()
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x)).collect()
        }));

        // Fill the queue from detached tasks; no processing loop is running,
        // so the items stay queued.
        for i in 1..=2 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        tokio::time::sleep(Duration::from_millis(10)).await;

        // This should fail immediately - the queue is full
        let result = batcher.add(3).await;
        assert!(matches!(result, Err(BatchError::QueueFull)));
    }

    #[tokio::test]
    async fn test_adaptive_batcher() {
        let config = BatchConfig {
            max_batch_size: 8,
            max_wait_ms: 50,
            max_queue_size: 100,
            preferred_batch_size: 4,
        };

        let batcher = Arc::new(AdaptiveBatcher::new(
            config,
            Duration::from_millis(100),
            |items: Vec<i32>| items.into_iter().map(|x| Ok(x * 2)).collect(),
        ));

        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        // Process some requests
        for i in 0..20 {
            let result = batcher.add(i).await.unwrap();
            assert_eq!(result, i * 2);
        }

        // Configuration should remain valid after any adaptation
        let final_config = batcher.current_config().await;
        assert!(final_config.max_batch_size > 0);
    }
}