1use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, VecDeque};
13use std::time::{Duration, Instant};
14use thiserror::Error;
15
16#[cfg(feature = "async")]
17use tokio::sync::oneshot;
18
/// Errors produced by the request queue and batch-formation machinery.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum BatchingError {
    /// The queue has reached `max_queue_depth`; the request was rejected.
    #[error("Request queue is full")]
    QueueFull,

    /// The request exceeded its per-request latency budget before being served.
    #[error("Request timeout after {0:?}")]
    Timeout(Duration),

    /// A batch size outside the configured/valid range was requested.
    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),

    /// The request was cancelled before a response could be produced.
    #[error("Request cancelled")]
    Cancelled,

    /// Requests could not be batched together because their input shapes differ.
    #[error("Incompatible request shapes")]
    IncompatibleShapes,
}
37
/// Request priority. Derives `Ord` on the discriminant, so
/// `Critical > High > Normal > Low` — the queue dequeues in that order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
50
/// Bookkeeping attached to every batched request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetadata {
    // Caller-supplied identifier (used by tests to check dequeue order).
    pub id: String,
    // Scheduling priority; defaults to `Priority::Normal` in `BatchRequest::new`.
    pub priority: Priority,
    // `Instant` is not serializable, so it is skipped and re-stamped on
    // deserialization — a deserialized request looks freshly arrived.
    #[serde(skip, default = "Instant::now")]
    pub arrival_time: Instant,
    // Optional latency budget; `None` means the request never times out.
    pub max_latency: Option<Duration>,
    // Shape of each input tensor/array, used for compatibility checks.
    pub input_shapes: Vec<Vec<usize>>,
}
66
/// A single request awaiting batching, carrying its payload of type `T`.
pub struct BatchRequest<T> {
    // Identity, priority, arrival time, and shape info.
    pub metadata: RequestMetadata,
    // The request payload handed to the batch executor.
    pub inputs: T,
    // Channel on which the result is delivered (async builds only).
    // `None` for fire-and-forget / synchronous use.
    #[cfg(feature = "async")]
    pub response_tx: Option<oneshot::Sender<Result<T, BatchingError>>>,
}
77
78impl<T> BatchRequest<T> {
79 pub fn new(id: String, inputs: T, input_shapes: Vec<Vec<usize>>) -> Self {
81 Self {
82 metadata: RequestMetadata {
83 id,
84 priority: Priority::Normal,
85 arrival_time: Instant::now(),
86 max_latency: None,
87 input_shapes,
88 },
89 inputs,
90 #[cfg(feature = "async")]
91 response_tx: None,
92 }
93 }
94
95 pub fn with_priority(mut self, priority: Priority) -> Self {
97 self.metadata.priority = priority;
98 self
99 }
100
101 pub fn with_max_latency(mut self, max_latency: Duration) -> Self {
103 self.metadata.max_latency = Some(max_latency);
104 self
105 }
106
107 pub fn is_timed_out(&self) -> bool {
109 if let Some(max_latency) = self.metadata.max_latency {
110 self.metadata.arrival_time.elapsed() > max_latency
111 } else {
112 false
113 }
114 }
115
116 pub fn age(&self) -> Duration {
118 self.metadata.arrival_time.elapsed()
119 }
120}
121
/// Tuning knobs for dynamic batch formation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchConfig {
    // Hard upper bound on requests per batch.
    pub max_batch_size: usize,
    // Queue depth at which `should_form_batch` fires immediately.
    pub min_batch_size: usize,
    // Oldest-request age at which a batch is formed regardless of depth.
    pub max_wait_time: Duration,
    // Total queued requests (across all priorities) before `enqueue` rejects.
    pub max_queue_depth: usize,
    // Enables the `AdaptiveBatcher` latency-feedback size controller.
    pub adaptive_sizing: bool,
    // Latency the adaptive controller steers toward; `None` disables steering.
    pub target_latency: Option<Duration>,
    // NOTE(review): declared but not referenced anywhere in this file —
    // presumably consumed by a batching executor elsewhere; confirm.
    pub enable_deduplication: bool,
    // NOTE(review): likewise unreferenced in this file.
    pub enable_splitting: bool,
}
142
143impl Default for DynamicBatchConfig {
144 fn default() -> Self {
145 Self {
146 max_batch_size: 32,
147 min_batch_size: 1,
148 max_wait_time: Duration::from_millis(10),
149 max_queue_depth: 1000,
150 adaptive_sizing: true,
151 target_latency: Some(Duration::from_millis(50)),
152 enable_deduplication: false,
153 enable_splitting: true,
154 }
155 }
156}
157
158impl DynamicBatchConfig {
159 pub fn throughput_optimized() -> Self {
161 Self {
162 max_batch_size: 128,
163 min_batch_size: 8,
164 max_wait_time: Duration::from_millis(50),
165 ..Default::default()
166 }
167 }
168
169 pub fn latency_optimized() -> Self {
171 Self {
172 max_batch_size: 16,
173 min_batch_size: 1,
174 max_wait_time: Duration::from_millis(1),
175 target_latency: Some(Duration::from_millis(10)),
176 ..Default::default()
177 }
178 }
179
180 pub fn interactive() -> Self {
182 Self {
183 max_batch_size: 8,
184 min_batch_size: 1,
185 max_wait_time: Duration::from_millis(5),
186 adaptive_sizing: true,
187 target_latency: Some(Duration::from_millis(20)),
188 ..Default::default()
189 }
190 }
191}
192
/// Running statistics about batching behaviour. Averages are maintained
/// incrementally in `update_batch`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchingStats {
    // Total requests that have gone through completed batches.
    pub total_requests: usize,
    // Number of completed batches recorded.
    pub total_batches: usize,
    // Running mean of requests per batch.
    pub avg_batch_size: f64,
    // Running mean of time requests spent queued before batching.
    pub avg_wait_time: Duration,
    // Running mean of batch execution latency.
    pub avg_latency: Duration,
    // Requests dropped for exceeding their latency budget.
    pub num_timeouts: usize,
    // Enqueue attempts rejected because the queue was full.
    pub num_overflows: usize,
    // Snapshot of the queue depth at the last update.
    pub current_queue_depth: usize,
}
213
214impl BatchingStats {
215 pub fn update_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
217 self.total_batches += 1;
218 self.total_requests += batch_size;
219
220 let n = self.total_batches as f64;
222 self.avg_batch_size = (self.avg_batch_size * (n - 1.0) + batch_size as f64) / n;
223
224 self.avg_wait_time = Duration::from_secs_f64(
225 (self.avg_wait_time.as_secs_f64() * (n - 1.0) + wait_time.as_secs_f64()) / n,
226 );
227
228 self.avg_latency = Duration::from_secs_f64(
229 (self.avg_latency.as_secs_f64() * (n - 1.0) + latency.as_secs_f64()) / n,
230 );
231 }
232
233 pub fn record_timeout(&mut self) {
235 self.num_timeouts += 1;
236 }
237
238 pub fn record_overflow(&mut self) {
240 self.num_overflows += 1;
241 }
242
243 pub fn throughput(&self) -> f64 {
245 if self.avg_latency.as_secs_f64() > 0.0 {
246 self.avg_batch_size / self.avg_latency.as_secs_f64()
247 } else {
248 0.0
249 }
250 }
251
252 pub fn efficiency(&self, max_batch_size: usize) -> f64 {
254 if max_batch_size > 0 {
255 self.avg_batch_size / max_batch_size as f64
256 } else {
257 0.0
258 }
259 }
260}
261
/// Priority-segregated request queue: one FIFO lane per `Priority` level.
pub struct RequestQueue<T> {
    // Lanes are created for all four priorities in `new`, so lookups by
    // priority never miss.
    queues: HashMap<Priority, VecDeque<BatchRequest<T>>>,
    config: DynamicBatchConfig,
}
267
268impl<T> RequestQueue<T> {
269 pub fn new(config: DynamicBatchConfig) -> Self {
271 let mut queues = HashMap::new();
272 queues.insert(Priority::Low, VecDeque::new());
273 queues.insert(Priority::Normal, VecDeque::new());
274 queues.insert(Priority::High, VecDeque::new());
275 queues.insert(Priority::Critical, VecDeque::new());
276
277 Self { queues, config }
278 }
279
280 pub fn enqueue(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
282 let total_depth: usize = self.queues.values().map(|q| q.len()).sum();
283 if total_depth >= self.config.max_queue_depth {
284 return Err(BatchingError::QueueFull);
285 }
286
287 let priority = request.metadata.priority;
288 self.queues.get_mut(&priority).unwrap().push_back(request);
289 Ok(())
290 }
291
292 pub fn dequeue_batch(&mut self, max_size: usize) -> Vec<BatchRequest<T>> {
294 let mut batch = Vec::new();
295 let priorities = [
296 Priority::Critical,
297 Priority::High,
298 Priority::Normal,
299 Priority::Low,
300 ];
301
302 for &priority in &priorities {
303 if batch.len() >= max_size {
304 break;
305 }
306
307 let queue = self.queues.get_mut(&priority).unwrap();
308 while let Some(request) = queue.pop_front() {
309 if request.is_timed_out() {
311 continue;
312 }
313
314 batch.push(request);
315
316 if batch.len() >= max_size {
317 break;
318 }
319 }
320 }
321
322 batch
323 }
324
325 pub fn depth(&self) -> usize {
327 self.queues.values().map(|q| q.len()).sum()
328 }
329
330 pub fn oldest_age(&self) -> Option<Duration> {
332 let priorities = [
333 Priority::Critical,
334 Priority::High,
335 Priority::Normal,
336 Priority::Low,
337 ];
338
339 for &priority in &priorities {
340 if let Some(request) = self.queues.get(&priority).unwrap().front() {
341 return Some(request.age());
342 }
343 }
344 None
345 }
346
347 pub fn should_form_batch(&self) -> bool {
349 if let Some(age) = self.oldest_age() {
351 if age >= self.config.max_wait_time {
352 return true;
353 }
354 }
355
356 let depth = self.depth();
358 if depth >= self.config.min_batch_size {
359 return true;
360 }
361
362 if !self.queues.get(&Priority::Critical).unwrap().is_empty() {
364 return true;
365 }
366
367 false
368 }
369}
370
/// Feedback controller that nudges the batch size toward a latency target
/// based on recent observations.
pub struct AdaptiveBatcher {
    config: DynamicBatchConfig,
    // The batch size currently recommended to callers.
    current_batch_size: usize,
    // Rolling windows of recent observations (bounded in `update`).
    latency_history: VecDeque<Duration>,
    throughput_history: VecDeque<f64>,
}
378
379impl AdaptiveBatcher {
380 pub fn new(config: DynamicBatchConfig) -> Self {
382 Self {
383 current_batch_size: config.max_batch_size / 2,
384 config,
385 latency_history: VecDeque::with_capacity(100),
386 throughput_history: VecDeque::with_capacity(100),
387 }
388 }
389
390 pub fn current_batch_size(&self) -> usize {
392 self.current_batch_size
393 }
394
395 pub fn update(&mut self, _batch_size: usize, latency: Duration, throughput: f64) {
397 self.latency_history.push_back(latency);
398 self.throughput_history.push_back(throughput);
399
400 while self.latency_history.len() > 100 {
402 self.latency_history.pop_front();
403 }
404 while self.throughput_history.len() > 100 {
405 self.throughput_history.pop_front();
406 }
407
408 if !self.config.adaptive_sizing {
409 return;
410 }
411
412 let target_latency = match self.config.target_latency {
413 Some(t) => t,
414 None => return,
415 };
416
417 if latency < target_latency * 8 / 10 {
420 self.current_batch_size = (self.current_batch_size + 1).min(self.config.max_batch_size);
422 } else if latency > target_latency {
423 self.current_batch_size =
425 (self.current_batch_size.saturating_sub(1)).max(self.config.min_batch_size);
426 }
427 }
428
429 pub fn avg_latency(&self) -> Option<Duration> {
431 if self.latency_history.is_empty() {
432 return None;
433 }
434
435 let sum: Duration = self.latency_history.iter().sum();
436 Some(sum / self.latency_history.len() as u32)
437 }
438
439 pub fn avg_throughput(&self) -> Option<f64> {
441 if self.throughput_history.is_empty() {
442 return None;
443 }
444
445 Some(self.throughput_history.iter().sum::<f64>() / self.throughput_history.len() as f64)
446 }
447}
448
/// Top-level batcher: a priority queue plus statistics plus the adaptive
/// batch-size controller, all driven from one config.
pub struct DynamicBatcher<T> {
    queue: RequestQueue<T>,
    stats: BatchingStats,
    adaptive: AdaptiveBatcher,
}
455
456impl<T> DynamicBatcher<T> {
457 pub fn new(config: DynamicBatchConfig) -> Self {
459 let adaptive = AdaptiveBatcher::new(config.clone());
460 let queue = RequestQueue::new(config.clone());
461
462 Self {
463 queue,
464 stats: BatchingStats::default(),
465 adaptive,
466 }
467 }
468
469 pub fn submit(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
471 self.queue.enqueue(request)?;
472 self.stats.current_queue_depth = self.queue.depth();
473 Ok(())
474 }
475
476 pub fn try_form_batch(&mut self) -> Option<Vec<BatchRequest<T>>> {
478 if !self.queue.should_form_batch() {
479 return None;
480 }
481
482 let batch_size = self.adaptive.current_batch_size();
483 let batch = self.queue.dequeue_batch(batch_size);
484
485 if batch.is_empty() {
486 return None;
487 }
488
489 self.stats.current_queue_depth = self.queue.depth();
490 Some(batch)
491 }
492
493 pub fn stats(&self) -> &BatchingStats {
495 &self.stats
496 }
497
498 pub fn record_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
500 self.stats.update_batch(batch_size, wait_time, latency);
501
502 let throughput = batch_size as f64 / latency.as_secs_f64();
503 self.adaptive.update(batch_size, latency, throughput);
504 }
505
506 pub fn queue_depth(&self) -> usize {
508 self.queue.depth()
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    // Priority derives Ord on the discriminant; dequeue order relies on this.
    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    // A request with a 1 ms budget should report timed-out after sleeping 2 ms.
    // Timing-based: relies on the sleep actually exceeding the budget.
    #[test]
    fn test_request_timeout() {
        let request = BatchRequest::new("test".to_string(), vec![1.0, 2.0], vec![vec![2]])
            .with_max_latency(Duration::from_millis(1));

        std::thread::sleep(Duration::from_millis(2));
        assert!(request.is_timed_out());
    }

    // Higher-priority requests come out first even if enqueued later.
    #[test]
    fn test_queue_enqueue_dequeue() {
        let config = DynamicBatchConfig::default();
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        let req1 = BatchRequest::new("1".to_string(), vec![1.0], vec![vec![1]]);
        let req2 = BatchRequest::new("2".to_string(), vec![2.0], vec![vec![1]])
            .with_priority(Priority::High);

        queue.enqueue(req1).unwrap();
        queue.enqueue(req2).unwrap();

        assert_eq!(queue.depth(), 2);

        let batch = queue.dequeue_batch(10);
        assert_eq!(batch.len(), 2);
        // High-priority "2" precedes normal-priority "1".
        assert_eq!(batch[0].metadata.id, "2");
    }

    // Enqueue beyond max_queue_depth is rejected with QueueFull.
    #[test]
    fn test_queue_overflow() {
        let config = DynamicBatchConfig {
            max_queue_depth: 2,
            ..Default::default()
        };
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        queue
            .enqueue(BatchRequest::new("1".to_string(), vec![1.0], vec![vec![1]]))
            .unwrap();
        queue
            .enqueue(BatchRequest::new("2".to_string(), vec![2.0], vec![vec![1]]))
            .unwrap();

        let result = queue.enqueue(BatchRequest::new("3".to_string(), vec![3.0], vec![vec![1]]));
        assert!(matches!(result, Err(BatchingError::QueueFull)));
    }

    // Running averages: (4 + 8) / 2 batches = 6.0 — exact in f64.
    #[test]
    fn test_batching_stats() {
        let mut stats = BatchingStats::default();

        stats.update_batch(4, Duration::from_millis(5), Duration::from_millis(10));
        stats.update_batch(8, Duration::from_millis(6), Duration::from_millis(12));

        assert_eq!(stats.total_requests, 12);
        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.avg_batch_size, 6.0);
    }

    // Under-target latency grows the batch size; sustained over-target
    // latency shrinks it below the starting point.
    #[test]
    fn test_adaptive_batcher() {
        let config = DynamicBatchConfig {
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(50)),
            min_batch_size: 1,
            max_batch_size: 32,
            ..Default::default()
        };

        let mut batcher = AdaptiveBatcher::new(config);
        let initial_size = batcher.current_batch_size();

        // 20 ms < 80% of 50 ms target -> size should not decrease.
        batcher.update(8, Duration::from_millis(20), 400.0);
        assert!(batcher.current_batch_size() >= initial_size);

        // 100 ms > target repeatedly -> size decays below the start.
        for _ in 0..10 {
            batcher.update(8, Duration::from_millis(100), 80.0);
        }
        assert!(batcher.current_batch_size() < initial_size);
    }

    // End-to-end: submit requests, then form a batch (latency_optimized has
    // min_batch_size = 1, so should_form_batch fires immediately).
    #[test]
    fn test_dynamic_batcher() {
        let config = DynamicBatchConfig::latency_optimized();
        let mut batcher: DynamicBatcher<Vec<f64>> = DynamicBatcher::new(config);

        for i in 0..5 {
            let request = BatchRequest::new(format!("req_{}", i), vec![i as f64], vec![vec![1]]);
            batcher.submit(request).unwrap();
        }

        assert_eq!(batcher.queue_depth(), 5);

        let batch = batcher.try_form_batch();
        assert!(batch.is_some());

        let batch = batch.unwrap();
        assert!(!batch.is_empty());
    }

    // Sanity-check the relative ordering of the preset configurations.
    #[test]
    fn test_config_presets() {
        let throughput = DynamicBatchConfig::throughput_optimized();
        assert!(throughput.max_batch_size > DynamicBatchConfig::default().max_batch_size);

        let latency = DynamicBatchConfig::latency_optimized();
        assert!(latency.max_wait_time < DynamicBatchConfig::default().max_wait_time);

        let interactive = DynamicBatchConfig::interactive();
        assert!(interactive.max_batch_size < throughput.max_batch_size);
    }

    // efficiency = avg_batch_size / max_batch_size.
    #[test]
    fn test_stats_efficiency() {
        let mut stats = BatchingStats::default();
        stats.update_batch(16, Duration::from_millis(5), Duration::from_millis(10));

        assert_eq!(stats.efficiency(32), 0.5);
        assert_eq!(stats.efficiency(16), 1.0);
    }
}