1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
25
26use crate::errors::KernelError;
27#[cfg(feature = "simd")]
28use crate::kernels::arithmetic::simd::{
29 float_dense_body_f32_simd, float_dense_body_f64_simd, float_masked_body_f32_simd,
30 float_masked_body_f64_simd, fma_dense_body_f32_simd, fma_dense_body_f64_simd,
31 fma_masked_body_f32_simd, fma_masked_body_f64_simd, int_dense_body_simd, int_masked_body_simd,
32};
33use crate::kernels::arithmetic::std::{
34 float_dense_body_std, float_masked_body_std, int_dense_body_std, int_masked_body_std,
35};
36use crate::operators::ArithmeticOperator::{self};
37use crate::utils::confirm_equal_len;
38#[cfg(feature = "simd")]
39use crate::utils::is_simd_aligned;
40#[cfg(feature = "datetime")]
41use minarrow::DatetimeAVT;
42#[cfg(feature = "datetime")]
43use minarrow::DatetimeArray;
44use minarrow::structs::variants::float::FloatArray;
45use minarrow::structs::variants::integer::IntegerArray;
46use minarrow::{Bitmask, Vec64};
47
macro_rules! impl_apply_int {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two equal-length `&[",
            stringify!($ty), "]` slices. \
            When the `simd` feature is enabled and both inputs pass `is_simd_aligned`, the \
            SIMD kernels are used with ", stringify!($lanes), " lanes; \
            otherwise the scalar kernels run. \
            Yields an `IntegerArray<", stringify!($ty), ">` whose null mask is `Some` only \
            when an input `mask` was supplied. \
            Fails with the error from `confirm_equal_len` if the slice lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<IntegerArray<$ty>, KernelError> {
            let n = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", n, rhs.len())?;

            // One output buffer serves whichever kernel ends up running.
            // NOTE(review): the length is set before any element is written; this
            // presumably relies on the kernel filling all `n` slots — confirm.
            let mut values = Vec64::<$ty>::with_capacity(n);
            unsafe { values.set_len(n) };

            #[cfg(feature = "simd")]
            {
                // The SIMD kernels are only taken when both inputs pass the alignment check.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    let null_mask = if let Some(validity) = mask {
                        let mut out_validity = minarrow::Bitmask::new_set_all(n, true);
                        int_masked_body_simd::<$ty, $lanes>(
                            op,
                            lhs,
                            rhs,
                            validity,
                            &mut values,
                            &mut out_validity,
                        );
                        Some(out_validity)
                    } else {
                        int_dense_body_simd::<$ty, $lanes>(op, lhs, rhs, &mut values);
                        None
                    };
                    return Ok(IntegerArray {
                        data: values.into(),
                        null_mask,
                    });
                }
            }

            // Scalar path: `simd` feature disabled, or an input failed the alignment check.
            let null_mask = if let Some(validity) = mask {
                let mut out_validity = minarrow::Bitmask::new_set_all(n, true);
                int_masked_body_std::<$ty>(
                    op,
                    lhs,
                    rhs,
                    validity,
                    &mut values,
                    &mut out_validity,
                );
                Some(out_validity)
            } else {
                int_dense_body_std::<$ty>(op, lhs, rhs, &mut values);
                None
            };
            Ok(IntegerArray {
                data: values.into(),
                null_mask,
            })
        }
    };
}
122
macro_rules! impl_apply_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_body_simd:ident, $masked_body_simd:ident) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two equal-length `&[",
            stringify!($ty), "]` slices. \
            When the `simd` feature is enabled and both inputs pass `is_simd_aligned`, the \
            dense or masked SIMD kernel is used with ", stringify!($lanes), " lanes; \
            otherwise the standard scalar kernels run. \
            Yields a `FloatArray<", stringify!($ty), ">` whose null mask is `Some` only \
            when an input `mask` was supplied. \
            Fails with the error from `confirm_equal_len` if the slice lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let n = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", n, rhs.len())?;

            // One output buffer serves whichever kernel ends up running.
            // NOTE(review): the length is set before any element is written; this
            // presumably relies on the kernel filling all `n` slots — confirm.
            let mut values = Vec64::<$ty>::with_capacity(n);
            unsafe { values.set_len(n) };

            #[cfg(feature = "simd")]
            {
                // The SIMD kernels are only taken when both inputs pass the alignment check.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    let null_mask = if let Some(validity) = mask {
                        let mut out_validity = minarrow::Bitmask::new_set_all(n, true);
                        $masked_body_simd::<$lanes>(
                            op,
                            lhs,
                            rhs,
                            validity,
                            &mut values,
                            &mut out_validity,
                        );
                        Some(out_validity)
                    } else {
                        $dense_body_simd::<$lanes>(op, lhs, rhs, &mut values);
                        None
                    };
                    return Ok(FloatArray {
                        data: values.into(),
                        null_mask,
                    });
                }
            }

            // Scalar path: `simd` feature disabled, or an input failed the alignment check.
            let null_mask = if let Some(validity) = mask {
                let mut out_validity = minarrow::Bitmask::new_set_all(n, true);
                float_masked_body_std::<$ty>(
                    op,
                    lhs,
                    rhs,
                    validity,
                    &mut values,
                    &mut out_validity,
                );
                Some(out_validity)
            } else {
                float_dense_body_std::<$ty>(op, lhs, rhs, &mut values);
                None
            };
            Ok(FloatArray {
                data: values.into(),
                null_mask,
            })
        }
    };
}
195
macro_rules! impl_apply_fma_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_simd:ident, $masked_simd:ident) => {
        // The doc string must name the real metavariables (`$dense_simd` / `$masked_simd`);
        // an unbound `$name` is passed through literally by `macro_rules!` and would be
        // rendered verbatim in the generated documentation.
        #[doc = concat!(
            "Performs element-wise multiply-add (`lhs * rhs + acc`) on `&[", stringify!($ty),
            "]` using SIMD (", stringify!($lanes), " lanes; dense or masked, via `",
            stringify!($dense_simd), "`/`", stringify!($masked_simd), "` as needed). \
            Falls back to a scalar `lhs[i] * rhs[i] + acc[i]` loop when the `simd` feature is \
            not enabled or any input slice fails the `is_simd_aligned` check. \
            In the scalar masked path, positions whose mask bit is unset are written as zero \
            and marked null in the output mask. \
            Results in a `FloatArray<", stringify!($ty), ">` whose null mask is `Some` only \
            when an input `mask` was supplied. \
            Fails with the error from `confirm_equal_len` if `rhs` or `acc` differ in length \
            from `lhs`."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            acc: &[$ty],
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", len, rhs.len())?;
            confirm_equal_len("acc length mismatch", len, acc.len())?;

            // NOTE(review): the length is set before any element is written; the SIMD
            // kernels are presumably expected to fill all `len` slots — confirm. The
            // scalar loops below write every index in `0..len`.
            let mut out = Vec64::<$ty>::with_capacity(len);
            unsafe { out.set_len(len) };

            #[cfg(feature = "simd")]
            {
                // The SIMD kernels are only taken when all three inputs pass the alignment check.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && is_simd_aligned(acc) {
                    match mask {
                        Some(mask) => {
                            // The output mask is only allocated on masked paths.
                            let mut out_mask = minarrow::Bitmask::new_set_all(len, true);
                            $masked_simd::<$lanes>(lhs, rhs, acc, mask, &mut out, &mut out_mask);
                            return Ok(FloatArray {
                                data: out.into(),
                                null_mask: Some(out_mask),
                            });
                        }
                        None => {
                            $dense_simd::<$lanes>(lhs, rhs, acc, &mut out);
                            return Ok(FloatArray {
                                data: out.into(),
                                null_mask: None,
                            });
                        }
                    }
                }
            }

            // Scalar path: `simd` feature disabled, or an input failed the alignment check.
            match mask {
                Some(mask) => {
                    let mut out_mask = minarrow::Bitmask::new_set_all(len, true);
                    for i in 0..len {
                        // NOTE(review): `mask` is not length-checked against `len` here;
                        // the unchecked read assumes the caller passes a mask covering at
                        // least `len` bits — confirm at call sites.
                        if unsafe { mask.get_unchecked(i) } {
                            out[i] = lhs[i] * rhs[i] + acc[i];
                        } else {
                            // Null slot: store a zero placeholder and clear the validity bit.
                            out[i] = 0 as $ty;
                            out_mask.set(i, false);
                        }
                    }
                    Ok(FloatArray {
                        data: out.into(),
                        null_mask: Some(out_mask),
                    })
                }
                None => {
                    for i in 0..len {
                        out[i] = lhs[i] * rhs[i] + acc[i];
                    }
                    Ok(FloatArray {
                        data: out.into(),
                        null_mask: None,
                    })
                }
            }
        }
    };
}
279
#[cfg(feature = "datetime")]
macro_rules! impl_apply_datetime {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two `DatetimeAVT<", stringify!($ty),
            ">` windows, each an `(array, offset, length)` tuple, of equal length. \
            The operands' null masks are passed to `merge_bitmasks_to_new`; when that yields a \
            mask the masked integer kernel is used, otherwise the dense kernel. \
            With the `simd` feature enabled the SIMD kernels (", stringify!($lanes),
            " lanes) run only when both data windows pass `is_simd_aligned`; otherwise the \
            scalar kernels are used. \
            Returns a `DatetimeArray<", stringify!($ty), ">`. \
            Fails with the error from `confirm_equal_len` if the window lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: DatetimeAVT<$ty>,
            rhs: DatetimeAVT<$ty>,
            op: ArithmeticOperator,
        ) -> Result<DatetimeArray<$ty>, KernelError> {
            use crate::utils::merge_bitmasks_to_new;
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("apply_datetime: length mismatch", llen, rlen)?;

            // NOTE(review): the masks are merged without applying `loff` / `roff`, while
            // the data below is windowed by those offsets. For non-zero offsets the mask
            // bits may not line up with the data — confirm `merge_bitmasks_to_new`
            // semantics and the callers' expectations.
            let out_mask =
                merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
            let ldata = &larr.data[loff..loff + llen];
            let rdata = &rarr.data[roff..roff + rlen];

            // NOTE(review): the length is set before any element is written; this
            // presumably relies on the kernel filling all `llen` slots — confirm.
            let mut out = Vec64::<$ty>::with_capacity(llen);
            unsafe {
                out.set_len(llen);
            }

            // Offset windows need not start on an aligned boundary, so apply the same
            // alignment gate the integer and float entry points use before taking SIMD.
            #[cfg(feature = "simd")]
            let use_simd = is_simd_aligned(ldata) && is_simd_aligned(rdata);

            match out_mask.as_ref() {
                Some(mask) => {
                    let mut result_mask = minarrow::Bitmask::new_set_all(llen, true);
                    #[cfg(feature = "simd")]
                    {
                        if use_simd {
                            int_masked_body_simd::<$ty, $lanes>(
                                op,
                                ldata,
                                rdata,
                                mask,
                                &mut out,
                                &mut result_mask,
                            );
                        } else {
                            int_masked_body_std::<$ty>(
                                op,
                                ldata,
                                rdata,
                                mask,
                                &mut out,
                                &mut result_mask,
                            );
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_masked_body_std::<$ty>(
                            op,
                            ldata,
                            rdata,
                            mask,
                            &mut out,
                            &mut result_mask,
                        );
                    }
                    Ok(DatetimeArray::from_vec64(out, Some(result_mask), None))
                }
                None => {
                    #[cfg(feature = "simd")]
                    {
                        if use_simd {
                            int_dense_body_simd::<$ty, $lanes>(op, ldata, rdata, &mut out);
                        } else {
                            int_dense_body_std::<$ty>(op, ldata, rdata, &mut out);
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_dense_body_std::<$ty>(op, ldata, rdata, &mut out);
                    }
                    Ok(DatetimeArray::from_vec64(out, None, None))
                }
            }
        }
    };
}
362
363impl_apply_int!(apply_int_i32, i32, W32);
366impl_apply_int!(apply_int_u32, u32, W32);
367impl_apply_int!(apply_int_i64, i64, W64);
368impl_apply_int!(apply_int_u64, u64, W64);
369#[cfg(feature = "extended_numeric_types")]
370impl_apply_int!(apply_int_i16, i16, W16);
371#[cfg(feature = "extended_numeric_types")]
372impl_apply_int!(apply_int_u16, u16, W16);
373#[cfg(feature = "extended_numeric_types")]
374impl_apply_int!(apply_int_i8, i8, W8);
375#[cfg(feature = "extended_numeric_types")]
376impl_apply_int!(apply_int_u8, u8, W8);
377
378impl_apply_float!(
379 apply_float_f32,
380 f32,
381 W32,
382 float_dense_body_f32_simd,
383 float_masked_body_f32_simd
384);
385impl_apply_float!(
386 apply_float_f64,
387 f64,
388 W64,
389 float_dense_body_f64_simd,
390 float_masked_body_f64_simd
391);
392
393impl_apply_fma_float!(
394 apply_fma_f32,
395 f32,
396 W32,
397 fma_dense_body_f32_simd,
398 fma_masked_body_f32_simd
399);
400
401impl_apply_fma_float!(
402 apply_fma_f64,
403 f64,
404 W64,
405 fma_dense_body_f64_simd,
406 fma_masked_body_f64_simd
407);
408
409#[cfg(feature = "datetime")]
410impl_apply_datetime!(apply_datetime_i32, i32, W32);
411#[cfg(feature = "datetime")]
412impl_apply_datetime!(apply_datetime_u32, u32, W32);
413#[cfg(feature = "datetime")]
414impl_apply_datetime!(apply_datetime_i64, i64, W64);
415#[cfg(feature = "datetime")]
416impl_apply_datetime!(apply_datetime_u64, u64, W64);