simd_kernels/kernels/arithmetic/simd.rs

// Copyright Peter Bower 2025. All Rights Reserved.
// Licensed under Mozilla Public License (MPL) 2.0.

//! # **SIMD Arithmetic Kernels Module** - *High-Performance Arithmetic*
//!
//! Inner SIMD-accelerated implementations using `std::simd` for maximum performance on modern hardware.
//! Prefer `dispatch.rs` for the general case ("maybe masked, maybe SIMD"); call these inner
//! functions directly (e.g. `int_dense_body_simd`) when you already know the shape of the data.
//!
//! ## Overview
//! - **Portable SIMD**: Uses `std::simd` for cross-platform vectorisation with compile-time lane optimisation
//! - **Null masks**: Dense (no nulls) and masked variants for Arrow-compatible null handling.
//!   These are unified in `dispatch.rs`, and opting out of masking carries no performance penalty.
//! - **Type support**: Integer and floating-point arithmetic with specialised FMA operations
//! - **Safety**: All unsafe operations are bounds-checked or guaranteed by caller invariants
//!
//! ## Architecture Notes
//! - Building blocks for higher-level dispatch layers, or for low-level hot loops
//!   where one wants to fully avoid abstraction overhead.
//! - Parallelisation intentionally excluded to allow flexible chunking strategies
//! - Power operations fall back to scalar loops for integers and to logarithmic computation (`exp(b * ln(a))`) for floats
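//!
//! ## Example
//! A minimal usage sketch (illustrative only; it assumes this module's items and
//! `ArithmeticOperator` are in scope and that 4 is a supported lane count on the target):
//!
//! ```ignore
//! let lhs = vec![1.5f64; 16];
//! let rhs = vec![2.0f64; 16];
//! let mut out = vec![0.0f64; 16];
//! // Dense (no-null) path: all three slices must have equal length.
//! float_dense_body_f64_simd::<4>(ArithmeticOperator::Multiply, &lhs, &rhs, &mut out);
//! assert_eq!(out[0], 3.0);
//! ```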

include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));

use core::simd::{LaneCount, Mask, Simd, SimdElement, SupportedLaneCount};
use std::ops::{Add, Div, Mul, Rem, Sub};
use std::simd::StdFloat;
use std::simd::cmp::SimdPartialEq;

use minarrow::Bitmask;
use num_traits::{One, PrimInt, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};

use crate::kernels::bitmask::simd::all_true_mask_simd;
use crate::operators::ArithmeticOperator;
use crate::utils::simd_mask;

/// SIMD integer arithmetic kernel for dense arrays (no nulls).
/// Vectorised operations with scalar fallback for power operations and array tails.
/// Panics on division/remainder by zero (consistent with scalar behaviour).
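/// # Example
/// A minimal sketch (illustrative; assumes this function and `ArithmeticOperator` are in scope):
/// ```ignore
/// let lhs: Vec<u32> = (1..=8).collect();
/// let rhs = vec![3u32; 8];
/// let mut out = vec![0u32; 8];
/// // 4 lanes chosen purely for illustration; any supported lane count works.
/// int_dense_body_simd::<u32, 4>(ArithmeticOperator::Multiply, &lhs, &rhs, &mut out);
/// assert_eq!(out, vec![3, 6, 9, 12, 15, 18, 21, 24]);
/// ```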
#[inline(always)]
pub fn int_dense_body_simd<T, const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[T],
    rhs: &[T],
    out: &mut [T],
) where
    T: Copy + One + PrimInt + ToPrimitive + Zero + SimdElement + WrappingMul,
    LaneCount<LANES>: SupportedLaneCount,
    Simd<T, LANES>: Add<Output = Simd<T, LANES>>
        + Sub<Output = Simd<T, LANES>>
        + Mul<Output = Simd<T, LANES>>
        + Div<Output = Simd<T, LANES>>
        + Rem<Output = Simd<T, LANES>>,
{
    let n = lhs.len();
    let mut vectorisable = n / LANES * LANES;
    let mut i = 0;
    while i < vectorisable {
        let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);
        let r = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b, // Panics if divisor is zero
            ArithmeticOperator::Remainder => a % b, // Panics if divisor is zero
            ArithmeticOperator::Power => {
                vectorisable = 0;
                break;
            }
        };
        r.copy_to_slice(&mut out[i..i + LANES]);
        i += LANES;
    }

    // Scalar tail
    for idx in vectorisable..n {
        out[idx] = match op {
            ArithmeticOperator::Add => lhs[idx] + rhs[idx],
            ArithmeticOperator::Subtract => lhs[idx] - rhs[idx],
            ArithmeticOperator::Multiply => lhs[idx] * rhs[idx],
            ArithmeticOperator::Divide => lhs[idx] / rhs[idx], // Panics if divisor is zero
            ArithmeticOperator::Remainder => lhs[idx] % rhs[idx], // Panics if divisor is zero
            ArithmeticOperator::Power => {
                let mut acc = T::one();
                let exp = rhs[idx].to_u32().unwrap_or(0);
                for _ in 0..exp {
                    acc = acc.wrapping_mul(&lhs[idx]);
                }
                acc
            }
        };
    }
}

/// SIMD integer arithmetic kernel with null mask support.
/// Division/remainder by zero produces null results (mask=false) rather than panicking.
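/// # Example
/// Sketch of intended usage. `Bitmask::new_set_all` is an assumed constructor used here
/// purely for illustration; substitute however your `minarrow` version builds an
/// all-valid mask of the required length.
/// ```ignore
/// let lhs = vec![10i64, 20, 30, 40];
/// let rhs = vec![2i64, 0, 5, 4];
/// let mask = Bitmask::new_set_all(4, true);          // all inputs valid (assumed constructor)
/// let mut out = vec![0i64; 4];
/// let mut out_mask = Bitmask::new_set_all(4, true);  // overwritten per element below
/// int_masked_body_simd::<i64, 4>(ArithmeticOperator::Divide, &lhs, &rhs, &mask, &mut out, &mut out_mask);
/// // 20 / 0 does not panic: that lane is nulled in the output mask instead.
/// assert!(!unsafe { out_mask.get_unchecked(1) });
/// ```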
#[inline(always)]
pub fn int_masked_body_simd<T, const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[T],
    rhs: &[T],
    mask: &Bitmask,
    out: &mut [T],
    out_mask: &mut Bitmask,
) where
    T: Copy
        + PrimInt
        + ToPrimitive
        + Zero
        + One
        + SimdElement
        + PartialEq
        + WrappingAdd
        + WrappingMul
        + WrappingSub,
    LaneCount<LANES>: SupportedLaneCount,
    Simd<T, LANES>: Add<Output = Simd<T, LANES>>
        + SimdPartialEq<Mask = Mask<<T as SimdElement>::Mask, LANES>>
        + Sub<Output = Simd<T, LANES>>
        + Mul<Output = Simd<T, LANES>>
        + Div<Output = Simd<T, LANES>>
        + Rem<Output = Simd<T, LANES>>,
{
    let n = lhs.len();
    let dense = all_true_mask_simd(mask);

    /* When the input mask is all-true we still cannot delegate to the dedicated dense
    implementation: that function stays mask-free to suit other workloads and therefore
    panics on `div/0`. The block below follows the same dense structure, but records a
    null in the output mask whenever a division or remainder by zero occurs. */

    if dense {
        // This block replaces the int_dense_body_simd call and handles masking for div/rem
        let vectorisable = n / LANES * LANES;
        let mut i = 0;
        while i < vectorisable {
            let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
            let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);

            let (r, valid): (Simd<T, LANES>, Mask<<T as SimdElement>::Mask, LANES>) = match op {
                ArithmeticOperator::Add => (a + b, Mask::splat(true)),
                ArithmeticOperator::Subtract => (a - b, Mask::splat(true)),
                ArithmeticOperator::Multiply => (a * b, Mask::splat(true)),
                ArithmeticOperator::Power => {
                    let mut tmp = [T::zero(); LANES];
                    for l in 0..LANES {
                        tmp[l] = a[l].pow(b[l].to_u32().unwrap_or(0));
                    }
                    (Simd::<T, LANES>::from_array(tmp), Mask::splat(true))
                }
                ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
                    let div_zero = b.simd_eq(Simd::splat(T::zero()));
                    let valid = !div_zero;
                    let safe_b = div_zero.select(Simd::splat(T::one()), b);
                    let r = match op {
                        ArithmeticOperator::Divide => a / safe_b,
                        ArithmeticOperator::Remainder => a % safe_b,
                        _ => unreachable!(),
                    };
                    let r = div_zero.select(Simd::splat(T::zero()), r);
                    (r, valid)
                }
            };
            r.copy_to_slice(&mut out[i..i + LANES]);
            // Write the out_mask based on the op
            let valid_bits = valid.to_bitmask();
            for l in 0..LANES {
                let idx = i + l;
                if idx < n {
                    unsafe {
                        out_mask.set_unchecked(idx, ((valid_bits >> l) & 1) == 1);
                    }
                }
            }
            i += LANES;
        }
        // Scalar tail
        for idx in vectorisable..n {
            match op {
                ArithmeticOperator::Add => {
                    out[idx] = lhs[idx].wrapping_add(&rhs[idx]);
                    unsafe {
                        out_mask.set_unchecked(idx, true);
                    }
                }
                ArithmeticOperator::Subtract => {
                    out[idx] = lhs[idx].wrapping_sub(&rhs[idx]);
                    unsafe {
                        out_mask.set_unchecked(idx, true);
                    }
                }
                ArithmeticOperator::Multiply => {
                    out[idx] = lhs[idx].wrapping_mul(&rhs[idx]);
                    unsafe {
                        out_mask.set_unchecked(idx, true);
                    }
                }
                ArithmeticOperator::Power => {
                    out[idx] = lhs[idx].pow(rhs[idx].to_u32().unwrap_or(0));
                    unsafe {
                        out_mask.set_unchecked(idx, true);
                    }
                }
                ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
                    if rhs[idx] == T::zero() {
                        out[idx] = T::zero();
                        unsafe {
                            out_mask.set_unchecked(idx, false);
                        }
                    } else {
                        out[idx] = match op {
                            ArithmeticOperator::Divide => lhs[idx] / rhs[idx],
                            ArithmeticOperator::Remainder => lhs[idx] % rhs[idx],
                            _ => unreachable!(),
                        };
                        unsafe {
                            out_mask.set_unchecked(idx, true);
                        }
                    }
                }
            }
        }
        return;
    }

    let mut i = 0;
    while i + LANES <= n {
        let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);
        let m_src: Mask<_, LANES> = simd_mask::<_, LANES>(mask, i, n); // validity mask

        // divisor-is-zero mask
        let div_zero: Mask<_, LANES> = b.simd_eq(Simd::splat(T::zero()));

        // ── compute result ───────────────────────────────────────────────────
        let res = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => {
                let safe_b = div_zero.select(Simd::splat(T::one()), b); // 0 → 1
                let q = a / safe_b;
                div_zero.select(Simd::splat(T::zero()), q) // restore 0
            }
            ArithmeticOperator::Remainder => {
                let safe_b = div_zero.select(Simd::splat(T::one()), b);
                let r = a % safe_b;
                div_zero.select(Simd::splat(T::zero()), r)
            }
            ArithmeticOperator::Power => {
                // scalar per-lane power
                let mut tmp = [T::zero(); LANES];
                for l in 0..LANES {
                    tmp[l] = a[l].pow(b[l].to_u32().unwrap_or(0));
                }
                Simd::<T, LANES>::from_array(tmp)
            }
        };

        // apply source validity mask, write results
        let selected = m_src.select(res, Simd::splat(T::zero()));
        selected.copy_to_slice(&mut out[i..i + LANES]);

        // write out-mask bits: combine source mask with div-by-zero validity
        let final_mask = match op {
            ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
                // For div/rem: valid iff source is valid AND not dividing by zero
                m_src & !div_zero
            }
            _ => m_src,
        };
        let mbits = final_mask.to_bitmask();
        for l in 0..LANES {
            let idx = i + l;
            if idx < n {
                unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
            }
        }
        i += LANES;
    }

    // scalar tail
    for j in i..n {
        let valid = unsafe { mask.get_unchecked(j) };
        if valid {
            let (result, final_valid) = match op {
                ArithmeticOperator::Add => (lhs[j].wrapping_add(&rhs[j]), true),
                ArithmeticOperator::Subtract => (lhs[j].wrapping_sub(&rhs[j]), true),
                ArithmeticOperator::Multiply => (lhs[j].wrapping_mul(&rhs[j]), true),
                ArithmeticOperator::Divide => {
                    if rhs[j] == T::zero() {
                        (T::zero(), false) // division by zero -> invalid
                    } else {
                        (lhs[j] / rhs[j], true)
                    }
                }
                ArithmeticOperator::Remainder => {
                    if rhs[j] == T::zero() {
                        (T::zero(), false) // remainder by zero -> invalid
                    } else {
                        (lhs[j] % rhs[j], true)
                    }
                }
                ArithmeticOperator::Power => (lhs[j].pow(rhs[j].to_u32().unwrap_or(0)), true),
            };
            out[j] = result;
            unsafe { out_mask.set_unchecked(j, final_valid) };
        } else {
            out[j] = T::zero();
            unsafe { out_mask.set_unchecked(j, false) };
        }
    }
}

/// SIMD f32 arithmetic kernel with null mask support.
/// Preserves IEEE 754 semantics: division by zero produces Inf/NaN, no exceptions.
/// Power operations are computed vectorised as `exp(b * ln(a))`.
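/// # Example
/// Sketch (illustrative; `Bitmask::new_set_all` is an assumed constructor for an all-valid mask):
/// ```ignore
/// let lhs = vec![1.0f32, 2.0, 3.0, 4.0];
/// let rhs = vec![0.0f32, 2.0, 2.0, 2.0];
/// let mask = Bitmask::new_set_all(4, true);
/// let mut out = vec![0.0f32; 4];
/// let mut out_mask = Bitmask::new_set_all(4, true);
/// float_masked_body_f32_simd::<4>(ArithmeticOperator::Divide, &lhs, &rhs, &mask, &mut out, &mut out_mask);
/// // IEEE 754: 1.0 / 0.0 == +inf and the lane stays valid, unlike the integer kernels.
/// assert!(out[0].is_infinite());
/// ```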
#[inline(always)]
pub fn float_masked_body_f32_simd<const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[f32],
    rhs: &[f32],
    mask: &Bitmask,
    out: &mut [f32],
    out_mask: &mut Bitmask,
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    type M = <f32 as SimdElement>::Mask;

    let n = lhs.len();
    let mut i = 0;
    let dense = all_true_mask_simd(mask);

    while i + LANES <= n {
        let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
        let m: Mask<M, LANES> = if dense {
            Mask::splat(true)
        } else {
            simd_mask::<M, LANES>(mask, i, n)
        };

        let res = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b,
            ArithmeticOperator::Remainder => a % b,
            ArithmeticOperator::Power => (b * a.ln()).exp(),
        };

        let selected = m.select(res, Simd::<f32, LANES>::splat(0.0));
        selected.copy_to_slice(&mut out[i..i + LANES]);

        let mbits = m.to_bitmask();
        for l in 0..LANES {
            let idx = i + l;
            if idx < n {
                unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
            }
        }
        i += LANES;
    }

    // Scalar fallback for the tail elements left over when `n % LANES != 0`
    for j in i..n {
        let valid = dense || unsafe { mask.get_unchecked(j) };
        if valid {
            out[j] = match op {
                ArithmeticOperator::Add => lhs[j] + rhs[j],
                ArithmeticOperator::Subtract => lhs[j] - rhs[j],
                ArithmeticOperator::Multiply => lhs[j] * rhs[j],
                ArithmeticOperator::Divide => lhs[j] / rhs[j],
                ArithmeticOperator::Remainder => lhs[j] % rhs[j],
                ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
            };
            unsafe { out_mask.set_unchecked(j, true) };
        } else {
            out[j] = 0.0;
            unsafe { out_mask.set_unchecked(j, false) };
        }
    }
}

/// SIMD f64 arithmetic kernel with null mask support.
/// Preserves IEEE 754 semantics: division by zero produces Inf/NaN, no exceptions.
/// Power operations are computed vectorised as `exp(b * ln(a))`.
#[inline(always)]
pub fn float_masked_body_f64_simd<const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[f64],
    rhs: &[f64],
    mask: &Bitmask,
    out: &mut [f64],
    out_mask: &mut Bitmask,
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    type M = <f64 as SimdElement>::Mask;

    let n = lhs.len();
    let dense = all_true_mask_simd(mask);

    if dense {
        // hot
        float_dense_body_f64_simd::<LANES>(op, lhs, rhs, out);
        out_mask.fill(true);
        return;
    }

    let mut i = 0;
    while i + LANES <= n {
        let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
        let m: Mask<M, LANES> = simd_mask::<M, LANES>(mask, i, n);

        let res = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b,
            ArithmeticOperator::Remainder => a % b,
            ArithmeticOperator::Power => (b * a.ln()).exp(),
        };

        let selected = m.select(res, Simd::<f64, LANES>::splat(0.0));
        selected.copy_to_slice(&mut out[i..i + LANES]);

        let mbits = m.to_bitmask();
        for l in 0..LANES {
            let idx = i + l;
            if idx < n {
                unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
            }
        }
        i += LANES;
    }

    // Scalar fallback for the tail elements left over when `n % LANES != 0`
    for j in i..n {
        let valid = unsafe { mask.get_unchecked(j) };
        if valid {
            out[j] = match op {
                ArithmeticOperator::Add => lhs[j] + rhs[j],
                ArithmeticOperator::Subtract => lhs[j] - rhs[j],
                ArithmeticOperator::Multiply => lhs[j] * rhs[j],
                ArithmeticOperator::Divide => lhs[j] / rhs[j],
                ArithmeticOperator::Remainder => lhs[j] % rhs[j],
                ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
            };
            unsafe { out_mask.set_unchecked(j, true) };
        } else {
            out[j] = 0.0;
            unsafe { out_mask.set_unchecked(j, false) };
        }
    }
}

/// SIMD f32 arithmetic kernel for dense arrays (no nulls).
/// Vectorised operations with a scalar fallback for array tails; power is computed as `exp(b * ln(a))`.
/// Division by zero produces Inf/NaN following IEEE 754 semantics.
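/// # Example
/// A minimal sketch (illustrative; assumes this function and `ArithmeticOperator` are in scope):
/// ```ignore
/// let lhs = vec![2.0f32; 8];
/// let rhs = vec![3.0f32; 8];
/// let mut out = vec![0.0f32; 8];
/// float_dense_body_f32_simd::<8>(ArithmeticOperator::Power, &lhs, &rhs, &mut out);
/// // Power is computed as exp(b * ln(a)), so expect ~8.0 up to floating-point error.
/// assert!((out[0] - 8.0).abs() < 1e-3);
/// ```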
#[inline(always)]
pub fn float_dense_body_f32_simd<const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[f32],
    rhs: &[f32],
    out: &mut [f32],
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    let n = lhs.len();
    let mut i = 0;
    while i + LANES <= n {
        let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
        let res = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b,
            ArithmeticOperator::Remainder => a % b,
            ArithmeticOperator::Power => (b * a.ln()).exp(),
        };
        res.copy_to_slice(&mut out[i..i + LANES]);
        i += LANES;
    }

    // Scalar fallback for the tail elements left over when `n % LANES != 0`
    for j in i..n {
        out[j] = match op {
            ArithmeticOperator::Add => lhs[j] + rhs[j],
            ArithmeticOperator::Subtract => lhs[j] - rhs[j],
            ArithmeticOperator::Multiply => lhs[j] * rhs[j],
            ArithmeticOperator::Divide => lhs[j] / rhs[j],
            ArithmeticOperator::Remainder => lhs[j] % rhs[j],
            ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
        };
    }
}

/// SIMD f64 arithmetic kernel for dense arrays (no nulls).
/// Vectorised operations with a scalar fallback for array tails; power is computed as `exp(b * ln(a))`.
/// Division by zero produces Inf/NaN following IEEE 754 semantics.
#[inline(always)]
pub fn float_dense_body_f64_simd<const LANES: usize>(
    op: ArithmeticOperator,
    lhs: &[f64],
    rhs: &[f64],
    out: &mut [f64],
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    let n = lhs.len();
    let mut i = 0;
    while i + LANES <= n {
        let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
        let res = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b,
            ArithmeticOperator::Remainder => a % b,
            ArithmeticOperator::Power => (b * a.ln()).exp(),
        };
        res.copy_to_slice(&mut out[i..i + LANES]);
        i += LANES;
    }

    // Scalar fallback for the tail elements left over when `n % LANES != 0`
    for j in i..n {
        out[j] = match op {
            ArithmeticOperator::Add => lhs[j] + rhs[j],
            ArithmeticOperator::Subtract => lhs[j] - rhs[j],
            ArithmeticOperator::Multiply => lhs[j] * rhs[j],
            ArithmeticOperator::Divide => lhs[j] / rhs[j],
            ArithmeticOperator::Remainder => lhs[j] % rhs[j],
            ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
        };
    }
}

/// SIMD f32 fused multiply-add kernel with null mask support.
/// Hardware-accelerated `a.mul_add(b, c)` with proper null propagation.
#[inline(always)]
pub fn fma_masked_body_f32_simd<const LANES: usize>(
    lhs: &[f32],
    rhs: &[f32],
    acc: &[f32],
    mask: &Bitmask,
    out: &mut [f32],
    out_mask: &mut minarrow::Bitmask,
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    use core::simd::{Mask, Simd};

    let n = lhs.len();
    let mut i = 0;
    let dense = all_true_mask_simd(mask);

    if dense {
        fma_dense_body_f32_simd::<LANES>(lhs, rhs, acc, out);
        out_mask.fill(true);
        return;
    }

    while i + LANES <= n {
        let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
        let c = Simd::<f32, LANES>::from_slice(&acc[i..i + LANES]);
        let m: Mask<i32, LANES> = simd_mask::<i32, LANES>(mask, i, n);

        let res = a.mul_add(b, c);

        let selected = m.select(res, Simd::<f32, LANES>::splat(0.0));
        selected.copy_to_slice(&mut out[i..i + LANES]);

        let mbits = m.to_bitmask();
        for l in 0..LANES {
            let idx = i + l;
            if idx < n {
                unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
            }
        }
        i += LANES;
    }

    // Scalar tail
    for j in i..n {
        let valid = unsafe { mask.get_unchecked(j) };
        if valid {
            out[j] = lhs[j].mul_add(rhs[j], acc[j]);
            unsafe { out_mask.set_unchecked(j, true) };
        } else {
            out[j] = 0.0;
            unsafe { out_mask.set_unchecked(j, false) };
        }
    }
}

/// SIMD f64 fused multiply-add kernel with null mask support.
/// Hardware-accelerated `a.mul_add(b, c)` with proper null propagation.
#[inline(always)]
pub fn fma_masked_body_f64_simd<const LANES: usize>(
    lhs: &[f64],
    rhs: &[f64],
    acc: &[f64],
    mask: &Bitmask,
    out: &mut [f64],
    out_mask: &mut minarrow::Bitmask,
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    use core::simd::{Mask, Simd};

    let n = lhs.len();
    let mut i = 0;
    let dense = all_true_mask_simd(mask);

    if dense {
        // Hot
        fma_dense_body_f64_simd::<LANES>(lhs, rhs, acc, out);
        out_mask.fill(true);
        return;
    }

    while i + LANES <= n {
        let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
        let c = Simd::<f64, LANES>::from_slice(&acc[i..i + LANES]);
        let m: Mask<i64, LANES> = simd_mask::<i64, LANES>(mask, i, n);

        let res = a.mul_add(b, c);

        let selected = m.select(res, Simd::<f64, LANES>::splat(0.0));
        selected.copy_to_slice(&mut out[i..i + LANES]);

        let mbits = m.to_bitmask();
        for l in 0..LANES {
            let idx = i + l;
            if idx < n {
                unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
            }
        }
        i += LANES;
    }

    // Scalar tail
    for j in i..n {
        let valid = unsafe { mask.get_unchecked(j) };
        if valid {
            out[j] = lhs[j].mul_add(rhs[j], acc[j]);
            unsafe { out_mask.set_unchecked(j, true) };
        } else {
            out[j] = 0.0;
            unsafe { out_mask.set_unchecked(j, false) };
        }
    }
}

/// SIMD f32 fused multiply-add kernel for dense arrays (no nulls).
/// Hardware-accelerated `a.mul_add(b, c)` with vectorised and scalar tail processing.
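/// # Example
/// A minimal sketch (illustrative):
/// ```ignore
/// let lhs = vec![2.0f32; 4];
/// let rhs = vec![3.0f32; 4];
/// let acc = vec![1.0f32; 4];
/// let mut out = vec![0.0f32; 4];
/// fma_dense_body_f32_simd::<4>(&lhs, &rhs, &acc, &mut out);
/// assert_eq!(out, vec![7.0f32; 4]); // 2 * 3 + 1, fused into a single rounding step
/// ```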
#[inline(always)]
pub fn fma_dense_body_f32_simd<const LANES: usize>(
    lhs: &[f32],
    rhs: &[f32],
    acc: &[f32],
    out: &mut [f32],
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    use core::simd::Simd;

    let n = lhs.len();
    let mut i = 0;

    while i + LANES <= n {
        let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
        let c = Simd::<f32, LANES>::from_slice(&acc[i..i + LANES]);
        let res = a.mul_add(b, c);
        res.copy_to_slice(&mut out[i..i + LANES]);
        i += LANES;
    }

    for j in i..n {
        out[j] = lhs[j].mul_add(rhs[j], acc[j]);
    }
}

/// SIMD f64 fused multiply-add kernel for dense arrays (no nulls).
/// Hardware-accelerated `a.mul_add(b, c)` with vectorised and scalar tail processing.
#[inline(always)]
pub fn fma_dense_body_f64_simd<const LANES: usize>(
    lhs: &[f64],
    rhs: &[f64],
    acc: &[f64],
    out: &mut [f64],
) where
    LaneCount<LANES>: SupportedLaneCount,
{
    use core::simd::Simd;

    let n = lhs.len();
    let mut i = 0;

    while i + LANES <= n {
        let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
        let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
        let c = Simd::<f64, LANES>::from_slice(&acc[i..i + LANES]);
        let res = a.mul_add(b, c);
        res.copy_to_slice(&mut out[i..i + LANES]);
        i += LANES;
    }

    // Tail uses scalar fallback
    for j in i..n {
        out[j] = lhs[j].mul_add(rhs[j], acc[j]);
    }
}