1use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, VecDeque};
13use std::time::{Duration, Instant};
14use thiserror::Error;
15
16#[cfg(feature = "async")]
17use tokio::sync::oneshot;
18
/// Errors produced by the request queue and batching machinery.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum BatchingError {
    /// The queue already holds `max_queue_depth` requests; the new request was rejected.
    #[error("Request queue is full")]
    QueueFull,

    /// The request waited longer than its latency budget.
    #[error("Request timeout after {0:?}")]
    Timeout(Duration),

    /// A batch size outside acceptable bounds was requested.
    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),

    /// The request was cancelled before a response could be produced.
    #[error("Request cancelled")]
    Cancelled,

    /// Requests whose input shapes do not match cannot share a batch.
    /// NOTE(review): not constructed anywhere in this module yet.
    #[error("Incompatible request shapes")]
    IncompatibleShapes,
}
37
/// Scheduling priority of a request; higher variants are drained first.
///
/// The derived `Ord` follows declaration/discriminant order, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
50
/// Bookkeeping attached to every queued request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetadata {
    /// Caller-supplied request identifier.
    pub id: String,
    /// Scheduling priority; see [`Priority`].
    pub priority: Priority,
    /// When the request entered the system. `Instant` cannot be serialized,
    /// so it is skipped and reset to "now" on deserialization — ages computed
    /// from a deserialized value restart from zero.
    #[serde(skip, default = "Instant::now")]
    pub arrival_time: Instant,
    /// Optional per-request latency budget; once exceeded the request counts
    /// as timed out and is dropped during dequeue.
    pub max_latency: Option<Duration>,
    /// Shape of each input tensor. NOTE(review): presumably intended for
    /// batch-compatibility checks (`IncompatibleShapes`), but not consulted
    /// in this module.
    pub input_shapes: Vec<Vec<usize>>,
}
66
/// A single request: metadata plus an opaque payload, and (with the `async`
/// feature) an optional channel on which the result is delivered.
pub struct BatchRequest<T> {
    /// Identity, priority, and timing information.
    pub metadata: RequestMetadata,
    /// The request payload (e.g. input tensors); opaque to the batcher.
    pub inputs: T,
    /// One-shot channel for sending the result back to the submitter;
    /// `None` for fire-and-forget requests.
    #[cfg(feature = "async")]
    pub response_tx: Option<oneshot::Sender<Result<T, BatchingError>>>,
}
77
78impl<T> BatchRequest<T> {
79 pub fn new(id: String, inputs: T, input_shapes: Vec<Vec<usize>>) -> Self {
81 Self {
82 metadata: RequestMetadata {
83 id,
84 priority: Priority::Normal,
85 arrival_time: Instant::now(),
86 max_latency: None,
87 input_shapes,
88 },
89 inputs,
90 #[cfg(feature = "async")]
91 response_tx: None,
92 }
93 }
94
95 pub fn with_priority(mut self, priority: Priority) -> Self {
97 self.metadata.priority = priority;
98 self
99 }
100
101 pub fn with_max_latency(mut self, max_latency: Duration) -> Self {
103 self.metadata.max_latency = Some(max_latency);
104 self
105 }
106
107 pub fn is_timed_out(&self) -> bool {
109 if let Some(max_latency) = self.metadata.max_latency {
110 self.metadata.arrival_time.elapsed() > max_latency
111 } else {
112 false
113 }
114 }
115
116 pub fn age(&self) -> Duration {
118 self.metadata.arrival_time.elapsed()
119 }
120}
121
/// Tuning knobs for dynamic batch formation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchConfig {
    /// Hard upper bound on requests per batch.
    pub max_batch_size: usize,
    /// Queue depth at which a batch may be formed without further waiting.
    pub min_batch_size: usize,
    /// Longest a front request may wait before a batch is forced.
    pub max_wait_time: Duration,
    /// Total queued requests allowed before `enqueue` returns `QueueFull`.
    pub max_queue_depth: usize,
    /// Whether the adaptive controller may grow/shrink the working batch size.
    pub adaptive_sizing: bool,
    /// Latency the adaptive controller steers toward; `None` disables steering.
    pub target_latency: Option<Duration>,
    /// NOTE(review): declared but not consulted anywhere in this module.
    pub enable_deduplication: bool,
    /// NOTE(review): declared but not consulted anywhere in this module.
    pub enable_splitting: bool,
}
142
143impl Default for DynamicBatchConfig {
144 fn default() -> Self {
145 Self {
146 max_batch_size: 32,
147 min_batch_size: 1,
148 max_wait_time: Duration::from_millis(10),
149 max_queue_depth: 1000,
150 adaptive_sizing: true,
151 target_latency: Some(Duration::from_millis(50)),
152 enable_deduplication: false,
153 enable_splitting: true,
154 }
155 }
156}
157
158impl DynamicBatchConfig {
159 pub fn throughput_optimized() -> Self {
161 Self {
162 max_batch_size: 128,
163 min_batch_size: 8,
164 max_wait_time: Duration::from_millis(50),
165 ..Default::default()
166 }
167 }
168
169 pub fn latency_optimized() -> Self {
171 Self {
172 max_batch_size: 16,
173 min_batch_size: 1,
174 max_wait_time: Duration::from_millis(1),
175 target_latency: Some(Duration::from_millis(10)),
176 ..Default::default()
177 }
178 }
179
180 pub fn interactive() -> Self {
182 Self {
183 max_batch_size: 8,
184 min_batch_size: 1,
185 max_wait_time: Duration::from_millis(5),
186 adaptive_sizing: true,
187 target_latency: Some(Duration::from_millis(20)),
188 ..Default::default()
189 }
190 }
191}
192
/// Running aggregate statistics over all recorded batches.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchingStats {
    /// Total requests across all recorded batches.
    pub total_requests: usize,
    /// Number of batches recorded.
    pub total_batches: usize,
    /// Running mean batch size.
    pub avg_batch_size: f64,
    /// Running mean queue wait time per batch.
    pub avg_wait_time: Duration,
    /// Running mean end-to-end batch latency.
    pub avg_latency: Duration,
    /// Requests counted as expired via `record_timeout`.
    pub num_timeouts: usize,
    /// Rejected enqueues counted via `record_overflow`.
    pub num_overflows: usize,
    /// Queue depth after the most recent submit/dequeue.
    pub current_queue_depth: usize,
}
213
214impl BatchingStats {
215 pub fn update_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
217 self.total_batches += 1;
218 self.total_requests += batch_size;
219
220 let n = self.total_batches as f64;
222 self.avg_batch_size = (self.avg_batch_size * (n - 1.0) + batch_size as f64) / n;
223
224 self.avg_wait_time = Duration::from_secs_f64(
225 (self.avg_wait_time.as_secs_f64() * (n - 1.0) + wait_time.as_secs_f64()) / n,
226 );
227
228 self.avg_latency = Duration::from_secs_f64(
229 (self.avg_latency.as_secs_f64() * (n - 1.0) + latency.as_secs_f64()) / n,
230 );
231 }
232
233 pub fn record_timeout(&mut self) {
235 self.num_timeouts += 1;
236 }
237
238 pub fn record_overflow(&mut self) {
240 self.num_overflows += 1;
241 }
242
243 pub fn throughput(&self) -> f64 {
245 if self.avg_latency.as_secs_f64() > 0.0 {
246 self.avg_batch_size / self.avg_latency.as_secs_f64()
247 } else {
248 0.0
249 }
250 }
251
252 pub fn efficiency(&self, max_batch_size: usize) -> f64 {
254 if max_batch_size > 0 {
255 self.avg_batch_size / max_batch_size as f64
256 } else {
257 0.0
258 }
259 }
260}
261
/// Priority-segregated FIFO of pending requests.
pub struct RequestQueue<T> {
    /// One FIFO per priority level; all four levels exist from construction.
    queues: HashMap<Priority, VecDeque<BatchRequest<T>>>,
    /// Limits governing enqueue and batch formation.
    config: DynamicBatchConfig,
}
267
268impl<T> RequestQueue<T> {
269 pub fn new(config: DynamicBatchConfig) -> Self {
271 let mut queues = HashMap::new();
272 queues.insert(Priority::Low, VecDeque::new());
273 queues.insert(Priority::Normal, VecDeque::new());
274 queues.insert(Priority::High, VecDeque::new());
275 queues.insert(Priority::Critical, VecDeque::new());
276
277 Self { queues, config }
278 }
279
280 pub fn enqueue(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
282 let total_depth: usize = self.queues.values().map(|q| q.len()).sum();
283 if total_depth >= self.config.max_queue_depth {
284 return Err(BatchingError::QueueFull);
285 }
286
287 let priority = request.metadata.priority;
288 self.queues
289 .get_mut(&priority)
290 .expect("priority queue always initialized")
291 .push_back(request);
292 Ok(())
293 }
294
295 pub fn dequeue_batch(&mut self, max_size: usize) -> Vec<BatchRequest<T>> {
297 let mut batch = Vec::new();
298 let priorities = [
299 Priority::Critical,
300 Priority::High,
301 Priority::Normal,
302 Priority::Low,
303 ];
304
305 for &priority in &priorities {
306 if batch.len() >= max_size {
307 break;
308 }
309
310 let queue = self
311 .queues
312 .get_mut(&priority)
313 .expect("priority queue always initialized");
314 while let Some(request) = queue.pop_front() {
315 if request.is_timed_out() {
317 continue;
318 }
319
320 batch.push(request);
321
322 if batch.len() >= max_size {
323 break;
324 }
325 }
326 }
327
328 batch
329 }
330
331 pub fn depth(&self) -> usize {
333 self.queues.values().map(|q| q.len()).sum()
334 }
335
336 pub fn oldest_age(&self) -> Option<Duration> {
338 let priorities = [
339 Priority::Critical,
340 Priority::High,
341 Priority::Normal,
342 Priority::Low,
343 ];
344
345 for &priority in &priorities {
346 if let Some(request) = self
347 .queues
348 .get(&priority)
349 .expect("priority queue always initialized")
350 .front()
351 {
352 return Some(request.age());
353 }
354 }
355 None
356 }
357
358 pub fn should_form_batch(&self) -> bool {
360 if let Some(age) = self.oldest_age() {
362 if age >= self.config.max_wait_time {
363 return true;
364 }
365 }
366
367 let depth = self.depth();
369 if depth >= self.config.min_batch_size {
370 return true;
371 }
372
373 if !self
375 .queues
376 .get(&Priority::Critical)
377 .expect("Critical priority queue always initialized")
378 .is_empty()
379 {
380 return true;
381 }
382
383 false
384 }
385}
386
/// Feedback controller that nudges the working batch size toward the
/// configured latency target.
pub struct AdaptiveBatcher {
    /// Bounds (`min_batch_size..=max_batch_size`) and the latency target.
    config: DynamicBatchConfig,
    /// Batch size the controller currently recommends.
    current_batch_size: usize,
    /// Recent batch latencies (rolling window, capped at 100 samples).
    latency_history: VecDeque<Duration>,
    /// Recent throughput samples in requests/sec (same 100-sample cap).
    throughput_history: VecDeque<f64>,
}
394
395impl AdaptiveBatcher {
396 pub fn new(config: DynamicBatchConfig) -> Self {
398 Self {
399 current_batch_size: config.max_batch_size / 2,
400 config,
401 latency_history: VecDeque::with_capacity(100),
402 throughput_history: VecDeque::with_capacity(100),
403 }
404 }
405
406 pub fn current_batch_size(&self) -> usize {
408 self.current_batch_size
409 }
410
411 pub fn update(&mut self, _batch_size: usize, latency: Duration, throughput: f64) {
413 self.latency_history.push_back(latency);
414 self.throughput_history.push_back(throughput);
415
416 while self.latency_history.len() > 100 {
418 self.latency_history.pop_front();
419 }
420 while self.throughput_history.len() > 100 {
421 self.throughput_history.pop_front();
422 }
423
424 if !self.config.adaptive_sizing {
425 return;
426 }
427
428 let target_latency = match self.config.target_latency {
429 Some(t) => t,
430 None => return,
431 };
432
433 if latency < target_latency * 8 / 10 {
436 self.current_batch_size = (self.current_batch_size + 1).min(self.config.max_batch_size);
438 } else if latency > target_latency {
439 self.current_batch_size =
441 (self.current_batch_size.saturating_sub(1)).max(self.config.min_batch_size);
442 }
443 }
444
445 pub fn avg_latency(&self) -> Option<Duration> {
447 if self.latency_history.is_empty() {
448 return None;
449 }
450
451 let sum: Duration = self.latency_history.iter().sum();
452 Some(sum / self.latency_history.len() as u32)
453 }
454
455 pub fn avg_throughput(&self) -> Option<f64> {
457 if self.throughput_history.is_empty() {
458 return None;
459 }
460
461 Some(self.throughput_history.iter().sum::<f64>() / self.throughput_history.len() as f64)
462 }
463}
464
/// Facade tying together the priority queue, statistics, and the adaptive
/// batch-size controller.
pub struct DynamicBatcher<T> {
    /// Pending requests, segregated by priority.
    queue: RequestQueue<T>,
    /// Aggregate counters and running averages.
    stats: BatchingStats,
    /// Controller that picks the next batch size.
    adaptive: AdaptiveBatcher,
}
471
472impl<T> DynamicBatcher<T> {
473 pub fn new(config: DynamicBatchConfig) -> Self {
475 let adaptive = AdaptiveBatcher::new(config.clone());
476 let queue = RequestQueue::new(config.clone());
477
478 Self {
479 queue,
480 stats: BatchingStats::default(),
481 adaptive,
482 }
483 }
484
485 pub fn submit(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
487 self.queue.enqueue(request)?;
488 self.stats.current_queue_depth = self.queue.depth();
489 Ok(())
490 }
491
492 pub fn try_form_batch(&mut self) -> Option<Vec<BatchRequest<T>>> {
494 if !self.queue.should_form_batch() {
495 return None;
496 }
497
498 let batch_size = self.adaptive.current_batch_size();
499 let batch = self.queue.dequeue_batch(batch_size);
500
501 if batch.is_empty() {
502 return None;
503 }
504
505 self.stats.current_queue_depth = self.queue.depth();
506 Some(batch)
507 }
508
509 pub fn stats(&self) -> &BatchingStats {
511 &self.stats
512 }
513
514 pub fn record_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
516 self.stats.update_batch(batch_size, wait_time, latency);
517
518 let throughput = batch_size as f64 / latency.as_secs_f64();
519 self.adaptive.update(batch_size, latency, throughput);
520 }
521
522 pub fn queue_depth(&self) -> usize {
524 self.queue.depth()
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience: a request carrying a single scalar input.
    fn scalar_request(id: &str, value: f64) -> BatchRequest<Vec<f64>> {
        BatchRequest::new(id.to_string(), vec![value], vec![vec![1]])
    }

    #[test]
    fn test_priority_ordering() {
        // Derived Ord must follow discriminant order.
        assert!(Priority::Low < Priority::Normal);
        assert!(Priority::Normal < Priority::High);
        assert!(Priority::High < Priority::Critical);
    }

    #[test]
    fn test_request_timeout() {
        let request = BatchRequest::new("test".to_string(), vec![1.0, 2.0], vec![vec![2]])
            .with_max_latency(Duration::from_millis(1));

        // Sleep past the budget so the request expires.
        std::thread::sleep(Duration::from_millis(2));
        assert!(request.is_timed_out());
    }

    #[test]
    fn test_queue_enqueue_dequeue() {
        let mut queue: RequestQueue<Vec<f64>> =
            RequestQueue::new(DynamicBatchConfig::default());

        queue.enqueue(scalar_request("1", 1.0)).expect("enqueue");
        queue
            .enqueue(scalar_request("2", 2.0).with_priority(Priority::High))
            .expect("enqueue");

        assert_eq!(queue.depth(), 2);

        let batch = queue.dequeue_batch(10);
        assert_eq!(batch.len(), 2);
        // High priority drains ahead of Normal.
        assert_eq!(batch[0].metadata.id, "2");
    }

    #[test]
    fn test_queue_overflow() {
        let config = DynamicBatchConfig {
            max_queue_depth: 2,
            ..Default::default()
        };
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        for &(id, value) in &[("1", 1.0), ("2", 2.0)] {
            queue.enqueue(scalar_request(id, value)).expect("enqueue");
        }

        // Third request exceeds max_queue_depth.
        let rejected = queue.enqueue(scalar_request("3", 3.0));
        assert!(matches!(rejected, Err(BatchingError::QueueFull)));
    }

    #[test]
    fn test_batching_stats() {
        let mut stats = BatchingStats::default();

        stats.update_batch(4, Duration::from_millis(5), Duration::from_millis(10));
        stats.update_batch(8, Duration::from_millis(6), Duration::from_millis(12));

        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.total_requests, 12);
        assert_eq!(stats.avg_batch_size, 6.0);
    }

    #[test]
    fn test_adaptive_batcher() {
        let config = DynamicBatchConfig {
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(50)),
            min_batch_size: 1,
            max_batch_size: 32,
            ..Default::default()
        };

        let mut controller = AdaptiveBatcher::new(config);
        let start = controller.current_batch_size();

        // Well under target: the recommended size must not shrink.
        controller.update(8, Duration::from_millis(20), 400.0);
        assert!(controller.current_batch_size() >= start);

        // Repeatedly over target: the size must fall below its start value.
        for _ in 0..10 {
            controller.update(8, Duration::from_millis(100), 80.0);
        }
        assert!(controller.current_batch_size() < start);
    }

    #[test]
    fn test_dynamic_batcher() {
        let mut batcher: DynamicBatcher<Vec<f64>> =
            DynamicBatcher::new(DynamicBatchConfig::latency_optimized());

        for i in 0..5 {
            batcher
                .submit(scalar_request(&format!("req_{}", i), i as f64))
                .expect("submit");
        }

        assert_eq!(batcher.queue_depth(), 5);

        let batch = batcher.try_form_batch().expect("batch should form");
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_config_presets() {
        let defaults = DynamicBatchConfig::default();

        let throughput = DynamicBatchConfig::throughput_optimized();
        assert!(throughput.max_batch_size > defaults.max_batch_size);

        let latency = DynamicBatchConfig::latency_optimized();
        assert!(latency.max_wait_time < defaults.max_wait_time);

        let interactive = DynamicBatchConfig::interactive();
        assert!(interactive.max_batch_size < throughput.max_batch_size);
    }

    #[test]
    fn test_stats_efficiency() {
        let mut stats = BatchingStats::default();
        stats.update_batch(16, Duration::from_millis(5), Duration::from_millis(10));

        assert_eq!(stats.efficiency(32), 0.5);
        assert_eq!(stats.efficiency(16), 1.0);
    }
}