tensorlogic_infer/batch.rs

use tensorlogic_ir::EinsumGraph;

/// The outputs produced by running an `EinsumGraph` over a batch of inputs,
/// together with the batch size they came from.
#[derive(Debug, Clone)]
pub struct BatchResult<T> {
    pub outputs: Vec<T>,
    pub batch_size: usize,
}

impl<T> BatchResult<T> {
    /// Wraps a vector of outputs, recording its length as the batch size.
    pub fn new(outputs: Vec<T>) -> Self {
        let batch_size = outputs.len();
        BatchResult {
            outputs,
            batch_size,
        }
    }

    /// Returns the number of outputs in the batch.
    pub fn len(&self) -> usize {
        self.outputs.len()
    }

    /// Returns `true` if the batch produced no outputs.
    pub fn is_empty(&self) -> bool {
        self.outputs.is_empty()
    }

    /// Iterates over the outputs in batch order.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.outputs.iter()
    }
}

/// An executor that can evaluate an `EinsumGraph` over many input sets at once.
pub trait TlBatchExecutor {
    type Tensor;
    type Error;

    /// Executes the graph once per entry in `batch_inputs`, serially.
    fn execute_batch(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error>;

    /// Like `execute_batch`, but may spread the work across threads.
    /// `num_threads: None` lets the implementation choose a thread count.
    fn execute_batch_parallel(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
        num_threads: Option<usize>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error>;

    /// The largest batch this executor can handle, or `None` if unbounded.
    fn max_batch_size(&self) -> Option<usize> {
        None
    }

    /// A batch size hint for callers that can choose; defaults to 32.
    fn optimal_batch_size(&self) -> usize {
        32
    }
}

#[cfg(test)]
mod tests {
    use super::*;
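
    // A minimal sketch of a `TlBatchExecutor` implementation, used below to
    // exercise the trait's default methods. `IdentityExecutor` is a
    // hypothetical mock (not part of the crate's API): it echoes each batch
    // element's first input back as its output and never evaluates the graph.
    struct IdentityExecutor;

    impl TlBatchExecutor for IdentityExecutor {
        type Tensor = i32;
        type Error = String;

        fn execute_batch(
            &mut self,
            _graph: &EinsumGraph,
            batch_inputs: Vec<Vec<i32>>,
        ) -> Result<BatchResult<i32>, String> {
            // Take the first input of each batch element as its "output",
            // failing if any element has no inputs at all.
            let outputs = batch_inputs
                .into_iter()
                .map(|inputs| {
                    inputs
                        .into_iter()
                        .next()
                        .ok_or_else(|| "empty input set".to_string())
                })
                .collect::<Result<Vec<_>, _>>()?;
            Ok(BatchResult::new(outputs))
        }

        fn execute_batch_parallel(
            &mut self,
            graph: &EinsumGraph,
            batch_inputs: Vec<Vec<i32>>,
            _num_threads: Option<usize>,
        ) -> Result<BatchResult<i32>, String> {
            // The mock has no real parallelism; it defers to the serial path.
            self.execute_batch(graph, batch_inputs)
        }
    }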

    #[test]
    fn test_batch_result() {
        let outputs = vec![1, 2, 3, 4];
        let result = BatchResult::new(outputs.clone());

        assert_eq!(result.len(), 4);
        assert_eq!(result.batch_size, 4);
        assert!(!result.is_empty());

        let collected: Vec<&i32> = result.iter().collect();
        assert_eq!(collected.len(), 4);
    }

    #[test]
    fn test_empty_batch_result() {
        let result: BatchResult<i32> = BatchResult::new(vec![]);
        assert_eq!(result.len(), 0);
        assert!(result.is_empty());
    }
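
    // A small check of the trait's default capacity hints, using the
    // `IdentityExecutor` sketch above; no graph is needed for these.
    #[test]
    fn test_default_batch_size_hints() {
        let exec = IdentityExecutor;
        assert_eq!(exec.max_batch_size(), None);
        assert_eq!(exec.optimal_batch_size(), 32);
    }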
}