Skip to main content

tensorlogic_infer/
dynamic_batching.rs

//! Dynamic batching for inference serving.
//!
//! This module provides dynamic batching capabilities for efficient inference serving:
//! - Automatic request batching with configurable timeouts
//! - Priority-based request queuing
//! - Adaptive batch sizing based on load
//! - Request deduplication
//! - Batch splitting for heterogeneous requests
//! - Latency and throughput optimization
10
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, VecDeque};
13use std::time::{Duration, Instant};
14use thiserror::Error;
15
16#[cfg(feature = "async")]
17use tokio::sync::oneshot;
18
19/// Dynamic batching errors.
20#[derive(Error, Debug, Clone, PartialEq)]
21pub enum BatchingError {
22    #[error("Request queue is full")]
23    QueueFull,
24
25    #[error("Request timeout after {0:?}")]
26    Timeout(Duration),
27
28    #[error("Invalid batch size: {0}")]
29    InvalidBatchSize(usize),
30
31    #[error("Request cancelled")]
32    Cancelled,
33
34    #[error("Incompatible request shapes")]
35    IncompatibleShapes,
36}
37
38/// Priority level for requests.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
40pub enum Priority {
41    /// Low priority (batch-friendly, can wait)
42    Low = 0,
43    /// Normal priority
44    Normal = 1,
45    /// High priority (minimize latency)
46    High = 2,
47    /// Critical priority (process immediately)
48    Critical = 3,
49}
50
51/// Request metadata for batching decisions.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct RequestMetadata {
54    /// Request ID
55    pub id: String,
56    /// Priority level
57    pub priority: Priority,
58    /// Arrival timestamp
59    #[serde(skip, default = "Instant::now")]
60    pub arrival_time: Instant,
61    /// Maximum tolerable latency
62    pub max_latency: Option<Duration>,
63    /// Input shapes (for compatibility checking)
64    pub input_shapes: Vec<Vec<usize>>,
65}
66
67/// A request to be batched.
68pub struct BatchRequest<T> {
69    /// Request metadata
70    pub metadata: RequestMetadata,
71    /// Input data
72    pub inputs: T,
73    /// Response channel (for async execution)
74    #[cfg(feature = "async")]
75    pub response_tx: Option<oneshot::Sender<Result<T, BatchingError>>>,
76}
77
78impl<T> BatchRequest<T> {
79    /// Create a new batch request.
80    pub fn new(id: String, inputs: T, input_shapes: Vec<Vec<usize>>) -> Self {
81        Self {
82            metadata: RequestMetadata {
83                id,
84                priority: Priority::Normal,
85                arrival_time: Instant::now(),
86                max_latency: None,
87                input_shapes,
88            },
89            inputs,
90            #[cfg(feature = "async")]
91            response_tx: None,
92        }
93    }
94
95    /// Set priority.
96    pub fn with_priority(mut self, priority: Priority) -> Self {
97        self.metadata.priority = priority;
98        self
99    }
100
101    /// Set maximum latency.
102    pub fn with_max_latency(mut self, max_latency: Duration) -> Self {
103        self.metadata.max_latency = Some(max_latency);
104        self
105    }
106
107    /// Check if request has timed out.
108    pub fn is_timed_out(&self) -> bool {
109        if let Some(max_latency) = self.metadata.max_latency {
110            self.metadata.arrival_time.elapsed() > max_latency
111        } else {
112            false
113        }
114    }
115
116    /// Get age of the request.
117    pub fn age(&self) -> Duration {
118        self.metadata.arrival_time.elapsed()
119    }
120}
121
122/// Configuration for dynamic batching.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct DynamicBatchConfig {
125    /// Maximum batch size
126    pub max_batch_size: usize,
127    /// Minimum batch size (for efficiency)
128    pub min_batch_size: usize,
129    /// Maximum wait time before forming a batch
130    pub max_wait_time: Duration,
131    /// Maximum queue depth
132    pub max_queue_depth: usize,
133    /// Enable adaptive batch sizing
134    pub adaptive_sizing: bool,
135    /// Target latency for adaptive sizing
136    pub target_latency: Option<Duration>,
137    /// Enable request deduplication
138    pub enable_deduplication: bool,
139    /// Enable batch splitting for heterogeneous requests
140    pub enable_splitting: bool,
141}
142
143impl Default for DynamicBatchConfig {
144    fn default() -> Self {
145        Self {
146            max_batch_size: 32,
147            min_batch_size: 1,
148            max_wait_time: Duration::from_millis(10),
149            max_queue_depth: 1000,
150            adaptive_sizing: true,
151            target_latency: Some(Duration::from_millis(50)),
152            enable_deduplication: false,
153            enable_splitting: true,
154        }
155    }
156}
157
158impl DynamicBatchConfig {
159    /// Create configuration optimized for throughput.
160    pub fn throughput_optimized() -> Self {
161        Self {
162            max_batch_size: 128,
163            min_batch_size: 8,
164            max_wait_time: Duration::from_millis(50),
165            ..Default::default()
166        }
167    }
168
169    /// Create configuration optimized for latency.
170    pub fn latency_optimized() -> Self {
171        Self {
172            max_batch_size: 16,
173            min_batch_size: 1,
174            max_wait_time: Duration::from_millis(1),
175            target_latency: Some(Duration::from_millis(10)),
176            ..Default::default()
177        }
178    }
179
180    /// Create configuration for interactive workloads.
181    pub fn interactive() -> Self {
182        Self {
183            max_batch_size: 8,
184            min_batch_size: 1,
185            max_wait_time: Duration::from_millis(5),
186            adaptive_sizing: true,
187            target_latency: Some(Duration::from_millis(20)),
188            ..Default::default()
189        }
190    }
191}
192
193/// Statistics for dynamic batching.
194#[derive(Debug, Clone, Default, Serialize, Deserialize)]
195pub struct BatchingStats {
196    /// Total requests processed
197    pub total_requests: usize,
198    /// Total batches formed
199    pub total_batches: usize,
200    /// Average batch size
201    pub avg_batch_size: f64,
202    /// Average wait time
203    pub avg_wait_time: Duration,
204    /// Average latency
205    pub avg_latency: Duration,
206    /// Number of timeouts
207    pub num_timeouts: usize,
208    /// Number of queue overflows
209    pub num_overflows: usize,
210    /// Current queue depth
211    pub current_queue_depth: usize,
212}
213
214impl BatchingStats {
215    /// Update statistics with a new batch.
216    pub fn update_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
217        self.total_batches += 1;
218        self.total_requests += batch_size;
219
220        // Update averages using incremental formula
221        let n = self.total_batches as f64;
222        self.avg_batch_size = (self.avg_batch_size * (n - 1.0) + batch_size as f64) / n;
223
224        self.avg_wait_time = Duration::from_secs_f64(
225            (self.avg_wait_time.as_secs_f64() * (n - 1.0) + wait_time.as_secs_f64()) / n,
226        );
227
228        self.avg_latency = Duration::from_secs_f64(
229            (self.avg_latency.as_secs_f64() * (n - 1.0) + latency.as_secs_f64()) / n,
230        );
231    }
232
233    /// Record a timeout.
234    pub fn record_timeout(&mut self) {
235        self.num_timeouts += 1;
236    }
237
238    /// Record a queue overflow.
239    pub fn record_overflow(&mut self) {
240        self.num_overflows += 1;
241    }
242
243    /// Get throughput (requests per second).
244    pub fn throughput(&self) -> f64 {
245        if self.avg_latency.as_secs_f64() > 0.0 {
246            self.avg_batch_size / self.avg_latency.as_secs_f64()
247        } else {
248            0.0
249        }
250    }
251
252    /// Get batching efficiency (ratio of actual to max batch size).
253    pub fn efficiency(&self, max_batch_size: usize) -> f64 {
254        if max_batch_size > 0 {
255            self.avg_batch_size / max_batch_size as f64
256        } else {
257            0.0
258        }
259    }
260}
261
262/// Request queue with priority support.
263pub struct RequestQueue<T> {
264    queues: HashMap<Priority, VecDeque<BatchRequest<T>>>,
265    config: DynamicBatchConfig,
266}
267
268impl<T> RequestQueue<T> {
269    /// Create a new request queue.
270    pub fn new(config: DynamicBatchConfig) -> Self {
271        let mut queues = HashMap::new();
272        queues.insert(Priority::Low, VecDeque::new());
273        queues.insert(Priority::Normal, VecDeque::new());
274        queues.insert(Priority::High, VecDeque::new());
275        queues.insert(Priority::Critical, VecDeque::new());
276
277        Self { queues, config }
278    }
279
280    /// Enqueue a request.
281    pub fn enqueue(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
282        let total_depth: usize = self.queues.values().map(|q| q.len()).sum();
283        if total_depth >= self.config.max_queue_depth {
284            return Err(BatchingError::QueueFull);
285        }
286
287        let priority = request.metadata.priority;
288        self.queues.get_mut(&priority).unwrap().push_back(request);
289        Ok(())
290    }
291
292    /// Dequeue requests to form a batch.
293    pub fn dequeue_batch(&mut self, max_size: usize) -> Vec<BatchRequest<T>> {
294        let mut batch = Vec::new();
295        let priorities = [
296            Priority::Critical,
297            Priority::High,
298            Priority::Normal,
299            Priority::Low,
300        ];
301
302        for &priority in &priorities {
303            if batch.len() >= max_size {
304                break;
305            }
306
307            let queue = self.queues.get_mut(&priority).unwrap();
308            while let Some(request) = queue.pop_front() {
309                // Skip timed-out requests
310                if request.is_timed_out() {
311                    continue;
312                }
313
314                batch.push(request);
315
316                if batch.len() >= max_size {
317                    break;
318                }
319            }
320        }
321
322        batch
323    }
324
325    /// Get total queue depth.
326    pub fn depth(&self) -> usize {
327        self.queues.values().map(|q| q.len()).sum()
328    }
329
330    /// Get oldest request age.
331    pub fn oldest_age(&self) -> Option<Duration> {
332        let priorities = [
333            Priority::Critical,
334            Priority::High,
335            Priority::Normal,
336            Priority::Low,
337        ];
338
339        for &priority in &priorities {
340            if let Some(request) = self.queues.get(&priority).unwrap().front() {
341                return Some(request.age());
342            }
343        }
344        None
345    }
346
347    /// Check if batch formation criteria are met.
348    pub fn should_form_batch(&self) -> bool {
349        // Form batch if max wait time exceeded
350        if let Some(age) = self.oldest_age() {
351            if age >= self.config.max_wait_time {
352                return true;
353            }
354        }
355
356        // Form batch if min batch size reached
357        let depth = self.depth();
358        if depth >= self.config.min_batch_size {
359            return true;
360        }
361
362        // Form batch immediately for critical requests
363        if !self.queues.get(&Priority::Critical).unwrap().is_empty() {
364            return true;
365        }
366
367        false
368    }
369}
370
371/// Adaptive batch size controller.
372pub struct AdaptiveBatcher {
373    config: DynamicBatchConfig,
374    current_batch_size: usize,
375    latency_history: VecDeque<Duration>,
376    throughput_history: VecDeque<f64>,
377}
378
379impl AdaptiveBatcher {
380    /// Create a new adaptive batcher.
381    pub fn new(config: DynamicBatchConfig) -> Self {
382        Self {
383            current_batch_size: config.max_batch_size / 2,
384            config,
385            latency_history: VecDeque::with_capacity(100),
386            throughput_history: VecDeque::with_capacity(100),
387        }
388    }
389
390    /// Get current recommended batch size.
391    pub fn current_batch_size(&self) -> usize {
392        self.current_batch_size
393    }
394
395    /// Update batch size based on observed latency.
396    pub fn update(&mut self, _batch_size: usize, latency: Duration, throughput: f64) {
397        self.latency_history.push_back(latency);
398        self.throughput_history.push_back(throughput);
399
400        // Keep only recent history
401        while self.latency_history.len() > 100 {
402            self.latency_history.pop_front();
403        }
404        while self.throughput_history.len() > 100 {
405            self.throughput_history.pop_front();
406        }
407
408        if !self.config.adaptive_sizing {
409            return;
410        }
411
412        let target_latency = match self.config.target_latency {
413            Some(t) => t,
414            None => return,
415        };
416
417        // Simple adaptive strategy: increase batch size if under target,
418        // decrease if over target
419        if latency < target_latency * 8 / 10 {
420            // Under target - can increase batch size
421            self.current_batch_size = (self.current_batch_size + 1).min(self.config.max_batch_size);
422        } else if latency > target_latency {
423            // Over target - decrease batch size
424            self.current_batch_size =
425                (self.current_batch_size.saturating_sub(1)).max(self.config.min_batch_size);
426        }
427    }
428
429    /// Get average recent latency.
430    pub fn avg_latency(&self) -> Option<Duration> {
431        if self.latency_history.is_empty() {
432            return None;
433        }
434
435        let sum: Duration = self.latency_history.iter().sum();
436        Some(sum / self.latency_history.len() as u32)
437    }
438
439    /// Get average recent throughput.
440    pub fn avg_throughput(&self) -> Option<f64> {
441        if self.throughput_history.is_empty() {
442            return None;
443        }
444
445        Some(self.throughput_history.iter().sum::<f64>() / self.throughput_history.len() as f64)
446    }
447}
448
449/// Dynamic batcher for inference requests.
450pub struct DynamicBatcher<T> {
451    queue: RequestQueue<T>,
452    stats: BatchingStats,
453    adaptive: AdaptiveBatcher,
454}
455
456impl<T> DynamicBatcher<T> {
457    /// Create a new dynamic batcher.
458    pub fn new(config: DynamicBatchConfig) -> Self {
459        let adaptive = AdaptiveBatcher::new(config.clone());
460        let queue = RequestQueue::new(config.clone());
461
462        Self {
463            queue,
464            stats: BatchingStats::default(),
465            adaptive,
466        }
467    }
468
469    /// Submit a request for batching.
470    pub fn submit(&mut self, request: BatchRequest<T>) -> Result<(), BatchingError> {
471        self.queue.enqueue(request)?;
472        self.stats.current_queue_depth = self.queue.depth();
473        Ok(())
474    }
475
476    /// Try to form a batch if criteria are met.
477    pub fn try_form_batch(&mut self) -> Option<Vec<BatchRequest<T>>> {
478        if !self.queue.should_form_batch() {
479            return None;
480        }
481
482        let batch_size = self.adaptive.current_batch_size();
483        let batch = self.queue.dequeue_batch(batch_size);
484
485        if batch.is_empty() {
486            return None;
487        }
488
489        self.stats.current_queue_depth = self.queue.depth();
490        Some(batch)
491    }
492
493    /// Get statistics.
494    pub fn stats(&self) -> &BatchingStats {
495        &self.stats
496    }
497
498    /// Record batch execution results.
499    pub fn record_batch(&mut self, batch_size: usize, wait_time: Duration, latency: Duration) {
500        self.stats.update_batch(batch_size, wait_time, latency);
501
502        let throughput = batch_size as f64 / latency.as_secs_f64();
503        self.adaptive.update(batch_size, latency, throughput);
504    }
505
506    /// Get current queue depth.
507    pub fn queue_depth(&self) -> usize {
508        self.queue.depth()
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-element request whose single input has shape `[1]`.
    fn scalar_request(id: &str, value: f64) -> BatchRequest<Vec<f64>> {
        BatchRequest::new(id.to_string(), vec![value], vec![vec![1]])
    }

    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    #[test]
    fn test_request_timeout() {
        let limit = Duration::from_millis(1);
        let request = BatchRequest::new("test".to_string(), vec![1.0, 2.0], vec![vec![2]])
            .with_max_latency(limit);

        // Wait past the limit so the request counts as expired.
        std::thread::sleep(Duration::from_millis(2));
        assert!(request.is_timed_out());
    }

    #[test]
    fn test_queue_enqueue_dequeue() {
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(DynamicBatchConfig::default());

        let normal = scalar_request("1", 1.0);
        let urgent = scalar_request("2", 2.0).with_priority(Priority::High);

        queue.enqueue(normal).unwrap();
        queue.enqueue(urgent).unwrap();
        assert_eq!(queue.depth(), 2);

        let drained = queue.dequeue_batch(10);
        assert_eq!(drained.len(), 2);
        // The high-priority request is drained ahead of the normal one.
        assert_eq!(drained[0].metadata.id, "2");
    }

    #[test]
    fn test_queue_overflow() {
        let config = DynamicBatchConfig {
            max_queue_depth: 2,
            ..Default::default()
        };
        let mut queue: RequestQueue<Vec<f64>> = RequestQueue::new(config);

        for &(id, value) in [("1", 1.0), ("2", 2.0)].iter() {
            queue.enqueue(scalar_request(id, value)).unwrap();
        }

        let rejected = queue.enqueue(scalar_request("3", 3.0));
        assert!(matches!(rejected, Err(BatchingError::QueueFull)));
    }

    #[test]
    fn test_batching_stats() {
        let mut stats = BatchingStats::default();

        let samples = [(4, 5, 10), (8, 6, 12)];
        for &(size, wait_ms, latency_ms) in samples.iter() {
            stats.update_batch(
                size,
                Duration::from_millis(wait_ms),
                Duration::from_millis(latency_ms),
            );
        }

        assert_eq!(stats.total_requests, 12);
        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.avg_batch_size, 6.0);
    }

    #[test]
    fn test_adaptive_batcher() {
        let mut batcher = AdaptiveBatcher::new(DynamicBatchConfig {
            adaptive_sizing: true,
            target_latency: Some(Duration::from_millis(50)),
            min_batch_size: 1,
            max_batch_size: 32,
            ..Default::default()
        });
        let initial_size = batcher.current_batch_size();

        // A latency well under the target must not shrink the batch size.
        batcher.update(8, Duration::from_millis(20), 400.0);
        assert!(batcher.current_batch_size() >= initial_size);

        // Repeated latencies over the target must shrink it below the start.
        for _ in 0..10 {
            batcher.update(8, Duration::from_millis(100), 80.0);
        }
        assert!(batcher.current_batch_size() < initial_size);
    }

    #[test]
    fn test_dynamic_batcher() {
        let mut batcher: DynamicBatcher<Vec<f64>> =
            DynamicBatcher::new(DynamicBatchConfig::latency_optimized());

        for i in 0..5 {
            let request = scalar_request(&format!("req_{}", i), i as f64);
            batcher.submit(request).unwrap();
        }
        assert_eq!(batcher.queue_depth(), 5);

        let formed = batcher.try_form_batch();
        assert!(formed.is_some());
        assert!(!formed.unwrap().is_empty());
    }

    #[test]
    fn test_config_presets() {
        let baseline = DynamicBatchConfig::default();

        let throughput = DynamicBatchConfig::throughput_optimized();
        assert!(throughput.max_batch_size > baseline.max_batch_size);

        let latency = DynamicBatchConfig::latency_optimized();
        assert!(latency.max_wait_time < baseline.max_wait_time);

        let interactive = DynamicBatchConfig::interactive();
        assert!(interactive.max_batch_size < throughput.max_batch_size);
    }

    #[test]
    fn test_stats_efficiency() {
        let mut stats = BatchingStats::default();
        stats.update_batch(16, Duration::from_millis(5), Duration::from_millis(10));

        assert_eq!(stats.efficiency(32), 0.5);
        assert_eq!(stats.efficiency(16), 1.0);
    }
}