tensorlogic_scirs_backend/batch_executor.rs

use crate::{Scirs2Exec, Scirs2Tensor};
use tensorlogic_infer::{BatchResult, ExecutorError, TlAutodiff, TlBatchExecutor};
use tensorlogic_ir::EinsumGraph;

#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;

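// Note: `parallel_ops` is assumed here to re-export rayon-style primitives
// (`par_iter`, `ThreadPoolBuilder`), gated behind the `parallel` feature.
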
impl TlBatchExecutor for Scirs2Exec {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    fn execute_batch(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
        if batch_inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty batch provided".to_string(),
            ));
        }

        let mut outputs = Vec::with_capacity(batch_inputs.len());

        for input_batch in batch_inputs {
            // Bind each input tensor to its slot in the graph, then run a
            // forward pass; the same executor is reused across batch items.
            for (idx, tensor) in input_batch.iter().enumerate() {
                if idx < graph.tensors.len() {
                    self.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                }
            }

            let output = self.forward(graph)?;
            outputs.push(output);
        }

        Ok(BatchResult::new(outputs))
    }

    fn execute_batch_parallel(
        &mut self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Self::Tensor>>,
        num_threads: Option<usize>,
    ) -> Result<BatchResult<Self::Tensor>, Self::Error> {
        #[cfg(feature = "parallel")]
        {
            if batch_inputs.is_empty() {
                return Err(ExecutorError::InvalidEinsumSpec(
                    "Empty batch provided".to_string(),
                ));
            }

            if let Some(threads) = num_threads {
                // Best effort: `build_global` errors if a global pool has
                // already been created, so the result is deliberately ignored.
                ThreadPoolBuilder::new()
                    .num_threads(threads)
                    .build_global()
                    .ok();
            }

            let results: Result<Vec<_>, _> = batch_inputs
                .par_iter()
                .map(|input_batch| {
                    // Clone the executor per batch item so worker threads do
                    // not share mutable tensor state.
                    let mut executor = self.clone();

                    for (idx, tensor) in input_batch.iter().enumerate() {
                        if idx < graph.tensors.len() {
                            executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                        }
                    }

                    executor.forward(graph)
                })
                .collect();

            let outputs = results?;
            Ok(BatchResult::new(outputs))
        }

        #[cfg(not(feature = "parallel"))]
        {
            // Without the `parallel` feature, fall back to sequential execution.
            let _ = num_threads;
            self.execute_batch(graph, batch_inputs)
        }
    }

    fn optimal_batch_size(&self) -> usize {
        // Heuristic: twice the available core count, defaulting to 4 when it
        // cannot be determined.
        let num_cpus = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);

        num_cpus * 2
    }
}

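/// Batch executor that owns a base `Scirs2Exec` and clones it once per batch
/// item, running items in parallel when the `parallel` feature is enabled and
/// falling back to a sequential loop otherwise.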
pub struct ParallelBatchExecutor {
    base: Scirs2Exec,
}

impl ParallelBatchExecutor {
    pub fn new() -> Self {
        ParallelBatchExecutor {
            base: Scirs2Exec::new(),
        }
    }

    pub fn with_memory_pool() -> Self {
        ParallelBatchExecutor {
            base: Scirs2Exec::with_memory_pool(),
        }
    }

    pub fn execute_parallel(
        &self,
        graph: &EinsumGraph,
        batch_inputs: Vec<Vec<Scirs2Tensor>>,
    ) -> Result<BatchResult<Scirs2Tensor>, ExecutorError> {
        if batch_inputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty batch provided".to_string(),
            ));
        }

        #[cfg(feature = "parallel")]
        {
            let results: Result<Vec<_>, _> = batch_inputs
                .par_iter()
                .map(|input_batch| {
                    let mut executor = self.base.clone();

                    for (idx, tensor) in input_batch.iter().enumerate() {
                        if idx < graph.tensors.len() {
                            executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                        }
                    }

                    executor.forward(graph)
                })
                .collect();

            let outputs = results?;
            Ok(BatchResult::new(outputs))
        }

        #[cfg(not(feature = "parallel"))]
        {
            // Sequential fallback: still clones the base executor per item so
            // behavior matches the parallel path.
            let mut outputs = Vec::with_capacity(batch_inputs.len());

            for input_batch in batch_inputs {
                let mut executor = self.base.clone();

                for (idx, tensor) in input_batch.iter().enumerate() {
                    if idx < graph.tensors.len() {
                        executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
                    }
                }

                let output = executor.forward(graph)?;
                outputs.push(output);
            }

            Ok(BatchResult::new(outputs))
        }
    }
}

impl Default for ParallelBatchExecutor {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for Scirs2Exec {
    fn clone(&self) -> Self {
        Scirs2Exec {
            tensors: self.tensors.clone(),
            tape: self.tape.clone(),
            // Note: clones do not inherit the memory pool; each worker clone
            // starts without one.
            pool: None,
        }
    }
}
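
// Usage sketch (hypothetical `expr` and `batch_inputs`; mirrors the
// integration tests below):
//
//     let graph = compile_to_einsum(&expr)?;
//     let executor = ParallelBatchExecutor::new();
//     let n = batch_inputs.len();
//     let result = executor.execute_parallel(&graph, batch_inputs)?;
//     assert_eq!(result.len(), n);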

#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    fn create_test_tensor(shape: &[usize], value: f64) -> ArrayD<f64> {
        ArrayD::from_elem(shape.to_vec(), value)
    }

    #[test]
    fn test_batch_executor_basic() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let mut executor = Scirs2Exec::new();

        let batch_inputs = vec![
            vec![create_test_tensor(&[5], 1.0), create_test_tensor(&[5], 2.0)],
            vec![create_test_tensor(&[5], 3.0), create_test_tensor(&[5], 4.0)],
            vec![create_test_tensor(&[5], 5.0), create_test_tensor(&[5], 6.0)],
        ];

        let result = executor.execute_batch(&graph, batch_inputs).unwrap();

        assert_eq!(result.len(), 3);
        assert!((result.outputs[0][0] - 3.0).abs() < 1e-6);
        assert!((result.outputs[1][0] - 7.0).abs() < 1e-6);
        assert!((result.outputs[2][0] - 11.0).abs() < 1e-6);
        assert_eq!(result.batch_size, 3);
    }

    #[test]
    fn test_optimal_batch_size() {
        let executor = Scirs2Exec::new();

        let batch_size = executor.optimal_batch_size();
        assert!(batch_size > 0);
        assert!(batch_size <= 128);
    }

    #[test]
    fn test_parallel_batch_executor() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::mul(x.clone(), x);
        let graph = compile_to_einsum(&expr).unwrap();

        let executor = ParallelBatchExecutor::new();

        let batch_inputs = vec![
            vec![create_test_tensor(&[3], 2.0)],
            vec![create_test_tensor(&[3], 3.0)],
        ];

        let result = executor.execute_parallel(&graph, batch_inputs).unwrap();

        assert_eq!(result.len(), 2);
        assert!((result.outputs[0][0] - 4.0).abs() < 1e-6);
        assert!((result.outputs[1][0] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_batch_error() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let graph = compile_to_einsum(&x).unwrap();

        let mut executor = Scirs2Exec::new();
        let batch_inputs: Vec<Vec<ArrayD<f64>>> = vec![];

        let result = executor.execute_batch(&graph, batch_inputs);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_parallel_same_as_sequential() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::add(x, y);
        let graph = compile_to_einsum(&expr).unwrap();

        let batch_inputs = vec![
            vec![create_test_tensor(&[3], 1.0), create_test_tensor(&[3], 2.0)],
            vec![create_test_tensor(&[3], 3.0), create_test_tensor(&[3], 4.0)],
        ];

        let mut executor = Scirs2Exec::new();
        let result_seq = executor
            .execute_batch(&graph, batch_inputs.clone())
            .unwrap();

        let mut executor2 = Scirs2Exec::new();
        let result_par = executor2
            .execute_batch_parallel(&graph, batch_inputs, None)
            .unwrap();

        assert_eq!(result_seq.len(), result_par.len());
        for (seq, par) in result_seq.outputs.iter().zip(result_par.outputs.iter()) {
            assert_eq!(seq.shape(), par.shape());
            for (s, p) in seq.iter().zip(par.iter()) {
                assert!((s - p).abs() < 1e-10);
            }
        }
    }
}