// simd_kernels/utils.rs

// Copyright Peter Bower 2025. All Rights Reserved.
// Licensed under Mozilla Public License (MPL) 2.0.

//! # **Utility Functions** - *SIMD Processing and Memory Management Utilities*
//!
//! Core utilities supporting SIMD kernel implementations with efficient memory handling,
//! bitmask operations, and performance-critical helper functions.

9use std::simd::{LaneCount, Mask, MaskElement, SimdElement, SupportedLaneCount};
10
11use minarrow::{Bitmask, Vec64};
12
13/// Extracts a core::SIMD `Mask<M, N>` for a batch of N lanes from a Minarrow `Bitmask`.
14///
15/// - `mask_bytes`: packed Arrow validity bits (LSB=index 0, bit=1 means valid)
16/// - `offset`: starting index (bit offset into the mask)
17/// - `logical_len`: number of logical bits in the mask
18/// - `M`: SIMD mask type (e.g., i64 for f64, i32 for f32, i8 for i8)
19///
20/// Returns: SIMD Mask<M, N> representing validity for these N lanes.
21/// Bits outside the logical length (i.e., mask is shorter than offset+N)
22/// are treated as valid.
23#[inline(always)]
24pub fn bitmask_to_simd_mask<const N: usize, M>(
25    mask_bytes: &[u8],
26    offset: usize,
27    logical_len: usize,
28) -> Mask<M, N>
29where
30    LaneCount<N>: SupportedLaneCount,
31    M: MaskElement + SimdElement,
32{
33    let lane_limit = (offset + N).min(logical_len);
34    let n_lanes = lane_limit - offset;
35    let mut bits: u64 = 0;
36    for j in 0..n_lanes {
37        let idx = offset + j;
38        let byte = mask_bytes[idx >> 3];
39        if ((byte >> (idx & 7)) & 1) != 0 {
40            bits |= 1u64 << j;
41        }
42    }
43    if n_lanes < N {
44        bits |= !0u64 << n_lanes;
45    }
46    Mask::<M, N>::from_bitmask(bits)
47}
48
49/// Converts a SIMD `Mask<M, N>` to a Minarrow `Bitmask` for the given logical length.
50/// Used at the end of a block operation within SIMD-accelerated kernel functions.
51#[inline(always)]
52pub fn simd_mask_to_bitmask<const N: usize, M>(mask: Mask<M, N>, len: usize) -> Bitmask
53where
54    LaneCount<N>: SupportedLaneCount,
55    M: MaskElement + SimdElement,
56{
57    let mut bits = Vec64::with_capacity((len + 7) / 8);
58    bits.resize((len + 7) / 8, 0);
59
60    let word = mask.to_bitmask();
61    let bytes = word.to_le_bytes();
62
63    let n_bytes = (len + 7) / 8;
64    bits[..n_bytes].copy_from_slice(&bytes[..n_bytes]);
65
66    if len % 8 != 0 {
67        let last = n_bytes - 1;
68        let mask_byte = (1u8 << (len % 8)) - 1;
69        bits[last] &= mask_byte;
70    }
71
72    Bitmask {
73        bits: bits.into(),
74        len,
75    }
76}
77
78/// Bulk-ORs a local bitmask block (from a SIMD mask or similar) into the global Minarrow bitmask at the correct byte offset.
79/// The block (`block_mask`) is expected to contain at least ceil(n_lanes/8) bytes,
80/// with the bit-packed validity bits starting from position 0.
81///
82/// Used to streamline repetitive boilerplate and ensure consistency across kernel null-mask handling.
83///
84/// ### Parameters
85/// - `out_mask`: mutable reference to the output/global Bitmask
86/// - `block_mask`: reference to the local Bitmask containing the block's bits
87/// - `offset`: starting bit offset in the global mask
88/// - `n_lanes`: number of bits in this block (usually SIMD lane count)
89#[inline(always)]
90pub fn write_global_bitmask_block(
91    out_mask: &mut Bitmask,
92    block_mask: &Bitmask,
93    offset: usize,
94    n_lanes: usize,
95) {
96    let n_bytes = (n_lanes + 7) / 8;
97    let base = offset / 8;
98    let block_bytes = &block_mask.bits[..n_bytes];
99    for b in 0..n_bytes {
100        if base + b < out_mask.bits.len() {
101            out_mask.bits[base + b] |= block_bytes[b];
102        }
103    }
104}
105
106/// Determines whether nulls are present given an optional null count and mask reference.
107/// Avoids computing mask cardinality to preserve performance guarantees.
108#[inline(always)]
109pub fn has_nulls(null_count: Option<usize>, mask: Option<&Bitmask>) -> bool {
110    match null_count {
111        Some(n) => n > 0,
112        None => mask.is_some(),
113    }
114}