// simd_kernels/kernels/scientific/scalar.rs

// Copyright (c) 2025 SpaceCell Enterprises Ltd
// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial licensing available. See LICENSE and LICENSING.md.

//! # **Universal Scalar Function Module** - *High-Performance Element-Wise Mathematical Operations*
//!
//! Full suite of vectorised mathematical functions that operate
//! element-wise on arrays of floating-point values. It serves as the computational backbone for
//! mathematical operations across the simd-kernels crate, providing both scalar and
//! SIMD-accelerated implementations with opt-in Arrow-compatible null masking.
//!
//! These are the semantic equivalent of *numpy ufuncs* in Python.
//!
//! ## Overview
//!
//! Universal scalar functions are fundamental building blocks for:
//! - **Data Preprocessing**: Normalisations, transformations, and scaling operations
//! - **Scientific Computing**: Mathematical transformations and special function evaluation
//! - **Machine Learning**: Activation functions, feature engineering, and data preparation
//! - **Signal Processing**: Filtering, transforms, and spectral analysis
//! - **Statistics**: Data transformations and statistical preprocessing
//! - **Financial Mathematics**: Risk calculations and price transformations
24use crate::kernels::scientific::erf::erf as erf_fn;
25use crate::utils::bitmask_to_simd_mask;
26use minarrow::utils::is_simd_aligned;
27use minarrow::{Bitmask, FloatArray, Vec64};
28
/// Generates a mapping kernel in three variants:
///
/// 1. `$name_to` - Zero-allocation canonical implementation, writes to caller's buffer
/// 2. `$name` - Allocates internally, delegates to `$name_to`
/// 3. `$name_elem` - Element-wise `fn(f64) -> f64` for kernel fusion
///
/// The `_to` variant is for pre-allocated parallel execution where each chunk
/// writes directly to its slice of a shared output buffer.
///
/// The `_elem` variant is for kernel fusion where multiple operations are composed
/// into a single loop, keeping intermediate values in registers instead of memory.
///
/// `$name` – allocating function name
/// `$name_to` – zero-allocation function name
/// `$name_elem` – element-wise function for fusion
/// `$expr` – expression mapping a scalar `f64 -> f64`
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $name_to:ident, $name_elem:ident, $expr:expr) => {
        /// Element-wise variant for kernel fusion.
        ///
        /// # Example
        /// ```ignore
        /// let ops = &[neg_elem, exp_elem, sin_elem];
        /// execute_fused::<8>(input, output, ops);
        /// // Equivalent to neg -> exp -> sin but with ONE memory read/write
        /// ```
        #[inline(always)]
        pub fn $name_elem(x: f64) -> f64 {
            $expr(x)
        }

        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Canonical implementation with full SIMD acceleration and null handling.
        /// For parallel execution with pre-allocated output. Null lanes are
        /// written as `f64::NAN` in `output`; the caller is expected to carry
        /// the validity mask alongside the buffer.
        ///
        /// # Errors
        /// Returns an error if nulls are indicated (via `null_count` or an
        /// unknown count with a mask present) but `null_mask` is `None`.
        ///
        /// # Panics
        /// Panics if `input.len() != output.len()`.
        #[inline(always)]
        pub fn $name_to<const LANES: usize>(
            input: &[f64],
            output: &mut [f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<(), &'static str> {
            let len = input.len();
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($name_to), ": input/output length mismatch")
            );

            if len == 0 {
                return Ok(());
            }

            // Decide if we need the null-aware path. A mask with an unknown
            // null count is treated conservatively as "has nulls".
            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            // Dense (no nulls) path.
            if !has_nulls {
                #[cfg(feature = "simd")]
                {
                    // Cheap runtime check: only take the SIMD path when the
                    // input buffer satisfies the alignment contract.
                    if is_simd_aligned(input) {
                        use core::simd::Simd;
                        let mut i = 0;
                        while i + LANES <= len {
                            let v = Simd::<f64, LANES>::from_slice(&input[i..i + LANES]);
                            let mut r = Simd::<f64, LANES>::splat(0.0);
                            // `$expr` is applied lane-by-lane; loads/stores
                            // remain vectorised and the fixed-trip loop unrolls.
                            for lane in 0..LANES {
                                r[lane] = $expr(v[lane]);
                            }
                            output[i..i + LANES].copy_from_slice(r.as_array());
                            i += LANES;
                        }
                        // Scalar tail for the remaining < LANES elements.
                        for j in i..len {
                            output[j] = $expr(input[j]);
                        }
                        return Ok(());
                    }
                }

                // Scalar fallback: SIMD disabled or input not SIMD-aligned.
                for j in 0..len {
                    output[j] = $expr(input[j]);
                }
                return Ok(());
            }

            // Null-aware path.
            let mb = null_mask.ok_or(concat!(
                stringify!($name_to),
                ": input mask required when nulls present"
            ))?;

            #[cfg(feature = "simd")]
            {
                // Cheap runtime check before committing to the SIMD path.
                if is_simd_aligned(input) {
                    use core::simd::{Mask, Simd};
                    let mask_bytes = mb.as_bytes();
                    let mut i = 0;
                    while i + LANES <= len {
                        // Pull the Arrow validity bits into a SIMD mask.
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);

                        // Gather inputs (null lanes become NaN).
                        let mut arr = [0.0f64; LANES];
                        for j in 0..LANES {
                            let idx = i + j;
                            // SAFETY: j < LANES, within the mask chunk just built.
                            arr[j] = if unsafe { lane_valid.test_unchecked(j) } {
                                input[idx]
                            } else {
                                f64::NAN
                            };
                        }
                        let v = Simd::<f64, LANES>::from_array(arr);

                        // Apply the scalar expression lane-by-lane.
                        let mut r = Simd::<f64, LANES>::splat(0.0);
                        for lane in 0..LANES {
                            r[lane] = $expr(v[lane]);
                        }
                        let r_arr = r.as_array();
                        output[i..i + LANES].copy_from_slice(r_arr);

                        i += LANES;
                    }
                    // Scalar tail.
                    for idx in i..len {
                        // SAFETY: idx < len <= mask length.
                        if !unsafe { mb.get_unchecked(idx) } {
                            output[idx] = f64::NAN;
                        } else {
                            output[idx] = $expr(input[idx]);
                        }
                    }

                    return Ok(());
                }
                // Fall through to the scalar path when the alignment check fails.
            }

            // Scalar fallback: SIMD disabled or input not SIMD-aligned.
            // (Previously duplicated under both cfg branches; one unconditional
            // loop is equivalent because the SIMD path above returns early.)
            for idx in 0..len {
                // SAFETY: idx < len <= mask length.
                if !unsafe { mb.get_unchecked(idx) } {
                    output[idx] = f64::NAN;
                } else {
                    output[idx] = $expr(input[idx]);
                }
            }

            Ok(())
        }

        /// Returns a new `FloatArray<f64>` with the function applied element-wise.
        ///
        /// The input null mask (if any) is cloned onto the result; null lanes
        /// in the output buffer are written as `f64::NAN` by `$name_to`.
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str> {
            let len = input.len();
            // Fast length-0 case.
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            let mut out = Vec64::with_capacity(len);
            // SAFETY: capacity was just allocated for `len` elements, and every
            // slot is fully overwritten by `$name_to` before the buffer is read.
            unsafe {
                out.set_len(len);
            }

            $name_to::<LANES>(input, out.as_mut_slice(), null_mask, null_count)?;

            Ok(FloatArray::from_vec64(out, null_mask.cloned()))
        }
    };
}
226
227// Basic operations
228impl_vecmap!(abs, abs_to, abs_elem, |x: f64| x.abs());
229impl_vecmap!(neg, neg_to, neg_elem, |x: f64| -x);
230impl_vecmap!(recip, recip_to, recip_elem, |x: f64| 1.0 / x);
231impl_vecmap!(sqrt, sqrt_to, sqrt_elem, |x: f64| x.sqrt());
232impl_vecmap!(cbrt, cbrt_to, cbrt_elem, |x: f64| x.cbrt());
233
234// Exponential and logarithmic
235impl_vecmap!(exp, exp_to, exp_elem, |x: f64| x.exp());
236impl_vecmap!(exp2, exp2_to, exp2_elem, |x: f64| x.exp2());
237impl_vecmap!(ln, ln_to, ln_elem, |x: f64| x.ln());
238impl_vecmap!(log2, log2_to, log2_elem, |x: f64| x.log2());
239impl_vecmap!(log10, log10_to, log10_elem, |x: f64| x.log10());
240
241// Trigonometric
242impl_vecmap!(sin, sin_to, sin_elem, |x: f64| x.sin());
243impl_vecmap!(cos, cos_to, cos_elem, |x: f64| x.cos());
244impl_vecmap!(tan, tan_to, tan_elem, |x: f64| x.tan());
245impl_vecmap!(asin, asin_to, asin_elem, |x: f64| x.asin());
246impl_vecmap!(acos, acos_to, acos_elem, |x: f64| x.acos());
247impl_vecmap!(atan, atan_to, atan_elem, |x: f64| x.atan());
248
249// Hyperbolic
250impl_vecmap!(sinh, sinh_to, sinh_elem, |x: f64| x.sinh());
251impl_vecmap!(cosh, cosh_to, cosh_elem, |x: f64| x.cosh());
252impl_vecmap!(tanh, tanh_to, tanh_elem, |x: f64| x.tanh());
253impl_vecmap!(asinh, asinh_to, asinh_elem, |x: f64| x.asinh());
254impl_vecmap!(acosh, acosh_to, acosh_elem, |x: f64| x.acosh());
255impl_vecmap!(atanh, atanh_to, atanh_elem, |x: f64| x.atanh());
256
// Error functions.
// `erf` delegates to the crate's scalar erf implementation.
impl_vecmap!(erf, erf_to, erf_elem, |x: f64| erf_fn(x));
// NOTE(review): computing erfc as 1 - erf(x) suffers catastrophic cancellation
// for large positive x, since erf(x) -> 1 and the subtraction loses all
// significant digits long before a dedicated erfc implementation would.
// Acceptable for moderate |x|; otherwise consider a direct erfc kernel.
impl_vecmap!(erfc, erfc_to, erfc_elem, |x: f64| 1.0 - erf_fn(x));
260
// Rounding
impl_vecmap!(ceil, ceil_to, ceil_elem, |x: f64| x.ceil());
impl_vecmap!(floor, floor_to, floor_elem, |x: f64| x.floor());
impl_vecmap!(trunc, trunc_to, trunc_elem, |x: f64| x.trunc());
// `f64::round` rounds half-way cases away from zero (not banker's rounding).
impl_vecmap!(round, round_to, round_elem, |x: f64| x.round());
// NOTE(review): `f64::signum` returns 1.0 for +0.0, -1.0 for -0.0, and NaN for
// NaN. This differs from numpy's `sign`, which returns 0 for zero — confirm
// this is the intended contract given the module's "numpy ufunc" framing.
impl_vecmap!(sign, sign_to, sign_elem, |x: f64| x.signum());
267
268// Activation functions
269impl_vecmap!(sigmoid, sigmoid_to, sigmoid_elem, |x: f64| 1.0
270 / (1.0 + (-x).exp()));
271impl_vecmap!(softplus, softplus_to, softplus_elem, |x: f64| (1.0
272 + x.exp())
273.ln());
274impl_vecmap!(relu, relu_to, relu_elem, |x: f64| if x > 0.0 {
275 x
276} else {
277 0.0
278});
279impl_vecmap!(gelu, gelu_to, gelu_elem, |x: f64| {
280 0.5 * x * (1.0 + erf_fn(x / std::f64::consts::SQRT_2))
281});