Skip to main content

simd_kernels/kernels/scientific/
scalar.rs

1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Universal Scalar Function Module** - *High-Performance Element-Wise Mathematical Operations*
6//!
//! This module provides a full suite of vectorised mathematical functions that operate
//! element-wise on arrays of floating-point values. It serves as the computational backbone for
//! mathematical operations across the simd-kernels crate, providing both scalar and
//! SIMD-accelerated implementations with opt-in Arrow-compatible null masking.
11//!
12//! These are the semantic equivalent of *numpy ufuncs* in Python.
13//!
14//! ## Overview
15//!
16//! Universal scalar functions are fundamental building blocks for:
17//! - **Data Preprocessing**: Normalisations, transformations, and scaling operations
18//! - **Scientific Computing**: Mathematical transformations and special function evaluation
19//! - **Machine Learning**: Activation functions, feature engineering, and data preparation
20//! - **Signal Processing**: Filtering, transforms, and spectral analysis
21//! - **Statistics**: Data transformations and statistical preprocessing
22//! - **Financial Mathematics**: Risk calculations and price transformations
23
24use crate::kernels::scientific::erf::erf as erf_fn;
25use crate::utils::bitmask_to_simd_mask;
26use minarrow::utils::is_simd_aligned;
27use minarrow::{Bitmask, FloatArray, Vec64};
28
/// Generates a mapping kernel in three variants:
///
/// 1. `$name_to` - Zero-allocation canonical implementation, writes to caller's buffer
/// 2. `$name` - Allocates internally, delegates to `$name_to`
/// 3. `$name_elem` - Element-wise `fn(f64) -> f64` for kernel fusion
///
/// The `_to` variant is for pre-allocated parallel execution where each chunk
/// writes directly to its slice of a shared output buffer.
///
/// The `_elem` variant is for kernel fusion where multiple operations are composed
/// into a single loop, keeping intermediate values in registers instead of memory.
///
/// In the null-aware path, every position whose validity bit is clear is written
/// as `f64::NAN` and `$expr` is not evaluated for it. The SIMD and scalar paths
/// produce identical output.
///
/// `$name`      – allocating function name
/// `$name_to`   – zero-allocation function name
/// `$name_elem` – element-wise function for fusion
/// `$expr`      – expression mapping a scalar `f64 -> f64` (closure or function path)
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $name_to:ident, $name_elem:ident, $expr:expr) => {
        /// Element-wise variant for kernel fusion.
        ///
        /// # Example
        /// ```ignore
        /// let ops = &[neg_elem, exp_elem, sin_elem];
        /// execute_fused::<8>(input, output, ops);
        /// // Equivalent to neg -> exp -> sin but with ONE memory read/write
        /// ```
        #[inline(always)]
        pub fn $name_elem(x: f64) -> f64 {
            $expr(x)
        }

        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Canonical implementation with optional SIMD chunking and null handling.
        /// For parallel execution with pre-allocated output. Positions whose
        /// validity bit is clear are written as `f64::NAN`.
        ///
        /// # Panics
        /// Panics if `input.len() != output.len()`.
        ///
        /// # Errors
        /// Returns `Err` when nulls are indicated (`null_count > 0`) but no
        /// `null_mask` is supplied.
        #[inline(always)]
        pub fn $name_to<const LANES: usize>(
            input: &[f64],
            output: &mut [f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<(), &'static str> {
            let len = input.len();
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($name_to), ": input/output length mismatch")
            );

            if len == 0 {
                return Ok(());
            }

            // An explicit count wins; without one, the mere presence of a mask
            // routes through the null-aware path.
            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            if !has_nulls {
                // Dense (no nulls) path.
                #[cfg(feature = "simd")]
                {
                    if is_simd_aligned(input) {
                        use core::simd::Simd;
                        let mut i = 0;
                        while i + LANES <= len {
                            let v = Simd::<f64, LANES>::from_slice(&input[i..i + LANES]);
                            let mut r = Simd::<f64, LANES>::splat(0.0);
                            for lane in 0..LANES {
                                r[lane] = $expr(v[lane]);
                            }
                            output[i..i + LANES].copy_from_slice(r.as_array());
                            i += LANES;
                        }
                        // Scalar tail for the final partial chunk.
                        for (slot, &x) in output[i..].iter_mut().zip(&input[i..]) {
                            *slot = $expr(x);
                        }
                        return Ok(());
                    }
                }

                // Scalar path: SIMD disabled, or input not SIMD-aligned.
                for (slot, &x) in output.iter_mut().zip(input) {
                    *slot = $expr(x);
                }
                return Ok(());
            }

            // Null-aware path: a mask is mandatory from here on.
            let mb = null_mask.ok_or(concat!(
                stringify!($name_to),
                ": input mask required when nulls present"
            ))?;

            // Maps one position: valid -> `$expr(x)`, null -> NaN.
            // SAFETY: `get_unchecked` skips the bounds check; this assumes `mb`
            // holds at least `len` bits. That is not verified here.
            // TODO(review): confirm every caller guarantees it.
            let map_masked = |idx: usize, x: f64| -> f64 {
                if unsafe { mb.get_unchecked(idx) } {
                    $expr(x)
                } else {
                    f64::NAN
                }
            };

            #[cfg(feature = "simd")]
            {
                // Check if input array is properly aligned for SIMD (cheap runtime check)
                if is_simd_aligned(input) {
                    use core::simd::{Mask, Simd};
                    let mask_bytes = mb.as_bytes();
                    let mut i = 0;
                    while i + LANES <= len {
                        // Arrow validity bits for this chunk as a SIMD mask.
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);

                        // Null lanes keep the NaN fill and `$expr` runs only on valid
                        // lanes. This matches the scalar path exactly; e.g. relu can
                        // no longer turn a null's NaN placeholder into 0.0.
                        let mut r = Simd::<f64, LANES>::splat(f64::NAN);
                        for lane in 0..LANES {
                            if lane_valid.test(lane) {
                                r[lane] = $expr(input[i + lane]);
                            }
                        }
                        output[i..i + LANES].copy_from_slice(r.as_array());
                        i += LANES;
                    }
                    // Scalar tail for the final partial chunk.
                    for idx in i..len {
                        output[idx] = map_masked(idx, input[idx]);
                    }
                    return Ok(());
                }
                // Fall through to the scalar path if the alignment check failed.
            }

            // Scalar path: SIMD disabled, or input not SIMD-aligned.
            for (idx, (slot, &x)) in output.iter_mut().zip(input).enumerate() {
                *slot = map_masked(idx, x);
            }
            Ok(())
        }

        /// Returns a new `FloatArray<f64>` with the function applied element-wise.
        ///
        /// The result carries a clone of `null_mask`. Positions that are null
        /// according to the mask are written as `f64::NAN`.
        ///
        /// # Errors
        /// Returns `Err` when nulls are indicated (`null_count > 0`) but no
        /// `null_mask` is supplied.
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str> {
            let len = input.len();
            // fast length-0 case
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            // Zero-fill rather than `set_len` on uninitialised memory:
            // `Vec::set_len` requires the exposed elements to be initialised
            // already, and handing out `&mut [f64]` over them would be UB.
            let mut out: Vec64<f64> = Vec64::with_capacity(len);
            out.resize(len, 0.0);

            $name_to::<LANES>(input, out.as_mut_slice(), null_mask, null_count)?;

            Ok(FloatArray::from_vec64(out, null_mask.cloned()))
        }
    };
}
224
225// Basic operations
226impl_vecmap!(abs, abs_to, abs_elem, |x: f64| x.abs());
227impl_vecmap!(neg, neg_to, neg_elem, |x: f64| -x);
228impl_vecmap!(recip, recip_to, recip_elem, |x: f64| 1.0 / x);
229impl_vecmap!(sqrt, sqrt_to, sqrt_elem, |x: f64| x.sqrt());
230impl_vecmap!(cbrt, cbrt_to, cbrt_elem, |x: f64| x.cbrt());
231
232// Exponential and logarithmic
233impl_vecmap!(exp, exp_to, exp_elem, |x: f64| x.exp());
234impl_vecmap!(exp2, exp2_to, exp2_elem, |x: f64| x.exp2());
235impl_vecmap!(ln, ln_to, ln_elem, |x: f64| x.ln());
236impl_vecmap!(log2, log2_to, log2_elem, |x: f64| x.log2());
237impl_vecmap!(log10, log10_to, log10_elem, |x: f64| x.log10());
238
239// Trigonometric
240impl_vecmap!(sin, sin_to, sin_elem, |x: f64| x.sin());
241impl_vecmap!(cos, cos_to, cos_elem, |x: f64| x.cos());
242impl_vecmap!(tan, tan_to, tan_elem, |x: f64| x.tan());
243impl_vecmap!(asin, asin_to, asin_elem, |x: f64| x.asin());
244impl_vecmap!(acos, acos_to, acos_elem, |x: f64| x.acos());
245impl_vecmap!(atan, atan_to, atan_elem, |x: f64| x.atan());
246
247// Hyperbolic
248impl_vecmap!(sinh, sinh_to, sinh_elem, |x: f64| x.sinh());
249impl_vecmap!(cosh, cosh_to, cosh_elem, |x: f64| x.cosh());
250impl_vecmap!(tanh, tanh_to, tanh_elem, |x: f64| x.tanh());
251impl_vecmap!(asinh, asinh_to, asinh_elem, |x: f64| x.asinh());
252impl_vecmap!(acosh, acosh_to, acosh_elem, |x: f64| x.acosh());
253impl_vecmap!(atanh, atanh_to, atanh_elem, |x: f64| x.atanh());
254
255// Error functions
256impl_vecmap!(erf, erf_to, erf_elem, |x: f64| erf_fn(x));
257impl_vecmap!(erfc, erfc_to, erfc_elem, |x: f64| 1.0 - erf_fn(x));
258
259// Rounding
260impl_vecmap!(ceil, ceil_to, ceil_elem, |x: f64| x.ceil());
261impl_vecmap!(floor, floor_to, floor_elem, |x: f64| x.floor());
262impl_vecmap!(trunc, trunc_to, trunc_elem, |x: f64| x.trunc());
263impl_vecmap!(round, round_to, round_elem, |x: f64| x.round());
264impl_vecmap!(sign, sign_to, sign_elem, |x: f64| x.signum());
265
266// Activation functions
267impl_vecmap!(sigmoid, sigmoid_to, sigmoid_elem, |x: f64| 1.0
268    / (1.0 + (-x).exp()));
269impl_vecmap!(softplus, softplus_to, softplus_elem, |x: f64| if x > 20.0 {
270    x
271} else {
272    (1.0 + x.exp()).ln()
273});
274impl_vecmap!(relu, relu_to, relu_elem, |x: f64| if x > 0.0 {
275    x
276} else {
277    0.0
278});
279impl_vecmap!(gelu, gelu_to, gelu_elem, |x: f64| {
280    0.5 * x * (1.0 + erf_fn(x / std::f64::consts::SQRT_2))
281});