tenrso_exec/executor/
optimized_ops.rs

//! Optimized operations integration layer
//!
//! This module provides a unified interface for using the various optimization
//! modules (SIMD, tiled reductions, vectorized broadcasting) based on executor
//! configuration and tensor characteristics.
//!
//! # Strategy
//!
//! 1. Check executor configuration flags
//! 2. Check tensor characteristics (size, shape)
//! 3. Select the best implementation:
//!    - Optimized version if enabled and beneficial
//!    - Standard version otherwise
//!
//! # Performance
//!
//! The integration layer adds minimal overhead (a few boolean checks)
//! while providing significant speedups for applicable operations.
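//!
//! # Example
//!
//! A minimal sketch of the dispatch pattern (the functions are crate-internal,
//! so the snippet is not run as a doctest; it mirrors the constructors used in
//! the tests at the bottom of this file):
//!
//! ```ignore
//! let executor = CpuExecutor::new();
//! let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
//! // Dispatches to SIMD if enabled and the tensor is large enough;
//! // otherwise the standard mapv-based fallback runs.
//! let result = optimized_unary(&executor, &input, UnaryOpType::Exp)?;
//! ```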

use super::simd_ops::{self, SimdBinaryOp, SimdUnaryOp};
use super::tiled_reductions;
use super::types::{BinaryOp, CpuExecutor};
use anyhow::Result;
use scirs2_core::numeric::{Float, FromPrimitive, Num};
use tenrso_core::{Axis, DenseND};

/// Optimized unary element-wise operation
///
/// Automatically selects between the SIMD and standard implementations
/// based on executor configuration and tensor size.
#[allow(dead_code)]
pub(crate) fn optimized_unary<T>(
    executor: &CpuExecutor,
    input: &DenseND<T>,
    op: UnaryOpType,
) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + FromPrimitive + Send + Sync,
{
    // Check if SIMD is enabled and the tensor is large enough
    if executor.enable_simd && simd_ops::should_use_simd(input.shape()) {
        let simd_op = match op {
            UnaryOpType::Neg => SimdUnaryOp::Neg,
            UnaryOpType::Abs => SimdUnaryOp::Abs,
            UnaryOpType::Exp => SimdUnaryOp::Exp,
            UnaryOpType::Log => SimdUnaryOp::Log,
            UnaryOpType::Sin => SimdUnaryOp::Sin,
            UnaryOpType::Cos => SimdUnaryOp::Cos,
            UnaryOpType::Sqrt => SimdUnaryOp::Sqrt,
            UnaryOpType::Sqr => SimdUnaryOp::Sqr,
            UnaryOpType::Recip => SimdUnaryOp::Recip,
            UnaryOpType::Tanh => SimdUnaryOp::Tanh,
            UnaryOpType::Sigmoid => SimdUnaryOp::Sigmoid,
            UnaryOpType::ReLU => SimdUnaryOp::ReLU,
            UnaryOpType::Gelu => SimdUnaryOp::Gelu,
            UnaryOpType::Elu => SimdUnaryOp::Elu,
            UnaryOpType::Selu => SimdUnaryOp::Selu,
            UnaryOpType::Softplus => SimdUnaryOp::Softplus,
            UnaryOpType::Sign => SimdUnaryOp::Sign,
        };
        return simd_ops::simd_unary(input, simd_op);
    }

    // Fall back to standard implementation
    let result = match op {
        UnaryOpType::Neg => input.view().mapv(|v| -v),
        UnaryOpType::Abs => input.view().mapv(|v| v.abs()),
        UnaryOpType::Exp => input.view().mapv(|v| v.exp()),
        UnaryOpType::Log => input.view().mapv(|v| v.ln()),
        UnaryOpType::Sin => input.view().mapv(|v| v.sin()),
        UnaryOpType::Cos => input.view().mapv(|v| v.cos()),
        UnaryOpType::Sqrt => input.view().mapv(|v| v.sqrt()),
        UnaryOpType::Sqr => input.view().mapv(|v| v * v),
        UnaryOpType::Recip => input.view().mapv(|v| v.recip()),
        UnaryOpType::Tanh => input.view().mapv(|v| v.tanh()),
        UnaryOpType::Sigmoid => input.view().mapv(|v| {
            let one = T::one();
            one / (one + (-v).exp())
        }),
        UnaryOpType::ReLU => input.view().mapv(|v| {
            let zero = T::zero();
            if v > zero {
                v
            } else {
                zero
            }
        }),
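        // Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
        // where 0.7978845608028654 ~= sqrt(2/pi).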
        UnaryOpType::Gelu => input.view().mapv(|v| {
            let half = T::from_f64(0.5).unwrap_or_else(T::one);
            let one = T::one();
            let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
            let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
            let x_cubed = v * v * v;
            let inner = coeff * (v + cubic_coeff * x_cubed);
            half * v * (one + inner.tanh())
        }),
        UnaryOpType::Elu => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            if v > zero {
                v
            } else {
                v.exp() - one
            }
        }),
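        // SELU with the standard fixed constants from Klambauer et al. (2017):
        // scale (lambda) ~= 1.0507 and alpha ~= 1.6733.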
        UnaryOpType::Selu => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
            let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
            if v > zero {
                scale * v
            } else {
                scale * alpha * (v.exp() - one)
            }
        }),
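        // Numerically stable softplus: ln(1 + e^x) = max(x, 0) + ln(1 + e^(-|x|)),
        // which avoids overflow of e^x for large positive x.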
        UnaryOpType::Softplus => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let abs_v = v.abs();
            let max_part = if v > zero { v } else { zero };
            max_part + (one + (-abs_v).exp()).ln()
        }),
        UnaryOpType::Sign => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let neg_one = -one;
            if v > zero {
                one
            } else if v < zero {
                neg_one
            } else {
                zero
            }
        }),
    };

    Ok(DenseND::from_array(result))
}

/// Optimized binary element-wise operation
///
/// Automatically selects between SIMD, vectorized broadcasting,
/// and the standard implementation based on executor configuration.
#[allow(dead_code)]
pub(crate) fn optimized_binary<T>(
    executor: &CpuExecutor,
    x: &DenseND<T>,
    y: &DenseND<T>,
    op: BinaryOp,
) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + Send + Sync + std::ops::AddAssign,
{
    // Check if shapes match (no broadcasting needed)
    if x.shape() == y.shape() {
        // Try SIMD for same-shape operations
        if executor.enable_simd && simd_ops::should_use_simd(x.shape()) {
            let simd_op = match op {
                BinaryOp::Add => SimdBinaryOp::Add,
                BinaryOp::Sub => SimdBinaryOp::Sub,
                BinaryOp::Mul => SimdBinaryOp::Mul,
                BinaryOp::Div => SimdBinaryOp::Div,
                BinaryOp::Pow => SimdBinaryOp::Pow,
                BinaryOp::Maximum => SimdBinaryOp::Maximum,
                BinaryOp::Minimum => SimdBinaryOp::Minimum,
            };
            return simd_ops::simd_binary(x, y, simd_op);
        }
    }

    // Standard ndarray operations; broadcasting is handled automatically when
    // the shapes differ. This path is also reached for same-shape inputs when
    // SIMD is disabled or the tensor is small.
    use scirs2_core::ndarray_ext::Zip;
    let result = match op {
        BinaryOp::Add => &x.view() + &y.view(),
        BinaryOp::Sub => &x.view() - &y.view(),
        BinaryOp::Mul => &x.view() * &y.view(),
        BinaryOp::Div => &x.view() / &y.view(),
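        // Pow/Maximum/Minimum have no arithmetic operator overloads, so they
        // are applied element-wise with Zip.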
        BinaryOp::Pow => Zip::from(&x.view())
            .and(&y.view())
            .map_collect(|&x_val, &y_val| x_val.powf(y_val)),
        BinaryOp::Maximum => Zip::from(&x.view())
            .and(&y.view())
            .map_collect(|&x_val, &y_val| if x_val > y_val { x_val } else { y_val }),
        BinaryOp::Minimum => Zip::from(&x.view())
            .and(&y.view())
            .map_collect(|&x_val, &y_val| if x_val < y_val { x_val } else { y_val }),
    };

    Ok(DenseND::from_array(result))
}

/// Optimized sum reduction
///
/// Uses tiled reduction for large tensors when enabled. An empty `axes`
/// slice reduces over all axes to a scalar.
#[allow(dead_code)]
pub(crate) fn optimized_sum<T>(
    executor: &CpuExecutor,
    input: &DenseND<T>,
    axes: &[Axis],
) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
{
    // For all-axes reduction (empty axes means reduce to scalar)
    if axes.is_empty() {
        if executor.enable_tiled_reductions && tiled_reductions::should_use_tiling(input.shape()) {
            let sum_val = tiled_reductions::tiled_sum_all(input)?;
            let result = scirs2_core::ndarray_ext::Array::from_elem(
                scirs2_core::ndarray_ext::IxDyn(&[]),
                sum_val,
            );
            return Ok(DenseND::from_array(result));
        } else {
            // Small tensor or tiling disabled - use a simple sum
            let sum_val: T = input.view().iter().cloned().sum();
            let result = scirs2_core::ndarray_ext::Array::from_elem(
                scirs2_core::ndarray_ext::IxDyn(&[]),
                sum_val,
            );
            return Ok(DenseND::from_array(result));
        }
    }

    // For single-axis reduction, try tiled axis reduction
    if axes.len() == 1 && executor.enable_tiled_reductions {
        return tiled_reductions::tiled_sum_axis(input, axes[0]);
    }

    // Fall back to standard implementation for multi-axis reduction
    let mut result = input.view().to_owned();
    let mut sorted_axes = axes.to_vec();
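    // Sort descending so reducing a higher axis never invalidates the
    // indices of the lower axes still to be reduced.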
    sorted_axes.sort_unstable_by(|a, b| b.cmp(a));

    for &axis_idx in &sorted_axes {
        let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
        result = result.sum_axis(axis);
    }

    Ok(DenseND::from_array(result))
}

/// Optimized mean reduction
///
/// Uses tiled reduction for large tensors when enabled. An empty `axes`
/// slice reduces over all axes to a scalar.
#[allow(dead_code)]
pub(crate) fn optimized_mean<T>(
    executor: &CpuExecutor,
    input: &DenseND<T>,
    axes: &[Axis],
) -> Result<DenseND<T>>
where
    T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
{
    // For all-axes reduction (empty axes means reduce to scalar)
    if axes.is_empty() {
        if executor.enable_tiled_reductions && tiled_reductions::should_use_tiling(input.shape()) {
            let mean_val = tiled_reductions::tiled_mean_all(input)?;
            let result = scirs2_core::ndarray_ext::Array::from_elem(
                scirs2_core::ndarray_ext::IxDyn(&[]),
                mean_val,
            );
            return Ok(DenseND::from_array(result));
        } else {
            // Small tensor or tiling disabled - use a simple mean
            let total_elements = input.view().len();
            let sum: T = input.view().iter().cloned().sum();
            let mean = sum
                / T::from_usize(total_elements)
                    .ok_or_else(|| anyhow::anyhow!("Cannot convert element count to tensor type"))?;
            let result = scirs2_core::ndarray_ext::Array::from_elem(
                scirs2_core::ndarray_ext::IxDyn(&[]),
                mean,
            );
            return Ok(DenseND::from_array(result));
        }
    }

    // Fall back to standard implementation
    let mut result = input.view().to_owned();
    let mut sorted_axes = axes.to_vec();
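    // Sort descending so reducing a higher axis never invalidates the
    // indices of the lower axes still to be reduced.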
    sorted_axes.sort_unstable_by(|a, b| b.cmp(a));

    for &axis_idx in &sorted_axes {
        let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
        result = result
            .mean_axis(axis)
            .ok_or_else(|| anyhow::anyhow!("Mean computation failed"))?;
    }

    Ok(DenseND::from_array(result))
}

/// Unary operation types for optimized dispatch
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
pub(crate) enum UnaryOpType {
    Neg,
    Abs,
    Exp,
    Log,
    Sin,
    Cos,
    Sqrt,
    Sqr,
    Recip,
    Tanh,
    Sigmoid,
    ReLU,
    Gelu,
    Elu,
    Selu,
    Softplus,
    Sign,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimized_unary_small_tensor() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let result = optimized_unary(&executor, &input, UnaryOpType::Neg).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], -1.0);
        assert_eq!(result_view[[1]], -2.0);
        assert_eq!(result_view[[2]], -3.0);
        assert_eq!(result_view[[3]], -4.0);
    }

    #[test]
    fn test_optimized_binary_same_shape() {
        let executor = CpuExecutor::new();
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[4]).unwrap();

        let result = optimized_binary(&executor, &a, &b, BinaryOp::Add).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 6.0);
        assert_eq!(result_view[[1]], 8.0);
        assert_eq!(result_view[[2]], 10.0);
        assert_eq!(result_view[[3]], 12.0);
    }

    #[test]
    fn test_optimized_sum_all() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let result = optimized_sum(&executor, &input, &[]).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[]], 15.0);
    }

    #[test]
    fn test_optimized_mean_all() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let result = optimized_mean(&executor, &input, &[]).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[]], 3.0);
    }

    #[test]
    fn test_optimization_disabled() {
        let executor = CpuExecutor::unoptimized();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        // Should still work, just without optimizations
        let result = optimized_unary(&executor, &input, UnaryOpType::Exp).unwrap();
        let result_view = result.view();

        assert!((result_view[[0]] - std::f64::consts::E).abs() < 1e-10);
    }
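
    // A sketch of an additional case exercising the Maximum arm; it assumes
    // the same DenseND/CpuExecutor API as the tests above.
    #[test]
    fn test_optimized_binary_maximum() {
        let executor = CpuExecutor::new();
        let a = DenseND::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap();
        let b = DenseND::from_vec(vec![4.0, 2.0, 3.0], &[3]).unwrap();

        let result = optimized_binary(&executor, &a, &b, BinaryOp::Maximum).unwrap();
        let result_view = result.view();

        assert_eq!(result_view[[0]], 4.0);
        assert_eq!(result_view[[1]], 5.0);
        assert_eq!(result_view[[2]], 3.0);
    }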
}