simd_kernels/kernels/arithmetic/
dispatch.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Arithmetic Dispatch Module** - *SIMD/Scalar Dispatch Layer for Arithmetic Operations*
5//!
6//! High-performance arithmetic kernel dispatcher that automatically selects between SIMD and scalar
7//! implementations based on data alignment and feature flags.
8//!
9//! ## Overview
10//! - **Dual-path execution**: SIMD-accelerated path with scalar fallback for unaligned data
11//! - **Type-specific dispatch**: Optimised kernels for integers (i32/i64/u32/u64), floats (f32/f64), and datetime types
12//! - **Null-aware operations**: Arrow-compatible null mask propagation and handling
13//! - **Build-time SIMD lanes**: Lane counts determined at build time based on target architecture
14//!
15//! ## Supported Operations  
16//! - **Basic arithmetic**: Add, subtract, multiply, divide, remainder, power
17//! - **Fused multiply-add (FMA)**: Hardware-accelerated `a * b + c` operations for floats
18//! - **Datetime arithmetic**: Temporal operations with integer kernel delegation
19//!
20//! ## Performance Strategy
21//! - SIMD requires 64-byte aligned input data. This is automatic with `minarrow`'s Vec64.
22//! - Scalar fallback ensures correctness regardless of input alignment
23
24include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
25
26use crate::errors::KernelError;
27#[cfg(feature = "simd")]
28use crate::kernels::arithmetic::simd::{
29    float_dense_body_f32_simd, float_dense_body_f64_simd, float_masked_body_f32_simd,
30    float_masked_body_f64_simd, fma_dense_body_f32_simd, fma_dense_body_f64_simd,
31    fma_masked_body_f32_simd, fma_masked_body_f64_simd, int_dense_body_simd, int_masked_body_simd,
32};
33use crate::kernels::arithmetic::std::{
34    float_dense_body_std, float_masked_body_std, int_dense_body_std, int_masked_body_std,
35};
36use crate::operators::ArithmeticOperator::{self};
37use crate::utils::confirm_equal_len;
38#[cfg(feature = "simd")]
39use crate::utils::is_simd_aligned;
40#[cfg(feature = "datetime")]
41use minarrow::DatetimeAVT;
42#[cfg(feature = "datetime")]
43use minarrow::DatetimeArray;
44use minarrow::structs::variants::float::FloatArray;
45use minarrow::structs::variants::integer::IntegerArray;
46use minarrow::{Bitmask, Vec64};
47
48// Kernels
49
/// Defines a public element-wise integer arithmetic function with SIMD/scalar dispatch.
///
/// Macro arguments: the name of the function to generate, the integer element type, and the
/// SIMD lane count passed to the vectorised kernels.
///
/// The generated function takes two `&[T]` slices plus an optional null mask and returns an
/// `IntegerArray<T>`. The SIMD kernels are only selected when the `simd` feature is enabled and
/// both inputs pass `is_simd_aligned`; every other case runs the scalar kernels.
macro_rules! impl_apply_int {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two `&[", stringify!($ty),
            "]` slices of equal length. Uses the SIMD kernels (", stringify!($lanes),
            " lanes) when the `simd` feature is enabled and both inputs are SIMD-aligned, \
            and the scalar kernels otherwise. \
            Returns an `IntegerArray<", stringify!($ty), ">` whose null mask is `Some` \
            exactly when an input `mask` was supplied. \
            Propagates the error from the length check when the slice lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<IntegerArray<$ty>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", len, rhs.len())?;

            // One output buffer serves whichever kernel ends up being selected.
            let mut out = Vec64::<$ty>::with_capacity(len);
            // SAFETY: capacity for `len` elements was reserved on the line above.
            // NOTE(review): the slots are uninitialised here; this relies on the selected
            // kernel writing every element before the buffer is read - confirm.
            unsafe { out.set_len(len) };

            // The vectorised kernels are only used when both inputs are 64-byte aligned.
            #[cfg(feature = "simd")]
            let aligned = is_simd_aligned(lhs) && is_simd_aligned(rhs);

            let null_mask = match mask {
                Some(input_mask) => {
                    // Output mask starts fully set; the masked kernel receives it mutably.
                    let mut result_mask = minarrow::Bitmask::new_set_all(len, true);
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            int_masked_body_simd::<$ty, $lanes>(
                                op,
                                lhs,
                                rhs,
                                input_mask,
                                &mut out,
                                &mut result_mask,
                            );
                        } else {
                            int_masked_body_std::<$ty>(
                                op,
                                lhs,
                                rhs,
                                input_mask,
                                &mut out,
                                &mut result_mask,
                            );
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_masked_body_std::<$ty>(
                            op,
                            lhs,
                            rhs,
                            input_mask,
                            &mut out,
                            &mut result_mask,
                        );
                    }
                    Some(result_mask)
                }
                None => {
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            int_dense_body_simd::<$ty, $lanes>(op, lhs, rhs, &mut out);
                        } else {
                            int_dense_body_std::<$ty>(op, lhs, rhs, &mut out);
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_dense_body_std::<$ty>(op, lhs, rhs, &mut out);
                    }
                    None
                }
            };

            Ok(IntegerArray {
                data: out.into(),
                null_mask,
            })
        }
    };
}
122
/// Defines a public element-wise floating-point arithmetic function with SIMD/scalar dispatch.
///
/// Macro arguments: the name of the function to generate, the float element type, the SIMD lane
/// count, and the names of the dense (no mask) and masked SIMD kernels for that type.
///
/// The generated function takes two `&[T]` slices plus an optional null mask and returns a
/// `FloatArray<T>`. The SIMD kernels are only selected when the `simd` feature is enabled and
/// both inputs pass `is_simd_aligned`; every other case runs the scalar kernels.
macro_rules! impl_apply_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_body_simd:ident, $masked_body_simd:ident) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two `&[", stringify!($ty),
            "]` slices of equal length. Uses the SIMD kernels (", stringify!($lanes),
            " lanes) when the `simd` feature is enabled and both inputs are SIMD-aligned, \
            and the scalar kernels otherwise. \
            Returns a `FloatArray<", stringify!($ty), ">` whose null mask is `Some` \
            exactly when an input `mask` was supplied. \
            Propagates the error from the length check when the slice lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", len, rhs.len())?;

            // One output buffer serves whichever kernel ends up being selected.
            let mut out = Vec64::<$ty>::with_capacity(len);
            // SAFETY: capacity for `len` elements was reserved on the line above.
            // NOTE(review): the slots are uninitialised here; this relies on the selected
            // kernel writing every element before the buffer is read - confirm.
            unsafe { out.set_len(len) };

            // The vectorised kernels are only used when both inputs are 64-byte aligned.
            #[cfg(feature = "simd")]
            let aligned = is_simd_aligned(lhs) && is_simd_aligned(rhs);

            let null_mask = match mask {
                Some(input_mask) => {
                    // Output mask starts fully set; the masked kernel receives it mutably.
                    let mut result_mask = minarrow::Bitmask::new_set_all(len, true);
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            $masked_body_simd::<$lanes>(
                                op,
                                lhs,
                                rhs,
                                input_mask,
                                &mut out,
                                &mut result_mask,
                            );
                        } else {
                            float_masked_body_std::<$ty>(
                                op,
                                lhs,
                                rhs,
                                input_mask,
                                &mut out,
                                &mut result_mask,
                            );
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        float_masked_body_std::<$ty>(
                            op,
                            lhs,
                            rhs,
                            input_mask,
                            &mut out,
                            &mut result_mask,
                        );
                    }
                    Some(result_mask)
                }
                None => {
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            $dense_body_simd::<$lanes>(op, lhs, rhs, &mut out);
                        } else {
                            float_dense_body_std::<$ty>(op, lhs, rhs, &mut out);
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        float_dense_body_std::<$ty>(op, lhs, rhs, &mut out);
                    }
                    None
                }
            };

            Ok(FloatArray {
                data: out.into(),
                null_mask,
            })
        }
    };
}
195
/// Defines a public fused multiply-add (`lhs * rhs + acc`) function with SIMD/scalar dispatch.
///
/// Macro arguments: the name of the function to generate, the float element type, the SIMD lane
/// count, and the names of the dense (no mask) and masked SIMD FMA kernels for that type.
///
/// The generated function takes three `&[T]` slices plus an optional null mask and returns a
/// `FloatArray<T>`. The SIMD kernels are only selected when the `simd` feature is enabled and
/// all three inputs pass `is_simd_aligned`; every other case runs a scalar loop that uses
/// `mul_add()`, so the product and sum are rounded once.
macro_rules! impl_apply_fma_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_simd:ident, $masked_simd:ident) => {
        #[doc = concat!(
            "Performs element-wise fused multiply-add (`lhs * rhs + acc`) on `&[", stringify!($ty),
            "]` slices of equal length. With the `simd` feature enabled and all three inputs \
            SIMD-aligned it uses `", stringify!($dense_simd), "` (no mask) or `",
            stringify!($masked_simd), "` (masked) with ", stringify!($lanes), " lanes; \
            otherwise it falls back to a scalar loop built on `mul_add`. \
            In the scalar masked case, positions whose mask bit is unset are written as zero \
            and cleared in the output mask. \
            Results in a `FloatArray<", stringify!($ty), ">`; propagates the error from the \
            length checks when the slice lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            acc: &[$ty],
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", len, rhs.len())?;
            confirm_equal_len("acc length mismatch", len, acc.len())?;

            let mut out = Vec64::<$ty>::with_capacity(len);
            // SAFETY: capacity for `len` elements was reserved on the line above.
            // NOTE(review): the slots are uninitialised here; every path below is expected to
            // write all `len` elements before the buffer is read - confirm for the SIMD kernels.
            unsafe { out.set_len(len) };

            #[cfg(feature = "simd")]
            {
                // The vectorised kernels are only used when every input is 64-byte aligned.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && is_simd_aligned(acc) {
                    return match mask {
                        Some(mask) => {
                            let mut out_mask = minarrow::Bitmask::new_set_all(len, true);
                            $masked_simd::<$lanes>(lhs, rhs, acc, mask, &mut out, &mut out_mask);
                            Ok(FloatArray {
                                data: out.into(),
                                null_mask: Some(out_mask),
                            })
                        }
                        None => {
                            $dense_simd::<$lanes>(lhs, rhs, acc, &mut out);
                            Ok(FloatArray {
                                data: out.into(),
                                null_mask: None,
                            })
                        }
                    };
                }
                // Otherwise fall through to the scalar loop below.
            }

            // Scalar path: SIMD disabled, or at least one input is not aligned.
            match mask {
                Some(mask) => {
                    // The output mask is only needed (and only allocated) when a mask is given.
                    let mut out_mask = minarrow::Bitmask::new_set_all(len, true);
                    for i in 0..len {
                        // SAFETY: `i < len`.
                        // NOTE(review): this assumes `mask` covers at least `len` bits; the
                        // mask length is not validated in this function - confirm at callers.
                        if unsafe { mask.get_unchecked(i) } {
                            // `mul_add` rounds once, as the FMA contract documents.
                            out[i] = lhs[i].mul_add(rhs[i], acc[i]);
                        } else {
                            // Give masked-out slots a defined value and mark them null.
                            out[i] = 0 as $ty;
                            out_mask.set(i, false);
                        }
                    }
                    Ok(FloatArray {
                        data: out.into(),
                        null_mask: Some(out_mask),
                    })
                }
                None => {
                    for i in 0..len {
                        out[i] = lhs[i].mul_add(rhs[i], acc[i]);
                    }
                    Ok(FloatArray {
                        data: out.into(),
                        null_mask: None,
                    })
                }
            }
        }
    };
}
279
/// Defines a public element-wise arithmetic function over two datetime array windows.
///
/// Macro arguments: the name of the function to generate, the integer element type backing the
/// datetime array, and the SIMD lane count passed to the vectorised integer kernels.
///
/// The generated function takes two `DatetimeAVT<T>` values - each destructured as
/// `(array, offset, len)` - and an `ArithmeticOperator`, and returns a `DatetimeArray<T>`.
/// The windowed data is handed to the integer kernels: the SIMD kernels when the `simd` feature
/// is enabled and both windows pass `is_simd_aligned`, the scalar kernels otherwise.
///
/// # Notes
/// - The operator is forwarded unchanged to the integer kernels; no operator (including
///   `Power`) is special-cased in this macro.
///   NOTE(review): earlier documentation stated that `Power` returns lhs unchanged - nothing
///   here implements that, so confirm against the kernels before relying on it.
/// - Any future datetime-specific operations would need to be added here.
#[cfg(feature = "datetime")]
macro_rules! impl_apply_datetime {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Applies an `ArithmeticOperator` element-wise to two `DatetimeAVT<", stringify!($ty),
            ">` windows of equal length by delegating to the integer kernels. Uses the SIMD \
            kernels (", stringify!($lanes), " lanes) when the `simd` feature is enabled and \
            both windows are SIMD-aligned, and the scalar kernels otherwise. \
            Returns a `DatetimeArray<", stringify!($ty), ">`; a null mask is attached when the \
            merged input mask is present. \
            Propagates the error from the length check when the window lengths differ."
        )]
        #[inline(always)]
        pub fn $fn_name(
            lhs: DatetimeAVT<$ty>,
            rhs: DatetimeAVT<$ty>,
            op: ArithmeticOperator,
        ) -> Result<DatetimeArray<$ty>, KernelError> {
            use crate::utils::merge_bitmasks_to_new;
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("apply_datetime: length mismatch", llen, rlen)?;

            // NOTE(review): the window offsets are not passed to the merge, only `llen` is.
            // Confirm the null masks are already window-relative, otherwise nulls may be
            // misaligned for non-zero offsets.
            let out_mask =
                merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
            // Range indexing panics if a window exceeds its backing buffer.
            let ldata = &larr.data[loff..loff + llen];
            let rdata = &rarr.data[roff..roff + rlen];

            let mut out = Vec64::<$ty>::with_capacity(llen);
            // SAFETY: capacity for `llen` elements was reserved on the line above.
            // NOTE(review): the slots are uninitialised here; this relies on the selected
            // kernel writing every element before the buffer is read - confirm.
            unsafe {
                out.set_len(llen);
            }

            // A window with a non-zero offset is not guaranteed to start on a 64-byte
            // boundary, so alignment is checked here exactly as the slice-based kernels do.
            #[cfg(feature = "simd")]
            let aligned = is_simd_aligned(ldata) && is_simd_aligned(rdata);

            match out_mask.as_ref() {
                Some(mask) => {
                    let mut result_mask = minarrow::Bitmask::new_set_all(llen, true);
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            int_masked_body_simd::<$ty, $lanes>(
                                op,
                                ldata,
                                rdata,
                                mask,
                                &mut out,
                                &mut result_mask,
                            );
                        } else {
                            int_masked_body_std::<$ty>(
                                op,
                                ldata,
                                rdata,
                                mask,
                                &mut out,
                                &mut result_mask,
                            );
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_masked_body_std::<$ty>(
                            op,
                            ldata,
                            rdata,
                            mask,
                            &mut out,
                            &mut result_mask,
                        );
                    }
                    Ok(DatetimeArray::from_vec64(out, Some(result_mask), None))
                }
                None => {
                    #[cfg(feature = "simd")]
                    {
                        if aligned {
                            int_dense_body_simd::<$ty, $lanes>(op, ldata, rdata, &mut out);
                        } else {
                            int_dense_body_std::<$ty>(op, ldata, rdata, &mut out);
                        }
                    }
                    #[cfg(not(feature = "simd"))]
                    {
                        int_dense_body_std::<$ty>(op, ldata, rdata, &mut out);
                    }
                    Ok(DatetimeArray::from_vec64(out, None, None))
                }
            }
        }
    };
}
362
363// Generates i32, u32, i64, u64, f32, f64 variants using lane counts via simd_lanes.rs
364
365impl_apply_int!(apply_int_i32, i32, W32);
366impl_apply_int!(apply_int_u32, u32, W32);
367impl_apply_int!(apply_int_i64, i64, W64);
368impl_apply_int!(apply_int_u64, u64, W64);
369#[cfg(feature = "extended_numeric_types")]
370impl_apply_int!(apply_int_i16, i16, W16);
371#[cfg(feature = "extended_numeric_types")]
372impl_apply_int!(apply_int_u16, u16, W16);
373#[cfg(feature = "extended_numeric_types")]
374impl_apply_int!(apply_int_i8, i8, W8);
375#[cfg(feature = "extended_numeric_types")]
376impl_apply_int!(apply_int_u8, u8, W8);
377
378impl_apply_float!(
379    apply_float_f32,
380    f32,
381    W32,
382    float_dense_body_f32_simd,
383    float_masked_body_f32_simd
384);
385impl_apply_float!(
386    apply_float_f64,
387    f64,
388    W64,
389    float_dense_body_f64_simd,
390    float_masked_body_f64_simd
391);
392
393impl_apply_fma_float!(
394    apply_fma_f32,
395    f32,
396    W32,
397    fma_dense_body_f32_simd,
398    fma_masked_body_f32_simd
399);
400
401impl_apply_fma_float!(
402    apply_fma_f64,
403    f64,
404    W64,
405    fma_dense_body_f64_simd,
406    fma_masked_body_f64_simd
407);
408
409#[cfg(feature = "datetime")]
410impl_apply_datetime!(apply_datetime_i32, i32, W32);
411#[cfg(feature = "datetime")]
412impl_apply_datetime!(apply_datetime_u32, u32, W32);
413#[cfg(feature = "datetime")]
414impl_apply_datetime!(apply_datetime_i64, i64, W64);
415#[cfg(feature = "datetime")]
416impl_apply_datetime!(apply_datetime_u64, u64, W64);