1#[cfg(feature = "parallel")]
26use scirs2_core::parallel_ops::*;
27
28#[cfg(feature = "parallel")]
29use std::sync::{Arc, Mutex};
30use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlAutodiff, TlExecutor};
31#[cfg(not(feature = "parallel"))]
32use tensorlogic_ir::EinsumGraph;
33#[cfg(feature = "parallel")]
34use tensorlogic_ir::{EinsumGraph, OpType};
35
36use crate::autodiff::ForwardTape;
37#[cfg(feature = "parallel")]
38use crate::dependency_analyzer::DependencyAnalysis;
39#[cfg(feature = "parallel")]
40use crate::ops::{parse_elem_op, parse_reduce_op};
41use crate::Scirs2Tensor;
42
/// Tuning knobs for parallel graph execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Explicit worker-thread count; `None` defers to the global parallel
    /// pool size (or 1 when the `parallel` feature is disabled).
    pub num_threads: Option<usize>,
    /// Minimum number of independent ops a dependency level must contain
    /// before it is executed with `par_iter` instead of sequentially.
    pub min_parallel_ops: usize,
    /// Whether the wrapped base executor should use its memory pool.
    pub enable_pooling: bool,
}
54
55impl Default for ParallelConfig {
56 fn default() -> Self {
57 Self {
58 num_threads: None, min_parallel_ops: 2,
60 enable_pooling: true,
61 }
62 }
63}
64
/// Statistics gathered during a single parallel `forward` pass.
#[derive(Debug, Clone)]
pub struct ParallelStats {
    /// Number of dependency levels the graph was partitioned into.
    pub num_levels: usize,
    /// Ops executed via the parallel (`par_iter`) path.
    pub parallel_ops: usize,
    /// Ops executed sequentially (levels below `min_parallel_ops`).
    pub sequential_ops: usize,
    /// Largest number of independent ops found in any single level.
    pub max_parallelism: usize,
    /// Speedup estimate reported by the dependency analysis — presumably
    /// an ideal, overhead-free figure; see `DependencyAnalysis`.
    pub estimated_speedup: f64,
}
79
/// Executor that runs independent graph operations of each dependency
/// level in parallel (when the `parallel` feature is enabled), wrapping
/// the sequential `Scirs2Exec` for storage, tape, and fallback execution.
pub struct ParallelScirs2Exec {
    /// Underlying sequential executor: tensor registry, backward tape,
    /// and optional memory pool.
    pub(crate) base: crate::executor::Scirs2Exec,
    /// Parallel-execution tuning knobs.
    pub config: ParallelConfig,
    /// Stats from the most recent forward pass; `None` before the first.
    pub stats: Option<ParallelStats>,
}
89
impl ParallelScirs2Exec {
    /// Creates a parallel executor with the default configuration.
    ///
    /// NOTE(review): `ParallelConfig::default()` sets `enable_pooling: true`,
    /// but this constructor builds the base executor without a memory pool,
    /// unlike `with_config`, which honors the flag — confirm this asymmetry
    /// is intended.
    pub fn new() -> Self {
        Self {
            base: crate::executor::Scirs2Exec::new(),
            config: ParallelConfig::default(),
            stats: None,
        }
    }

    /// Creates an executor from an explicit config, enabling the base
    /// executor's memory pool when `config.enable_pooling` is set.
    pub fn with_config(config: ParallelConfig) -> Self {
        let base = if config.enable_pooling {
            crate::executor::Scirs2Exec::with_memory_pool()
        } else {
            crate::executor::Scirs2Exec::new()
        };

        Self {
            base,
            config,
            stats: None,
        }
    }

    /// Records a thread-count override in the config.
    ///
    /// NOTE(review): this only updates the stored config value; nothing
    /// visible here resizes an already-running thread pool.
    pub fn set_num_threads(&mut self, num_threads: usize) {
        self.config.num_threads = Some(num_threads);
    }

    /// Effective thread count: the configured override, or the global
    /// parallel pool's current size.
    #[cfg(feature = "parallel")]
    pub fn num_threads(&self) -> usize {
        self.config.num_threads.unwrap_or_else(current_num_threads)
    }

