simd_kernels/kernels/scientific/
scalar.rs1use crate::kernels::scientific::erf::erf as erf_fn;
24use crate::utils::{bitmask_to_simd_mask, simd_mask_to_bitmask, write_global_bitmask_block};
25use minarrow::utils::is_simd_aligned;
26use minarrow::{Bitmask, FloatArray, Vec64};
27use std::simd::{LaneCount, SupportedLaneCount};
28
/// Generates a null-aware, element-wise `f64` kernel.
///
/// `impl_vecmap!(name, expr)` expands to
///
/// ```ignore
/// pub fn name<const LANES: usize>(
///     input: &[f64],
///     null_mask: Option<&Bitmask>,
///     null_count: Option<usize>,
/// ) -> Result<FloatArray<f64>, &'static str>
/// ```
///
/// The generated function applies `expr` (a callable taking and returning `f64`, normally a
/// closure literal) to every element of `input`.
///
/// Null handling:
/// - Nulls are considered present when `null_count` is `Some(n)` with `n > 0`. When
///   `null_count` is `None`, nulls are assumed present whenever `null_mask` is `Some`.
/// - Without nulls, the result carries no null mask.
/// - With nulls, null positions are written as `f64::NAN`, and `expr` is not evaluated
///   for them. The output mask copies validity from `null_mask` and is dropped (`None`)
///   if every bit ends up set.
///
/// With the `simd` feature enabled and `is_simd_aligned(input)` true, the masked path reads
/// validity `LANES` elements at a time through the `crate::utils` mask helpers. Otherwise,
/// and for the remaining tail, a scalar loop is used. Both paths produce identical output.
///
/// The expansion names `FloatArray`, `Vec64`, `Bitmask`, `LaneCount`, `SupportedLaneCount`,
/// `is_simd_aligned`, `bitmask_to_simd_mask`, `simd_mask_to_bitmask` and
/// `write_global_bitmask_block` unqualified. They must be in scope at the invocation site.
///
/// # Errors
/// - Nulls are indicated but `null_mask` is `None`.
/// - `null_mask` is shorter than `input`. Reading past it would be out-of-bounds.
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $expr:expr) => {
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            let len = input.len();
            if len == 0 {
                return Ok(FloatArray::from_slice(&[]));
            }

            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            if !has_nulls {
                // `$expr` is a scalar callable, so evaluating it lane-by-lane through a
                // `Simd` register yields exactly this loop's output. One loop serves all
                // builds.
                let mut out = Vec64::with_capacity(len);
                for &x in input {
                    out.push($expr(x));
                }
                return Ok(FloatArray::from_vec64(out, None));
            }

            let Some(mb) = null_mask else {
                return Err(concat!(
                    stringify!($name),
                    ": input mask required when nulls present"
                ));
            };
            // The unchecked reads below rely on the mask covering every input element.
            if mb.len() < len {
                return Err(concat!(
                    stringify!($name),
                    ": null mask shorter than input"
                ));
            }

            let mut out = Vec64::with_capacity(len);
            let mut out_mask = Bitmask::new_set_all(len, true);

            // `done` = number of leading elements already emitted by the chunked path.
            #[cfg(feature = "simd")]
            let done = if is_simd_aligned(input) {
                use core::simd::Mask;
                let mask_bytes = mb.as_bytes();
                let mut i = 0;
                while i + LANES <= len {
                    let lane_valid: Mask<i8, LANES> =
                        bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);
                    for (j, &x) in input[i..i + LANES].iter().enumerate() {
                        out.push(if lane_valid.test(j) { $expr(x) } else { f64::NAN });
                    }
                    let block = simd_mask_to_bitmask::<LANES, i8>(lane_valid, LANES);
                    write_global_bitmask_block(&mut out_mask, &block, i, LANES);
                    i += LANES;
                }
                i
            } else {
                0
            };
            #[cfg(not(feature = "simd"))]
            let done = 0;

            // Scalar path: the whole input when the chunked path did not run, else the tail.
            for (idx, &x) in input.iter().enumerate().skip(done) {
                // SAFETY: idx < len <= mb.len(), checked above.
                let valid = unsafe { mb.get_unchecked(idx) };
                out.push(if valid { $expr(x) } else { f64::NAN });
                // SAFETY: idx < len, and out_mask was built with length len.
                unsafe { out_mask.set_unchecked(idx, valid) };
            }

            let out_null = if out_mask.all_set() {
                None
            } else {
                Some(out_mask)
            };
            Ok(FloatArray {
                data: out.into(),
                null_mask: out_null,
            })
        }
    };
}
210
211impl_vecmap!(abs, |x: f64| x.abs());
212impl_vecmap!(neg, |x: f64| -x);
213impl_vecmap!(recip, |x: f64| 1.0 / x);
214impl_vecmap!(sqrt, |x: f64| x.sqrt());
215impl_vecmap!(cbrt, |x: f64| x.cbrt());
216
217impl_vecmap!(exp, |x: f64| x.exp());
218impl_vecmap!(exp2, |x: f64| x.exp2());
219impl_vecmap!(ln, |x: f64| x.ln());
220impl_vecmap!(log2, |x: f64| x.log2());
221impl_vecmap!(log10, |x: f64| x.log10());
222
223impl_vecmap!(sin, |x: f64| x.sin());
224impl_vecmap!(cos, |x: f64| x.cos());
225impl_vecmap!(tan, |x: f64| x.tan());
226impl_vecmap!(asin, |x: f64| x.asin());
227impl_vecmap!(acos, |x: f64| x.acos());
228impl_vecmap!(atan, |x: f64| x.atan());
229
230impl_vecmap!(sinh, |x: f64| x.sinh());
231impl_vecmap!(cosh, |x: f64| x.cosh());
232impl_vecmap!(tanh, |x: f64| x.tanh());
233impl_vecmap!(asinh, |x: f64| x.asinh());
234impl_vecmap!(acosh, |x: f64| x.acosh());
235impl_vecmap!(atanh, |x: f64| x.atanh());
236
237impl_vecmap!(erf, |x: f64| erf_fn(x));
238impl_vecmap!(erfc, |x: f64| erf_fn(x));
239
240impl_vecmap!(ceil, |x: f64| x.ceil());
241impl_vecmap!(floor, |x: f64| x.floor());
242impl_vecmap!(trunc, |x: f64| x.trunc());
243impl_vecmap!(round, |x: f64| x.round());
244impl_vecmap!(sign, |x: f64| x.signum());
245
246impl_vecmap!(sigmoid, |x: f64| 1.0 / (1.0 + (-x).exp()));
247impl_vecmap!(softplus, |x: f64| (1.0 + x.exp()).ln());
248impl_vecmap!(relu, |x: f64| if x > 0.0 { x } else { 0.0 });
249impl_vecmap!(gelu, |x: f64| 0.5
250 * x
251 * (1.0 + erf_fn(x / std::f64::consts::SQRT_2)));