// tensorlogic_infer/batch.rs
//! Batch execution support for processing multiple inputs efficiently.

use tensorlogic_ir::EinsumGraph;
/// Result of a batch execution: the per-input outputs plus the batch size.
#[derive(Debug, Clone)]
pub struct BatchResult<T> {
    /// One output per batch element, in input order.
    pub outputs: Vec<T>,
    /// Number of elements in the batch (equals `outputs.len()` at construction).
    pub batch_size: usize,
}

impl<T> BatchResult<T> {
    /// Wrap a vector of outputs, recording its length as the batch size.
    pub fn new(outputs: Vec<T>) -> Self {
        Self {
            batch_size: outputs.len(),
            outputs,
        }
    }

    /// Number of outputs held.
    pub fn len(&self) -> usize {
        self.outputs.len()
    }

    /// `true` when the batch produced no outputs.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrowing iterator over the outputs.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.outputs.iter()
    }
}
34/// Extension trait for batch execution
35pub trait TlBatchExecutor {
36    type Tensor;
37    type Error;
38
39    /// Execute a graph on a batch of inputs
40    fn execute_batch(
41        &mut self,
42        graph: &EinsumGraph,
43        batch_inputs: Vec<Vec<Self::Tensor>>,
44    ) -> Result<BatchResult<Self::Tensor>, Self::Error>;
45
46    /// Execute a graph on a batch of inputs with parallel processing
47    fn execute_batch_parallel(
48        &mut self,
49        graph: &EinsumGraph,
50        batch_inputs: Vec<Vec<Self::Tensor>>,
51        num_threads: Option<usize>,
52    ) -> Result<BatchResult<Self::Tensor>, Self::Error>;
53
54    /// Get maximum recommended batch size for this executor
55    fn max_batch_size(&self) -> Option<usize> {
56        None // Default: no limit
57    }
58
59    /// Get optimal batch size for this executor
60    fn optimal_batch_size(&self) -> usize {
61        32 // Default recommendation
62    }
63}
#[cfg(test)]
mod tests {
    use super::*;

    /// A populated batch reports its length consistently through
    /// `len`, `batch_size`, `is_empty`, and `iter`.
    #[test]
    fn test_batch_result() {
        // Fixed: previously cloned the input vec for no reason
        // (clippy::redundant_clone) — `new` takes ownership directly.
        let result = BatchResult::new(vec![1, 2, 3, 4]);

        assert_eq!(result.len(), 4);
        assert_eq!(result.batch_size, 4);
        assert!(!result.is_empty());

        // Fixed: counting via the iterator directly instead of collecting
        // into an intermediate Vec<&i32> (clippy::needless_collect).
        assert_eq!(result.iter().count(), 4);
    }

    /// An empty batch is reported as empty with zero length.
    #[test]
    fn test_empty_batch_result() {
        let result: BatchResult<i32> = BatchResult::new(Vec::new());
        assert_eq!(result.len(), 0);
        assert!(result.is_empty());
    }
}