    /// Effective thread count without the `parallel` feature: the
    /// configured override, or 1.
    #[cfg(not(feature = "parallel"))]
    pub fn num_threads(&self) -> usize {
        self.config.num_threads.unwrap_or(1)
    }

    /// Toggles memory pooling on both the config and the base executor,
    /// keeping the two in sync.
    pub fn set_pooling(&mut self, enable: bool) {
        self.config.enable_pooling = enable;
        if enable {
            self.base.enable_pooling();
        } else {
            self.base.disable_pooling();
        }
    }

    /// Memory-pool statistics from the base executor, if a pool is active.
    pub fn pool_stats(&self) -> Option<crate::memory_pool::PoolStats> {
        self.base.pool_stats()
    }

    /// Statistics from the most recent `forward` call, if any.
    pub fn execution_stats(&self) -> Option<&ParallelStats> {
        self.stats.as_ref()
    }

    /// Registers an input tensor under `name` with the base executor.
    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        self.base.add_tensor(name, tensor);
    }

    /// Looks up a registered tensor by name in the base executor.
    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.base.get_tensor(name)
    }

    /// Executes a single graph node against already-materialized inputs.
    ///
    /// Takes `&self` (no executor state is touched) so it can be invoked
    /// concurrently from `par_iter` during the parallel forward pass.
    ///
    /// # Errors
    /// `InvalidEinsumSpec` for arity/spec problems, `ShapeMismatch` for
    /// incompatible binary operands, `UnsupportedOperation` for element
    /// ops this parallel path does not implement.
    #[cfg(feature = "parallel")]
    fn execute_operation(
        &self,
        node: &tensorlogic_ir::EinsumNode,
        input_tensors: &[Scirs2Tensor],
    ) -> Result<Scirs2Tensor, ExecutorError> {
        match &node.op {
            // Einsum contraction: build views over every input and hand
            // the spec string to scirs2_linalg.
            OpType::Einsum { spec } => {
                let views: Vec<_> = input_tensors.iter().map(|t| t.view()).collect();
                let view_refs: Vec<_> = views.iter().collect();
                scirs2_linalg::einsum(spec, &view_refs)
                    .map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
            }
            OpType::ElemUnary { op } => {
                if input_tensors.len() != 1 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Unary operation requires 1 input, got {}",
                        input_tensors.len()
                    )));
                }
                let elem_op = parse_elem_op(op)?;
                match elem_op {
                    ElemOp::Relu => Ok(input_tensors[0].mapv(|v| v.max(0.0))),
                    ElemOp::Sigmoid => Ok(input_tensors[0].mapv(|v| 1.0 / (1.0 + (-v).exp()))),
                    ElemOp::OneMinus => Ok(input_tensors[0].mapv(|v| 1.0 - v)),
                    // Other unary ops are only handled by the sequential
                    // base executor, not by this parallel fast path.
                    _ => Err(ExecutorError::UnsupportedOperation(format!(
                        "Unary operation {:?} not supported",
                        elem_op
                    ))),
                }
            }
            OpType::ElemBinary { op } => {
                if input_tensors.len() != 2 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Binary operation requires 2 inputs, got {}",
                        input_tensors.len()
                    )));
                }
                let elem_op = parse_elem_op(op)?;
                let x = &input_tensors[0];
                let y = &input_tensors[1];

                // Only scalar (0-d) broadcasting is supported: a 0-d
                // operand is materialized as a filled array of the other
                // operand's shape; otherwise shapes must match exactly.
                let x_is_scalar = x.ndim() == 0;
                let y_is_scalar = y.ndim() == 0;

                // `x_broadcast`/`y_broadcast` are declared up front so the
                // views taken inside the branches outlive the `if` chain.
                let (x_broadcast, y_broadcast);
                let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
                    let scalar_value = x.iter().next().unwrap();
                    x_broadcast =
                        scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
                    (&x_broadcast.view(), &y.view())
                } else if y_is_scalar && !x_is_scalar {
                    let scalar_value = y.iter().next().unwrap();
                    y_broadcast =
                        scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
                    (&x.view(), &y_broadcast.view())
                } else if x.shape() != y.shape() {
                    return Err(ExecutorError::ShapeMismatch(format!(
                        "Shape mismatch: {:?} vs {:?}",
                        x.shape(),
                        y.shape()
                    )));
                } else {
                    (&x.view(), &y.view())
                };

                let result = match elem_op {
                    ElemOp::Add => x_ref + y_ref,
                    ElemOp::Subtract => x_ref - y_ref,
                    ElemOp::Multiply => x_ref * y_ref,
                    ElemOp::Divide => x_ref / y_ref,
                    ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.min(b)),
                    ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.max(b)),
                    // Fuzzy-logic OR as max(a, b).
                    ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a.max(b)),
                    // Probabilistic sum: a + b - a*b.
                    ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a + b - a * b),
                    // NAND: 1 - a*b.
                    ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| 1.0 - (a * b)),
                    // NOR: 1 - max(a, b).
                    ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| 1.0 - a.max(b)),
                    // XOR: a + b - 2ab.
                    ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
                        .and(y_ref)
                        .map_collect(|&a, &b| a + b - 2.0 * a * b),
                    _ => {
                        return Err(ExecutorError::UnsupportedOperation(format!(
                            "Binary operation {:?} not supported",
                            elem_op
                        )))
                    }
                };
                Ok(result)
            }
            OpType::Reduce { op, axes } => {
                if input_tensors.len() != 1 {
                    return Err(ExecutorError::InvalidEinsumSpec(format!(
                        "Reduce operation requires 1 input, got {}",
                        input_tensors.len()
                    )));
                }
                let reduce_op = parse_reduce_op(op)?;
                let x = &input_tensors[0];

                use scirs2_core::ndarray::Axis;
                // Reduce axes from highest to lowest so earlier axis
                // indices stay valid as dimensions are removed.
                // NOTE(review): this assumes `axes` is sorted ascending —
                // confirm the invariant at graph-construction time.
                let mut result = x.clone();
                for &axis in axes.iter().rev() {
                    result = match reduce_op {
                        ReduceOp::Sum => result.sum_axis(Axis(axis)),
                        ReduceOp::Max => result.map_axis(Axis(axis), |view| {
                            view.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
                        }),
                        ReduceOp::Min => result.map_axis(Axis(axis), |view| {
                            view.iter().fold(f64::INFINITY, |a, &b| a.min(b))
                        }),
                        ReduceOp::Mean => {
                            // Axis length is read before the reduction, so
                            // `count` is the number of elements averaged.
                            let sum = result.sum_axis(Axis(axis));
                            let count = result.len_of(Axis(axis)) as f64;
                            sum / count
                        }
                        ReduceOp::Product => {
                            result.map_axis(Axis(axis), |view| view.iter().product())
                        }
                    };
                }
                Ok(result)
            }
        }
    }
}
304
impl Default for ParallelScirs2Exec {
    /// Equivalent to [`ParallelScirs2Exec::new`].
    fn default() -> Self {
        Self::new()
    }
}
310
/// `TlExecutor` for the parallel executor. Individual tensor operations
/// are not parallelized here; every call is forwarded to the wrapped
/// sequential executor. Parallelism applies only to whole-graph `forward`.
impl TlExecutor for ParallelScirs2Exec {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;

