tenrso_exec/executor/
tiled_reductions.rs1#![allow(dead_code)]
19
20use anyhow::Result;
21use scirs2_core::ndarray_ext::Axis as NdAxis;
22use scirs2_core::numeric::{Float, FromPrimitive, Num};
23use tenrso_core::{Axis, DenseND};
24
/// Number of elements handed to a single tile when a reduction walks a
/// contiguous slice in chunks.
const TILE_SIZE: usize = 4 * 1024;

/// Total element count at or above which the reductions in this module take
/// their tiled code path.
const TILING_THRESHOLD: usize = 100_000;
32
33#[inline]
35pub(crate) fn should_use_tiling(shape: &[usize]) -> bool {
36 let total_elements: usize = shape.iter().product();
37 total_elements >= TILING_THRESHOLD
38}
39
40pub(crate) fn tiled_sum_all<T>(input: &DenseND<T>) -> Result<T>
47where
48 T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
49{
50 let input_view = input.view();
51 let total_elements = input_view.len();
52
53 if total_elements < TILING_THRESHOLD {
54 return Ok(input_view.iter().cloned().sum());
56 }
57
58 let num_tiles = total_elements.div_ceil(TILE_SIZE);
60 let mut tile_sums = Vec::with_capacity(num_tiles);
61
62 let input_slice = input_view.as_slice();
64 if let Some(slice) = input_slice {
65 for chunk in slice.chunks(TILE_SIZE) {
67 let tile_sum: T = chunk.iter().cloned().sum();
68 tile_sums.push(tile_sum);
69 }
70 } else {
71 return Ok(input_view.iter().cloned().sum());
73 }
74
75 Ok(tile_sums.into_iter().sum())
77}
78
79pub(crate) fn tiled_mean_all<T>(input: &DenseND<T>) -> Result<T>
81where
82 T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
83{
84 let total_elements = input.view().len();
85 let sum = tiled_sum_all(input)?;
86 let mean = sum / T::from_usize(total_elements).unwrap();
87 Ok(mean)
88}
89
90pub(crate) fn tiled_max_all<T>(input: &DenseND<T>) -> Result<T>
92where
93 T: Clone + Num + Send + Sync + PartialOrd,
94{
95 let input_view = input.view();
96 let total_elements = input_view.len();
97
98 if total_elements < TILING_THRESHOLD {
99 return input_view
101 .iter()
102 .cloned()
103 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
104 .ok_or_else(|| anyhow::anyhow!("Cannot compute max of empty tensor"));
105 }
106
107 let input_slice = input_view.as_slice();
109 if let Some(slice) = input_slice {
110 let mut tile_maxes = Vec::new();
111
112 for chunk in slice.chunks(TILE_SIZE) {
114 if let Some(tile_max) = chunk
115 .iter()
116 .cloned()
117 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
118 {
119 tile_maxes.push(tile_max);
120 }
121 }
122
123 tile_maxes
125 .into_iter()
126 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
127 .ok_or_else(|| anyhow::anyhow!("Cannot compute max of empty tensor"))
128 } else {
129 input_view
131 .iter()
132 .cloned()
133 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
134 .ok_or_else(|| anyhow::anyhow!("Cannot compute max of empty tensor"))
135 }
136}
137
138pub(crate) fn tiled_min_all<T>(input: &DenseND<T>) -> Result<T>
140where
141 T: Clone + Num + Send + Sync + PartialOrd,
142{
143 let input_view = input.view();
144 let total_elements = input_view.len();
145
146 if total_elements < TILING_THRESHOLD {
147 return input_view
149 .iter()
150 .cloned()
151 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
152 .ok_or_else(|| anyhow::anyhow!("Cannot compute min of empty tensor"));
153 }
154
155 let input_slice = input_view.as_slice();
157 if let Some(slice) = input_slice {
158 let mut tile_mins = Vec::new();
159
160 for chunk in slice.chunks(TILE_SIZE) {
162 if let Some(tile_min) = chunk
163 .iter()
164 .cloned()
165 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
166 {
167 tile_mins.push(tile_min);
168 }
169 }
170
171 tile_mins
173 .into_iter()
174 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
175 .ok_or_else(|| anyhow::anyhow!("Cannot compute min of empty tensor"))
176 } else {
177 input_view
179 .iter()
180 .cloned()
181 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
182 .ok_or_else(|| anyhow::anyhow!("Cannot compute min of empty tensor"))
183 }
184}
185
186pub(crate) fn tiled_sum_axis<T>(input: &DenseND<T>, axis: Axis) -> Result<DenseND<T>>
193where
194 T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
195{
196 let input_view = input.view();
197
198 if !should_use_tiling(input.shape()) {
199 let nd_axis = NdAxis(axis);
201 let result = input_view.sum_axis(nd_axis);
202 return Ok(DenseND::from_array(result));
203 }
204
205 let nd_axis = NdAxis(axis);
208 let result = input_view.sum_axis(nd_axis);
209 Ok(DenseND::from_array(result))
210}
211
212#[allow(dead_code)]
214pub(crate) fn tiled_mean_axis<T>(input: &DenseND<T>, axis: Axis) -> Result<DenseND<T>>
215where
216 T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
217{
218 let input_view = input.view();
219 let nd_axis = NdAxis(axis);
220
221 let result = input_view
222 .mean_axis(nd_axis)
223 .ok_or_else(|| anyhow::anyhow!("Mean computation failed"))?;
224
225 Ok(DenseND::from_array(result))
226}
227
228#[allow(dead_code)]
235pub(crate) fn tiled_matvec<T>(matrix: &DenseND<T>, vector: &DenseND<T>) -> Result<DenseND<T>>
236where
237 T: Clone + Num + Send + Sync + std::ops::AddAssign + std::default::Default,
238{
239 if matrix.shape().len() != 2 || vector.shape().len() != 1 {
241 return Err(anyhow::anyhow!(
242 "tiled_matvec requires 2D matrix and 1D vector"
243 ));
244 }
245
246 let m = matrix.shape()[0];
247 let n = matrix.shape()[1];
248 if vector.shape()[0] != n {
249 return Err(anyhow::anyhow!(
250 "Matrix columns ({}) must match vector size ({})",
251 n,
252 vector.shape()[0]
253 ));
254 }
255
256 if m * n < TILING_THRESHOLD {
258 let mut result_data = vec![T::zero(); m];
260 #[allow(clippy::needless_range_loop)]
261 for i in 0..m {
262 let mut sum = T::zero();
263 for j in 0..n {
264 sum += matrix.view()[[i, j]].clone() * vector.view()[[j]].clone();
265 }
266 result_data[i] = sum;
267 }
268 return DenseND::from_vec(result_data, &[m]);
269 }
270
271 let mut result_data = vec![T::zero(); m];
274 #[allow(clippy::needless_range_loop)]
275 for i in 0..m {
276 let mut sum = T::zero();
277 for j in 0..n {
278 sum += matrix.view()[[i, j]].clone() * vector.view()[[j]].clone();
279 }
280 result_data[i] = sum;
281 }
282 DenseND::from_vec(result_data, &[m])
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Shapes below 100_000 elements are not tiled; shapes at or above are.
    #[test]
    fn test_should_use_tiling() {
        let below_threshold: [[usize; 2]; 2] = [[100, 100], [300, 300]];
        for shape in &below_threshold {
            assert!(!should_use_tiling(shape));
        }

        let above_threshold: [[usize; 2]; 2] = [[400, 400], [1000, 1000]];
        for shape in &above_threshold {
            assert!(should_use_tiling(shape));
        }
    }

    /// A five-element tensor is summed on the non-tiled path.
    #[test]
    fn test_tiled_sum_all_small() {
        let tensor = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        assert_eq!(tiled_sum_all(&tensor).unwrap(), 15.0);
    }

    /// 200_000 elements exceed the threshold, so the tiled path is exercised.
    #[test]
    fn test_tiled_sum_all_large() {
        let values: Vec<f64> = (0..200_000).map(|i| i as f64).collect();
        let tensor = DenseND::from_vec(values, &[200_000]).unwrap();

        // Sum of 0..n is n * (n - 1) / 2.
        let expected = 199_999.0 * 200_000.0 / 2.0;
        let actual = tiled_sum_all(&tensor).unwrap();
        assert!((actual - expected).abs() < 1.0);
    }

    /// Mean of 1..=5 is 3.
    #[test]
    fn test_tiled_mean_all() {
        let tensor = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        assert_eq!(tiled_mean_all(&tensor).unwrap(), 3.0);
    }

    /// The maximum is found regardless of its position.
    #[test]
    fn test_tiled_max_all() {
        let tensor = DenseND::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0], &[5]).unwrap();
        assert_eq!(tiled_max_all(&tensor).unwrap(), 9.0);
    }

    /// The minimum is found regardless of its position.
    #[test]
    fn test_tiled_min_all() {
        let tensor = DenseND::from_vec(vec![5.0, 1.0, 3.0, 9.0, 2.0], &[5]).unwrap();
        assert_eq!(tiled_min_all(&tensor).unwrap(), 1.0);
    }

    /// Summing a 2x3 tensor along axis 0 yields the three column sums.
    #[test]
    fn test_tiled_sum_axis() {
        let tensor = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let summed = tiled_sum_axis(&tensor, 0).unwrap();

        assert_eq!(summed.shape(), &[3]);

        let columns = summed.view();
        assert_eq!(columns[[0]], 5.0);
        assert_eq!(columns[[1]], 7.0);
        assert_eq!(columns[[2]], 9.0);
    }

    /// [[1, 2], [3, 4]] * [5, 6] = [17, 39].
    #[test]
    fn test_tiled_matvec() {
        let matrix = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let vector = DenseND::from_vec(vec![5.0, 6.0], &[2]).unwrap();

        let product = tiled_matvec(&matrix, &vector).unwrap();

        assert_eq!(product.shape(), &[2]);

        let entries = product.view();
        assert_eq!(entries[[0]], 17.0);
        assert_eq!(entries[[1]], 39.0);
    }
}