tensorlogic_infer/dynamic_batching.rs

//! Dynamic batching for inference serving.
//!
//! This module provides the building blocks for efficient inference serving:
//! - Automatic request batching with configurable timeouts
//! - Priority-based request queuing
//! - Adaptive batch sizing based on load
//! - Request deduplication
//! - Batch splitting for heterogeneous requests
//! - Latency and throughput optimization
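//!
//! # Example
//!
//! A minimal, illustrative sketch of the intended flow. It assumes this
//! module is reachable as `tensorlogic_infer::dynamic_batching`; the payload
//! type and the timings are arbitrary placeholders.
//!
//! ```
//! use std::time::Duration;
//! use tensorlogic_infer::dynamic_batching::{BatchRequest, DynamicBatchConfig, DynamicBatcher};
//!
//! // Pick a preset, then submit requests as they arrive.
//! let mut batcher: DynamicBatcher<Vec<f64>> =
//!     DynamicBatcher::new(DynamicBatchConfig::latency_optimized());
//! let request = BatchRequest::new("req_0".to_string(), vec![1.0, 2.0], vec![vec![2]]);
//! batcher.submit(request).expect("queue has capacity");
//!
//! // Poll for a batch; one is returned once the formation criteria are met.
//! if let Some(batch) = batcher.try_form_batch() {
//!     // ... run inference on `batch`, then report the observed timings so
//!     // that adaptive sizing and the statistics stay up to date.
//!     batcher.record_batch(batch.len(), Duration::from_millis(2), Duration::from_millis(9));
//! }
//! ```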

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use thiserror::Error;

#[cfg(feature = "async")]
use tokio::sync::oneshot;

/// Dynamic batching errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum BatchingError {
    #[error("Request queue is full")]
    QueueFull,

    #[error("Request timeout after {0:?}")]
    Timeout(Duration),

    #[error("Invalid batch size: {0}")]
    InvalidBatchSize(usize),

    #[error("Request cancelled")]
    Cancelled,

    #[error("Incompatible request shapes")]
    IncompatibleShapes,
}

/// Priority level for requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Priority {
    /// Low priority (batch-friendly, can wait)
    Low = 0,
    /// Normal priority
    Normal = 1,
    /// High priority (minimize latency)
    High = 2,
    /// Critical priority (process immediately)
    Critical = 3,
}

/// Request metadata for batching decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetadata {
    /// Request ID
    pub id: String,
    /// Priority level
    pub priority: Priority,
    /// Arrival timestamp
    #[serde(skip, default = "Instant::now")]
    pub arrival_time: Instant,
    /// Maximum tolerable latency
    pub max_latency: Option<Duration>,
    /// Input shapes (for compatibility checking)
    pub input_shapes: Vec<Vec<usize>>,
}

/// A request to be batched.
pub struct BatchRequest<T> {
    /// Request metadata
    pub metadata: RequestMetadata,
    /// Input data
    pub inputs: T,
    /// Response channel (for async execution)
    #[cfg(feature = "async")]
    pub response_tx: Option<oneshot::Sender<Result<T, BatchingError>>>,
}

impl<T> BatchRequest<T> {
    /// Create a new batch request.
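    ///
    /// A small, illustrative sketch of the builder chain (payload and shapes
    /// are arbitrary):
    ///
    /// ```
    /// use std::time::Duration;
    /// use tensorlogic_infer::dynamic_batching::{BatchRequest, Priority};
    ///
    /// let request = BatchRequest::new("req_42".to_string(), vec![0.5f64], vec![vec![1]])
    ///     .with_priority(Priority::High)
    ///     .with_max_latency(Duration::from_millis(25));
    /// assert!(!request.is_timed_out());
    /// ```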
    pub fn new(id: String, inputs: T, input_shapes: Vec<Vec<usize>>) -> Self {
        Self {
            metadata: RequestMetadata {
                id,
                priority: Priority::Normal,
                arrival_time: Instant::now(),
                max_latency: None,
                input_shapes,
            },
            inputs,
            #[cfg(feature = "async")]
            response_tx: None,
        }
    }

    /// Set priority.
    pub fn with_priority(mut self, priority: Priority) -> Self {
        self.metadata.priority = priority;
        self
    }

    /// Set maximum tolerable latency.
    pub fn with_max_latency(mut self, max_latency: Duration) -> Self {
        self.metadata.max_latency = Some(max_latency);
        self
    }

    /// Check if request has timed out.
    pub fn is_timed_out(&self) -> bool {
        if let Some(max_latency) = self.metadata.max_latency {
            self.metadata.arrival_time.elapsed() > max_latency
        } else {
            false
        }
    }

    /// Get age of the request.
    pub fn age(&self) -> Duration {
        self.metadata.arrival_time.elapsed()
    }
}

/// Configuration for dynamic batching.
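///
/// A typical setup is sketched below: start from a preset and override only
/// what you need (the values are illustrative).
///
/// ```
/// use tensorlogic_infer::dynamic_batching::DynamicBatchConfig;
///
/// let config = DynamicBatchConfig {
///     max_batch_size: 64,
///     ..DynamicBatchConfig::throughput_optimized()
/// };
/// assert_eq!(config.min_batch_size, 8);
/// ```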
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchConfig {
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Minimum batch size (for efficiency)
    pub min_batch_size: usize,
    /// Maximum wait time before forming a batch
    pub max_wait_time: Duration,
    /// Maximum queue depth
    pub max_queue_depth: usize,
    /// Enable adaptive batch sizing
    pub adaptive_sizing: bool,
    /// Target latency for adaptive sizing
    pub target_latency: Option<Duration>,
    /// Enable request deduplication
    pub enable_deduplication: bool,
    /// Enable batch splitting for heterogeneous requests
    pub enable_splitting: bool,
}

impl Default for DynamicBatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            min_batch_size: 1,
            max_wait_time: Duration::from_millis(10),
            max_queue_depth: 1000,
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(50)),
            enable_deduplication: false,
            enable_splitting: true,
        }
    }
}

impl DynamicBatchConfig {
    /// Create configuration optimized for throughput.
    pub fn throughput_optimized() -> Self {
        Self {
            max_batch_size: 128,
            min_batch_size: 8,
            max_wait_time: Duration::from_millis(50),
            ..Default::default()
        }
    }

    /// Create configuration optimized for latency.
    pub fn latency_optimized() -> Self {
        Self {
            max_batch_size: 16,
            min_batch_size: 1,
            max_wait_time: Duration::from_millis(1),
            target_latency: Some(Duration::from_millis(10)),
            ..Default::default()
        }
    }

    /// Create configuration for interactive workloads.
    pub fn interactive() -> Self {
        Self {
            max_batch_size: 8,
            min_batch_size: 1,
            max_wait_time: Duration::from_millis(5),
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(20)),
            ..Default::default()
        }
    }
}

/// Statistics for dynamic batching.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchingStats {
    /// Total requests processed
    pub total_requests: usize,
    /// Total batches formed
    pub total_batches: usize,
    /// Average batch size
    pub avg_batch_size: f64,
    /// Average wait time
    pub avg_wait_time: Duration,
    /// Average latency
    pub avg_latency: Duration,
    /// Number of timeouts
    pub num_timeouts: usize,
    /// Number of queue overflows
    pub num_overflows: usize,
    /// Current queue depth
    pub current_queue_depth: usize,
}

impl BatchingStats {
    /// Update statistics with a new batch.
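    ///
    /// Averages are maintained incrementally: after `n` batches the running
    /// mean becomes `(mean * (n - 1) + sample) / n`. A quick illustrative
    /// check:
    ///
    /// ```
    /// use std::time::Duration;
    /// use tensorlogic_infer::dynamic_batching::BatchingStats;
    ///
    /// let mut stats = BatchingStats::default();
    /// stats.update_batch(4, Duration::from_millis(5), Duration::from_millis(10));
    /// stats.update_batch(8, Duration::from_millis(5), Duration::from_millis(10));
    /// assert_eq!(stats.avg_batch_size, 6.0);
    /// ```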
    pub fn update_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
        self.total_batches += 1;
        self.total_requests += batch_size;

        // Update averages using the incremental mean formula
        let n = self.total_batches as f64;
        self.avg_batch_size = (self.avg_batch_size * (n - 1.0) + batch_size as f64) / n;

        self.avg_wait_time = Duration::from_secs_f64(
            (self.avg_wait_time.as_secs_f64() * (n - 1.0) + wait_time.as_secs_f64()) / n,
        );

        self.avg_latency = Duration::from_secs_f64(
            (self.avg_latency.as_secs_f64() * (n - 1.0) + latency.as_secs_f64()) / n,
        );
    }

    /// Record a timeout.
    pub fn record_timeout(&mut self) {
        self.num_timeouts += 1;
    }

    /// Record a queue overflow.
    pub fn record_overflow(&mut self) {
        self.num_overflows += 1;
    }

    /// Get estimated throughput in requests per second, assuming batches
    /// execute back-to-back.
    pub fn throughput(&self) -> f64 {
        if self.avg_latency.as_secs_f64() > 0.0 {
            self.avg_batch_size / self.avg_latency.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Get batching efficiency (ratio of actual to max batch size).
    pub fn efficiency(&self, max_batch_size: usize) -> f64 {
        if max_batch_size > 0 {
            self.avg_batch_size / max_batch_size as f64
        } else {
            0.0
        }
    }
}

/// Request queue with priority support.
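///
/// Requests stay FIFO within each priority level, and batches drain the
/// highest priorities first. A minimal sketch:
///
/// ```
/// use tensorlogic_infer::dynamic_batching::{
///     BatchRequest, DynamicBatchConfig, Priority, RequestQueue,
/// };
///
/// let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(DynamicBatchConfig::default());
/// queue
///     .enqueue(BatchRequest::new("low".to_string(), vec![1.0], vec![vec![1]]))
///     .expect("queue has capacity");
/// queue
///     .enqueue(
///         BatchRequest::new("high".to_string(), vec![2.0], vec![vec![1]])
///             .with_priority(Priority::High),
///     )
///     .expect("queue has capacity");
///
/// let batch = queue.dequeue_batch(2);
/// assert_eq!(batch[0].metadata.id, "high");
/// ```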
pub struct RequestQueue<T> {
    queues: HashMap<Priority, VecDeque<BatchRequest<T>>>,
    config: DynamicBatchConfig,
}

impl<T> RequestQueue<T> {
    /// Create a new request queue.
    pub fn new(config: DynamicBatchConfig) -> Self {
        let mut queues = HashMap::new();
        queues.insert(Priority::Low, VecDeque::new());
        queues.insert(Priority::Normal, VecDeque::new());
        queues.insert(Priority::High, VecDeque::new());
        queues.insert(Priority::Critical, VecDeque::new());

        Self { queues, config }
    }

    /// Enqueue a request.
    pub fn enqueue(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
        let total_depth: usize = self.queues.values().map(|q| q.len()).sum();
        if total_depth >= self.config.max_queue_depth {
            return Err(BatchingError::QueueFull);
        }

        let priority = request.metadata.priority;
        self.queues
            .get_mut(&priority)
            .expect("priority queue always initialized")
            .push_back(request);
        Ok(())
    }

    /// Dequeue requests to form a batch, highest priority first.
    pub fn dequeue_batch(&mut self, max_size: usize) -> Vec<BatchRequest<T>> {
        let mut batch = Vec::new();
        let priorities = [
            Priority::Critical,
            Priority::High,
            Priority::Normal,
            Priority::Low,
        ];

        for &priority in &priorities {
            if batch.len() >= max_size {
                break;
            }

            let queue = self
                .queues
                .get_mut(&priority)
                .expect("priority queue always initialized");
            while let Some(request) = queue.pop_front() {
                // Timed-out requests are dropped rather than batched
                if request.is_timed_out() {
                    continue;
                }

                batch.push(request);

                if batch.len() >= max_size {
                    break;
                }
            }
        }

        batch
    }

    /// Get total queue depth.
    pub fn depth(&self) -> usize {
        self.queues.values().map(|q| q.len()).sum()
    }

    /// Get the age of the oldest queued request across all priority levels.
    pub fn oldest_age(&self) -> Option<Duration> {
        // Each queue is FIFO, so its front is its oldest entry; the oldest
        // request overall is therefore the maximum front age across queues.
        self.queues
            .values()
            .filter_map(|queue| queue.front())
            .map(|request| request.age())
            .max()
    }

    /// Check if batch formation criteria are met.
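    ///
    /// A batch should be formed when any of these holds: the oldest queued
    /// request has waited at least `max_wait_time`, at least `min_batch_size`
    /// requests are queued, or a critical-priority request is pending.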
    pub fn should_form_batch(&self) -> bool {
        // Form batch if max wait time exceeded
        if let Some(age) = self.oldest_age() {
            if age >= self.config.max_wait_time {
                return true;
            }
        }

        // Form batch if min batch size reached
        let depth = self.depth();
        if depth >= self.config.min_batch_size {
            return true;
        }

        // Form batch immediately for critical requests
        if !self
            .queues
            .get(&Priority::Critical)
            .expect("Critical priority queue always initialized")
            .is_empty()
        {
            return true;
        }

        false
    }
}

/// Adaptive batch size controller.
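///
/// The controller nudges the recommended batch size by one step per
/// observation: sustained latency above the target shrinks it, comfortable
/// headroom (below ~80% of the target) grows it. An illustrative sketch:
///
/// ```
/// use std::time::Duration;
/// use tensorlogic_infer::dynamic_batching::{AdaptiveBatcher, DynamicBatchConfig};
///
/// let mut batcher = AdaptiveBatcher::new(DynamicBatchConfig::default());
/// let before = batcher.current_batch_size();
///
/// // Well under the default 50 ms target, so the batch size may grow.
/// batcher.update(before, Duration::from_millis(10), 800.0);
/// assert!(batcher.current_batch_size() >= before);
/// ```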
pub struct AdaptiveBatcher {
    config: DynamicBatchConfig,
    current_batch_size: usize,
    latency_history: VecDeque<Duration>,
    throughput_history: VecDeque<f64>,
}

impl AdaptiveBatcher {
    /// Create a new adaptive batcher.
    pub fn new(config: DynamicBatchConfig) -> Self {
        Self {
            // Start midway and let feedback steer from there, but never
            // below the configured minimum.
            current_batch_size: (config.max_batch_size / 2).max(config.min_batch_size),
            config,
            latency_history: VecDeque::with_capacity(100),
            throughput_history: VecDeque::with_capacity(100),
        }
    }

    /// Get current recommended batch size.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    /// Update batch size based on observed latency.
    pub fn update(&mut self, _batch_size: usize, latency: Duration, throughput: f64) {
        self.latency_history.push_back(latency);
        self.throughput_history.push_back(throughput);

        // Keep only recent history
        while self.latency_history.len() > 100 {
            self.latency_history.pop_front();
        }
        while self.throughput_history.len() > 100 {
            self.throughput_history.pop_front();
        }

        if !self.config.adaptive_sizing {
            return;
        }

        let target_latency = match self.config.target_latency {
            Some(t) => t,
            None => return,
        };

        // Simple additive strategy: grow the batch while latency sits below
        // 80% of the target, shrink it once the target is exceeded.
        if latency < target_latency * 8 / 10 {
            // Under target - can increase batch size
            self.current_batch_size = (self.current_batch_size + 1).min(self.config.max_batch_size);
        } else if latency > target_latency {
            // Over target - decrease batch size
            self.current_batch_size =
                (self.current_batch_size.saturating_sub(1)).max(self.config.min_batch_size);
        }
    }

    /// Get average recent latency.
    pub fn avg_latency(&self) -> Option<Duration> {
        if self.latency_history.is_empty() {
            return None;
        }

        let sum: Duration = self.latency_history.iter().sum();
        Some(sum / self.latency_history.len() as u32)
    }

    /// Get average recent throughput.
    pub fn avg_throughput(&self) -> Option<f64> {
        if self.throughput_history.is_empty() {
            return None;
        }

        Some(self.throughput_history.iter().sum::<f64>() / self.throughput_history.len() as f64)
    }
}

/// Dynamic batcher for inference requests.
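///
/// Ties together the priority queue, the adaptive size controller, and the
/// running statistics. A serving-loop sketch; `run_inference` is a
/// placeholder for the caller's executor:
///
/// ```no_run
/// use std::time::Instant;
/// use tensorlogic_infer::dynamic_batching::{DynamicBatchConfig, DynamicBatcher};
///
/// let mut batcher: DynamicBatcher<Vec<f32>> =
///     DynamicBatcher::new(DynamicBatchConfig::throughput_optimized());
///
/// loop {
///     if let Some(batch) = batcher.try_form_batch() {
///         let wait = batch.iter().map(|r| r.age()).max().unwrap_or_default();
///         let start = Instant::now();
///         // run_inference(&batch);
///         batcher.record_batch(batch.len(), wait, start.elapsed());
///     }
/// }
/// ```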
pub struct DynamicBatcher<T> {
    queue: RequestQueue<T>,
    stats: BatchingStats,
    adaptive: AdaptiveBatcher,
}

impl<T> DynamicBatcher<T> {
    /// Create a new dynamic batcher.
    pub fn new(config: DynamicBatchConfig) -> Self {
        let adaptive = AdaptiveBatcher::new(config.clone());
        let queue = RequestQueue::new(config);

        Self {
            queue,
            stats: BatchingStats::default(),
            adaptive,
        }
    }

    /// Submit a request for batching.
    pub fn submit(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
        if let Err(err) = self.queue.enqueue(request) {
            // Surface queue overflows in the statistics before propagating.
            if matches!(err, BatchingError::QueueFull) {
                self.stats.record_overflow();
            }
            return Err(err);
        }
        self.stats.current_queue_depth = self.queue.depth();
        Ok(())
    }

    /// Try to form a batch if criteria are met.
    pub fn try_form_batch(&mut self) -> Option<Vec<BatchRequest<T>>> {
        if !self.queue.should_form_batch() {
            return None;
        }

        let batch_size = self.adaptive.current_batch_size();
        let batch = self.queue.dequeue_batch(batch_size);

        if batch.is_empty() {
            return None;
        }

        self.stats.current_queue_depth = self.queue.depth();
        Some(batch)
    }

    /// Get statistics.
    pub fn stats(&self) -> &BatchingStats {
        &self.stats
    }

    /// Record batch execution results.
    pub fn record_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
        self.stats.update_batch(batch_size, wait_time, latency);

        // Guard against a zero-duration latency producing an infinite rate.
        let secs = latency.as_secs_f64();
        let throughput = if secs > 0.0 {
            batch_size as f64 / secs
        } else {
            0.0
        };
        self.adaptive.update(batch_size, latency, throughput);
    }

    /// Get current queue depth.
    pub fn queue_depth(&self) -> usize {
        self.queue.depth()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    #[test]
    fn test_request_timeout() {
        let request = BatchRequest::new("test".to_string(), vec![1.0, 2.0], vec![vec![2]])
            .with_max_latency(Duration::from_millis(1));

        std::thread::sleep(Duration::from_millis(2));
        assert!(request.is_timed_out());
    }

    #[test]
    fn test_queue_enqueue_dequeue() {
        let config = DynamicBatchConfig::default();
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        let req1 = BatchRequest::new("1".to_string(), vec![1.0], vec![vec![1]]);
        let req2 = BatchRequest::new("2".to_string(), vec![2.0], vec![vec![1]])
            .with_priority(Priority::High);

        queue.enqueue(req1).expect("enqueue should succeed");
        queue.enqueue(req2).expect("enqueue should succeed");

        assert_eq!(queue.depth(), 2);

        let batch = queue.dequeue_batch(10);
        assert_eq!(batch.len(), 2);
        // High priority should be first
        assert_eq!(batch[0].metadata.id, "2");
    }

    #[test]
    fn test_queue_overflow() {
        let config = DynamicBatchConfig {
            max_queue_depth: 2,
            ..Default::default()
        };
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        queue
            .enqueue(BatchRequest::new("1".to_string(), vec![1.0], vec![vec![1]]))
            .expect("enqueue should succeed");
        queue
            .enqueue(BatchRequest::new("2".to_string(), vec![2.0], vec![vec![1]]))
            .expect("enqueue should succeed");

        let result = queue.enqueue(BatchRequest::new("3".to_string(), vec![3.0], vec![vec![1]]));
        assert!(matches!(result, Err(BatchingError::QueueFull)));
    }

    #[test]
    fn test_batching_stats() {
        let mut stats = BatchingStats::default();

        stats.update_batch(4, Duration::from_millis(5), Duration::from_millis(10));
        stats.update_batch(8, Duration::from_millis(6), Duration::from_millis(12));

        assert_eq!(stats.total_requests, 12);
        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.avg_batch_size, 6.0);
    }

    #[test]
    fn test_adaptive_batcher() {
        let config = DynamicBatchConfig {
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(50)),
            min_batch_size: 1,
            max_batch_size: 32,
            ..Default::default()
        };

        let mut batcher = AdaptiveBatcher::new(config);
        let initial_size = batcher.current_batch_size();

        // Simulate low latency - should increase batch size
        batcher.update(8, Duration::from_millis(20), 400.0);
        assert!(batcher.current_batch_size() >= initial_size);

        // Simulate high latency - should decrease batch size
        for _ in 0..10 {
            batcher.update(8, Duration::from_millis(100), 80.0);
        }
        assert!(batcher.current_batch_size() < initial_size);
    }

    #[test]
    fn test_dynamic_batcher() {
        let config = DynamicBatchConfig::latency_optimized();
        let mut batcher: DynamicBatcher<Vec<f64>> = DynamicBatcher::new(config);

        // Submit requests
        for i in 0..5 {
            let request = BatchRequest::new(format!("req_{}", i), vec![i as f64], vec![vec![1]]);
            batcher.submit(request).expect("submit should succeed");
        }

        assert_eq!(batcher.queue_depth(), 5);

        // Form batch
        let batch = batcher.try_form_batch();
        assert!(batch.is_some());

        let batch = batch.expect("batch should be formed");
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_config_presets() {
        let throughput = DynamicBatchConfig::throughput_optimized();
        assert!(throughput.max_batch_size > DynamicBatchConfig::default().max_batch_size);

        let latency = DynamicBatchConfig::latency_optimized();
        assert!(latency.max_wait_time < DynamicBatchConfig::default().max_wait_time);

        let interactive = DynamicBatchConfig::interactive();
        assert!(interactive.max_batch_size < throughput.max_batch_size);
    }

    #[test]
    fn test_stats_efficiency() {
        let mut stats = BatchingStats::default();
        stats.update_batch(16, Duration::from_millis(5), Duration::from_millis(10));

        assert_eq!(stats.efficiency(32), 0.5);
        assert_eq!(stats.efficiency(16), 1.0);
    }
}