simd_kernels/kernels/scientific/
scalar.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Universal Scalar Function Module** - *High-Performance Element-Wise Mathematical Operations*
5//!
6//! A full suite of vectorised mathematical functions that operate
7//! element-wise on arrays of floating-point values. This module serves as the computational
8//! backbone for mathematical operations across the simd-kernels crate, providing both scalar and
9//! SIMD-accelerated implementations with opt-in Arrow-compatible null masking.
10//!
11//! These are the semantic equivalent of *numpy ufuncs* in Python.
12//!
13//! ## Overview
14//!
15//! Universal scalar functions are fundamental building blocks for:
16//! - **Data Preprocessing**: Normalisations, transformations, and scaling operations
17//! - **Scientific Computing**: Mathematical transformations and special function evaluation
18//! - **Machine Learning**: Activation functions, feature engineering, and data preparation
19//! - **Signal Processing**: Filtering, transforms, and spectral analysis
20//! - **Statistics**: Data transformations and statistical preprocessing
21//! - **Financial Mathematics**: Risk calculations and price transformations
22
23use crate::kernels::scientific::erf::erf as erf_fn;
24use crate::utils::{bitmask_to_simd_mask, simd_mask_to_bitmask, write_global_bitmask_block};
25use minarrow::utils::is_simd_aligned;
26use minarrow::{Bitmask, FloatArray, Vec64};
27use std::simd::{LaneCount, SupportedLaneCount};
28
29/// Generates a mapping kernel that returns a FloatArray<f64>,
30/// propagating any input nulls (and never touching lanes that were null).
31///
32/// `$name`  – function name to create  
33/// `$expr`  – expression mapping a scalar `f64 → f64`  
34#[macro_export]
35macro_rules! impl_vecmap {
36    ($name:ident, $expr:expr) => {
37        #[inline(always)]
38        pub fn $name<const LANES: usize>(
39            input: &[f64],
40            null_mask: Option<&Bitmask>,
41            null_count: Option<usize>,
42        ) -> Result<FloatArray<f64>, &'static str>
43        where
44            LaneCount<LANES>: SupportedLaneCount,
45        {
46            let len = input.len();
47            // fast length‐0 case
48            if len == 0 {
49                return Ok(FloatArray::from_slice(&[]));
50            }
51            // decide if we need the null‐aware path
52            let has_nulls = match null_count {
53                Some(n) => n > 0,
54                None => null_mask.is_some(),
55            };
56            // dense (no nulls) path
57            if !has_nulls {
58                let mut out = Vec64::with_capacity(len);
59                #[cfg(feature = "simd")]
60                {
61                    // Check if input array is properly aligned for SIMD (cheap runtime check)
62                    if is_simd_aligned(input) {
63                        use core::simd::Simd;
64                        let mut i = 0;
65                        while i + LANES <= len {
66                            let v = Simd::<f64, LANES>::from_slice(&input[i..i + LANES]);
67                            let mut r = Simd::<f64, LANES>::splat(0.0);
68                            for lane in 0..LANES {
69                                r[lane] = $expr(v[lane]);
70                            }
71                            out.extend_from_slice(r.as_array());
72                            i += LANES;
73                        }
74                        // scalar tail
75                        for &x in &input[i..] {
76                            out.push($expr(x));
77                        }
78                        return Ok(FloatArray::from_vec64(out, None));
79                    }
80                    // Fall through to scalar path if alignment check failed
81                }
82                // Scalar fallback - alignment check failed
83                #[cfg(not(feature = "simd"))]
84                {
85                    for &x in input {
86                        out.push($expr(x));
87                    }
88                }
89                #[cfg(feature = "simd")]
90                {
91                    for &x in input {
92                        out.push($expr(x));
93                    }
94                }
95                return Ok(FloatArray::from_vec64(out, None));
96            }
97            // null‐aware path
98            let mb = null_mask.expect(concat!(
99                stringify!($name),
100                ": input mask required when nulls present"
101            ));
102            let mut out = Vec64::with_capacity(len);
103            let mut out_mask = Bitmask::new_set_all(len, true);
104
105            #[cfg(feature = "simd")]
106            {
107                // Check if input array is properly aligned for SIMD (cheap runtime check)
108                if is_simd_aligned(input) {
109                    use core::simd::{Mask, Simd};
110                    let mask_bytes = mb.as_bytes();
111                    let mut i = 0;
112                    while i + LANES <= len {
113                        // pull in the Arrow validity into a SIMD mask
114                        let lane_valid: Mask<i8, LANES> =
115                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);
116
117                        // Gather inputs (nulls → NaN)
118                        let mut arr = [0.0f64; LANES];
119                        for j in 0..LANES {
120                            let idx = i + j;
121                            arr[j] = if unsafe { lane_valid.test_unchecked(j) } {
122                                input[idx]
123                            } else {
124                                f64::NAN
125                            };
126                        }
127                        let v = Simd::<f64, LANES>::from_array(arr);
128
129                        // Apply your scalar expr in SIMD form
130                        let mut r = Simd::<f64, LANES>::splat(0.0);
131                        for lane in 0..LANES {
132                            r[lane] = $expr(v[lane]);
133                        }
134                        let r_arr = r.as_array();
135                        out.extend_from_slice(r_arr);
136
137                        // write those same validity bits back into our new null‐bitmap
138                        let block = simd_mask_to_bitmask::<LANES, i8>(lane_valid, LANES);
139                        write_global_bitmask_block(&mut out_mask, &block, i, LANES);
140
141                        i += LANES;
142                    }
143                    // scalar tail
144                    for idx in i..len {
145                        if !unsafe { mb.get_unchecked(idx) } {
146                            out.push(f64::NAN);
147                            unsafe { out_mask.set_unchecked(idx, false) };
148                        } else {
149                            let y = $expr(input[idx]);
150                            out.push(y);
151                            unsafe { out_mask.set_unchecked(idx, true) };
152                        }
153                    }
154
155                    // if every lane stayed valid, drop the mask
156                    let null_bitmap = if out_mask.all_set() {
157                        None
158                    } else {
159                        Some(out_mask)
160                    };
161                    return Ok(FloatArray {
162                        data: out.into(),
163                        null_mask: null_bitmap,
164                    });
165                }
166                // Fall through to scalar path if alignment check failed
167            }
168
169            // Scalar fallback - alignment check failed
170            #[cfg(not(feature = "simd"))]
171            {
172                for idx in 0..len {
173                    if !unsafe { mb.get_unchecked(idx) } {
174                        out.push(f64::NAN);
175                        unsafe { out_mask.set_unchecked(idx, false) };
176                    } else {
177                        let y = $expr(input[idx]);
178                        out.push(y);
179                        unsafe { out_mask.set_unchecked(idx, true) };
180                    }
181                }
182            }
183            #[cfg(feature = "simd")]
184            {
185                for idx in 0..len {
186                    if !unsafe { mb.get_unchecked(idx) } {
187                        out.push(f64::NAN);
188                        unsafe { out_mask.set_unchecked(idx, false) };
189                    } else {
190                        let y = $expr(input[idx]);
191                        out.push(y);
192                        unsafe { out_mask.set_unchecked(idx, true) };
193                    }
194                }
195            }
196
197            // if every lane stayed valid, drop the mask
198            let null_bitmap = if out_mask.all_set() {
199                None
200            } else {
201                Some(out_mask)
202            };
203            Ok(FloatArray {
204                data: out.into(),
205                null_mask: null_bitmap,
206            })
207        }
208    };
209}
210
211impl_vecmap!(abs, |x: f64| x.abs());
212impl_vecmap!(neg, |x: f64| -x);
213impl_vecmap!(recip, |x: f64| 1.0 / x);
214impl_vecmap!(sqrt, |x: f64| x.sqrt());
215impl_vecmap!(cbrt, |x: f64| x.cbrt());
216
217impl_vecmap!(exp, |x: f64| x.exp());
218impl_vecmap!(exp2, |x: f64| x.exp2());
219impl_vecmap!(ln, |x: f64| x.ln());
220impl_vecmap!(log2, |x: f64| x.log2());
221impl_vecmap!(log10, |x: f64| x.log10());
222
223impl_vecmap!(sin, |x: f64| x.sin());
224impl_vecmap!(cos, |x: f64| x.cos());
225impl_vecmap!(tan, |x: f64| x.tan());
226impl_vecmap!(asin, |x: f64| x.asin());
227impl_vecmap!(acos, |x: f64| x.acos());
228impl_vecmap!(atan, |x: f64| x.atan());
229
230impl_vecmap!(sinh, |x: f64| x.sinh());
231impl_vecmap!(cosh, |x: f64| x.cosh());
232impl_vecmap!(tanh, |x: f64| x.tanh());
233impl_vecmap!(asinh, |x: f64| x.asinh());
234impl_vecmap!(acosh, |x: f64| x.acosh());
235impl_vecmap!(atanh, |x: f64| x.atanh());
236
237impl_vecmap!(erf, |x: f64| erf_fn(x));
238impl_vecmap!(erfc, |x: f64| erf_fn(x));
239
240impl_vecmap!(ceil, |x: f64| x.ceil());
241impl_vecmap!(floor, |x: f64| x.floor());
242impl_vecmap!(trunc, |x: f64| x.trunc());
243impl_vecmap!(round, |x: f64| x.round());
244impl_vecmap!(sign, |x: f64| x.signum());
245
246impl_vecmap!(sigmoid, |x: f64| 1.0 / (1.0 + (-x).exp()));
247impl_vecmap!(softplus, |x: f64| (1.0 + x.exp()).ln());
248impl_vecmap!(relu, |x: f64| if x > 0.0 { x } else { 0.0 });
249impl_vecmap!(gelu, |x: f64| 0.5
250    * x
251    * (1.0 + erf_fn(x / std::f64::consts::SQRT_2)));