Skip to main content

simd_kernels/kernels/scientific/
scalar.rs

// Copyright (c) 2025 SpaceCell Enterprises Ltd
// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial licensing available. See LICENSE and LICENSING.md.

//! # **Universal Scalar Function Module** - *High-Performance Element-Wise Mathematical Operations*
//!
//! Full suite of vectorised mathematical functions that operate
//! element-wise on arrays of floating-point values. It serves as the computational backbone for
//! mathematical operations across the simd-kernels crate, for both scalar and SIMD-accelerated
//! implementations with opt-in Arrow-compatible null masking.
//!
//! These are the semantic equivalent of *numpy ufuncs* in Python.
//!
//! ## Overview
//!
//! Universal scalar functions are fundamental building blocks for:
//! - **Data Preprocessing**: Normalisations, transformations, and scaling operations
//! - **Scientific Computing**: Mathematical transformations and special function evaluation
//! - **Machine Learning**: Activation functions, feature engineering, and data preparation
//! - **Signal Processing**: Filtering, transforms, and spectral analysis
//! - **Statistics**: Data transformations and statistical preprocessing
//! - **Financial Mathematics**: Risk calculations and price transformations
24use crate::kernels::scientific::erf::erf as erf_fn;
25use crate::utils::bitmask_to_simd_mask;
26use minarrow::utils::is_simd_aligned;
27use minarrow::{Bitmask, FloatArray, Vec64};
28
/// Generates a mapping kernel in three variants:
///
/// 1. `$name_to` - Zero-allocation canonical implementation, writes to caller's buffer
/// 2. `$name` - Allocates internally, delegates to `$name_to`
/// 3. `$name_elem` - Element-wise `fn(f64) -> f64` for kernel fusion
///
/// The `_to` variant is for pre-allocated parallel execution where each chunk
/// writes directly to its slice of a shared output buffer.
///
/// The `_elem` variant is for kernel fusion where multiple operations are composed
/// into a single loop, keeping intermediate values in registers instead of memory.
///
/// `$name`      – allocating function name
/// `$name_to`   – zero-allocation function name
/// `$name_elem` – element-wise function for fusion
/// `$expr`      – expression mapping a scalar `f64 -> f64`
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $name_to:ident, $name_elem:ident, $expr:expr) => {
        /// Element-wise variant for kernel fusion.
        ///
        /// # Example
        /// ```ignore
        /// let ops = &[neg_elem, exp_elem, sin_elem];
        /// execute_fused::<8>(input, output, ops);
        /// // Equivalent to neg -> exp -> sin but with ONE memory read/write
        /// ```
        #[inline(always)]
        pub fn $name_elem(x: f64) -> f64 {
            $expr(x)
        }

        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Canonical implementation with full SIMD acceleration and null handling.
        /// For parallel execution with pre-allocated output. Null lanes are written
        /// as `f64::NAN`; validity is carried by the caller's mask.
        ///
        /// # Errors
        /// Returns `Err` when nulls are present but `null_mask` is `None`.
        ///
        /// # Panics
        /// Panics if `input.len() != output.len()`.
        #[inline(always)]
        pub fn $name_to<const LANES: usize>(
            input: &[f64],
            output: &mut [f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<(), &'static str> {
            let len = input.len();
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($name_to), ": input/output length mismatch")
            );

            if len == 0 {
                return Ok(());
            }

            // Decide whether we need the null-aware path.
            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            // Dense (no nulls) path.
            if !has_nulls {
                #[cfg(feature = "simd")]
                {
                    if is_simd_aligned(input) {
                        use core::simd::Simd;
                        let mut i = 0;
                        while i + LANES <= len {
                            let v = Simd::<f64, LANES>::from_slice(&input[i..i + LANES]);
                            let mut r = Simd::<f64, LANES>::splat(0.0);
                            for lane in 0..LANES {
                                r[lane] = $expr(v[lane]);
                            }
                            output[i..i + LANES].copy_from_slice(r.as_array());
                            i += LANES;
                        }
                        // Scalar tail.
                        for j in i..len {
                            output[j] = $expr(input[j]);
                        }
                        return Ok(());
                    }
                }

                // Scalar fallback (SIMD feature off, or input not SIMD-aligned).
                for j in 0..len {
                    output[j] = $expr(input[j]);
                }
                return Ok(());
            }

            // Null-aware path: a mask is mandatory once nulls are reported.
            let mb = null_mask.ok_or(concat!(
                stringify!($name_to),
                ": input mask required when nulls present"
            ))?;

            #[cfg(feature = "simd")]
            {
                // Check if input array is properly aligned for SIMD (cheap runtime check).
                if is_simd_aligned(input) {
                    use core::simd::{Mask, Simd};
                    let mask_bytes = mb.as_bytes();
                    let mut i = 0;
                    while i + LANES <= len {
                        // Pull the Arrow validity bits into a SIMD mask.
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);

                        // Gather inputs (nulls -> NaN).
                        let mut arr = [0.0f64; LANES];
                        for j in 0..LANES {
                            arr[j] = if unsafe { lane_valid.test_unchecked(j) } {
                                input[i + j]
                            } else {
                                f64::NAN
                            };
                        }
                        let v = Simd::<f64, LANES>::from_array(arr);

                        // Apply the scalar expr lane-wise.
                        let mut r = Simd::<f64, LANES>::splat(0.0);
                        for lane in 0..LANES {
                            r[lane] = $expr(v[lane]);
                        }
                        let r_arr = r.as_array();

                        // Write back, forcing null lanes to NaN. Previously the raw
                        // `$expr(NaN)` result was stored, which disagrees with the
                        // scalar paths whenever `$expr(NaN) != NaN` (e.g. relu maps
                        // NaN to 0.0), making null-lane values depend on alignment
                        // and chunk position.
                        for lane in 0..LANES {
                            output[i + lane] = if unsafe { lane_valid.test_unchecked(lane) } {
                                r_arr[lane]
                            } else {
                                f64::NAN
                            };
                        }

                        i += LANES;
                    }
                    // Scalar tail.
                    for idx in i..len {
                        output[idx] = if unsafe { mb.get_unchecked(idx) } {
                            $expr(input[idx])
                        } else {
                            f64::NAN
                        };
                    }

                    return Ok(());
                }
                // Fall through to the scalar path if the alignment check failed.
            }

            // Scalar null-aware fallback (SIMD feature off, or alignment check failed).
            // A single unconditional loop replaces the previous identical copies under
            // `#[cfg(feature = "simd")]` and `#[cfg(not(feature = "simd"))]`: the SIMD
            // branch above returns on success, so reaching here is the same condition
            // in both configurations.
            for idx in 0..len {
                output[idx] = if unsafe { mb.get_unchecked(idx) } {
                    $expr(input[idx])
                } else {
                    f64::NAN
                };
            }

            Ok(())
        }

        /// Returns a new `FloatArray<f64>` with the function applied element-wise.
        /// The input null mask (if any) is cloned onto the result; null lanes are
        /// written as `f64::NAN` in the value buffer.
        ///
        /// # Errors
        /// Returns `Err` when nulls are present but `null_mask` is `None`.
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str> {
            let len = input.len();
            // Fast length-0 case.
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            let mut out = Vec64::with_capacity(len);
            // SAFETY: capacity for `len` elements was just allocated. On the success
            // path every slot is overwritten by `$name_to` before the buffer is
            // observed; on the error path the buffer is dropped without its contents
            // being read, and `f64` has no drop glue.
            unsafe {
                out.set_len(len);
            }

            $name_to::<LANES>(input, out.as_mut_slice(), null_mask, null_count)?;

            Ok(FloatArray::from_vec64(out, null_mask.cloned()))
        }
    };
}
226
// Basic operations
impl_vecmap!(abs, abs_to, abs_elem, |x: f64| x.abs());
impl_vecmap!(neg, neg_to, neg_elem, |x: f64| -x);
// recip(0.0) yields +/-inf per IEEE-754 division semantics.
impl_vecmap!(recip, recip_to, recip_elem, |x: f64| 1.0 / x);
// sqrt returns NaN for negative inputs; cbrt is defined for all reals.
impl_vecmap!(sqrt, sqrt_to, sqrt_elem, |x: f64| x.sqrt());
impl_vecmap!(cbrt, cbrt_to, cbrt_elem, |x: f64| x.cbrt());
233
// Exponential and logarithmic
impl_vecmap!(exp, exp_to, exp_elem, |x: f64| x.exp());
impl_vecmap!(exp2, exp2_to, exp2_elem, |x: f64| x.exp2());
// Logarithms return NaN for negative inputs and -inf at 0.0.
impl_vecmap!(ln, ln_to, ln_elem, |x: f64| x.ln());
impl_vecmap!(log2, log2_to, log2_elem, |x: f64| x.log2());
impl_vecmap!(log10, log10_to, log10_elem, |x: f64| x.log10());
240
// Trigonometric (arguments and results in radians)
impl_vecmap!(sin, sin_to, sin_elem, |x: f64| x.sin());
impl_vecmap!(cos, cos_to, cos_elem, |x: f64| x.cos());
impl_vecmap!(tan, tan_to, tan_elem, |x: f64| x.tan());
// Inverse trig: asin/acos return NaN for inputs outside [-1, 1].
impl_vecmap!(asin, asin_to, asin_elem, |x: f64| x.asin());
impl_vecmap!(acos, acos_to, acos_elem, |x: f64| x.acos());
impl_vecmap!(atan, atan_to, atan_elem, |x: f64| x.atan());
248
// Hyperbolic
impl_vecmap!(sinh, sinh_to, sinh_elem, |x: f64| x.sinh());
impl_vecmap!(cosh, cosh_to, cosh_elem, |x: f64| x.cosh());
impl_vecmap!(tanh, tanh_to, tanh_elem, |x: f64| x.tanh());
// Inverse hyperbolic: acosh returns NaN for x < 1; atanh returns NaN outside [-1, 1].
impl_vecmap!(asinh, asinh_to, asinh_elem, |x: f64| x.asinh());
impl_vecmap!(acosh, acosh_to, acosh_elem, |x: f64| x.acosh());
impl_vecmap!(atanh, atanh_to, atanh_elem, |x: f64| x.atanh());
256
// Error functions
impl_vecmap!(erf, erf_to, erf_elem, |x: f64| erf_fn(x));
// NOTE(review): computing erfc as 1 - erf(x) loses the entire tail for large x
// (catastrophic cancellation once erf(x) rounds to 1.0, around |x| > ~6, where
// the true erfc is tiny but nonzero). A dedicated erfc implementation would
// preserve it — confirm whether callers need tail accuracy before relying on this.
impl_vecmap!(erfc, erfc_to, erfc_elem, |x: f64| 1.0 - erf_fn(x));
260
// Rounding and sign
impl_vecmap!(ceil, ceil_to, ceil_elem, |x: f64| x.ceil());
impl_vecmap!(floor, floor_to, floor_elem, |x: f64| x.floor());
impl_vecmap!(trunc, trunc_to, trunc_elem, |x: f64| x.trunc());
// f64::round uses round-half-away-from-zero, not banker's rounding.
impl_vecmap!(round, round_to, round_elem, |x: f64| x.round());
// signum returns 1.0 for +0.0 and -1.0 for -0.0; NaN stays NaN.
impl_vecmap!(sign, sign_to, sign_elem, |x: f64| x.signum());
267
268// Activation functions
269impl_vecmap!(sigmoid, sigmoid_to, sigmoid_elem, |x: f64| 1.0
270    / (1.0 + (-x).exp()));
271impl_vecmap!(softplus, softplus_to, softplus_elem, |x: f64| (1.0
272    + x.exp())
273.ln());
274impl_vecmap!(relu, relu_to, relu_elem, |x: f64| if x > 0.0 {
275    x
276} else {
277    0.0
278});
279impl_vecmap!(gelu, gelu_to, gelu_elem, |x: f64| {
280    0.5 * x * (1.0 + erf_fn(x / std::f64::consts::SQRT_2))
281});