    /// Einsum contraction, delegated to the base executor.
    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        self.base.einsum(spec, inputs)
    }

    /// Unary elementwise op, delegated to the base executor.
    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        self.base.elem_op(op, x)
    }

    /// Binary elementwise op, delegated to the base executor.
    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        self.base.elem_op_binary(op, x, y)
    }

    /// Axis reduction, delegated to the base executor.
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        self.base.reduce(op, x, axes)
    }
}
342
#[cfg(feature = "parallel")]
impl TlAutodiff for ParallelScirs2Exec {
    type Tape = ForwardTape;

    /// Level-parallel forward pass over `graph`.
    ///
    /// The dependency analysis partitions ops into levels whose members
    /// have no dependencies on each other. A level with at least
    /// `config.min_parallel_ops` ops is executed with `par_iter`;
    /// smaller levels run sequentially. All intermediate tensors and each
    /// node's inputs are recorded on the base executor's tape so
    /// `backward` can replay them.
    ///
    /// Returns the tensor at the graph's *first* output index.
    ///
    /// # Errors
    /// `InvalidEinsumSpec` for an empty graph or missing outputs;
    /// `TensorNotFound` if a required tensor was never registered or
    /// computed; plus any error from executing an individual node.
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        if graph.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "Empty graph provided".to_string(),
            ));
        }

        if graph.outputs.is_empty() {
            return Err(ExecutorError::InvalidEinsumSpec(
                "No output tensors specified".to_string(),
            ));
        }

        // Partition the graph into dependency levels.
        let analysis = DependencyAnalysis::analyze(graph);

        // One slot per graph-tensor index; `None` until computed.
        // Mutex-guarded so parallel workers can snapshot their inputs
        // while the post-level write-back stores results.
        let computed_tensors: Arc<Mutex<Vec<Option<Scirs2Tensor>>>> =
            Arc::new(Mutex::new(vec![None; graph.tensors.len()]));

        // Per-node input tensors, captured for the backward tape.
        let node_inputs: Arc<Mutex<Vec<Vec<Scirs2Tensor>>>> =
            Arc::new(Mutex::new(Vec::with_capacity(graph.nodes.len())));

        // Seed the slots from registered tensors. Lookup order: exact
        // name; then the base name before any '[' suffix; then names of
        // the form `const_<value>` parsed into 0-d constants. Names that
        // resolve to nothing (including unparsable const values) stay
        // `None` and surface later as TensorNotFound.
        {
            let mut storage = computed_tensors.lock().unwrap();
            for (idx, tensor_name) in graph.tensors.iter().enumerate() {
                if let Some(tensor) = self.base.tensors.get(tensor_name) {
                    storage[idx] = Some(tensor.clone());
                } else {
                    let base_name = tensor_name.split('[').next().unwrap_or(tensor_name);
                    if let Some(tensor) = self.base.tensors.get(base_name) {
                        storage[idx] = Some(tensor.clone());
                    } else if tensor_name.starts_with("const_") || base_name.starts_with("const_") {
                        let const_name = if tensor_name.starts_with("const_") {
                            tensor_name
                        } else {
                            base_name
                        };
                        if let Some(value_str) = const_name.strip_prefix("const_") {
                            if let Ok(value) = value_str.parse::<f64>() {
                                use scirs2_core::ndarray::arr0;
                                storage[idx] = Some(arr0(value).into_dyn());
                            }
                        }
                    }
                }
            }
        }

        let mut parallel_ops = 0;
        let mut sequential_ops = 0;

        for level_ops in &analysis.execution_levels {
            let should_parallelize = level_ops.len() >= self.config.min_parallel_ops;

            if should_parallelize {
                parallel_ops += level_ops.len();

                // Run every op in this level concurrently. Each worker
                // holds the storage lock only long enough to clone its
                // input tensors, then computes without the lock. The
                // first error short-circuits the whole level.
                let results: Vec<_> = level_ops
                    .par_iter()
                    .map(|&op_idx| {
                        let node = &graph.nodes[op_idx];

                        let inputs: Result<Vec<_>, _> = {
                            let storage = computed_tensors.lock().unwrap();
                            node.inputs
                                .iter()
                                .map(|&idx| {
                                    storage
                                        .get(idx)
                                        .and_then(|t| t.as_ref())
                                        .cloned()
                                        .ok_or_else(|| {
                                            ExecutorError::TensorNotFound(format!(
                                                "Tensor at index {} not found",
                                                idx
                                            ))
                                        })
                                })
                                .collect()
                        };

                        let input_tensors = inputs?;
                        let result = self.execute_operation(node, &input_tensors)?;

                        Ok((op_idx, node.outputs.clone(), input_tensors, result))
                    })
                    .collect::<Result<Vec<_>, ExecutorError>>()?;

                // Write all results back under a single lock acquisition.
                {
                    let mut storage = computed_tensors.lock().unwrap();
                    let mut inputs_vec = node_inputs.lock().unwrap();

                    // Grow the per-node input log so the largest op index
                    // of this level is addressable.
                    while inputs_vec.len()
                        <= results.iter().map(|(idx, _, _, _)| *idx).max().unwrap_or(0)
                    {
                        inputs_vec.push(Vec::new());
                    }

                    for (op_idx, outputs, inputs, tensor) in results {
                        // Only the first output slot is populated — same
                        // convention as the sequential branch below.
                        if let Some(&output_idx) = outputs.first() {
                            storage[output_idx] = Some(tensor);
                        }

                        inputs_vec[op_idx] = inputs;
                    }
                }
            } else {
                sequential_ops += level_ops.len();

                // Small level: execute in order, holding both locks for
                // the whole level (no contention — nothing else runs).
                let mut storage = computed_tensors.lock().unwrap();
                let mut inputs_vec = node_inputs.lock().unwrap();

                for &op_idx in level_ops {
                    let node = &graph.nodes[op_idx];

                    let inputs: Result<Vec<_>, _> = node
                        .inputs
                        .iter()
                        .map(|&idx| {
                            storage
                                .get(idx)
                                .and_then(|t| t.as_ref())
                                .cloned()
                                .ok_or_else(|| {
                                    ExecutorError::TensorNotFound(format!(
                                        "Tensor at index {} not found",
                                        idx
                                    ))
                                })
                        })
                        .collect();

                    let input_tensors = inputs?;
                    let result = self.execute_operation(node, &input_tensors)?;

                    if let Some(&output_idx) = node.outputs.first() {
                        storage[output_idx] = Some(result);
                    }

                    while inputs_vec.len() <= op_idx {
                        inputs_vec.push(Vec::new());
                    }
                    inputs_vec[op_idx] = input_tensors;
                }
            }
        }

        // All workers are finished, so both Arcs are uniquely owned;
        // `unwrap` here would only panic on a logic error above.
        let final_tensors = Arc::try_unwrap(computed_tensors)
            .unwrap()
            .into_inner()
            .unwrap();
        let final_inputs = Arc::try_unwrap(node_inputs).unwrap().into_inner().unwrap();

        // Record the tape so the sequential backward pass can replay it.
        self.base.tape = Some(ForwardTape {
            tensors: final_tensors.clone(),
            node_inputs: final_inputs,
        });

        self.stats = Some(ParallelStats {
            num_levels: analysis.num_levels,
            parallel_ops,
            sequential_ops,
            max_parallelism: analysis.max_parallelism,
            estimated_speedup: analysis.estimated_speedup(),
        });

        // Only the first declared output is returned.
        let output_idx = graph.outputs[0];
        final_tensors
            .get(output_idx)
            .and_then(|t| t.clone())
            .ok_or_else(|| ExecutorError::TensorNotFound("Output tensor not computed".to_string()))
    }

    /// Backward pass: delegates to the sequential base executor, which
    /// consumes the tape recorded by this impl's `forward`.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.base.backward(graph, loss_grad)
    }
}
549
/// Fallback autodiff impl when the `parallel` feature is disabled:
/// both passes delegate straight to the sequential base executor.
#[cfg(not(feature = "parallel"))]
impl TlAutodiff for ParallelScirs2Exec {
    type Tape = ForwardTape;

