// File: tensorlogic_scirs_backend/parallel_executor.rs
1//! Parallel executor implementation using Rayon for multi-threaded execution.
2//!
3//! This module provides a parallel implementation of the TensorLogic executor
4//! that can execute independent operations concurrently using thread pools.
5//!
6//! ## Key Features
7//!
8//! - **Level-by-level execution**: Operations are grouped by execution level
9//! - **Rayon thread pools**: Configurable thread pool for parallel execution
10//! - **Automatic dependency handling**: Uses DependencyAnalysis for safe parallelization
11//! - **Performance monitoring**: Tracks parallel vs sequential execution times
12//!
13//! ## Example
14//!
15//! ```rust,ignore
16//! use tensorlogic_scirs_backend::ParallelScirs2Exec;
17//! use tensorlogic_infer::TlAutodiff;
18//!
19//! let mut executor = ParallelScirs2Exec::new();
20//! executor.set_num_threads(4); // Use 4 threads
21//!
22//! let result = executor.forward(&graph)?;
23//! ```
24
25#[cfg(feature = "parallel")]
26use scirs2_core::parallel_ops::*;
27
28#[cfg(feature = "parallel")]
29use std::sync::{Arc, Mutex};
30use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlAutodiff, TlExecutor};
31#[cfg(not(feature = "parallel"))]
32use tensorlogic_ir::EinsumGraph;
33#[cfg(feature = "parallel")]
34use tensorlogic_ir::{EinsumGraph, OpType};
35
36use crate::autodiff::ForwardTape;
37#[cfg(feature = "parallel")]
38use crate::dependency_analyzer::DependencyAnalysis;
39#[cfg(feature = "parallel")]
40use crate::ops::{parse_elem_op, parse_reduce_op};
41use crate::Scirs2Tensor;
42
/// Configuration for parallel execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of threads to use (None = use all available cores).
    ///
    /// NOTE(review): this value is recorded but no scoped Rayon pool is
    /// built in this module; `forward` runs on the global pool — confirm
    /// the setting is applied elsewhere.
    pub num_threads: Option<usize>,
    /// Minimum number of operations per level to enable parallelization
    /// (levels with fewer ops run sequentially to avoid overhead)
    pub min_parallel_ops: usize,
    /// Enable memory pooling for tensor reuse in the base executor
    pub enable_pooling: bool,
}
54
55impl Default for ParallelConfig {
56    fn default() -> Self {
57        Self {
58            num_threads: None, // Use all available cores
59            min_parallel_ops: 2,
60            enable_pooling: true,
61        }
62    }
63}
64
/// Statistics about parallel execution, populated after each `forward` call.
#[derive(Debug, Clone)]
pub struct ParallelStats {
    /// Number of execution levels (dependency-ordered groups of ops)
    pub num_levels: usize,
    /// Number of operations executed in parallel
    pub parallel_ops: usize,
    /// Number of operations executed sequentially
    pub sequential_ops: usize,
    /// Maximum number of concurrent operations in any level
    pub max_parallelism: usize,
    /// Estimated speedup from parallelization (from DependencyAnalysis)
    pub estimated_speedup: f64,
}
79
/// Parallel executor using Rayon for multi-threaded execution.
///
/// Wraps the sequential `Scirs2Exec` (used for tensor storage, delegated
/// single-op execution, and the backward pass) and adds level-parallel
/// graph execution when the `parallel` feature is enabled.
pub struct ParallelScirs2Exec {
    /// Base executor for sequential operations and tensor/tape storage
    pub(crate) base: crate::executor::Scirs2Exec,
    /// Configuration for parallel execution
    pub config: ParallelConfig,
    /// Statistics from last execution (None until `forward` has run)
    pub stats: Option<ParallelStats>,
}
89
impl ParallelScirs2Exec {
    /// Create a new parallel executor with default configuration.
    pub fn new() -> Self {
        Self {
            base: crate::executor::Scirs2Exec::new(),
            config: ParallelConfig::default(),
            stats: None,
        }
    }

