tenrso_exec/executor/parallel.rs

//! Parallel execution helpers for the dense executor: element-wise maps,
//! axis reductions, and an einsum-backed matmul wrapper. Element-wise ops
//! switch to a parallel path once the element count reaches
//! `PARALLEL_THRESHOLD`; smaller tensors stay on the sequential path.

use anyhow::Result;
use scirs2_core::ndarray_ext::{Array, Axis as NdAxis, IxDyn, Zip};
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use tenrso_core::{Axis, DenseND};

/// Minimum number of elements before a parallel path is used.
const PARALLEL_THRESHOLD: usize = 10_000;

/// Returns true when the tensor has enough elements that parallel
/// execution is expected to pay for its scheduling overhead.
#[inline]
pub(crate) fn should_parallelize(shape: &[usize]) -> bool {
    let total_elements: usize = shape.iter().product();
    total_elements >= PARALLEL_THRESHOLD
}

/// Applies an element-wise unary operation. Tensors at or above
/// `PARALLEL_THRESHOLD` elements use a parallel map; smaller tensors use a
/// plain sequential map.
#[allow(dead_code)]
pub(crate) fn parallel_unary<T, F>(input: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync,
    F: Fn(T) -> T + Send + Sync,
{
    let input_view = input.view();

    // Small tensors: sequential map avoids thread-pool overhead.
    if !should_parallelize(input.shape()) {
        let result = input_view.mapv(op);
        return Ok(DenseND::from_array(result));
    }

    // Large tensors: parallel element-wise map, matching the parallel path
    // used by `parallel_binary` below.
    let result = Zip::from(&input_view).par_map_collect(|a| op(a.clone()));
    Ok(DenseND::from_array(result))
}

/// Applies an element-wise binary operation to two tensors of equal shape,
/// switching to a parallel zip at or above `PARALLEL_THRESHOLD` elements.
#[allow(dead_code)]
pub(crate) fn parallel_binary<T, F>(x: &DenseND<T>, y: &DenseND<T>, op: F) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync,
{
    let x_view = x.view();
    let y_view = y.view();

    if x.shape() == y.shape() {
        // Small tensors: sequential zip avoids thread-pool overhead.
        if !should_parallelize(x.shape()) {
            let result = Zip::from(&x_view)
                .and(&y_view)
                .map_collect(|a, b| op(a.clone(), b.clone()));
            return Ok(DenseND::from_array(result));
        }

        // Large tensors: parallel zip over both operands.
        let result = Zip::from(&x_view)
            .and(&y_view)
            .par_map_collect(|a, b| op(a.clone(), b.clone()));
        return Ok(DenseND::from_array(result));
    }

    // Shapes differ: fall back to a sequential zip. Note that `Zip` requires
    // matching shapes, so genuinely mismatched inputs panic here rather than
    // broadcast.
    let result = Zip::from(&x_view)
        .and(&y_view)
        .map_collect(|a, b| op(a.clone(), b.clone()));
    Ok(DenseND::from_array(result))
}

/// Sums over the given axes. An empty `axes` slice reduces over all elements
/// and returns a 0-dimensional tensor.
#[allow(dead_code)]
pub(crate) fn parallel_reduce_sum<T>(input: &DenseND<T>, axes: &[Axis]) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
{
    if axes.is_empty() {
        // Full reduction: sum every element into a scalar (0-d) tensor.
        let input_view = input.view();
        let sum: T = input_view.iter().cloned().sum();

        let result_array = Array::from_elem(IxDyn(&[]), sum);
        return Ok(DenseND::from_array(result_array));
    }

    // Reduce one axis at a time. Each reduction removes a dimension, so later
    // axis indices are interpreted against the already-reduced tensor.
    let mut result = input.clone();
    for &axis in axes {
        if axis >= result.shape().len() {
            return Err(anyhow::anyhow!(
                "Axis {} out of bounds for tensor with {} dimensions",
                axis,
                result.shape().len()
            ));
        }

        let result_view = result.view();
        let reduced = result_view.sum_axis(NdAxis(axis));
        result = DenseND::from_array(reduced);
    }

    Ok(result)
}

/// Averages over the given axes. An empty `axes` slice reduces over all
/// elements and returns a 0-dimensional tensor.
#[allow(dead_code)]
pub(crate) fn parallel_reduce_mean<T>(input: &DenseND<T>, axes: &[Axis]) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
{
    if axes.is_empty() {
        // Full reduction: mean of every element as a scalar (0-d) tensor.
        let input_view = input.view();
        let total_elements = input_view.len();
        let sum: T = input_view.iter().cloned().sum();
        let mean = sum / T::from_usize(total_elements).unwrap();

        let result_array = Array::from_elem(IxDyn(&[]), mean);
        return Ok(DenseND::from_array(result_array));
    }

    // Reduce one axis at a time; as in `parallel_reduce_sum`, later axis
    // indices refer to the already-reduced tensor.
    let mut result = input.clone();
    for &axis in axes {
        if axis >= result.shape().len() {
            return Err(anyhow::anyhow!("Axis {} out of bounds", axis));
        }

        let result_view = result.view();
        let reduced = result_view
            .mean_axis(NdAxis(axis))
            .ok_or_else(|| anyhow::anyhow!("Mean computation failed"))?;
        result = DenseND::from_array(reduced);
    }

    Ok(result)
}

/// Matrix multiplication expressed as the einsum contraction `"ij,jk->ik"`
/// and delegated to the dense contraction executor.
#[allow(dead_code)]
pub(crate) fn parallel_matmul<T>(a: &DenseND<T>, b: &DenseND<T>) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync + std::ops::AddAssign + std::default::Default,
{
    use crate::ops::execute_dense_contraction;
    use tenrso_planner::EinsumSpec;

    let spec = EinsumSpec::parse("ij,jk->ik")?;
    execute_dense_contraction(&spec, a, b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_parallelize() {
        assert!(!should_parallelize(&[100]));
        assert!(!should_parallelize(&[50, 50]));
        assert!(should_parallelize(&[10000]));
        assert!(should_parallelize(&[100, 100, 2]));
    }

    #[test]
    fn test_parallel_unary() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_unary(&input, |x| x * 2.0).unwrap();
        let result_view = result.view();

        assert!((result_view[[0]] as f64 - 2.0).abs() < 1e-10);
        assert!((result_view[[1]] as f64 - 4.0).abs() < 1e-10);
        assert!((result_view[[2]] as f64 - 6.0).abs() < 1e-10);
        assert!((result_view[[3]] as f64 - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_binary() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();
        let result = parallel_binary(&a, &b, |x, y| x + y).unwrap();
        let result_view = result.view();

        assert!((result_view[[0]] as f64 - 3.0).abs() < 1e-10);
        assert!((result_view[[1]] as f64 - 5.0).abs() < 1e-10);
        assert!((result_view[[2]] as f64 - 7.0).abs() < 1e-10);
        assert!((result_view[[3]] as f64 - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_reduce_sum_all() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_reduce_sum(&input, &[]).unwrap();
        let result_view = result.view();

        assert!((result_view[[]] as f64 - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_reduce_mean_all() {
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = parallel_reduce_mean(&input, &[]).unwrap();
        let result_view = result.view();

        assert!((result_view[[]] as f64 - 2.5).abs() < 1e-10);
    }
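
    // The two tests below are a sketch covering paths the existing tests do
    // not exercise: axis-wise reduction and the einsum-backed matmul. They
    // assume `DenseND::from_vec` accepts a multi-dimensional shape with
    // row-major element order; adjust the constructors if that does not hold.
    #[test]
    fn test_parallel_reduce_sum_axis() {
        // 2x3 tensor [[1, 2, 3], [4, 5, 6]]; summing over axis 1 gives [6, 15].
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let result = parallel_reduce_sum(&input, &[1]).unwrap();
        let result_view = result.view();

        assert_eq!(result.shape(), &[2]);
        assert!((result_view[[0]] - 6.0).abs() < 1e-10);
        assert!((result_view[[1]] - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_matmul_2x2() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let result = parallel_matmul(&a, &b).unwrap();
        let result_view = result.view();

        assert!((result_view[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((result_view[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((result_view[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((result_view[[1, 1]] - 50.0).abs() < 1e-10);
    }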
}