// tenrso_exec/executor/parallel.rs

1//! Parallel execution utilities for tensor operations
2//!
3//! This module provides parallel implementations of operations using scirs2-core's
4//! parallel execution capabilities (backed by rayon).
5
6use anyhow::Result;
7use scirs2_core::ndarray_ext::{Array, Axis as NdAxis, IxDyn, Zip};
8use scirs2_core::numeric::{Float, FromPrimitive, Num};
9use tenrso_core::{Axis, DenseND};
10
/// Threshold for parallel execution (number of elements).
/// Operations with fewer elements than this will run serially — below
/// roughly this size, task-scheduling overhead tends to outweigh the
/// gains (heuristic; TODO confirm with benchmarks).
const PARALLEL_THRESHOLD: usize = 10_000;

/// Returns `true` when a tensor of the given shape holds enough elements
/// for parallel execution to be expected to pay off.
#[inline]
pub(crate) fn should_parallelize(shape: &[usize]) -> bool {
    shape.iter().product::<usize>() >= PARALLEL_THRESHOLD
}
21
22/// Apply element-wise unary operation in parallel
23#[allow(dead_code)]
24pub(crate) fn parallel_unary<T, F>(input: &DenseND<T>, op: F) -> Result<DenseND<T>>
25where
26    T: Clone + Num + Send + Sync,
27    F: Fn(T) -> T + Send + Sync,
28{
29    let input_view = input.view();
30
31    if !should_parallelize(input.shape()) {
32        // Small tensor - use serial execution
33        let result = input_view.mapv(op);
34        return Ok(DenseND::from_array(result));
35    }
36
37    // Parallel execution using scirs2-core's Zip
38    let result = input_view.mapv(op);
39    Ok(DenseND::from_array(result))
40}
41
42/// Apply element-wise binary operation in parallel with broadcasting
43#[allow(dead_code)]
44pub(crate) fn parallel_binary<T, F>(x: &DenseND<T>, y: &DenseND<T>, op: F) -> Result<DenseND<T>>
45where
46    T: Clone + Num + Send + Sync,
47    F: Fn(T, T) -> T + Send + Sync,
48{
49    let x_view = x.view();
50    let y_view = y.view();
51
52    // Check if shapes are compatible
53    if x.shape() == y.shape() {
54        if !should_parallelize(x.shape()) {
55            // Small tensor - use serial execution
56            let result = Zip::from(&x_view)
57                .and(&y_view)
58                .map_collect(|a, b| op(a.clone(), b.clone()));
59            return Ok(DenseND::from_array(result));
60        }
61
62        // Parallel execution
63        let result = Zip::from(&x_view)
64            .and(&y_view)
65            .par_map_collect(|a, b| op(a.clone(), b.clone()));
66        return Ok(DenseND::from_array(result));
67    }
68
69    // Broadcasting case - fall back to serial for now
70    // TODO: Implement parallel broadcasting
71    let result = Zip::from(&x_view)
72        .and(&y_view)
73        .map_collect(|a, b| op(a.clone(), b.clone()));
74    Ok(DenseND::from_array(result))
75}
76
77/// Parallel reduction along specified axes
78#[allow(dead_code)]
79pub(crate) fn parallel_reduce_sum<T>(input: &DenseND<T>, axes: &[Axis]) -> Result<DenseND<T>>
80where
81    T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
82{
83    if axes.is_empty() {
84        // Reduce all axes
85        let input_view = input.view();
86        let sum: T = input_view.iter().cloned().sum();
87
88        let result_array = Array::from_elem(IxDyn(&[]), sum);
89        return Ok(DenseND::from_array(result_array));
90    }
91
92    // Reduce along specific axes
93    let mut result = input.clone();
94    for &axis in axes {
95        if axis >= result.shape().len() {
96            return Err(anyhow::anyhow!(
97                "Axis {} out of bounds for tensor with {} dimensions",
98                axis,
99                result.shape().len()
100            ));
101        }
102
103        let result_view = result.view();
104        let reduced = result_view.sum_axis(NdAxis(axis));
105        result = DenseND::from_array(reduced);
106    }
107
108    Ok(result)
109}
110
111/// Parallel mean reduction along specified axes
112#[allow(dead_code)]
113pub(crate) fn parallel_reduce_mean<T>(input: &DenseND<T>, axes: &[Axis]) -> Result<DenseND<T>>
114where
115    T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
116{
117    if axes.is_empty() {
118        // Mean of all elements
119        let input_view = input.view();
120        let total_elements = input_view.len();
121        let sum: T = input_view.iter().cloned().sum();
122        let mean = sum / T::from_usize(total_elements).unwrap();
123
124        let result_array = Array::from_elem(IxDyn(&[]), mean);
125        return Ok(DenseND::from_array(result_array));
126    }
127
128    // Mean along specific axes
129    let mut result = input.clone();
130    for &axis in axes {
131        if axis >= result.shape().len() {
132            return Err(anyhow::anyhow!("Axis {} out of bounds", axis));
133        }
134
135        let result_view = result.view();
136        let reduced = result_view
137            .mean_axis(NdAxis(axis))
138            .ok_or_else(|| anyhow::anyhow!("Mean computation failed"))?;
139        result = DenseND::from_array(reduced);
140    }
141
142    Ok(result)
143}
144
145/// Parallel matrix multiplication optimized for large matrices
146#[allow(dead_code)]
147pub(crate) fn parallel_matmul<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
148where
149    T: Clone + Num + Send + Sync + std::ops::AddAssign + std::default::Default,
150{
151    // For now, delegate to the standard matmul
152    // TODO: Implement blocked parallel matmul for large matrices
153    use crate::ops::execute_dense_contraction;
154    use tenrso_planner::EinsumSpec;
155
156    let spec = EinsumSpec::parse("ij,jk->ik")?;
157    execute_dense_contraction(&spec, a, b)
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that two f64 values agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_should_parallelize() {
        // Below the threshold: run serially.
        assert!(!should_parallelize(&[100]));
        assert!(!should_parallelize(&[50, 50]));
        // At or above the threshold: run in parallel.
        assert!(should_parallelize(&[10000]));
        assert!(should_parallelize(&[100, 100, 2]));
    }

    #[test]
    fn test_parallel_unary() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_unary(&input, |x| x * 2.0).unwrap();
        let view = result.view();

        let expected = [2.0, 4.0, 6.0, 8.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(view[[i]], want);
        }
    }

    #[test]
    fn test_parallel_binary() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();
        let result = parallel_binary(&a, &b, |x, y| x + y).unwrap();
        let view = result.view();

        let expected = [3.0, 5.0, 7.0, 9.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(view[[i]], want);
        }
    }

    #[test]
    fn test_parallel_reduce_sum_all() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_reduce_sum(&input, &[]).unwrap();
        // Full reduction yields a 0-d tensor indexed with `[[]]`.
        assert_close(result.view()[[]], 10.0);
    }

    #[test]
    fn test_parallel_reduce_mean_all() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_reduce_mean(&input, &[]).unwrap();
        assert_close(result.view()[[]], 2.5);
    }
}