// ruvector_scipix/optimize/batch.rs
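//! Dynamic request batching: queue items from concurrent tasks, flush them
//! to a synchronous processor in size- and latency-bounded batches, and fan
//! the per-item results back to the callers.
//!
//! A minimal usage sketch (the crate path and the doubling processor are
//! illustrative assumptions, not part of this module):
//!
//! ```ignore
//! use std::sync::Arc;
//! // Path assumed from this file's location in the repository.
//! use ruvector_scipix::optimize::batch::{BatchConfig, DynamicBatcher};
//!
//! #[tokio::main]
//! async fn main() {
//!     // Toy processor: double every item in the batch.
//!     let batcher = Arc::new(DynamicBatcher::new(BatchConfig::default(), |items: Vec<i32>| {
//!         items.into_iter().map(|x| Ok(x * 2)).collect()
//!     }));
//!
//!     // The batching loop runs on its own task.
//!     let runner = batcher.clone();
//!     tokio::spawn(async move { runner.run().await });
//!
//!     assert_eq!(batcher.add(21).await.unwrap(), 42);
//! }
//! ```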
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, oneshot};
use tokio::time::sleep;

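/// One queued request, paired with the channel that delivers its result.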
pub struct BatchItem<T, R> {
    pub data: T,
    pub response: oneshot::Sender<BatchResult<R>>,
    pub enqueued_at: Instant,
}

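/// Result type delivered back to callers of [`DynamicBatcher::add`].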
pub type BatchResult<T> = std::result::Result<T, BatchError>;

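/// Ways a queued request can fail.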
#[derive(Debug, Clone)]
pub enum BatchError {
    Timeout,
    ProcessingFailed(String),
    QueueFull,
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            BatchError::Timeout => write!(f, "Batch processing timeout"),
            BatchError::ProcessingFailed(msg) => write!(f, "Processing failed: {}", msg),
            BatchError::QueueFull => write!(f, "Queue is full"),
        }
    }
}

impl std::error::Error for BatchError {}

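/// Tuning knobs for [`DynamicBatcher`].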
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Hard cap on the number of items drained per batch.
    pub max_batch_size: usize,
    /// Longest an item may wait in the queue before a flush is forced.
    pub max_wait_ms: u64,
    /// Queue capacity; `add` fails with `QueueFull` beyond this.
    pub max_queue_size: usize,
    /// Queue depth at which a flush is triggered without further waiting.
    pub preferred_batch_size: usize,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            max_queue_size: 1000,
            preferred_batch_size: 16,
        }
    }
}

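/// Batches concurrently submitted items and runs them through a single
/// synchronous processor, delivering per-item results to each caller.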
pub struct DynamicBatcher<T, R> {
    // Shared so the configuration can be tuned while the batcher runs
    // (see `AdaptiveBatcher`).
    config: Arc<Mutex<BatchConfig>>,
    queue: Arc<Mutex<VecDeque<BatchItem<T, R>>>>,
    processor: Arc<dyn Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync>,
    shutdown: Arc<Mutex<bool>>,
}

impl<T, R> DynamicBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    pub fn new<F>(config: BatchConfig, processor: F) -> Self
    where
        F: Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync + 'static,
    {
        Self {
            config: Arc::new(Mutex::new(config)),
            queue: Arc::new(Mutex::new(VecDeque::new())),
            processor: Arc::new(processor),
            shutdown: Arc::new(Mutex::new(false)),
        }
    }

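    /// Handle to the live configuration, so callers (notably
    /// `AdaptiveBatcher`) can tune the batcher while it runs.
    pub fn config_handle(&self) -> Arc<Mutex<BatchConfig>> {
        self.config.clone()
    }
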
    /// Enqueues an item and waits for its result from a processed batch.
    pub async fn add(&self, item: T) -> BatchResult<R> {
        let (tx, rx) = oneshot::channel();

        let batch_item = BatchItem {
            data: item,
            response: tx,
            enqueued_at: Instant::now(),
        };

        let max_queue_size = self.config.lock().await.max_queue_size;
        {
            let mut queue = self.queue.lock().await;
            if queue.len() >= max_queue_size {
                return Err(BatchError::QueueFull);
            }
            queue.push_back(batch_item);
        }

        // A dropped sender (e.g. shutdown before this item was processed)
        // surfaces as a timeout.
        rx.await.map_err(|_| BatchError::Timeout)?
    }

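    /// Drives the batching loop until `shutdown` is called; run this on a
    /// dedicated task.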
    pub async fn run(&self) {
        loop {
            {
                let shutdown = self.shutdown.lock().await;
                if *shutdown {
                    break;
                }
            }

            let config = self.config.lock().await.clone();
            let should_process = {
                let queue = self.queue.lock().await;
                // Age of the oldest queued item, if any. Keying the wait off
                // the item itself (rather than the time of the last flush)
                // avoids flushing a lone item immediately after an idle
                // period.
                let oldest_wait_ms = queue
                    .front()
                    .map(|item| item.enqueued_at.elapsed().as_millis())
                    .unwrap_or(0);
                // Flush when the batch is full, when the preferred size is
                // reached, or when the oldest item has waited long enough.
                queue.len() >= config.max_batch_size
                    || queue.len() >= config.preferred_batch_size
                    || (!queue.is_empty() && oldest_wait_ms >= config.max_wait_ms as u128)
            };

            if should_process {
                self.process_batch().await;
            } else {
                sleep(Duration::from_millis(1)).await;
            }
        }

        // Drain anything still queued so in-flight callers get results.
        while self.queue_size().await > 0 {
            self.process_batch().await;
        }
    }

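    /// Drains up to `max_batch_size` items, runs the processor once over the
    /// whole batch, and fans the per-item results back to waiting callers.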
    async fn process_batch(&self) {
        let max_batch_size = self.config.lock().await.max_batch_size;
        let items = {
            let mut queue = self.queue.lock().await;
            let batch_size = max_batch_size.min(queue.len());
            if batch_size == 0 {
                return;
            }
            queue.drain(..batch_size).collect::<Vec<_>>()
        };

        let (data, responses): (Vec<_>, Vec<_>) = items
            .into_iter()
            .map(|item| (item.data, item.response))
            .unzip();

        let results = (self.processor)(data);

        // If the processor returns fewer results than inputs, the unmatched
        // senders are dropped and those callers observe `Timeout`.
        for (response_tx, result) in responses.into_iter().zip(results) {
            let batch_result = result.map_err(BatchError::ProcessingFailed);
            let _ = response_tx.send(batch_result);
        }
    }

    /// Signals `run` to stop; the loop drains the remaining queue first.
    pub async fn shutdown(&self) {
        let mut shutdown = self.shutdown.lock().await;
        *shutdown = true;
    }

    pub async fn queue_size(&self) -> usize {
        self.queue.lock().await.len()
    }

    /// Current queue depth and the wait time of the oldest queued item.
    pub async fn stats(&self) -> BatchStats {
        let queue = self.queue.lock().await;
        let queue_size = queue.len();

        // The queue front is the oldest item, so its age is the longest
        // current wait.
        let max_wait = queue
            .front()
            .map(|item| item.enqueued_at.elapsed())
            .unwrap_or(Duration::ZERO);

        BatchStats {
            queue_size,
            max_wait_time: max_wait,
        }
    }
}

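/// Point-in-time snapshot of the batcher's queue.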
#[derive(Debug, Clone)]
pub struct BatchStats {
    pub queue_size: usize,
    pub max_wait_time: Duration,
}

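/// Wraps [`DynamicBatcher`] and adjusts `max_batch_size` based on observed
/// end-to-end latency, steering it toward a target.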
pub struct AdaptiveBatcher<T, R> {
    inner: DynamicBatcher<T, R>,
    // The same handle as `inner`'s configuration, so adjustments take
    // effect on the running batcher.
    config: Arc<Mutex<BatchConfig>>,
    latency_history: Arc<Mutex<VecDeque<Duration>>>,
    target_latency: Duration,
}

impl<T, R> AdaptiveBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    pub fn new<F>(
        initial_config: BatchConfig,
        target_latency: Duration,
        processor: F,
    ) -> Self
    where
        F: Fn(Vec<T>) -> Vec<Result<R, String>> + Send + Sync + 'static,
    {
        let inner = DynamicBatcher::new(initial_config, processor);
        // Share the inner batcher's config rather than keeping a separate
        // copy; otherwise the adaptive adjustments would never reach it.
        let config = inner.config_handle();

        Self {
            inner,
            config,
            latency_history: Arc::new(Mutex::new(VecDeque::with_capacity(100))),
            target_latency,
        }
    }

    /// Submits an item, records its end-to-end latency, and periodically
    /// retunes the batch size.
    pub async fn add(&self, item: T) -> Result<R, BatchError> {
        let start = Instant::now();
        let result = self.inner.add(item).await;
        let latency = start.elapsed();

        // Sliding window of the last 100 observed latencies.
        {
            let mut history = self.latency_history.lock().await;
            history.push_back(latency);
            if history.len() > 100 {
                history.pop_front();
            }
        }

        // Every 10 samples: shrink the batch size ~10% when the average is
        // over target, grow it when comfortably under (below half target).
        {
            let history = self.latency_history.lock().await;
            if history.len() % 10 == 0 && history.len() >= 10 {
                let avg_latency: Duration = history.iter().sum::<Duration>() / history.len() as u32;

                let mut config = self.config.lock().await;
                if avg_latency > self.target_latency {
                    config.max_batch_size = (config.max_batch_size * 9 / 10).max(1);
                } else if avg_latency < self.target_latency / 2 {
                    // The `+ 1` floor keeps small sizes from stalling, since
                    // integer 11/10 growth is a no-op below 10.
                    config.max_batch_size = (config.max_batch_size * 11 / 10)
                        .max(config.max_batch_size + 1)
                        .min(128);
                }
            }
        }

        result
    }

    pub async fn run(&self) {
        self.inner.run().await;
    }

    /// Snapshot of the current (possibly retuned) configuration.
    pub async fn current_config(&self) -> BatchConfig {
        self.config.lock().await.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dynamic_batcher() {
        let config = BatchConfig {
            max_batch_size: 4,
            max_wait_ms: 100,
            max_queue_size: 100,
            preferred_batch_size: 2,
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x * 2)).collect()
        }));

        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        let mut handles = vec![];
        for i in 0..8 {
            let batcher = batcher.clone();
            handles.push(tokio::spawn(async move { batcher.add(i).await }));
        }

        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap().unwrap();
            assert_eq!(result, (i as i32) * 2);
        }

        batcher.shutdown().await;
    }

    #[tokio::test]
    async fn test_batch_stats() {
        let config = BatchConfig::default();
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));

        // `add` only resolves after a batch is processed, and no run loop is
        // started here, so each call must be driven from its own task; the
        // items then sit in the queue where `stats` can observe them.
        for i in 1..=3 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        sleep(Duration::from_millis(50)).await;

        let stats = batcher.stats().await;
        assert_eq!(stats.queue_size, 3);
    }

    #[tokio::test]
    async fn test_queue_full() {
        let config = BatchConfig {
            max_queue_size: 2,
            ..Default::default()
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));

        // Fill the queue from background tasks; with no run loop the two
        // items are never drained.
        for i in 0..2 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        sleep(Duration::from_millis(50)).await;

        let result = batcher.add(3).await;
        assert!(matches!(result, Err(BatchError::QueueFull)));
    }

    #[tokio::test]
    async fn test_adaptive_batcher() {
        let config = BatchConfig {
            max_batch_size: 8,
            max_wait_ms: 50,
            max_queue_size: 100,
            preferred_batch_size: 4,
        };

        let batcher = Arc::new(AdaptiveBatcher::new(
            config,
            Duration::from_millis(100),
            |items: Vec<i32>| items.into_iter().map(|x| Ok(x * 2)).collect(),
        ));

        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        for i in 0..20 {
            let result = batcher.add(i).await.unwrap();
            assert_eq!(result, i * 2);
        }

        let final_config = batcher.current_config().await;
        assert!(final_config.max_batch_size > 0);
    }
}