    /// Create a parallel executor with custom configuration.
    ///
    /// When `config.enable_pooling` is set, the base executor is built
    /// with a memory pool so intermediate tensors can be reused.
    pub fn with_config(config: ParallelConfig) -> Self {
        let base = if config.enable_pooling {
            crate::executor::Scirs2Exec::with_memory_pool()
        } else {
            crate::executor::Scirs2Exec::new()
        };

        Self {
            base,
            config,
            stats: None,
        }
    }

    /// Set the number of threads to use.
    ///
    /// NOTE(review): this only records the value in `config`; `forward`
    /// runs on the global Rayon pool, so the setting is not visibly
    /// applied to a scoped thread pool here — confirm intended behavior.
    pub fn set_num_threads(&mut self, num_threads: usize) {
        self.config.num_threads = Some(num_threads);
    }

    /// Get the number of threads configured (returns actual thread count).
    #[cfg(feature = "parallel")]
    pub fn num_threads(&self) -> usize {
        // When unset, report the global Rayon pool's current size.
        self.config.num_threads.unwrap_or_else(current_num_threads)
    }

    /// Without the `parallel` feature there is no thread pool; default to 1.
    #[cfg(not(feature = "parallel"))]
    pub fn num_threads(&self) -> usize {
        self.config.num_threads.unwrap_or(1)
    }

    /// Enable or disable memory pooling.
    ///
    /// Kept in sync on both the config and the base executor so the two
    /// never disagree.
    pub fn set_pooling(&mut self, enable: bool) {
        self.config.enable_pooling = enable;
        if enable {
            self.base.enable_pooling();
        } else {
            self.base.disable_pooling();
        }
    }

    /// Get pool statistics if pooling is enabled.
    pub fn pool_stats(&self) -> Option<crate::memory_pool::PoolStats> {
        self.base.pool_stats()
    }

    /// Get statistics from the last execution, if any.
    pub fn execution_stats(&self) -> Option<&ParallelStats> {
        self.stats.as_ref()
    }

