tensorlogic_scirs_backend/batch_executor.rs

//! Batch execution support for parallel processing.
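//!
//! A minimal usage sketch (assuming a graph compiled with
//! `tensorlogic_compiler::compile_to_einsum` and batch inputs ordered to
//! match `graph.tensors`, as in the tests below):
//!
//! ```ignore
//! use tensorlogic_infer::TlBatchExecutor;
//!
//! let graph = compile_to_einsum(&expr)?;
//! let mut executor = Scirs2Exec::new();
//! let batch_len = batch_inputs.len();
//! let result = executor.execute_batch(&graph, batch_inputs)?;
//! assert_eq!(result.len(), batch_len);
//! ```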

use crate::{Scirs2Exec, Scirs2Tensor};
use tensorlogic_infer::{BatchResult, ExecutorError, TlAutodiff, TlBatchExecutor};
use tensorlogic_ir::EinsumGraph;

#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;

impl TlBatchExecutor for Scirs2Exec {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    fn execute_batch(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
        if batch_inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty batch provided".to_string(),
            ));
        }

        let mut outputs = Vec::with_capacity(batch_inputs.len());

        for input_batch in batch_inputs {
            // Store tensors for this batch item
            for (idx, tensor) in input_batch.iter().enumerate() {
                if idx < graph.tensors.len() {
                    self.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                }
            }

            let output = self.forward(graph)?;
            outputs.push(output);
        }

        Ok(BatchResult::new(outputs))
    }

    fn execute_batch_parallel(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
        num_threads: Option<usize>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
        #[cfg(feature = "parallel")]
        {
            if batch_inputs.is_empty() {
                return Err(ExecutorError::InvalidEinsumSpec(
                    "Empty batch provided".to_string(),
                ));
            }

            // Configure thread pool if requested
            if let Some(threads) = num_threads {
                ThreadPoolBuilder::new()
                    .num_threads(threads)
                    .build_global()
                    .ok(); // Ignore if already initialized
            }

            // Execute batch items in parallel
            let results: Result<Vec<_>, _> = batch_inputs
                .par_iter()
                .map(|input_batch| {
                    let mut executor = self.clone();

                    for (idx, tensor) in input_batch.iter().enumerate() {
                        if idx < graph.tensors.len() {
                            executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                        }
                    }

                    executor.forward(graph)
                })
                .collect();

            let outputs = results?;
            Ok(BatchResult::new(outputs))
        }

        #[cfg(not(feature = "parallel"))]
        {
            // Fall back to sequential execution when the parallel feature is disabled.
            let _ = num_threads; // silence the unused-variable warning
            self.execute_batch(graph, batch_inputs)
        }
    }

    fn optimal_batch_size(&self) -> usize {
        // For CPU execution, a moderate batch size balances parallelism and overhead
        let num_cpus = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);

        // Use 2x the number of CPUs as a heuristic
        num_cpus * 2
    }
}
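
// Sketch (hypothetical caller code, not part of this module's API): chunking a
// large workload with `optimal_batch_size` and dispatching each chunk through
// the parallel path. `all_inputs` is an assumed `Vec<Vec<Scirs2Tensor>>`.
//
//     let chunk = executor.optimal_batch_size();
//     for window in all_inputs.chunks(chunk) {
//         let result = executor.execute_batch_parallel(&graph, window.to_vec(), None)?;
//         // ...consume result.outputs...
//     }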

/// Parallel batch executor using rayon for CPU parallelism
pub struct ParallelBatchExecutor {
    /// Base executor template
    base: Scirs2Exec,
}

impl ParallelBatchExecutor {
    /// Create a new parallel batch executor
    pub fn new() -> Self {
        ParallelBatchExecutor {
            base: Scirs2Exec::new(),
        }
    }

    /// Create parallel batch executor with memory pooling
    pub fn with_memory_pool() -> Self {
        ParallelBatchExecutor {
            base: Scirs2Exec::with_memory_pool(),
        }
    }

    /// Execute batch in parallel using rayon
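    ///
    /// # Example
    ///
    /// A minimal sketch (assuming a graph from `compile_to_einsum` and batch
    /// inputs ordered to match `graph.tensors`, as in the tests below):
    ///
    /// ```ignore
    /// let executor = ParallelBatchExecutor::new();
    /// let result = executor.execute_parallel(&graph, batch_inputs)?;
    /// for output in &result.outputs {
    ///     // ...use each per-item output tensor...
    /// }
    /// ```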
    pub fn execute_parallel(
        &self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Scirs2Tensor>>,
    ) -> Result<BatchResult<Scirs2Tensor>, ExecutorError> {
        if batch_inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty batch provided".to_string(),
            ));
        }

        #[cfg(feature = "parallel")]
        {
            // Execute batch items in parallel using rayon
            let results: Result<Vec<_>, _> = batch_inputs
                .par_iter()
                .map(|input_batch| {
                    let mut executor = self.base.clone();

                    for (idx, tensor) in input_batch.iter().enumerate() {
                        if idx < graph.tensors.len() {
                            executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                        }
                    }

                    executor.forward(graph)
                })
                .collect();

            let outputs = results?;
            Ok(BatchResult::new(outputs))
        }

        #[cfg(not(feature = "parallel"))]
        {
            // Fall back to sequential execution when the parallel feature is disabled.
            let mut outputs = Vec::with_capacity(batch_inputs.len());

            for input_batch in batch_inputs {
                let mut executor = self.base.clone();

                for (idx, tensor) in input_batch.iter().enumerate() {
                    if idx < graph.tensors.len() {
                        executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                    }
                }

                let output = executor.forward(graph)?;
                outputs.push(output);
            }

            Ok(BatchResult::new(outputs))
        }
    }
}

impl Default for ParallelBatchExecutor {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for Scirs2Exec {
    fn clone(&self) -> Self {
        Scirs2Exec {
            tensors: self.tensors.clone(),
            tape: self.tape.clone(),
            pool: None, // Don't clone pool to avoid shared state issues
        }
    }
}

#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    fn create_test_tensor(shape: &[usize], value: f64) -> ArrayD<f64> {
        ArrayD::from_elem(shape.to_vec(), value)
    }

    #[test]
    fn test_batch_executor_basic() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = Scirs2Exec::new();

        // Create batch of 3 items
        let batch_inputs = vec![
            vec![create_test_tensor(&[5], 1.0), create_test_tensor(&[5], 2.0)],
            vec![create_test_tensor(&[5], 3.0), create_test_tensor(&[5], 4.0)],
            vec![create_test_tensor(&[5], 5.0), create_test_tensor(&[5], 6.0)],
        ];

        let result = executor.execute_batch(&graph, batch_inputs).unwrap();

        assert_eq!(result.len(), 3);
        assert!((result.outputs[0][0] - 3.0).abs() < 1e-6); // 1 + 2
        assert!((result.outputs[1][0] - 7.0).abs() < 1e-6); // 3 + 4
        assert!((result.outputs[2][0] - 11.0).abs() < 1e-6); // 5 + 6
        assert_eq!(result.batch_size, 3);
    }

    #[test]
    fn test_optimal_batch_size() {
        let executor = Scirs2Exec::new();

        let batch_size = executor.optimal_batch_size();
        assert!(batch_size > 0);
        assert!(batch_size <= 128); // Reasonable upper bound
    }

    #[test]
    fn test_parallel_batch_executor() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::mul(x.clone(), x);
        let graph = compile_to_einsum(&expr).unwrap();

        let executor = ParallelBatchExecutor::new();

        let batch_inputs = vec![
            vec![create_test_tensor(&[3], 2.0)],
            vec![create_test_tensor(&[3], 3.0)],
        ];

        let result = executor.execute_parallel(&graph, batch_inputs).unwrap();

        assert_eq!(result.len(), 2);
        assert!((result.outputs[0][0] - 4.0).abs() < 1e-6); // 2 * 2
        assert!((result.outputs[1][0] - 9.0).abs() < 1e-6); // 3 * 3
    }

    #[test]
    fn test_empty_batch_error() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let graph = compile_to_einsum(&x).unwrap();

        let mut executor = Scirs2Exec::new();
        let batch_inputs: Vec<Vec<ArrayD<f64>>> = vec![];

        let result = executor.execute_batch(&graph, batch_inputs);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_parallel_same_as_sequential() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let batch_inputs = vec![
            vec![create_test_tensor(&[3], 1.0), create_test_tensor(&[3], 2.0)],
            vec![create_test_tensor(&[3], 3.0), create_test_tensor(&[3], 4.0)],
        ];

        let mut executor = Scirs2Exec::new();
        let result_seq = executor
            .execute_batch(&graph, batch_inputs.clone())
            .unwrap();

        let mut executor2 = Scirs2Exec::new();
        let result_par = executor2
            .execute_batch_parallel(&graph, batch_inputs, None)
            .unwrap();

        assert_eq!(result_seq.len(), result_par.len());
        for (seq, par) in result_seq.outputs.iter().zip(result_par.outputs.iter()) {
            assert_eq!(seq.shape(), par.shape());
            for (s, p) in seq.iter().zip(par.iter()) {
                assert!((s - p).abs() < 1e-10);
            }
        }
    }
}