    /// Sequential forward pass via the base executor.
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        self.base.forward(graph)
    }

    /// Sequential backward pass via the base executor.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss_grad: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.base.backward(graph, loss_grad)
    }
}
568
#[cfg(test)]
#[cfg(feature = "parallel")]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use tensorlogic_ir::EinsumNode;

    /// Builds a diamond-shaped graph with three dependency levels:
    /// level 1 computes c = relu(a) and d = sigmoid(b) independently
    /// (so it can be parallelized), level 2 computes e = c + d, and
    /// level 3 computes f = relu(e), the sole output.
    fn create_parallel_test_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();

        let a = g.add_tensor("a");
        let b = g.add_tensor("b");
        let c = g.add_tensor("c");
        let d = g.add_tensor("d");
        let e = g.add_tensor("e");
        let f = g.add_tensor("f");

        g.add_input(a).unwrap();
        g.add_input(b).unwrap();

        // Compact builder for single-input elementwise nodes.
        let unary = |name: &str, from: usize, to: usize| EinsumNode {
            op: OpType::ElemUnary {
                op: name.to_string(),
            },
            inputs: vec![from],
            outputs: vec![to],
            metadata: None,
        };

        g.add_node(unary("relu", a, c)).unwrap();
        g.add_node(unary("sigmoid", b, d)).unwrap();
        g.add_node(EinsumNode {
            op: OpType::ElemBinary {
                op: "add".to_string(),
            },
            inputs: vec![c, d],
            outputs: vec![e],
            metadata: None,
        })
        .unwrap();
        g.add_node(unary("relu", e, f)).unwrap();

        g.add_output(f).unwrap();
        g
    }

    #[test]
    fn test_parallel_executor_creation() {
        let exec = ParallelScirs2Exec::new();
        assert_eq!(exec.config.min_parallel_ops, 2);
        assert!(exec.config.enable_pooling);
    }

    #[test]
    fn test_set_num_threads() {
        let mut exec = ParallelScirs2Exec::new();
        exec.set_num_threads(4);
        assert_eq!(exec.config.num_threads, Some(4));
    }

    #[test]
    fn test_parallel_forward_pass() {
        let g = create_parallel_test_graph();
        let mut exec = ParallelScirs2Exec::new();

        exec.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        exec.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());

        let out = exec.forward(&g).unwrap();
        assert_eq!(out.shape(), &[3]);

        // The relu/sigmoid level crosses the min_parallel_ops threshold.
        let stats = exec.execution_stats().unwrap();
        assert_eq!(stats.num_levels, 3);
        assert!(stats.parallel_ops >= 2);
    }

    #[test]
    fn test_parallel_vs_sequential_correctness() {
        let g = create_parallel_test_graph();

        let mut par = ParallelScirs2Exec::new();
        par.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        par.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());
        let par_out = par.forward(&g).unwrap();

        let mut seq = crate::executor::Scirs2Exec::new();
        seq.add_tensor("a", array![-1.0, 2.0, -3.0].into_dyn());
        seq.add_tensor("b", array![0.0, 1.0, 2.0].into_dyn());
        let seq_out = seq.forward(&g).unwrap();

        // Parallel execution must match the sequential reference exactly
        // (up to floating-point noise).
        assert_eq!(par_out.shape(), seq_out.shape());
        for (p, s) in par_out.iter().zip(seq_out.iter()) {
            assert!((p - s).abs() < 1e-10);
        }
    }

    #[test]
    fn test_parallel_stats() {
        let g = create_parallel_test_graph();
        let mut exec = ParallelScirs2Exec::new();

        exec.add_tensor("a", array![1.0, 2.0].into_dyn());
        exec.add_tensor("b", array![3.0, 4.0].into_dyn());

        exec.forward(&g).unwrap();

        let stats = exec.execution_stats().unwrap();
        assert_eq!(stats.num_levels, 3);
        assert!(stats.max_parallelism >= 2);
        assert!(stats.estimated_speedup > 1.0);
    }

    #[test]
    fn test_pooling_integration() {
        let g = create_parallel_test_graph();
        let mut exec = ParallelScirs2Exec::new();
        exec.set_pooling(true);

        exec.add_tensor("a", array![1.0, 2.0].into_dyn());
        exec.add_tensor("b", array![3.0, 4.0].into_dyn());

        exec.forward(&g).unwrap();

        // Just confirm the pool-stats accessor works with pooling on.
        let _pool_stats = exec.pool_stats();
    }

    #[test]
    fn test_min_parallel_ops_threshold() {
        // A single-op graph has one op per level — below the threshold,
        // so everything must run sequentially.
        let mut g = EinsumGraph::new();

        let a = g.add_tensor("a");
        let b = g.add_tensor("b");

        g.add_input(a).unwrap();
        g.add_node(EinsumNode {
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            inputs: vec![a],
            outputs: vec![b],
            metadata: None,
        })
        .unwrap();
        g.add_output(b).unwrap();

        let mut exec = ParallelScirs2Exec::new();
        exec.add_tensor("a", array![1.0, 2.0, 3.0].into_dyn());

        exec.forward(&g).unwrap();

        let stats = exec.execution_stats().unwrap();
        assert_eq!(stats.sequential_ops, 1);
        assert_eq!(stats.parallel_ops, 0);
    }

    #[test]
    fn test_backward_pass_with_parallel() {
        let g = create_parallel_test_graph();
        let mut exec = ParallelScirs2Exec::new();

        exec.add_tensor("a", array![1.0, 2.0, 3.0].into_dyn());
        exec.add_tensor("b", array![0.5, 1.0, 1.5].into_dyn());

        exec.forward(&g).unwrap();

        let ones = array![1.0, 1.0, 1.0].into_dyn();
        let tape = exec.backward(&g, &ones).unwrap();

        assert!(!tape.is_empty());
    }
}