simd_kernels/kernels/arithmetic/
mod.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Arithmetic Kernels Module** - *High-Performance Arithmetic*
5//!
6//! SIMD-optimised arithmetic operations for numeric arrays with null-aware semantics.
7//!
8//! ## Modules
9//! - **`dispatch`**: Smart dispatch layer selecting SIMD vs scalar implementations based on alignment
10//! - **`simd`**: SIMD-accelerated implementations using `std::simd` with portable vectorisation  
11//! - **`std`**: Scalar fallback implementations for compatibility and unaligned data
12//! - **`string`**: Specialised arithmetic operations for string concatenation and manipulation
13//!
14//! ## Operations
15//! Supports standard arithmetic operations (add, subtract, multiply, divide, remainder, power)
16//! plus fused multiply-add (FMA) for floating-point types with hardware acceleration.
17//! 
18//! ## Scope
//! **These kernels do not use multi-threaded parallelism; that is expected to be applied in the
//! engine layer, which is app-specific.**
21
22pub mod dispatch;
23#[cfg(feature = "simd")]
24pub mod simd;
25pub mod std;
26pub mod string;
27
28// Shared tests for SIMD and Std
29
30#[cfg(test)]
31mod tests {
32    use minarrow::structs::variants::float::FloatArray;
33    use minarrow::structs::variants::integer::IntegerArray;
34    use minarrow::{Bitmask, MaskedArray, vec64};
35
36    use crate::kernels::arithmetic::dispatch::{
37        apply_float_f32, apply_float_f64, apply_fma_f32, apply_fma_f64, apply_int_i32,
38        apply_int_i64, apply_int_u32, apply_int_u64,
39    };
40    #[cfg(feature = "extended_numeric_types")]
41    use crate::kernels::arithmetic::dispatch::{
42        apply_int_i8, apply_int_i16, apply_int_u8, apply_int_u16,
43    };
44    #[cfg(feature = "simd")]
45    use crate::kernels::arithmetic::simd::int_dense_body_simd;
46    use crate::operators::ArithmeticOperator;
47
48    fn assert_int<T>(arr: &IntegerArray<T>, values: &[T], valid: Option<&[bool]>)
49    where
50        T: num_traits::PrimInt + std::fmt::Debug,
51    {
52        assert_eq!(arr.data.as_slice(), values);
53        match (valid, &arr.null_mask) {
54            (None, None) => {}
55            (Some(expected), Some(mask)) => {
56                for (i, bit) in expected.iter().enumerate() {
57                    assert_eq!(
58                        unsafe { mask.get_unchecked(i) },
59                        *bit,
60                        "mask mismatch at {i}"
61                    );
62                }
63            }
64            (None, Some(mask)) => {
65                assert!(mask.all_true(), "mask unexpectedly present");
66            }
67            (Some(_), None) => panic!("expected mask missing"),
68        }
69    }
70
71    fn assert_float<T>(arr: &FloatArray<T>, values: &[T], valid: Option<&[bool]>)
72    where
73        T: num_traits::Float + std::fmt::Debug,
74    {
75        assert_eq!(arr.data.as_slice(), values);
76        match (valid, &arr.null_mask) {
77            (None, None) => {}
78            (Some(expected), Some(mask)) => {
79                for (i, bit) in expected.iter().enumerate() {
80                    assert_eq!(
81                        unsafe { mask.get_unchecked(i) },
82                        *bit,
83                        "mask mismatch at {i}"
84                    );
85                }
86            }
87            (None, Some(mask)) => {
88                assert!(mask.all_true(), "mask unexpectedly present");
89            }
90            (Some(_), None) => panic!("expected mask missing"),
91        }
92    }
93
94    fn bitmask(bits: &[bool]) -> Bitmask {
95        let mut m = Bitmask::new_set_all(bits.len(), false);
96        for (i, b) in bits.iter().enumerate() {
97            unsafe { m.set_unchecked(i, *b) };
98        }
99        m
100    }
101
102    macro_rules! int_kernel_suite {
103        ($fn_dense:ident, $fn_masked:ident, $fn_empty:ident, $ty:ty, $apply_fn:ident) => {
104            #[test]
105            fn $fn_dense() {
106                let lhs = vec64![1, 4, 9, 16];
107                let rhs = vec64![1, 2, 3, 4];
108
109                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Add, None).unwrap();
110                assert_int(
111                    &out,
112                    &IntegerArray::<$ty>::from_slice(&[2, 6, 12, 20]),
113                    None,
114                );
115
116                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Subtract, None).unwrap();
117                assert_int(&out, &IntegerArray::<$ty>::from_slice(&[0, 2, 6, 12]), None);
118
119                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Multiply, None).unwrap();
120                assert_int(
121                    &out,
122                    &IntegerArray::<$ty>::from_slice(&[1, 8, 27, 64]),
123                    None,
124                );
125
126                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Divide, None).unwrap();
127                assert_int(&out, &IntegerArray::<$ty>::from_slice(&[1, 2, 3, 4]), None);
128
129                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Remainder, None).unwrap();
130                assert_int(&out, &IntegerArray::<$ty>::from_slice(&[0, 0, 0, 0]), None);
131
132                let expected: Vec<$ty> = lhs
133                    .iter()
134                    .zip(rhs.iter())
135                    .map(|(&a, &b)| {
136                        let mut acc = <$ty as num_traits::One>::one();
137                        for _ in 0..(b as u32) {
138                            acc = acc.wrapping_mul(a);
139                        }
140                        acc
141                    })
142                    .collect();
143                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Power, None).unwrap();
144                assert_int(&out, &IntegerArray::<$ty>::from_slice(&expected), None);
145
146                // Division by zero should panic
147                let rhs_divzero: &[$ty] = &[0, 0, 0, 0];
148                let result = std::panic::catch_unwind(|| {
149                    $apply_fn(&lhs, rhs_divzero, ArithmeticOperator::Divide, None).unwrap()
150                });
151                assert!(
152                    result.is_err(),
153                    "Dense integer kernel division by zero must panic"
154                );
155
156                let result = std::panic::catch_unwind(|| {
157                    $apply_fn(&lhs, rhs_divzero, ArithmeticOperator::Remainder, None).unwrap()
158                });
159                assert!(
160                    result.is_err(),
161                    "Dense integer kernel remainder by zero must panic"
162                );
163            }
164
165            #[test]
166            fn $fn_masked() {
167                let lhs = vec64![10, 20, 30, 40];
168                let rhs = vec64![2, 0, 3, 5];
169                let mask = bitmask(&[true, false, true, false]);
170
171                // Division: mask==true and rhs!=0 are valid, otherwise null
172                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Divide, Some(&mask)).unwrap();
173                assert_int(
174                    &out,
175                    &IntegerArray::<$ty>::from_slice(&[5, 0, 10, 0]),
176                    Some(&[true, false, true, false]),
177                );
178
179                // Remainder with mask, matching above
180                let out =
181                    $apply_fn(&lhs, &rhs, ArithmeticOperator::Remainder, Some(&mask)).unwrap();
182                assert_int(
183                    &out,
184                    &IntegerArray::<$ty>::from_slice(&[0, 0, 0, 0]),
185                    Some(&[true, false, true, false]),
186                );
187
188                // Division by zero where mask is true but rhs is zero must yield null in mask (false) and output 0
189                let mask_divzero = bitmask(&[true, true, true, true]);
190                let rhs_divzero: &[$ty] = &[1, 0, 2, 0];
191                let lhs2: &[$ty] = &[100, 100, 100, 100];
192
193                let out = $apply_fn(
194                    lhs2,
195                    rhs_divzero,
196                    ArithmeticOperator::Divide,
197                    Some(&mask_divzero),
198                )
199                .unwrap();
200                assert_int(
201                    &out,
202                    &IntegerArray::<$ty>::from_slice(&[100, 0, 50, 0]),
203                    Some(&[true, false, true, false]),
204                );
205            }
206
207            #[test]
208            fn $fn_empty() {
209                let lhs = vec64![];
210                let rhs = vec64![];
211                let out = $apply_fn(&lhs, &rhs, ArithmeticOperator::Add, None).unwrap();
212                assert!(out.is_empty());
213            }
214        };
215    }
216
    // Instantiate the integer suite once per element type. Each invocation names the
    // three generated tests (dense, masked, empty), the element type, and the dispatch
    // function under test.
    //
    // The 8- and 16-bit suites are compiled only with the `extended_numeric_types`
    // feature, matching the feature gate on the corresponding dispatch imports.
    #[cfg(feature = "extended_numeric_types")]
    int_kernel_suite!(
        apply_int_i8_dense,
        apply_int_i8_masked,
        apply_int_i8_empty,
        i8,
        apply_int_i8
    );
    #[cfg(feature = "extended_numeric_types")]
    int_kernel_suite!(
        apply_int_u8_dense,
        apply_int_u8_masked,
        apply_int_u8_empty,
        u8,
        apply_int_u8
    );
    #[cfg(feature = "extended_numeric_types")]
    int_kernel_suite!(
        apply_int_i16_dense,
        apply_int_i16_masked,
        apply_int_i16_empty,
        i16,
        apply_int_i16
    );
    #[cfg(feature = "extended_numeric_types")]
    int_kernel_suite!(
        apply_int_u16_dense,
        apply_int_u16_masked,
        apply_int_u16_empty,
        u16,
        apply_int_u16
    );
    // 32- and 64-bit suites are always compiled.
    int_kernel_suite!(
        apply_int_i32_dense,
        apply_int_i32_masked,
        apply_int_i32_empty,
        i32,
        apply_int_i32
    );
    int_kernel_suite!(
        apply_int_u32_dense,
        apply_int_u32_masked,
        apply_int_u32_empty,
        u32,
        apply_int_u32
    );
    int_kernel_suite!(
        apply_int_i64_dense,
        apply_int_i64_masked,
        apply_int_i64_empty,
        i64,
        apply_int_i64
    );
    int_kernel_suite!(
        apply_int_u64_dense,
        apply_int_u64_masked,
        apply_int_u64_empty,
        u64,
        apply_int_u64
    );