    /// Add a named tensor to the executor (stored on the base executor).
    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        self.base.add_tensor(name, tensor);
    }

    /// Get a tensor by name.
    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.base.get_tensor(name)
    }

    /// Execute a single operation (helper function).
    ///
    /// Takes `&self` (not `&mut self`) so it can be called from inside a
    /// Rayon `par_iter` closure; all state is passed in via `input_tensors`
    /// and the result is returned to the caller for storage.
    #[cfg(feature = "parallel")]
    fn execute_operation(
        &self,
        node: &tensorlogic_ir::EinsumNode,
        input_tensors: &[Scirs2Tensor],
    ) -> Result<Scirs2Tensor, ExecutorError> {
        // Dispatch based on operation type
        match &node.op {
            OpType::Einsum { spec } => {
                // Need to use a mutable executor for einsum
                // For now, we'll use the sequential path through self.base
                // In a real parallel implementation, we'd need to handle this differently
                let views: Vec<_> = input_tensors.iter().map(|t| t.view()).collect();
                let view_refs: Vec<_> = views.iter().collect();
                scirs2_linalg::einsum(spec, &view_refs)
                    .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
            }
            OpType::ElemUnary { op } => {
                if input_tensors.len() != 1 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Unary operation requires 1 input, got {}",
                        input_tensors.len()
                    )));
                }
                let elem_op = parse_elem_op(op)?;
                // Only the unary ops needed by the parallel path are handled
                // here; anything else is rejected explicitly.
                match elem_op {
                    ElemOp::Relu => Ok(input_tensors[0].mapv(|v| v.max(0.0))),
                    ElemOp::Sigmoid => Ok(input_tensors[0].mapv(|v| 1.0 / (1.0 + (-v).exp()))),
                    ElemOp::OneMinus => Ok(input_tensors[0].mapv(|v| 1.0 - v)),
                    _ => Err(ExecutorError::UnsupportedOperation(format!(
                        "Unary operation {:?} not supported",
                        elem_op
                    ))),
                }
            }
            OpType::ElemBinary { op } => {
                if input_tensors.len() != 2 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Binary operation requires 2 inputs, got {}",
                        input_tensors.len()
                    )));
                }
                let elem_op = parse_elem_op(op)?;
                let x = &input_tensors[0];
                let y = &input_tensors[1];

                // Handle scalar broadcasting: a 0-d tensor is materialized
                // to the other operand's shape before the element-wise op.
                let x_is_scalar = x.ndim() == 0;
                let y_is_scalar = y.ndim() == 0;

                // Deferred initialization: only the branch that actually
                // broadcasts assigns its `*_broadcast` binding, so the
                // references below stay valid for the whole match.
                let (x_broadcast, y_broadcast);
                let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
                    let scalar_value = x.iter().next().unwrap();
                    x_broadcast =
                        scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
                    (&x_broadcast.view(), &y.view())
                } else if y_is_scalar && !x_is_scalar {
                    let scalar_value = y.iter().next().unwrap();
                    y_broadcast =
                        scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
                    (&x.view(), &y_broadcast.view())
                } else if x.shape() != y.shape() {
                    // Neither side is scalar and the shapes disagree: no
                    // general broadcasting is attempted here.
                    return Err(ExecutorError::ShapeMismatch(format!(
                        "Shape mismatch: {:?} vs {:?}",
                        x.shape(),
                        y.shape()
                    )));
                } else {
                    (&x.view(), &y.view())
                };

                // Fuzzy-logic ops are expressed arithmetically on values
                // (presumably in [0, 1] — not validated here):
                //   OrMax = max, OrProbSum = a + b - ab, Nand = 1 - ab,
                //   Nor = 1 - max, Xor = a + b - 2ab.
                // Note OrMax is intentionally identical to Max.
                let result = match elem_op {
                    ElemOp::Add => x_ref + y_ref,
                    ElemOp::Subtract => x_ref - y_ref,
                    ElemOp::Multiply => x_ref * y_ref,
                    ElemOp::Divide => x_ref / y_ref,
                    ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.min(b)),
                    ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.max(b)),
                    ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.max(b)),
                    ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a + b - a * b),
                    ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| 1.0 - (a * b)),
                    ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| 1.0 - a.max(b)),
                    ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a + b - 2.0 * a * b),
                    _ => {
                        return Err(ExecutorError::UnsupportedOperation(format!(
                            "Binary operation {:?} not supported",
                            elem_op
                        )))
                    }
                };
                Ok(result)
            }
            OpType::Reduce { op, axes } => {
                if input_tensors.len() != 1 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Reduce operation requires 1 input, got {}",
                        input_tensors.len()
                    )));
                }
                let reduce_op = parse_reduce_op(op)?;
                let x = &input_tensors[0];

                use scirs2_core::ndarray::Axis;
                let mut result = x.clone();
                // Reduce axes in reverse (descending) order so that each
                // removal does not shift the indices of axes yet to be
                // reduced. Assumes `axes` is sorted ascending — TODO confirm.
                for &axis in axes.iter().rev() {
                    result = match reduce_op {
                        ReduceOp::Sum => result.sum_axis(Axis(axis)),
                        // Fold from -inf/+inf so empty lanes yield the
                        // identity rather than panicking.
                        ReduceOp::Max => result.map_axis(Axis(axis), |view| {
                            view.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
                        }),
                        ReduceOp::Min => result.map_axis(Axis(axis), |view| {
                            view.iter().fold(f64::INFINITY, |a, &b| a.min(b))
                        }),
                        ReduceOp::Mean => {
                            // `count` is read from `result` before it is
                            // replaced, i.e. the length along this axis.
                            let sum = result.sum_axis(Axis(axis));
                            let count = result.len_of(Axis(axis)) as f64;
                            sum / count
                        }
                        ReduceOp::Product => {
                            result.map_axis(Axis(axis), |view| view.iter().product())
                        }
                    };
                }
                Ok(result)
            }
        }
    }
}
304
305impl Default for ParallelScirs2Exec {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
// Delegate basic TlExecutor methods to the base executor.
//
// Single-op calls are inherently sequential; only whole-graph `forward`
// benefits from parallelism, so these simply forward to `self.base`.
impl TlExecutor for ParallelScirs2Exec {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    /// Evaluate an einsum spec on `inputs` via the base executor.
    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        self.base.einsum(spec, inputs)
    }

    /// Apply a unary element-wise op via the base executor.
    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        self.base.elem_op(op, x)
    }

    /// Apply a binary element-wise op via the base executor.
    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        self.base.elem_op_binary(op, x, y)
    }

    /// Reduce `x` along `axes` via the base executor.
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        self.base.reduce(op, x, axes)
    }
}
342
#[cfg(feature = "parallel")]
impl TlAutodiff for ParallelScirs2Exec {
    type Tape = ForwardTape;

