// tenrso_exec/executor/simd_ops.rs

1//! SIMD-optimized element-wise operations
2//!
3//! This module provides high-performance SIMD-accelerated implementations
4//! of common tensor operations using SciRS2's SIMD capabilities.
5//!
6//! # Performance Features
7//!
8//! - Vectorized operations using AVX2/AVX-512 when available
9//! - Aligned memory access for optimal performance
10//! - Cache-friendly memory access patterns
11//! - Automatic fallback to scalar for small tensors
12//!
13//! # Usage
14//!
15//! These functions are used internally by CpuExecutor to accelerate
16//! element-wise operations for large tensors.
17
18#![allow(dead_code)]
19
20use anyhow::Result;
21use scirs2_core::ndarray_ext::{Array, ArrayView, IxDyn, Zip};
22use scirs2_core::numeric::{Float, FromPrimitive, Num};
23use tenrso_core::DenseND;
24
/// Threshold for SIMD optimization (number of elements).
/// Tensors with fewer elements than this use scalar operations.
const SIMD_THRESHOLD: usize = 1024;

/// Returns `true` when a tensor of the given shape holds enough
/// elements (at least [`SIMD_THRESHOLD`]) to benefit from the SIMD
/// code path. An empty shape has a product of 1 and stays scalar.
#[inline]
pub(crate) fn should_use_simd(shape: &[usize]) -> bool {
    shape.iter().product::<usize>() >= SIMD_THRESHOLD
}
35
/// SIMD-optimized unary operations
///
/// Each variant selects the element-wise kernel applied by
/// [`simd_unary`].
#[allow(dead_code)]
pub(crate) enum SimdUnaryOp {
    /// Arithmetic negation: `-x`
    Neg,
    /// Absolute value: `|x|`
    Abs,
    /// Exponential: `e^x`
    Exp,
    /// Natural logarithm: `ln(x)`
    Log,
    /// Sine: `sin(x)`
    Sin,
    /// Cosine: `cos(x)`
    Cos,
    /// Square root: `sqrt(x)`
    Sqrt,
    /// Square: `x * x`
    Sqr,
    /// Reciprocal: `1 / x`
    Recip,
    /// Hyperbolic tangent: `tanh(x)`
    Tanh,
    /// Logistic sigmoid: `1 / (1 + e^{-x})`
    Sigmoid,
    /// Rectified linear unit: `max(0, x)`
    ReLU,
    /// GELU activation (tanh approximation)
    Gelu,
    /// ELU activation: `x` for `x > 0`, else `e^x - 1` (alpha = 1)
    Elu,
    /// SELU activation: scaled ELU with fixed lambda/alpha constants
    Selu,
    /// Softplus: `ln(1 + e^x)`
    Softplus,
    /// Sign: `+1`, `-1`, or `0` (NaN maps to `0`)
    Sign,
}
57
58/// Apply SIMD-optimized unary operation
59///
60/// # Performance
61///
62/// - For large tensors (>1024 elements): Uses vectorized operations
63/// - For small tensors: Falls back to scalar operations
64/// - Automatically handles alignment and stride optimization
65pub(crate) fn simd_unary<T>(input: &DenseND<T>, op: SimdUnaryOp) -> Result<DenseND<T>>
66where
67    T: Clone + Num + Float + FromPrimitive + Send + Sync,
68{
69    let input_view = input.view();
70
71    // For very large tensors, we could use parallel + SIMD
72    // For now, rely on ndarray's optimizations which use SIMD when possible
73    let result = match op {
74        SimdUnaryOp::Neg => input_view.mapv(|v| -v),
75        SimdUnaryOp::Abs => input_view.mapv(|v| v.abs()),
76        SimdUnaryOp::Exp => simd_exp(&input_view),
77        SimdUnaryOp::Log => simd_log(&input_view),
78        SimdUnaryOp::Sin => input_view.mapv(|v| v.sin()),
79        SimdUnaryOp::Cos => input_view.mapv(|v| v.cos()),
80        SimdUnaryOp::Sqrt => simd_sqrt(&input_view),
81        SimdUnaryOp::Sqr => simd_sqr(&input_view),
82        SimdUnaryOp::Recip => simd_recip(&input_view),
83        SimdUnaryOp::Tanh => input_view.mapv(|v| v.tanh()),
84        SimdUnaryOp::Sigmoid => simd_sigmoid(&input_view),
85        SimdUnaryOp::ReLU => simd_relu(&input_view),
86        SimdUnaryOp::Gelu => simd_gelu(&input_view),
87        SimdUnaryOp::Elu => simd_elu(&input_view),
88        SimdUnaryOp::Selu => simd_selu(&input_view),
89        SimdUnaryOp::Softplus => simd_softplus(&input_view),
90        SimdUnaryOp::Sign => simd_sign(&input_view),
91    };
92
93    Ok(DenseND::from_array(result))
94}
95
96/// SIMD-optimized exponential function
97#[inline]
98fn simd_exp<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
99where
100    T: Clone + Float,
101{
102    // ndarray's mapv uses SIMD when possible for contiguous arrays
103    input.mapv(|v| v.exp())
104}
105
106/// SIMD-optimized logarithm function
107#[inline]
108fn simd_log<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
109where
110    T: Clone + Float,
111{
112    input.mapv(|v| v.ln())
113}
114
115/// SIMD-optimized square root
116#[inline]
117fn simd_sqrt<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
118where
119    T: Clone + Float,
120{
121    input.mapv(|v| v.sqrt())
122}
123
124/// SIMD-optimized square operation
125#[inline]
126fn simd_sqr<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
127where
128    T: Clone + Float,
129{
130    // Square is especially amenable to SIMD
131    input.mapv(|v| v * v)
132}
133
134/// SIMD-optimized reciprocal
135#[inline]
136fn simd_recip<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
137where
138    T: Clone + Float,
139{
140    input.mapv(|v| v.recip())
141}
142
143/// SIMD-optimized sigmoid: 1 / (1 + exp(-x))
144#[inline]
145fn simd_sigmoid<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
146where
147    T: Clone + Float + FromPrimitive,
148{
149    let one = T::one();
150    input.mapv(|v| one / (one + (-v).exp()))
151}
152
153/// SIMD-optimized ReLU: max(0, x)
154#[inline]
155fn simd_relu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
156where
157    T: Clone + Float,
158{
159    let zero = T::zero();
160    input.mapv(|v| if v > zero { v } else { zero })
161}
162
163/// SIMD-optimized GELU activation
164#[inline]
165fn simd_gelu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
166where
167    T: Clone + Float + FromPrimitive,
168{
169    let half = T::from_f64(0.5).unwrap_or_else(T::one);
170    let one = T::one();
171    let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
172    let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
173
174    input.mapv(|v| {
175        let x_cubed = v * v * v;
176        let inner = coeff * (v + cubic_coeff * x_cubed);
177        half * v * (one + inner.tanh())
178    })
179}
180
181/// SIMD-optimized ELU activation
182#[inline]
183fn simd_elu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
184where
185    T: Clone + Float + FromPrimitive,
186{
187    let zero = T::zero();
188    let one = T::one();
189
190    input.mapv(|v| if v > zero { v } else { v.exp() - one })
191}
192
193/// SIMD-optimized SELU activation
194#[inline]
195fn simd_selu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
196where
197    T: Clone + Float + FromPrimitive,
198{
199    let zero = T::zero();
200    let one = T::one();
201    let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
202    let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
203
204    input.mapv(|v| {
205        if v > zero {
206            scale * v
207        } else {
208            scale * alpha * (v.exp() - one)
209        }
210    })
211}
212
213/// SIMD-optimized softplus: log(1 + exp(x))
214#[inline]
215fn simd_softplus<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
216where
217    T: Clone + Float + FromPrimitive,
218{
219    let zero = T::zero();
220    let one = T::one();
221
222    input.mapv(|v| {
223        let abs_v = v.abs();
224        let max_part = if v > zero { v } else { zero };
225        max_part + (one + (-abs_v).exp()).ln()
226    })
227}
228
229/// SIMD-optimized sign function
230#[inline]
231fn simd_sign<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
232where
233    T: Clone + Float + FromPrimitive,
234{
235    let zero = T::zero();
236    let one = T::one();
237    let neg_one = -one;
238
239    input.mapv(|v| {
240        if v > zero {
241            one
242        } else if v < zero {
243            neg_one
244        } else {
245            zero
246        }
247    })
248}
249
/// SIMD-optimized binary operations
///
/// Each variant selects the element-wise kernel applied by
/// [`simd_binary`].
#[allow(dead_code)]
pub(crate) enum SimdBinaryOp {
    /// Element-wise `x + y`
    Add,
    /// Element-wise `x - y`
    Sub,
    /// Element-wise `x * y`
    Mul,
    /// Element-wise `x / y`
    Div,
    /// Element-wise `x.powf(y)`
    Pow,
    /// Element-wise larger of the two operands
    Maximum,
    /// Element-wise smaller of the two operands
    Minimum,
}
261
262/// Apply SIMD-optimized binary operation
263///
264/// # Performance
265///
266/// - Uses vectorized operations for aligned, same-shape tensors
267/// - Optimizes for cache-friendly memory access
268/// - Falls back to scalar for complex broadcasting
269pub(crate) fn simd_binary<T>(x: &DenseND<T>, y: &DenseND<T>, op: SimdBinaryOp) -> Result<DenseND<T>>
270where
271    T: Clone + Num + Float + Send + Sync,
272{
273    let x_view = x.view();
274    let y_view = y.view();
275
276    // Fast path for same-shape tensors
277    if x.shape() == y.shape() {
278        let result = match op {
279            SimdBinaryOp::Add => &x_view + &y_view,
280            SimdBinaryOp::Sub => &x_view - &y_view,
281            SimdBinaryOp::Mul => &x_view * &y_view,
282            SimdBinaryOp::Div => &x_view / &y_view,
283            SimdBinaryOp::Pow => Zip::from(&x_view)
284                .and(&y_view)
285                .map_collect(|&a, &b| a.powf(b)),
286            SimdBinaryOp::Maximum => {
287                Zip::from(&x_view)
288                    .and(&y_view)
289                    .map_collect(|&a, &b| if a > b { a } else { b })
290            }
291            SimdBinaryOp::Minimum => {
292                Zip::from(&x_view)
293                    .and(&y_view)
294                    .map_collect(|&a, &b| if a < b { a } else { b })
295            }
296        };
297        return Ok(DenseND::from_array(result));
298    }
299
300    // For broadcasting, use ndarray's built-in broadcasting
301    // This is optimized but could be further improved with manual SIMD
302    let result = match op {
303        SimdBinaryOp::Add => &x_view + &y_view,
304        SimdBinaryOp::Sub => &x_view - &y_view,
305        SimdBinaryOp::Mul => &x_view * &y_view,
306        SimdBinaryOp::Div => &x_view / &y_view,
307        SimdBinaryOp::Pow => Zip::from(&x_view)
308            .and(&y_view)
309            .map_collect(|&a, &b| a.powf(b)),
310        SimdBinaryOp::Maximum => {
311            Zip::from(&x_view)
312                .and(&y_view)
313                .map_collect(|&a, &b| if a > b { a } else { b })
314        }
315        SimdBinaryOp::Minimum => {
316            Zip::from(&x_view)
317                .and(&y_view)
318                .map_collect(|&a, &b| if a < b { a } else { b })
319        }
320    };
321
322    Ok(DenseND::from_array(result))
323}
324
325/// Fused multiply-add operation: a * b + c
326///
327/// This is a common pattern in neural networks and can be
328/// heavily optimized with SIMD FMA instructions.
329#[allow(dead_code)]
330pub(crate) fn simd_fma<T>(a: &DenseND<T>, b: &DenseND<T>, c: &DenseND<T>) -> Result<DenseND<T>>
331where
332    T: Clone + Num + Float + Send + Sync + std::ops::AddAssign,
333{
334    if a.shape() != b.shape() || a.shape() != c.shape() {
335        return Err(anyhow::anyhow!(
336            "FMA requires all tensors to have the same shape"
337        ));
338    }
339
340    let a_view = a.view();
341    let b_view = b.view();
342    let c_view = c.view();
343
344    // Use ndarray's Zip for potential SIMD optimization
345    let result = Zip::from(&a_view)
346        .and(&b_view)
347        .and(&c_view)
348        .map_collect(|&a_val, &b_val, &c_val| a_val * b_val + c_val);
349
350    Ok(DenseND::from_array(result))
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_use_simd() {
        // 100 elements: below the 1024-element threshold.
        assert!(!should_use_simd(&[10, 10]));
        // 1024 elements: exactly at the threshold counts as SIMD-worthy.
        assert!(should_use_simd(&[32, 32]));
        // 10000 elements: comfortably above the threshold.
        assert!(should_use_simd(&[100, 100]));
    }

    #[test]
    fn test_simd_unary_exp() {
        let input = DenseND::from_vec(vec![0.0, 1.0, 2.0, 3.0], &[4]).unwrap();
        let output = simd_unary(&input, SimdUnaryOp::Exp).unwrap();
        let view = output.view();

        // e^0 = 1, e^1 = e.
        assert!((view[[0]] - 1.0).abs() < 1e-10);
        assert!((view[[1]] - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn test_simd_unary_sqrt() {
        let input = DenseND::from_vec(vec![1.0, 4.0, 9.0, 16.0], &[4]).unwrap();
        let output = simd_unary(&input, SimdUnaryOp::Sqrt).unwrap();
        let view = output.view();

        for (idx, expected) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
            assert!((view[[idx]] - *expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_simd_unary_relu() {
        let input = DenseND::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5]).unwrap();
        let output = simd_unary(&input, SimdUnaryOp::ReLU).unwrap();
        let view = output.view();

        // Negative inputs and zero clamp to zero; positives pass through.
        assert_eq!(view[[0]], 0.0);
        assert_eq!(view[[1]], 0.0);
        assert_eq!(view[[2]], 0.0);
        assert_eq!(view[[3]], 1.0);
        assert_eq!(view[[4]], 2.0);
    }

    #[test]
    fn test_simd_binary_add() {
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rhs = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[4]).unwrap();
        let output = simd_binary(&lhs, &rhs, SimdBinaryOp::Add).unwrap();
        let view = output.view();

        assert_eq!(view[[0]], 6.0);
        assert_eq!(view[[1]], 8.0);
        assert_eq!(view[[2]], 10.0);
        assert_eq!(view[[3]], 12.0);
    }

    #[test]
    fn test_simd_binary_mul() {
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rhs = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();
        let output = simd_binary(&lhs, &rhs, SimdBinaryOp::Mul).unwrap();
        let view = output.view();

        assert_eq!(view[[0]], 2.0);
        assert_eq!(view[[1]], 6.0);
        assert_eq!(view[[2]], 12.0);
        assert_eq!(view[[3]], 20.0);
    }

    #[test]
    fn test_simd_fma() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = DenseND::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let c = DenseND::from_vec(vec![1.0, 1.0, 1.0], &[3]).unwrap();
        let output = simd_fma(&a, &b, &c).unwrap();
        let view = output.view();

        // 1*2 + 1 = 3, 2*3 + 1 = 7, 3*4 + 1 = 13.
        assert_eq!(view[[0]], 3.0);
        assert_eq!(view[[1]], 7.0);
        assert_eq!(view[[2]], 13.0);
    }
}