simd_kernels/kernels/arithmetic/
std.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Standard Arithmetic Kernels Module** - *Scalar Fallback / Non-SIMD Implementations*
5//!
6//! Portable scalar implementations of arithmetic operations for compatibility and unaligned data.
7//!
8//! Prefer dispatch.rs for easily handling the general case, otherwise you can use these inner functions
9//! directly (e.g., "dense_std") vs. "maybe masked, maybe std". 
10//! 
11//! ## Overview
12//! - **Scalar loops**: Standard element-wise operations without vectorisation
13//! - **Fallback role**: Used when SIMD alignment requirements aren't met or SIMD is disabled
14//! - **Full compatibility**: Works on any architecture regardless of SIMD support
15//! - **Null-aware**: Supports Arrow-compatible null mask propagation
16//!
17//! ## Design Notes
18//! - Intentionally avoids parallelisation to allow higher-level chunking strategies
19//! - Wrapping arithmetic for integers to prevent overflow panics
20//! - Division by zero handling: panics for integers, produces Inf/NaN for floats
21
22use crate::operators::ArithmeticOperator;
23use minarrow::Bitmask;
24use num_traits::{Float, PrimInt, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub};
25
26/// Scalar integer arithmetic kernel for dense arrays (no nulls).
27/// Performs element-wise operations using wrapping arithmetic to prevent overflow panics.
28/// Panics on division/remainder by zero.
29#[inline(always)]
30pub fn int_dense_body_std<T: PrimInt + ToPrimitive + WrappingAdd + WrappingSub + WrappingMul>(
31    op: ArithmeticOperator,
32    lhs: &[T],
33    rhs: &[T],
34    out: &mut [T],
35) {
36    let n = lhs.len();
37    for i in 0..n {
38        out[i] = match op {
39            ArithmeticOperator::Add => lhs[i].wrapping_add(&rhs[i]),
40            ArithmeticOperator::Subtract => lhs[i].wrapping_sub(&rhs[i]),
41            ArithmeticOperator::Multiply => lhs[i].wrapping_mul(&rhs[i]),
42            ArithmeticOperator::Divide => {
43                if rhs[i] == T::zero() {
44                    panic!("Division by zero")
45                } else {
46                    lhs[i] / rhs[i]
47                }
48            }
49            ArithmeticOperator::Remainder => {
50                if rhs[i] == T::zero() {
51                    panic!("Remainder by zero")
52                } else {
53                    lhs[i] % rhs[i]
54                }
55            }
56            ArithmeticOperator::Power => lhs[i].pow(rhs[i].to_u32().unwrap_or(0)),
57        };
58    }
59}
60
61/// Scalar integer arithmetic kernel with null mask support.
62/// Handles division by zero gracefully by marking results as null instead of panicking.
63/// Invalid inputs (mask=false) and zero division produce null outputs.
64#[inline(always)]
65pub fn int_masked_body_std<T: PrimInt + ToPrimitive + WrappingAdd + WrappingSub + WrappingMul>(
66    op: ArithmeticOperator,
67    lhs: &[T],
68    rhs: &[T],
69    mask: &Bitmask,
70    out: &mut [T],
71    out_mask: &mut Bitmask,
72) {
73    let n = lhs.len();
74    for i in 0..n {
75        let valid = unsafe { mask.get_unchecked(i) };
76        if valid {
77            let (result, final_valid) = match op {
78                ArithmeticOperator::Add => (lhs[i].wrapping_add(&rhs[i]), true),
79                ArithmeticOperator::Subtract => (lhs[i].wrapping_sub(&rhs[i]), true),
80                ArithmeticOperator::Multiply => (lhs[i].wrapping_mul(&rhs[i]), true),
81                ArithmeticOperator::Divide => {
82                    if rhs[i] == T::zero() {
83                        (T::zero(), false) // division by zero -> invalid
84                    } else {
85                        (lhs[i] / rhs[i], true)
86                    }
87                }
88                ArithmeticOperator::Remainder => {
89                    if rhs[i] == T::zero() {
90                        (T::zero(), false) // remainder by zero -> invalid
91                    } else {
92                        (lhs[i] % rhs[i], true)
93                    }
94                }
95                ArithmeticOperator::Power => (lhs[i].pow(rhs[i].to_u32().unwrap_or(0)), true),
96            };
97            out[i] = result;
98            unsafe {
99                out_mask.set_unchecked(i, final_valid);
100            }
101        } else {
102            out[i] = T::zero();
103            unsafe {
104                out_mask.set_unchecked(i, false);
105            }
106        }
107    }
108}
109
110/// Scalar floating-point arithmetic kernel for dense arrays (no nulls).
111/// Division by zero produces Inf/NaN rather than panicking.
112/// Power operations use logarithmic exponentiation: `exp(b * ln(a))`.
113#[inline(always)]
114pub fn float_dense_body_std<T: Float>(op: ArithmeticOperator, lhs: &[T], rhs: &[T], out: &mut [T]) {
115    let n = lhs.len();
116    for i in 0..n {
117        out[i] = match op {
118            ArithmeticOperator::Add => lhs[i] + rhs[i],
119            ArithmeticOperator::Subtract => lhs[i] - rhs[i],
120            ArithmeticOperator::Multiply => lhs[i] * rhs[i],
121            ArithmeticOperator::Divide => lhs[i] / rhs[i],
122            ArithmeticOperator::Remainder => lhs[i] % rhs[i],
123            ArithmeticOperator::Power => (rhs[i] * lhs[i].ln()).exp(),
124        };
125    }
126}
127
128/// Scalar floating-point arithmetic kernel with null mask support.
129/// Preserves IEEE 754 semantics: division by zero produces Inf/NaN, no panicking.
130/// Invalid inputs (mask=false) produce null outputs with zero values.
131#[inline(always)]
132pub fn float_masked_body_std<T: Float>(
133    op: ArithmeticOperator,
134    lhs: &[T],
135    rhs: &[T],
136    mask: &Bitmask,
137    out: &mut [T],
138    out_mask: &mut Bitmask,
139) {
140    let n = lhs.len();
141    for i in 0..n {
142        let valid = unsafe { mask.get_unchecked(i) };
143        if valid {
144            out[i] = match op {
145                ArithmeticOperator::Add => lhs[i] + rhs[i],
146                ArithmeticOperator::Subtract => lhs[i] - rhs[i],
147                ArithmeticOperator::Multiply => lhs[i] * rhs[i],
148                ArithmeticOperator::Divide => lhs[i] / rhs[i],
149                ArithmeticOperator::Remainder => lhs[i] % rhs[i],
150                ArithmeticOperator::Power => (rhs[i] * lhs[i].ln()).exp(),
151            };
152            unsafe {
153                out_mask.set_unchecked(i, true);
154            }
155        } else {
156            out[i] = T::zero();
157            unsafe {
158                out_mask.set_unchecked(i, false);
159            }
160        }
161    }
162}
163
164/// Fused multiply add (a * b + acc) with null mask
165#[inline(always)]
166pub fn fma_masked_body_std<T: Float>(
167    lhs: &[T],
168    rhs: &[T],
169    acc: &[T],
170    mask: &Bitmask,
171    out: &mut [T],
172    out_mask: &mut Bitmask,
173) {
174    let n = lhs.len();
175    for i in 0..n {
176        let valid = unsafe { mask.get_unchecked(i) };
177        if valid {
178            out[i] = lhs[i].mul_add(rhs[i], acc[i]);
179            unsafe {
180                out_mask.set_unchecked(i, true);
181            }
182        } else {
183            out[i] = T::zero();
184            unsafe {
185                out_mask.set_unchecked(i, false);
186            }
187        }
188    }
189}
190
191/// Dense fused multiply add (a * b + acc)
192#[inline(always)]
193pub fn fma_dense_body_std<T: Float>(lhs: &[T], rhs: &[T], acc: &[T], out: &mut [T]) {
194    let n = lhs.len();
195    for i in 0..n {
196        out[i] = lhs[i].mul_add(rhs[i], acc[i]);
197    }
198}