simd_kernels/kernels/scientific/
scalar.rs1use minarrow::{Bitmask, FloatArray, Vec64};
24use crate::utils::{
25 bitmask_to_simd_mask, is_simd_aligned, simd_mask_to_bitmask, write_global_bitmask_block,
26};
27use std::simd::{LaneCount, SupportedLaneCount};
28use crate::kernels::scientific::erf::erf as erf_fn;

/// Generates a null-aware, element-wise `f64 -> f64` kernel.
///
/// `impl_vecmap!(name, func)` expands to a public function
///
/// ```ignore
/// pub fn name<const LANES: usize>(
///     input: &[f64],
///     null_mask: Option<&Bitmask>,
///     null_count: Option<usize>,
/// ) -> Result<FloatArray<f64>, &'static str>
/// ```
///
/// which applies `func` (any callable `f64 -> f64`, e.g. a closure with an annotated
/// parameter or a path such as `f64::sqrt`) to every element of `input`.
///
/// # Null handling
/// * Nulls are considered present when `null_count` is `Some(n)` with `n > 0`, or, when
///   `null_count` is `None`, whenever `null_mask` is `Some`. With `Some(0)` the mask is ignored.
/// * Without nulls every element is mapped and the result carries no null mask.
/// * With nulls, `func` is evaluated only for valid slots; null slots hold `f64::NAN` and are
///   cleared in the output mask. If no slot ends up null, the output mask is `None`.
///
/// # `simd` feature
/// With the `simd` feature enabled and `is_simd_aligned(input)` true, the nullable path converts
/// validity bits per block of `LANES` slots via the `crate::utils` bitmask helpers. `func` is
/// still applied one element at a time, so results are identical to the scalar path.
///
/// # Errors
/// * Nulls are indicated but `null_mask` is `None`.
/// * `null_mask` is backed by fewer bytes than needed for `input.len()` bits.
///
/// The expansion names `Bitmask`, `FloatArray`, `Vec64`, `LaneCount`, `SupportedLaneCount` and
/// the `crate::utils` helpers unqualified, so they must be in scope where it is invoked.
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $expr:expr) => {
        #[doc = concat!(
            "Applies `",
            stringify!($name),
            "` element-wise to an `f64` slice, propagating nulls.\n\n",
            "Generated by `impl_vecmap!`: null slots become `NaN` and are cleared in the output ",
            "mask; returns `Err` if nulls are indicated without a sufficiently long `null_mask`."
        )]
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            // Bind the element function once; every path below applies this same callable.
            let func = $expr;
            let len = input.len();
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            if !has_nulls {
                // `func` is a scalar callable, so staging values through a `Simd` vector lane by
                // lane would only add copies; a plain loop produces the same output.
                let mut out = Vec64::with_capacity(len);
                for &x in input {
                    out.push(func(x));
                }
                return Ok(FloatArray::from_vec64(out, None));
            }

            let Some(mb) = null_mask else {
                return Err(concat!(
                    stringify!($name),
                    ": input mask required when nulls present"
                ));
            };
            // The unchecked bit reads below are only sound if the mask covers every input slot.
            if mb.as_bytes().len() < len.div_ceil(8) {
                return Err(concat!(stringify!($name), ": null mask shorter than input"));
            }

            let mut out = Vec64::with_capacity(len);
            // Starts all-valid; null slots are cleared as they are encountered.
            let mut out_mask = Bitmask::new_set_all(len, true);
            // First index not yet handled by the blocked path (only advanced under `simd`).
            #[cfg_attr(not(feature = "simd"), allow(unused_mut))]
            let mut start = 0usize;

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(input) {
                    use core::simd::Mask;
                    let mask_bytes = mb.as_bytes();
                    while start + LANES <= len {
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, start, len);
                        // Evaluate `func` only on valid lanes so null slots are always NaN,
                        // matching the scalar path below.
                        for (lane, &x) in input[start..start + LANES].iter().enumerate() {
                            out.push(if lane_valid.test(lane) { func(x) } else { f64::NAN });
                        }
                        let block = simd_mask_to_bitmask::<LANES, i8>(lane_valid, LANES);
                        write_global_bitmask_block(&mut out_mask, &block, start, LANES);
                        start += LANES;
                    }
                }
            }

            // Scalar path: the whole input, or just the tail left after the last full block.
            for idx in start..len {
                // SAFETY: `idx < len`, and the length check above ensures `mb` covers `len` bits.
                if unsafe { mb.get_unchecked(idx) } {
                    out.push(func(input[idx]));
                } else {
                    out.push(f64::NAN);
                    // SAFETY: `out_mask` was created with `len` slots and `idx < len`.
                    unsafe { out_mask.set_unchecked(idx, false) };
                }
            }

            let out_null_mask = (!out_mask.all_set()).then_some(out_mask);
            Ok(FloatArray {
                data: out.into(),
                null_mask: out_null_mask,
            })
        }
    };
}
211
212impl_vecmap!(abs, |x: f64| x.abs());
213impl_vecmap!(neg, |x: f64| -x);
214impl_vecmap!(recip, |x: f64| 1.0 / x);
215impl_vecmap!(sqrt, |x: f64| x.sqrt());
216impl_vecmap!(cbrt, |x: f64| x.cbrt());
217
218impl_vecmap!(exp, |x: f64| x.exp());
219impl_vecmap!(exp2, |x: f64| x.exp2());
220impl_vecmap!(ln, |x: f64| x.ln());
221impl_vecmap!(log2, |x: f64| x.log2());
222impl_vecmap!(log10, |x: f64| x.log10());
223
224impl_vecmap!(sin, |x: f64| x.sin());
225impl_vecmap!(cos, |x: f64| x.cos());
226impl_vecmap!(tan, |x: f64| x.tan());
227impl_vecmap!(asin, |x: f64| x.asin());
228impl_vecmap!(acos, |x: f64| x.acos());
229impl_vecmap!(atan, |x: f64| x.atan());
230
231impl_vecmap!(sinh, |x: f64| x.sinh());
232impl_vecmap!(cosh, |x: f64| x.cosh());
233impl_vecmap!(tanh, |x: f64| x.tanh());
234impl_vecmap!(asinh, |x: f64| x.asinh());
235impl_vecmap!(acosh, |x: f64| x.acosh());
236impl_vecmap!(atanh, |x: f64| x.atanh());
237
238impl_vecmap!(erf, |x: f64| erf_fn(x));
239impl_vecmap!(erfc, |x: f64| erf_fn(x));
240
241impl_vecmap!(ceil, |x: f64| x.ceil());
242impl_vecmap!(floor, |x: f64| x.floor());
243impl_vecmap!(trunc, |x: f64| x.trunc());
244impl_vecmap!(round, |x: f64| x.round());
245impl_vecmap!(sign, |x: f64| x.signum());
246
247impl_vecmap!(sigmoid, |x: f64| 1.0 / (1.0 + (-x).exp()));
248impl_vecmap!(softplus, |x: f64| (1.0 + x.exp()).ln());
249impl_vecmap!(relu, |x: f64| if x > 0.0 { x } else { 0.0 });
250impl_vecmap!(gelu, |x: f64| 0.5
251 * x
252 * (1.0 + erf_fn(x / std::f64::consts::SQRT_2)));