simd_kernels/utils.rs
// Copyright Peter Bower 2025. All Rights Reserved.
// Licensed under Mozilla Public License (MPL) 2.0.

//! # **Utility Functions** - *SIMD Processing and Memory Management Utilities*
//!
//! Core utilities supporting SIMD kernel implementations with efficient memory handling,
//! bitmask operations, and performance-critical helper functions.

#[cfg(feature = "str_arithmetic")]
use std::mem::MaybeUninit;
use std::{
    collections::HashSet,
    simd::{LaneCount, Mask, MaskElement, SimdElement, SupportedLaneCount},
};

use minarrow::{Bitmask, CategoricalArray, Integer, MaskedArray, StringArray, Vec64};
#[cfg(feature = "str_arithmetic")]
use ryu::Float;

use crate::errors::KernelError;

/// Extracts a `core::simd` `Mask<M, N>` for a batch of `N` lanes from a Minarrow `Bitmask`.
///
/// - `mask_bytes`: packed Arrow validity bits (LSB = index 0, bit = 1 means valid)
/// - `offset`: starting index (bit offset into the mask)
/// - `logical_len`: number of logical bits in the mask
/// - `M`: SIMD mask element type (e.g., i64 for f64, i32 for f32, i8 for i8)
///
/// Returns: SIMD `Mask<M, N>` representing validity for these N lanes.
/// Bits outside the logical length (i.e., where the mask is shorter than `offset + N`)
/// are treated as valid.
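///
/// # Usage Example
/// A minimal sketch (illustrative only; `validity` is assumed to be a `minarrow::Bitmask`
/// whose packed bytes are reachable via `validity.bits`, as elsewhere in this module):
/// ```rust,ignore
/// use std::simd::Mask;
///
/// // Validity for 8 f64 lanes starting at element 16.
/// let lanes: Mask<i64, 8> =
///     bitmask_to_simd_mask::<8, i64>(&validity.bits[..], 16, validity.len);
/// // `lanes` can now gate a SIMD operation, e.g. `lanes.select(computed, fallback)`.
/// ```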
#[inline(always)]
pub fn bitmask_to_simd_mask<const N: usize, M>(
    mask_bytes: &[u8],
    offset: usize,
    logical_len: usize,
) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement + SimdElement,
{
    let lane_limit = (offset + N).min(logical_len);
    let n_lanes = lane_limit - offset;
    let mut bits: u64 = 0;
    for j in 0..n_lanes {
        let idx = offset + j;
        let byte = mask_bytes[idx >> 3];
        if ((byte >> (idx & 7)) & 1) != 0 {
            bits |= 1u64 << j;
        }
    }
    if n_lanes < N {
        bits |= !0u64 << n_lanes;
    }
    Mask::<M, N>::from_bitmask(bits)
}

/// Converts a SIMD `Mask<M, N>` to a Minarrow `Bitmask` for the given logical length.
/// Used at the end of a block operation within SIMD-accelerated kernel functions.
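///
/// # Usage Example
/// A minimal sketch:
/// ```rust,ignore
/// use std::simd::Mask;
///
/// // Convert an 8-lane comparison result back into an Arrow-style validity bitmask.
/// let lanes = Mask::<i64, 8>::from_array([true, true, false, true, false, false, true, true]);
/// let bm = simd_mask_to_bitmask::<8, i64>(lanes, 8);
/// assert_eq!(bm.len, 8);
/// ```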
#[inline(always)]
pub fn simd_mask_to_bitmask<const N: usize, M>(mask: Mask<M, N>, len: usize) -> Bitmask
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement + SimdElement,
{
    let mut bits = Vec64::with_capacity((len + 7) / 8);
    bits.resize((len + 7) / 8, 0);

    let word = mask.to_bitmask();
    let bytes = word.to_le_bytes();

    let n_bytes = (len + 7) / 8;
    bits[..n_bytes].copy_from_slice(&bytes[..n_bytes]);

    if len % 8 != 0 {
        let last = n_bytes - 1;
        let mask_byte = (1u8 << (len % 8)) - 1;
        bits[last] &= mask_byte;
    }

    Bitmask {
        bits: bits.into(),
        len,
    }
}

/// Bulk-ORs a local bitmask block (from a SIMD mask or similar) into the global Minarrow
/// bitmask at the correct byte offset. The block (`block_mask`) is expected to contain at
/// least `ceil(n_lanes / 8)` bytes, with the bit-packed validity bits starting from position 0.
///
/// Used to streamline repetitive boilerplate and ensure consistency across kernel null-mask handling.
///
/// ### Parameters
/// - `out_mask`: mutable reference to the output/global `Bitmask`
/// - `block_mask`: reference to the local `Bitmask` containing the block's bits
/// - `offset`: starting bit offset in the global mask (expected to be byte-aligned, i.e. a
///   multiple of 8, since the merge is performed per byte)
/// - `n_lanes`: number of bits in this block (usually the SIMD lane count)
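///
/// # Usage Example
/// A minimal sketch (assuming `block_lanes` and `block_index` come from the surrounding
/// SIMD block loop):
/// ```rust,ignore
/// let mut out = Bitmask::new_set_all(len, false);
/// // Materialise one 8-lane block and OR it into the global mask at a byte-aligned offset.
/// let block = simd_mask_to_bitmask::<8, i64>(block_lanes, 8);
/// write_global_bitmask_block(&mut out, &block, block_index * 8, 8);
/// ```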
#[inline(always)]
pub fn write_global_bitmask_block(
    out_mask: &mut Bitmask,
    block_mask: &Bitmask,
    offset: usize,
    n_lanes: usize,
) {
    let n_bytes = (n_lanes + 7) / 8;
    let base = offset / 8;
    let block_bytes = &block_mask.bits[..n_bytes];
    for b in 0..n_bytes {
        if base + b < out_mask.bits.len() {
            out_mask.bits[base + b] |= block_bytes[b];
        }
    }
}

/// Determines whether nulls are present given an optional null count and mask reference.
/// Avoids computing mask cardinality to preserve performance guarantees.
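///
/// # Usage Example
/// A minimal sketch of the dispatch pattern this enables:
/// ```rust,ignore
/// let null_count: Option<usize> = Some(0);
/// let mask: Option<&Bitmask> = None;
/// if has_nulls(null_count, mask) {
///     // null-aware kernel path
/// } else {
///     // dense fast path
/// }
/// ```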
#[inline(always)]
pub fn has_nulls(null_count: Option<usize>, mask: Option<&Bitmask>) -> bool {
    match null_count {
        Some(n) => n > 0,
        None => mask.is_some(),
    }
}

/// Creates a SIMD mask from a bitmask window for vectorised conditional operations.
///
/// Converts a contiguous section of a bitmask into a SIMD mask.
/// The resulting mask can be used to selectively enable/disable SIMD lanes during
/// computation, providing efficient support for sparse or conditional operations.
///
/// # Type Parameters
/// - `T`: Mask element type implementing `MaskElement` (typically i8, i16, i32, or i64)
/// - `N`: Number of SIMD lanes, must match the SIMD vector width for the target operation
///
/// # Parameters
/// - `mask`: Source bitmask containing validity information
/// - `offset`: Starting bit offset within the bitmask
/// - `len`: Maximum number of bits to consider (bounds checking)
///
/// # Returns
/// A `Mask<T, N>` where each lane corresponds to the validity of the corresponding input element.
/// Lanes beyond `len` are set to false for safety.
///
/// # Usage Example
/// ```rust,ignore
/// use simd_kernels::utils::simd_mask;
///
/// // Create an 8-lane mask for conditional SIMD operations
/// let mask: Mask<i32, 8> = simd_mask(&bitmask, 0, 64);
/// let result = mask.select(simd_vector, default_vector);
/// ```
#[inline(always)]
pub fn simd_mask<T: MaskElement, const N: usize>(
    mask: &Bitmask,
    offset: usize,
    len: usize,
) -> Mask<T, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    let mut bits = [false; N];
    for l in 0..N {
        let idx = offset + l;
        bits[l] = idx < len && unsafe { mask.get_unchecked(idx) };
    }
    Mask::from_array(bits)
}

/// Merge two optional Bitmasks into a new output mask, computing per-row AND.
/// Returns None if both inputs are None (output is dense).
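///
/// # Usage Example
/// A minimal sketch:
/// ```rust,ignore
/// let lhs: Option<Bitmask> = Some(Bitmask::new_set_all(4, true));
/// let rhs: Option<Bitmask> = None;
/// // Per-row AND of the two validity masks; a dense (None) side is treated as all-valid.
/// let merged = merge_bitmasks_to_new(lhs.as_ref(), rhs.as_ref(), 4);
/// assert!(merged.is_some());
/// ```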
#[inline]
pub fn merge_bitmasks_to_new(
    lhs: Option<&Bitmask>,
    rhs: Option<&Bitmask>,
    len: usize,
) -> Option<Bitmask> {
    match (lhs, rhs) {
        (None, None) => None,
        (Some(l), None) | (None, Some(l)) => {
            debug_assert!(l.len() >= len, "Bitmask too short in merge");
            let mut out = Bitmask::new_set_all(len, true);
            for i in 0..len {
                out.set(i, l.get(i));
            }
            Some(out)
        }
        (Some(l), Some(r)) => {
            debug_assert!(l.len() >= len, "Left Bitmask too short in merge");
            debug_assert!(r.len() >= len, "Right Bitmask too short in merge");
            let mut out = Bitmask::new_set_all(len, true);
            for i in 0..len {
                out.set(i, l.get(i) && r.get(i));
            }
            Some(out)
        }
    }
}

/// Checks that the mask capacity matches the comparison length.
/// Used so we can avoid bounds checks in the hot loop.
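///
/// # Usage Example
/// A minimal sketch (`data` and `mask` stand in for a kernel's input slice and optional
/// validity mask):
/// ```rust,ignore
/// // Validate once up-front so the hot loop can index the mask without bounds checks.
/// confirm_mask_capacity(data.len(), mask)?;
/// ```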
#[inline(always)]
pub fn confirm_mask_capacity(cmp_len: usize, mask: Option<&Bitmask>) -> Result<(), KernelError> {
    if let Some(m) = mask {
        confirm_capacity("mask (Bitmask)", m.capacity(), cmp_len)?;
    }
    Ok(())
}

/// Formats a finite float via `ryu` and strips a trailing ".0", so concatenated decimal
/// values such as 'Hello1.0' become 'Hello1'.
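///
/// # Usage Example
/// A minimal sketch (requires the `str_arithmetic` feature):
/// ```rust,ignore
/// use std::mem::MaybeUninit;
///
/// let mut buf = [MaybeUninit::<u8>::uninit(); 24];
/// // ryu renders 1.0 as "1.0"; the trailing ".0" is stripped here.
/// assert_eq!(format_finite(&mut buf, 1.0_f64), "1");
/// assert_eq!(format_finite(&mut buf, 2.5_f64), "2.5");
/// ```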
#[inline]
#[cfg(feature = "str_arithmetic")]
pub fn format_finite<F: Float>(buf: &mut [MaybeUninit<u8>; 24], f: F) -> &str {
    unsafe {
        let ptr = buf.as_mut_ptr() as *mut u8;
        let n = f.write_to_ryu_buffer(ptr);
        debug_assert!(n <= buf.len());

        let slice = core::slice::from_raw_parts(ptr, n);
        let s = core::str::from_utf8_unchecked(slice);

        // Strip trailing ".0" if present
        if s.ends_with(".0") {
            let trimmed_len = s.len() - 2;
            core::str::from_utf8_unchecked(&slice[..trimmed_len])
        } else {
            s
        }
    }
}

/// Estimate cardinality ratio on a sample from a CategoricalArray.
/// Used to quickly figure out the optimal strategy when comparing
/// StringArray and CategoricalArrays.
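///
/// # Usage Example
/// A minimal sketch (the 0.1 threshold is illustrative only):
/// ```rust,ignore
/// // Sample up to 128 entries; ~0.0 means few distinct values, ~1.0 means mostly unique.
/// let ratio = estimate_categorical_cardinality(&cat, 128);
/// if ratio < 0.1 {
///     // low cardinality: dictionary-level comparison is likely cheaper
/// }
/// ```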
#[inline(always)]
pub fn estimate_categorical_cardinality(cat: &CategoricalArray<u32>, sample_size: usize) -> f64 {
    let len = cat.data.len();
    if len == 0 {
        return 0.0;
    }
    let mut seen = HashSet::with_capacity(sample_size.min(len));
    let step = (len / sample_size.max(1)).max(1);
    for i in (0..len).step_by(step) {
        let s = unsafe { cat.get_str_unchecked(i) };
        seen.insert(s);
        if seen.len() >= sample_size {
            break;
        }
    }
    (seen.len() as f64) / (sample_size.min(len) as f64)
}

/// Estimate cardinality ratio on a sample from a StringArray.
/// Used to quickly figure out the optimal strategy when comparing
/// StringArray and CategoricalArrays.
#[inline(always)]
pub fn estimate_string_cardinality<T: Integer>(arr: &StringArray<T>, sample_size: usize) -> f64 {
    let len = arr.len();
    if len == 0 {
        return 0.0;
    }
    let mut seen = HashSet::with_capacity(sample_size.min(len));
    let step = (len / sample_size.max(1)).max(1);
    for i in (0..len).step_by(step) {
        let s = unsafe { arr.get_str_unchecked(i) };
        seen.insert(s);
        if seen.len() >= sample_size {
            break;
        }
    }
    (seen.len() as f64) / (sample_size.min(len) as f64)
}

/// Validates that actual capacity matches expected capacity for kernel operations.
///
/// Essential validation function used throughout the kernel library to ensure data structure
/// capacities are correct before performing operations. Prevents buffer overruns and ensures
/// memory safety by catching capacity mismatches early with descriptive error messages.
///
/// # Parameters
/// - `label`: Descriptive label for the validation context (used in error messages)
/// - `actual`: The actual capacity of the data structure being validated
/// - `expected`: The expected capacity required for the operation
///
/// # Returns
/// `Ok(())` if capacities match, otherwise `KernelError::InvalidArguments` with detailed message.
///
/// # Error Conditions
/// Returns `KernelError::InvalidArguments` when `actual != expected`, providing a clear
/// error message indicating the mismatch and context.
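///
/// # Usage Example
/// A minimal sketch (the names `out` and `expected_len` are illustrative):
/// ```rust,ignore
/// // Fail fast with a descriptive error before writing into the output buffer.
/// confirm_capacity("output buffer", out.capacity(), expected_len)?;
/// ```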
#[inline(always)]
pub fn confirm_capacity(label: &str, actual: usize, expected: usize) -> Result<(), KernelError> {
    if actual != expected {
        return Err(KernelError::InvalidArguments(format!(
            "{}: capacity mismatch (expected {}, got {})",
            label, expected, actual
        )));
    }
    Ok(())
}

/// Validates that two lengths are equal for binary kernel operations.
///
/// Critical validation function ensuring input arrays have matching lengths before performing
/// binary operations like comparisons, arithmetic, or logical operations. Prevents undefined
/// behaviour and provides clear error diagnostics when length mismatches occur.
///
/// # Parameters
/// - `label`: Descriptive context label for error reporting (e.g., "compare numeric")
/// - `a`: Length of the first input array or data structure
/// - `b`: Length of the second input array or data structure
///
/// # Returns
/// `Ok(())` if lengths are equal, otherwise `KernelError::LengthMismatch` with diagnostic details.
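///
/// # Usage Example
/// A minimal sketch (`lhs` and `rhs` stand in for the two input arrays):
/// ```rust,ignore
/// confirm_equal_len("compare numeric", lhs.len(), rhs.len())?;
/// ```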
#[inline(always)]
pub fn confirm_equal_len(label: &str, a: usize, b: usize) -> Result<(), KernelError> {
    if a != b {
        return Err(KernelError::LengthMismatch(format!(
            "{}: length mismatch (lhs: {}, rhs: {})",
            label, a, b
        )));
    }
    Ok(())
}

/// SIMD alignment check. Returns true if the slice is properly
/// 64-byte aligned for SIMD operations, false otherwise.
/// An empty slice is considered aligned.
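///
/// # Usage Example
/// A minimal sketch of the dispatch pattern (`values` is illustrative):
/// ```rust,ignore
/// if is_simd_aligned(values) {
///     // aligned SIMD loads
/// } else {
///     // scalar or unaligned fallback
/// }
/// ```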
#[inline(always)]
pub fn is_simd_aligned<T>(slice: &[T]) -> bool {
    if slice.is_empty() {
        true
    } else {
        (slice.as_ptr() as usize) % 64 == 0
    }
}