    /// Forward pass with level-parallel execution.
    ///
    /// Nodes are grouped into dependency levels by `DependencyAnalysis`.
    /// Each level whose op count reaches `config.min_parallel_ops` is run
    /// with Rayon's `par_iter`; smaller levels run sequentially. Computed
    /// tensors and per-node inputs are recorded into a `ForwardTape` for
    /// the backward pass, and `self.stats` is refreshed.
    ///
    /// # Errors
    /// Returns `InvalidEinsumSpec` for an empty graph or a graph with no
    /// outputs, and `TensorNotFound` if a node's input (or the final
    /// output) was never computed.
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        if graph.outputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No output tensors specified".to_string(),
            ));
        }

        // Analyze dependencies: produces levels of mutually independent ops.
        let analysis = DependencyAnalysis::analyze(graph);

        // Shared slot-per-tensor storage. Mutex-guarded because parallel
        // tasks read inputs from it concurrently; writes happen after each
        // level joins.
        let computed_tensors: Arc<Mutex<Vec<Option<Scirs2Tensor>>>> =
            Arc::new(Mutex::new(vec![None; graph.tensors.len()]));

        // Per-node input tensors, kept for the backward pass (indexed by
        // node index; grown on demand below).
        let node_inputs: Arc<Mutex<Vec<Vec<Scirs2Tensor>>>> =
            Arc::new(Mutex::new(Vec::with_capacity(graph.nodes.len())));

        // Initialize input tensors from our stored tensors
        {
            let mut storage = computed_tensors.lock().unwrap();
            for (idx, tensor_name) in graph.tensors.iter().enumerate() {
                if let Some(tensor) = self.base.tensors.get(tensor_name) {
                    storage[idx] = Some(tensor.clone());
                } else {
                    // Handle tensors with axes notation (e.g., "age[a]" -> "age")
                    let base_name = tensor_name.split('[').next().unwrap_or(tensor_name);
                    if let Some(tensor) = self.base.tensors.get(base_name) {
                        storage[idx] = Some(tensor.clone());
                    } else if tensor_name.starts_with("const_") || base_name.starts_with("const_") {
                        // Handle constant tensors named "const_<value>": the
                        // numeric suffix is parsed into a 0-d tensor.
                        let const_name = if tensor_name.starts_with("const_") {
                            tensor_name
                        } else {
                            base_name
                        };
                        if let Some(value_str) = const_name.strip_prefix("const_") {
                            if let Ok(value) = value_str.parse::<f64>() {
                                use scirs2_core::ndarray::arr0;
                                storage[idx] = Some(arr0(value).into_dyn());
                            }
                        }
                        // Note: an unparsable const suffix leaves the slot
                        // None; the error surfaces later as TensorNotFound.
                    }
                }
            }
        }

        // Track statistics
        let mut parallel_ops = 0;
        let mut sequential_ops = 0;

        // Execute operations level by level
        for level_ops in &analysis.execution_levels {
            let should_parallelize = level_ops.len() >= self.config.min_parallel_ops;

            if should_parallelize {
                // Parallel execution for this level
                parallel_ops += level_ops.len();

                // Execute all operations in this level in parallel.
                // NOTE(review): each task clones its inputs while holding
                // the storage lock, which serializes that portion of the
                // work — acceptable for coarse ops, worth revisiting.
                let results: Vec<_> = level_ops
                    .par_iter()
                    .map(|&op_idx| {
                        let node = &graph.nodes[op_idx];

                        // Read inputs from shared storage (cloned so the
                        // lock is released before the op executes).
                        let inputs: Result<Vec<_>, _> = {
                            let storage = computed_tensors.lock().unwrap();
                            node.inputs
                                .iter()
                                .map(|&idx| {
                                    storage
                                        .get(idx)
                                        .and_then(|t| t.as_ref())
                                        .cloned()
                                        .ok_or_else(|| {
                                            ExecutorError::TensorNotFound(format!(
                                                "Tensor at index {} not found",
                                                idx
                                            ))
                                        })
                                })
                                .collect()
                        };

                        let input_tensors = inputs?;
                        let result = self.execute_operation(node, &input_tensors)?;

                        Ok((op_idx, node.outputs.clone(), input_tensors, result))
                    })
                    // collect into Result short-circuits on the first error.
                    .collect::<Result<Vec<_>, ExecutorError>>()?;

                // Store results (single-threaded again; level has joined).
                {
                    let mut storage = computed_tensors.lock().unwrap();
                    let mut inputs_vec = node_inputs.lock().unwrap();

                    // Ensure node_inputs has enough capacity
                    while inputs_vec.len()
                        <= results.iter().map(|(idx, _, _, _)| *idx).max().unwrap_or(0)
                    {
                        inputs_vec.push(Vec::new());
                    }

                    for (op_idx, outputs, inputs, tensor) in results {
                        // Store in tensor storage. Only the first output
                        // slot is written — assumes single-output nodes;
                        // TODO confirm for multi-output ops.
                        if let Some(&output_idx) = outputs.first() {
                            storage[output_idx] = Some(tensor);
                        }

                        // Store inputs for backward pass
                        inputs_vec[op_idx] = inputs;
                    }
                }
            } else {
                // Sequential execution for this level
                sequential_ops += level_ops.len();

                let mut storage = computed_tensors.lock().unwrap();
                let mut inputs_vec = node_inputs.lock().unwrap();

                for &op_idx in level_ops {
                    let node = &graph.nodes[op_idx];

                    let inputs: Result<Vec<_>, _> = node
                        .inputs
                        .iter()
                        .map(|&idx| {
                            storage
                                .get(idx)
                                .and_then(|t| t.as_ref())
                                .cloned()
                                .ok_or_else(|| {
                                    ExecutorError::TensorNotFound(format!(
                                        "Tensor at index {} not found",
                                        idx
                                    ))
                                })
                        })
                        .collect();

                    let input_tensors = inputs?;
                    let result = self.execute_operation(node, &input_tensors)?;

                    // Store result (first output slot only, as above)
                    if let Some(&output_idx) = node.outputs.first() {
                        storage[output_idx] = Some(result);
                    }

                    // Store inputs for backward pass
                    while inputs_vec.len() <= op_idx {
                        inputs_vec.push(Vec::new());
                    }
                    inputs_vec[op_idx] = input_tensors;
                }
            }
        }

        // Store tape for backward pass. All parallel tasks have joined, so
        // these Arcs are uniquely owned and try_unwrap cannot fail.
        let final_tensors = Arc::try_unwrap(computed_tensors)
            .unwrap()
            .into_inner()
            .unwrap();
        let final_inputs = Arc::try_unwrap(node_inputs).unwrap().into_inner().unwrap();

        // The tape keeps its own copy of the tensors; `final_tensors` is
        // still needed below to look up the output.
        self.base.tape = Some(ForwardTape {
            tensors: final_tensors.clone(),
            node_inputs: final_inputs,
        });

        // Store statistics
        self.stats = Some(ParallelStats {
            num_levels: analysis.num_levels,
            parallel_ops,
            sequential_ops,
            max_parallelism: analysis.max_parallelism,
            estimated_speedup: analysis.estimated_speedup(),
        });

        // Return the (first) output tensor; additional outputs are ignored.
        let output_idx = graph.outputs[0];
        final_tensors
            .get(output_idx)
            .and_then(|t| t.clone())
            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor not computed".to_string()))
    }

    /// Backward pass, delegated to the base executor.
    ///
    /// The backward pass is typically more sequential due to dependency
    /// chains, so no parallel variant is provided here. Relies on the tape
    /// stored in `self.base.tape` by `forward`.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.base.backward(graph, loss_grad)
    }
}
549
// If the `parallel` feature is not enabled, provide a non-parallel
// implementation so the public API is identical either way.
#[cfg(not(feature = "parallel"))]
impl TlAutodiff for ParallelScirs2Exec {
    type Tape = ForwardTape;