277
278    macro_rules! float_kernel_suite {
279        ($test_fn:ident, $ty:ty, $apply_fn:ident, $eps:expr) => {
280            #[test]
281            fn $test_fn() {
282                let lhs = vec64![1.0, 4.0, 9.0, 16.0];
283                let rhs = vec64![0.5, 2.0, 3.0, 4.0];
284
285                let lhs: &[$ty] = lhs.as_slice();
286                let rhs: &[$ty] = rhs.as_slice();
287
288                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Add, None).unwrap();
289                assert_eq!(arr.data.as_slice(), &[1.5 as $ty, 6.0, 12.0, 20.0]);
290
291                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Subtract, None).unwrap();
292                assert_eq!(arr.data.as_slice(), &[0.5 as $ty, 2.0, 6.0, 12.0]);
293
294                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Multiply, None).unwrap();
295                assert_eq!(arr.data.as_slice(), &[0.5 as $ty, 8.0, 27.0, 64.0]);
296
297                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Divide, None).unwrap();
298                assert_eq!(arr.data.as_slice(), &[2.0 as $ty, 2.0, 3.0, 4.0]);
299
300                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Remainder, None).unwrap();
301                assert!(
302                    arr.data
303                        .as_slice()
304                        .iter()
305                        .zip(
306                            [1.0 % 0.5, 4.0 % 2.0, 9.0 % 3.0, 16.0 % 4.0]
307                                .iter()
308                                .map(|&x| x as $ty)
309                        )
310                        .all(|(a, b)| (*a - b).abs() < $eps)
311                );
312
313                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Power, None).unwrap();
314                let expected: Vec<$ty> = lhs
315                    .iter()
316                    .zip(rhs.iter())
317                    .map(|(&a, &b)| (b * a.ln()).exp())
318                    .collect();
319                assert!(
320                    arr.data
321                        .as_slice()
322                        .iter()
323                        .zip(expected.iter())
324                        .all(|(a, b)| (*a - *b).abs() < $eps)
325                );
326
327                // Division by zero for floats yields Inf/NaN, never panics
328                let rhs_divzero: &[$ty] = &[0.0, 0.0, 0.0, 0.0];
329                let arr = $apply_fn(lhs, rhs_divzero, ArithmeticOperator::Divide, None).unwrap();
330                assert!(
331                    arr.data.iter().all(|&x| x.is_infinite()),
332                    "Float division by zero should yield Inf"
333                );
334
335                let arr = $apply_fn(lhs, rhs_divzero, ArithmeticOperator::Remainder, None).unwrap();
336                assert!(
337                    arr.data.iter().all(|&x| x.is_nan()),
338                    "Float remainder by zero should yield NaN"
339                );
340
341                // Masked test
342                let mask = bitmask(&[true, false, true, false]);
343                let arr = $apply_fn(lhs, rhs, ArithmeticOperator::Multiply, Some(&mask)).unwrap();
344                assert_eq!(arr.data.as_slice(), &[0.5 as $ty, 0.0, 27.0, 0.0]);
345                assert_eq!(arr.null_mask.as_ref().unwrap().len(), 4);
346
347                // Empty
348                let arr = $apply_fn(&[], &[], ArithmeticOperator::Add, None).unwrap();
349                assert!(arr.is_empty());
350            }
351        };
352    }
353
    // Instantiate the float suite once per width. The last argument is the absolute
    // tolerance the suite applies to its approximate (remainder and power) comparisons.
    float_kernel_suite!(apply_float_f32_dense, f32, apply_float_f32, 1e-6f32);
    float_kernel_suite!(apply_float_f64_dense, f64, apply_float_f64, 1e-12f64);
