simd_kernels/kernels/scientific/
scalar.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Universal Scalar Function Module** - *High-Performance Element-Wise Mathematical Operations*
5//!
//! A full suite of vectorised mathematical functions that operate element-wise on arrays of
//! floating-point values. This module serves as the computational backbone for mathematical
//! operations across the simd-kernels crate, providing both scalar and SIMD-accelerated
//! implementations with opt-in Arrow-compatible null masking.
10//! 
11//! These are the semantic equivalent of *numpy ufuncs* in Python.
12//!
13//! ## Overview
14//!
15//! Universal scalar functions are fundamental building blocks for:
16//! - **Data Preprocessing**: Normalisations, transformations, and scaling operations
17//! - **Scientific Computing**: Mathematical transformations and special function evaluation
18//! - **Machine Learning**: Activation functions, feature engineering, and data preparation
19//! - **Signal Processing**: Filtering, transforms, and spectral analysis
20//! - **Statistics**: Data transformations and statistical preprocessing
21//! - **Financial Mathematics**: Risk calculations and price transformations
22
23use minarrow::{Bitmask, FloatArray, Vec64};
24use crate::utils::{
25    bitmask_to_simd_mask, is_simd_aligned, simd_mask_to_bitmask, write_global_bitmask_block,
26};
27use std::simd::{LaneCount, SupportedLaneCount};
28use crate::kernels::scientific::erf::erf as erf_fn;
29
/// Generates an element-wise `f64 → f64` mapping kernel that returns a
/// `FloatArray<f64>` and carries the input validity through to the output.
///
/// # Macro arguments
/// - `$name` – name of the function to generate.
/// - `$expr` – callable expression (closure or function path) mapping one `f64` to one `f64`.
///
/// # Generated function
/// `fn $name<const LANES: usize>(input, null_mask, null_count)
///     -> Result<FloatArray<f64>, &'static str>`
///
/// - An empty `input` yields an empty array.
/// - The null-aware path is taken when `null_count` is `Some(n)` with `n > 0`, or when
///   `null_count` is `None` and a mask was supplied. In every other case the mask is ignored
///   and the output carries no null mask.
/// - With the `simd` feature enabled and `is_simd_aligned(input)` true, the data is walked in
///   `LANES`-wide chunks followed by a scalar tail; otherwise a single scalar loop is used.
///   `$expr` is applied lane by lane in both cases.
/// - Null slots are written as `NaN` by the scalar loop. In the chunked loop null lanes are
///   fed `NaN` and store whatever `$expr(NaN)` returns; either way the slot is marked null
///   in the output mask.
/// - The output mask is dropped when `all_set()` reports that every bit is set.
///
/// # Errors
/// Returns `Err` when nulls are reported (`null_count > 0`) but `null_mask` is `None`.
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $expr:expr) => {
        #[doc = concat!(
            "Applies `",
            stringify!($name),
            "` element-wise to `input`, propagating nulls from `null_mask`."
        )]
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            let len = input.len();
            // Nothing to map.
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            // An explicit null count wins; without one, the presence of a mask decides.
            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            // ---------------------------------------------------------------
            // Dense path: no nulls, so no output mask is produced.
            // ---------------------------------------------------------------
            if !has_nulls {
                let mut out = Vec64::with_capacity(len);
                // Number of leading elements already written by the chunked loop.
                // Stays 0 when the `simd` feature is off or the input is unaligned.
                #[allow(unused_mut)]
                let mut done = 0usize;

                #[cfg(feature = "simd")]
                {
                    // Cheap runtime check; unaligned input is handled entirely by the
                    // scalar loop below.
                    if is_simd_aligned(input) {
                        use core::simd::Simd;
                        while done + LANES <= len {
                            let v = Simd::<f64, LANES>::from_slice(&input[done..done + LANES]);
                            let mut r = Simd::<f64, LANES>::splat(0.0);
                            for lane in 0..LANES {
                                r[lane] = $expr(v[lane]);
                            }
                            out.extend_from_slice(r.as_array());
                            done += LANES;
                        }
                    }
                }

                // Scalar remainder (the whole input when no chunk was processed).
                for &x in &input[done..] {
                    out.push($expr(x));
                }
                return Ok(FloatArray::from_vec64(out, None));
            }

            // ---------------------------------------------------------------
            // Null-aware path.
            // ---------------------------------------------------------------
            // Nulls were reported, so a mask is mandatory. Report its absence as an
            // error rather than panicking inside a `Result`-returning function.
            let mb = null_mask.ok_or(concat!(
                stringify!($name),
                ": input mask required when nulls present"
            ))?;
            let mut out = Vec64::with_capacity(len);
            let mut out_mask = Bitmask::new_set_all(len, true);
            // Number of leading elements already written by the chunked loop.
            #[allow(unused_mut)]
            let mut done = 0usize;

            #[cfg(feature = "simd")]
            {
                // Cheap runtime check; unaligned input is handled entirely by the
                // scalar loop below.
                if is_simd_aligned(input) {
                    use core::simd::{Mask, Simd};
                    let mask_bytes = mb.as_bytes();
                    while done + LANES <= len {
                        // Load this chunk's validity bits into a lane mask.
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, done, len);

                        // Gather inputs, substituting NaN for null lanes.
                        let mut arr = [0.0f64; LANES];
                        for j in 0..LANES {
                            // SAFETY: `j < LANES`, which is the lane count of `lane_valid`.
                            arr[j] = if unsafe { lane_valid.test_unchecked(j) } {
                                input[done + j]
                            } else {
                                f64::NAN
                            };
                        }
                        let v = Simd::<f64, LANES>::from_array(arr);

                        // Apply the scalar expression lane by lane (null lanes see NaN).
                        let mut r = Simd::<f64, LANES>::splat(0.0);
                        for lane in 0..LANES {
                            r[lane] = $expr(v[lane]);
                        }
                        out.extend_from_slice(r.as_array());

                        // Copy the same validity bits into the output mask.
                        let block = simd_mask_to_bitmask::<LANES, i8>(lane_valid, LANES);
                        write_global_bitmask_block(&mut out_mask, &block, done, LANES);

                        done += LANES;
                    }
                }
            }

            // Scalar remainder (the whole input when no chunk was processed).
            for idx in done..len {
                // SAFETY: `idx < len`. This relies on the caller-supplied mask covering at
                // least `len` bits, which is not validated here.
                // NOTE(review): confirm callers guarantee the mask length.
                let valid = unsafe { mb.get_unchecked(idx) };
                if valid {
                    out.push($expr(input[idx]));
                    // SAFETY: `idx < len` and `out_mask` was created for `len` entries.
                    unsafe { out_mask.set_unchecked(idx, true) };
                } else {
                    out.push(f64::NAN);
                    // SAFETY: `idx < len` and `out_mask` was created for `len` entries.
                    unsafe { out_mask.set_unchecked(idx, false) };
                }
            }

            // If every slot turned out valid, drop the mask.
            let null_bitmap = if out_mask.all_set() {
                None
            } else {
                Some(out_mask)
            };
            Ok(FloatArray {
                data: out.into(),
                null_mask: null_bitmap,
            })
        }
    };
}
211
212impl_vecmap!(abs, |x: f64| x.abs());
213impl_vecmap!(neg, |x: f64| -x);
214impl_vecmap!(recip, |x: f64| 1.0 / x);
215impl_vecmap!(sqrt, |x: f64| x.sqrt());
216impl_vecmap!(cbrt, |x: f64| x.cbrt());
217
218impl_vecmap!(exp, |x: f64| x.exp());
219impl_vecmap!(exp2, |x: f64| x.exp2());
220impl_vecmap!(ln, |x: f64| x.ln());
221impl_vecmap!(log2, |x: f64| x.log2());
222impl_vecmap!(log10, |x: f64| x.log10());
223
224impl_vecmap!(sin, |x: f64| x.sin());
225impl_vecmap!(cos, |x: f64| x.cos());
226impl_vecmap!(tan, |x: f64| x.tan());
227impl_vecmap!(asin, |x: f64| x.asin());
228impl_vecmap!(acos, |x: f64| x.acos());
229impl_vecmap!(atan, |x: f64| x.atan());
230
231impl_vecmap!(sinh, |x: f64| x.sinh());
232impl_vecmap!(cosh, |x: f64| x.cosh());
233impl_vecmap!(tanh, |x: f64| x.tanh());
234impl_vecmap!(asinh, |x: f64| x.asinh());
235impl_vecmap!(acosh, |x: f64| x.acosh());
236impl_vecmap!(atanh, |x: f64| x.atanh());
237
238impl_vecmap!(erf, |x: f64| erf_fn(x));
239impl_vecmap!(erfc, |x: f64| erf_fn(x));
240
241impl_vecmap!(ceil, |x: f64| x.ceil());
242impl_vecmap!(floor, |x: f64| x.floor());
243impl_vecmap!(trunc, |x: f64| x.trunc());
244impl_vecmap!(round, |x: f64| x.round());
245impl_vecmap!(sign, |x: f64| x.signum());
246
247impl_vecmap!(sigmoid, |x: f64| 1.0 / (1.0 + (-x).exp()));
248impl_vecmap!(softplus, |x: f64| (1.0 + x.exp()).ln());
249impl_vecmap!(relu, |x: f64| if x > 0.0 { x } else { 0.0 });
250impl_vecmap!(gelu, |x: f64| 0.5
251    * x
252    * (1.0 + erf_fn(x / std::f64::consts::SQRT_2)));