    /// Forward pass: falls back to the base executor's sequential
    /// implementation (no `ParallelStats` are recorded on this path).
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        self.base.forward(graph)
    }

    /// Backward pass, delegated to the base executor.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.base.backward(graph, loss_grad)
    }
}
568
//! Tests for the parallel executor (only compiled with `parallel` enabled).
#[cfg(test)]
#[cfg(feature = "parallel")]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use tensorlogic_ir::EinsumNode;

    /// Build a small diamond-shaped graph with a parallelizable first level.
    fn create_parallel_test_graph() -> EinsumGraph {
        // Create a graph with parallelizable operations:
        // Tensors: 0=a, 1=b, 2=c, 3=d, 4=e, 5=f
        // Op0: c = relu(a)    (level 0, independent)
        // Op1: d = sigmoid(b) (level 0, independent)
        // Op2: e = c + d      (level 1, depends on Op0, Op1)
        // Op3: f = relu(e)    (level 2, depends on Op2)

        let mut graph = EinsumGraph::new();

        let a_idx = graph.add_tensor("a"); // 0
        let b_idx = graph.add_tensor("b"); // 1
        let c_idx = graph.add_tensor("c"); // 2
        let d_idx = graph.add_tensor("d"); // 3
        let e_idx = graph.add_tensor("e"); // 4
        let f_idx = graph.add_tensor("f"); // 5

        graph.add_input(a_idx).unwrap();
        graph.add_input(b_idx).unwrap();

        // Op0: c = relu(a)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        // Op1: d = sigmoid(b)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "sigmoid".to_string(),
                },
                inputs: vec![b_idx],
                outputs: vec![d_idx],
                metadata: None,
            })
            .unwrap();

        // Op2: e = c + d
        graph
            .add_node(EinsumNode {
                op: OpType::ElemBinary {
                    op: "add".to_string(),
                },
                inputs: vec![c_idx, d_idx],
                outputs: vec![e_idx],
                metadata: None,
            })
            .unwrap();

        // Op3: f = relu(e)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![e_idx],
                outputs: vec![f_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(f_idx).unwrap();

        graph
    }

    /// Default config values are wired through the constructor.
    #[test]
    fn test_parallel_executor_creation() {
        let executor = ParallelScirs2Exec::new();
        assert_eq!(executor.config.min_parallel_ops, 2);
        assert!(executor.config.enable_pooling);
    }

    /// `set_num_threads` records the requested count in the config.
    #[test]
    fn test_set_num_threads() {
        let mut executor = ParallelScirs2Exec::new();
        executor.set_num_threads(4);
        assert_eq!(executor.config.num_threads, Some(4));
    }

    /// Forward pass produces the right shape and parallelizes level 0.
    #[test]
    fn test_parallel_forward_pass() {
        let graph = create_parallel_test_graph();
        let mut executor = ParallelScirs2Exec::new();

        executor.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        executor.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());

        let result = executor.forward(&graph).unwrap();

        // Verify output shape
        assert_eq!(result.shape(), &[3]);

        // Verify statistics
        let stats = executor.execution_stats().unwrap();
        assert_eq!(stats.num_levels, 3);
        assert!(stats.parallel_ops >= 2); // Op0 and Op1 should run in parallel
    }

    /// Parallel and sequential executors must agree element-wise.
    #[test]
    fn test_parallel_vs_sequential_correctness() {
        let graph = create_parallel_test_graph();

        // Execute with parallel executor
        let mut parallel_exec = ParallelScirs2Exec::new();
        parallel_exec.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        parallel_exec.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());
        let parallel_result = parallel_exec.forward(&graph).unwrap();

        // Execute with sequential executor
        let mut sequential_exec = crate::executor::Scirs2Exec::new();
        sequential_exec.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        sequential_exec.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());
        let sequential_result = sequential_exec.forward(&graph).unwrap();

        // Results should match
        assert_eq!(parallel_result.shape(), sequential_result.shape());

        for (p, s) in parallel_result.iter().zip(sequential_result.iter()) {
            assert!((p - s).abs() < 1e-10);
        }
    }

    /// Statistics reflect the dependency analysis of the test graph.
    #[test]
    fn test_parallel_stats() {
        let graph = create_parallel_test_graph();
        let mut executor = ParallelScirs2Exec::new();

        executor.add_tensor("a", array![1.0, 2.0].into_dyn());
        executor.add_tensor("b", array![3.0, 4.0].into_dyn());

        executor.forward(&graph).unwrap();

        let stats = executor.execution_stats().unwrap();
        assert_eq!(stats.num_levels, 3);
        assert!(stats.max_parallelism >= 2);
        assert!(stats.estimated_speedup > 1.0);
    }

    /// Pooling can be enabled and a forward pass still succeeds.
    #[test]
    fn test_pooling_integration() {
        let graph = create_parallel_test_graph();
        let mut executor = ParallelScirs2Exec::new();
        executor.set_pooling(true);

        executor.add_tensor("a", array![1.0, 2.0].into_dyn());
        executor.add_tensor("b", array![3.0, 4.0].into_dyn());

        executor.forward(&graph).unwrap();

        // Pool should have some statistics (if pooling is used)
        let _pool_stats = executor.pool_stats();
        // Note: pool might not be used in forward pass, so this is optional
    }

    /// A level below `min_parallel_ops` falls back to sequential execution.
    #[test]
    fn test_min_parallel_ops_threshold() {
        // Create a graph with only 1 independent operation
        let mut graph = EinsumGraph::new();

        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");

        graph.add_input(a_idx).unwrap();

        // Single operation
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(b_idx).unwrap();

        let mut executor = ParallelScirs2Exec::new();
        executor.add_tensor("a", array![1.0, 2.0, 3.0].into_dyn());

        executor.forward(&graph).unwrap();

        let stats = executor.execution_stats().unwrap();
        // Since there's only 1 op, it should run sequentially
        assert_eq!(stats.sequential_ops, 1);
        assert_eq!(stats.parallel_ops, 0);
    }

    /// Backward pass runs against the tape recorded by the parallel forward.
    #[test]
    fn test_backward_pass_with_parallel() {
        let graph = create_parallel_test_graph();
        let mut executor = ParallelScirs2Exec::new();

        executor.add_tensor("a", array![1.0, 2.0, 3.0].into_dyn());
        executor.add_tensor("b", array![0.5, 1.0, 1.5].into_dyn());

        executor.forward(&graph).unwrap();

        // Backward pass
        let loss_grad = array![1.0, 1.0, 1.0].into_dyn();

        let tape = executor.backward(&graph, &loss_grad).unwrap();

        // Should have gradients for inputs
        assert!(!tape.is_empty());
    }
}