356
357    #[test]
358    fn fma_f32() {
359        let lhs = vec64![1.0f32, 2.0, 3.0];
360        let rhs = vec64![4.0f32, 5.0, 6.0];
361        let acc = vec64![0.5f32, 0.5, 0.5];
362        let out = apply_fma_f32(&lhs, &rhs, &acc, None).unwrap();
363        assert_float(&out, &[4.5, 10.5, 18.5], None);
364
365        let mask = bitmask(&[true, false, true]);
366        let out = apply_fma_f32(&lhs, &rhs, &acc, Some(&mask)).unwrap();
367        assert_float(&out, &[4.5, 0.0, 18.5], Some(&[true, false, true]));
368
369        let out = apply_fma_f32(&[], &[], &[], None).unwrap();
370        assert!(out.is_empty());
371    }
372
373    #[test]
374    fn fma_f64() {
375        let lhs = vec64![1.0f64, 2.0, 3.0];
376        let rhs = vec64![4.0f64, 5.0, 6.0];
377        let acc = vec64![0.5f64, 0.5, 0.5];
378        let out = apply_fma_f64(&lhs, &rhs, &acc, None).unwrap();
379        assert_float(&out, &[4.5, 10.5, 18.5], None);
380
381        let mask = bitmask(&[true, false, true]);
382        let out = apply_fma_f64(&lhs, &rhs, &acc, Some(&mask)).unwrap();
383        assert_float(&out, &[4.5, 0.0, 18.5], Some(&[true, false, true]));
384    }
385
386    #[test]
387    fn merge_masks_correctness() {
388        let a = bitmask(&[true, false, true, true]);
389        let b = bitmask(&[true, true, false, true]);
390        let merged = crate::utils::merge_bitmasks_to_new(Some(&a), Some(&b), 4).unwrap();
391        let expected = vec![true, false, false, true];
392        let merged_vec: Vec<bool> = (0..4).map(|i| merged.get(i)).collect();
393        assert_eq!(merged_vec, expected);
394    }
395
396    // ─────────────────────────────────────────────────────────────────────────────
397    // Datetime Kernels
398    // ─────────────────────────────────────────────────────────────────────────────
399    #[cfg(feature = "datetime")]
400    use minarrow::structs::variants::datetime::DatetimeArray;
401
402    #[cfg(feature = "datetime")]
403    use crate::kernels::arithmetic::dispatch::apply_datetime_i64;
404
405    #[cfg(feature = "datetime")]
406    #[test]
407    fn datetime_add() {
408        let lhs = DatetimeArray::<i64>::from_slice(&[1_000i64, 2_000, 3_000], None);
409        let rhs = DatetimeArray::<i64>::from_slice(&[10, 20, 30], None);
410        let lhs_slice = (&lhs, 0, lhs.len());
411        let rhs_slice = (&rhs, 0, rhs.len());
412        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
413        assert_eq!(out.data.as_slice(), &[1_010, 2_020, 3_030]);
414        assert!(out.null_mask.is_none());
415    }
416
417    #[cfg(feature = "datetime")]
418    #[test]
419    fn datetime_all_ops() {
420        let lhs = DatetimeArray::<i64>::from_slice(&[10, 20, 30, 40], None);
421        let rhs = DatetimeArray::<i64>::from_slice(&[1, 2, 3, 4], None);
422        let lhs_slice = (&lhs, 0, lhs.len());
423        let rhs_slice = (&rhs, 0, rhs.len());
424
425        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
426        assert_eq!(out.data.as_slice(), &[11, 22, 33, 44]);
427
428        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Subtract).unwrap();
429        assert_eq!(out.data.as_slice(), &[9, 18, 27, 36]);
430
431        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Multiply).unwrap();
432        assert_eq!(out.data.as_slice(), &[10, 40, 90, 160]);
433
434        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Divide).unwrap();
435        assert_eq!(out.data.as_slice(), &[10, 10, 10, 10]);
436
437        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Remainder).unwrap();
438        assert_eq!(out.data.as_slice(), &[0, 0, 0, 0]);
439
440        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Power).unwrap();
441        assert_eq!(
442            out.data.as_slice(),
443            &[10_i64.pow(1), 20_i64.pow(2), 30_i64.pow(3), 40_i64.pow(4)]
444        );
445    }
446
447    #[cfg(feature = "datetime")]
448    #[test]
449    fn datetime_masked_and_empty() {
450        let lhs = DatetimeArray::<i64>::from_slice(&[10, 20, 30, 40], None);
451        let rhs = DatetimeArray::<i64>::from_slice(&[1, 2, 3, 4], None);
452        let mask = bitmask(&[true, false, true, true]);
453        let lhs_slice = (&lhs, 0, lhs.len());
454        let rhs_slice = (&rhs, 0, rhs.len());
455
456        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
457        assert_eq!(out.data.as_slice(), &[11, 22, 33, 44]);
458
459        // Masked
460        let mut lhs_masked = lhs.clone();
461        lhs_masked.null_mask = Some(mask.clone());
462        let lhs_slice_masked = (&lhs_masked, 0, lhs_masked.len());
463        let out = apply_datetime_i64(lhs_slice_masked, rhs_slice, ArithmeticOperator::Add).unwrap();
464        let expected = vec![11, 0, 33, 44];
465        let mask_vec: Vec<bool> = (0..4).map(|i| mask.get(i)).collect();
466        assert_eq!(out.data.as_slice(), &expected);
467        assert_eq!(
468            out.null_mask
469                .as_ref()
470                .map(|m| (0..4).map(|i| m.get(i)).collect::<Vec<_>>()),
471            Some(mask_vec)
472        );
473
474        // Empty
475        let lhs_empty = DatetimeArray::<i64>::from_slice(&[], None);
476        let rhs_empty = DatetimeArray::<i64>::from_slice(&[], None);
477        let lhs_slice = (&lhs_empty, 0, lhs_empty.len());
478        let rhs_slice = (&rhs_empty, 0, rhs_empty.len());
479        let out = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
480        assert!(out.is_empty());
481    }
482
483    #[cfg(feature = "datetime")]
484    #[test]
485    #[should_panic(expected = "apply_datetime: length mismatch")]
486    fn datetime_len_mismatch_panics() {
487        let lhs = DatetimeArray::<i64>::from_slice(&[1_000i64, 2_000], None);
488        let rhs = DatetimeArray::<i64>::from_slice(&[10], None);
489        let lhs_slice = (&lhs, 0, lhs.len());
490        let rhs_slice = (&rhs, 0, rhs.len());
491        let _ = apply_datetime_i64(lhs_slice, rhs_slice, ArithmeticOperator::Add).unwrap();
492    }
493
494    #[cfg(feature = "simd")]
495    #[test]
496    fn test_int_dense_power_short_vs_long_input_simd() {
497        let lhs_short = vec64![2u32; 16];
498        let rhs_short = vec64![10u32; 16];
499        let mut out_short = vec64![0u32; 16];
500
501        let lhs_long = vec64![2u32; 128];
502        let rhs_long = vec64![10u32; 128];
503        let mut out_long = vec64![0u32; 128];
504
505        int_dense_body_simd::<u32, 4>(
506            ArithmeticOperator::Power,
507            &lhs_short,
508            &rhs_short,
509            &mut out_short,
510        );
511        int_dense_body_simd::<u32, 4>(
512            ArithmeticOperator::Power,
513            &lhs_long,
514            &rhs_long,
515            &mut out_long,
516        );
517
518        for &v in out_short.iter() {
519            assert_eq!(v, 1024);
520        }
521        for &v in out_long.iter() {
522            assert_eq!(v, 1024);
523        }
524    }
525}