Skip to main content

simd_kernels/kernels/
comparison.rs

1// Copyright (c) 2025 Peter Bower
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Comparison Operations Kernels Module** - *High-Performance Element-wise Comparison Operations*
6//!
7//! Optimised comparison kernels providing comprehensive element-wise comparison operations
8//! across numeric, string, and categorical data types with SIMD acceleration and null-aware semantics.
9//! Foundation for filtering, conditional logic, and analytical query processing.
10//!
11//! ## Core Operations
12//! - **Numeric comparisons**: Equal, not equal, greater than, less than, greater/less than or equal
13//! - **String comparisons**: UTF-8 aware lexicographic ordering with efficient prefix matching
14//! - **Categorical comparisons**: Dictionary-encoded comparisons avoiding string materialisation
15//! - **Null-aware semantics**: Proper three-valued logic handling (true/false/null)
16//! - **SIMD vectorisation**: Hardware-accelerated bulk comparison operations
17//! - **Bitmask operations**: Efficient boolean result representation using bit manipulation
18
19include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21// SIMD
22use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34    and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41/// Returns a new Bitmask for boolean buffers, all bits cleared (false).
42#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44    Bitmask::new_set_all(len, false)
45}
46
47/// Merge two Bitmasks using bitwise AND, or propagate one if only one is present.
48fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49    match (a, b) {
50        (None, None) => None,
51        (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52        (Some(x), Some(y)) => {
53            let mut out = Bitmask::new_set_all(len, true);
54            for i in 0..len {
55                unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56            }
57            Some(out)
58        }
59    }
60}
61
62// Int and float
63
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Compares `lhs[i]` with `rhs[i]` under `op` and sets bit `i` of `output`
        /// wherever the comparison holds. Bits are only ever set, never cleared,
        /// so `output` is expected to arrive zeroed. Positions that `mask` marks
        /// as null are skipped and left untouched. Operators other than the six
        /// equality/ordering comparisons set no bits.
        ///
        /// With the `simd` feature enabled, the vectorised path is taken when no
        /// mask is supplied and both inputs pass `is_simd_aligned`; every other
        /// case uses the scalar loop.
        ///
        /// # Errors
        /// - Propagates the error from `confirm_equal_len` when `lhs` and `rhs`
        ///   differ in length.
        /// - `KernelError::InvalidArguments` when `mask` has capacity for fewer
        ///   than `lhs.len()` bits.
        ///
        /// # Panics
        /// Panics if output capacity < lhs.len().
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            confirm_equal_len("compare numeric length mismatch", len, rhs.len())?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );
            // The scalar loop reads the mask with `get_unchecked`, so a mask that
            // is too short must be rejected here rather than read out of bounds.
            if let Some(m) = mask {
                if m.capacity() < len {
                    return Err(KernelError::InvalidArguments(format!(
                        "mask capacity too small (expected ≥ {}, got {})",
                        len,
                        m.capacity()
                    )));
                }
            }

            // Scalar comparison shared by the SIMD remainder and the fallback loop.
            let scalar_cmp = |a: $ty, b: $ty| -> bool {
                match op {
                    ComparisonOperator::Equals => a == b,
                    ComparisonOperator::NotEquals => a != b,
                    ComparisonOperator::LessThan => a < b,
                    ComparisonOperator::LessThanOrEqualTo => a <= b,
                    ComparisonOperator::GreaterThan => a > b,
                    ComparisonOperator::GreaterThanOrEqualTo => a >= b,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                // Vectorised path: only for mask-free input that passes the
                // alignment check. Anything else falls through to the scalar loop.
                if mask.is_none() && is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                    const N: usize = $lanes;

                    let mut i = 0;
                    while i + N <= len {
                        let a = Simd::<$ty, N>::from_slice(&lhs[i..i + N]);
                        let b = Simd::<$ty, N>::from_slice(&rhs[i..i + N]);
                        let m: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        // Lane `l` of the comparison maps to bit `l` of `bits`.
                        let bits = m.to_bitmask();
                        for l in 0..N {
                            if ((bits >> l) & 1) == 1 {
                                // SAFETY: i + l < len <= output.capacity(), asserted above.
                                unsafe { output.set_unchecked(i + l, true) };
                            }
                        }
                        i += N;
                    }

                    // Remainder when `len` is not a multiple of the lane count.
                    for j in i..len {
                        if scalar_cmp(lhs[j], rhs[j]) {
                            // SAFETY: j < len <= output.capacity(), asserted above.
                            unsafe { output.set_unchecked(j, true) };
                        }
                    }

                    return Ok(());
                }
            }

            // Scalar path: used without the `simd` feature, for input that fails
            // the alignment check, and whenever a null mask is present.
            for (i, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
                if let Some(m) = mask {
                    // SAFETY: i < len <= m.capacity(), validated above.
                    if !unsafe { m.get_unchecked(i) } {
                        continue;
                    }
                }
                if scalar_cmp(a, b) {
                    // SAFETY: i < len <= output.capacity(), asserted above.
                    unsafe { output.set_unchecked(i, true) };
                }
            }
            Ok(())
        }

        /// Type-specific comparison kernel returning a freshly allocated result.
        ///
        /// Allocates a zeroed bitmask of `lhs.len()` bits, fills it through the
        /// matching `*_to` kernel and wraps it in a `BooleanArray`.
        ///
        /// # Parameters
        /// - `lhs`: Left-hand side slice for comparison
        /// - `rhs`: Right-hand side slice for comparison
        /// - `mask`: Optional validity mask; null positions stay `false` in the
        ///   data and the mask is cloned into the result's `null_mask`
        /// - `op`: Comparison operator to apply
        ///
        /// # Returns
        /// `Result<BooleanArray<()>, KernelError>` containing comparison results.
        ///
        /// # Errors
        /// Propagates any error returned by the `*_to` kernel.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_bitmask(len);
            $fn_name_to(lhs, rhs, mask, op, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
195
196/// Zero-allocation variant: writes directly to caller's output buffer.
197///
198/// Unified numeric comparison dispatch with optional null mask support.
199/// The output Bitmask must have capacity >= lhs.len().
200#[inline(always)]
201pub fn cmp_numeric_to<T: Numeric + Copy + 'static>(
202    lhs: &[T],
203    rhs: &[T],
204    mask: Option<&Bitmask>,
205    op: ComparisonOperator,
206    output: &mut Bitmask,
207) -> Result<(), KernelError> {
208    macro_rules! dispatch {
209        ($t:ty, $f:ident) => {
210            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
211                return $f(
212                    unsafe { std::mem::transmute(lhs) },
213                    unsafe { std::mem::transmute(rhs) },
214                    mask,
215                    op,
216                    output,
217                );
218            }
219        };
220    }
221    dispatch!(i32, cmp_i32_to);
222    dispatch!(i64, cmp_i64_to);
223    dispatch!(u32, cmp_u32_to);
224    dispatch!(u64, cmp_u64_to);
225    dispatch!(f32, cmp_f32_to);
226    dispatch!(f64, cmp_f64_to);
227
228    unreachable!("Unsupported numeric type for compare_numeric");
229}
230
231/// Unified numeric comparison dispatch with optional null mask support.
232///
233/// High-performance generic comparison function that dispatches to type-specific SIMD implementations
234/// based on runtime type identification. Supports all numeric types with optional null mask filtering
235/// and comprehensive error handling for mismatched lengths and unsupported types.
236///
237/// # Type Parameters
238/// - `T`: Numeric type implementing `Numeric + Copy + 'static` (i32, i64, u32, u64, f32, f64)
239///
240/// # Parameters
241/// - `lhs`: Left-hand side numeric slice for comparison
242/// - `rhs`: Right-hand side numeric slice for comparison
243/// - `mask`: Optional validity mask applied after comparison (AND operation)
244/// - `op`: Comparison operator to apply (Equals, NotEquals, LessThan, etc.)
245///
246/// # Returns
247/// `Result<BooleanArray<()>, KernelError>` containing comparison results or error details.
248///
249/// # Dispatch Strategy
250/// Uses `TypeId` runtime checking to dispatch to optimal type-specific implementations:
251/// - `i32`/`u32`: 32-bit integer SIMD kernels with W32 lane configuration
252/// - `i64`/`u64`: 64-bit integer SIMD kernels with W64 lane configuration
253/// - `f32`/`f64`: IEEE 754 floating-point SIMD kernels with specialised NaN handling
254///
255/// # Error Conditions
256/// - `KernelError::LengthMismatch`: Input slices have different lengths
257/// - `KernelError::InvalidArguments`: Unsupported numeric type (unreachable in practice)
258///
259/// # Performance Benefits
260/// - Zero-cost dispatch: Type resolution optimised away at compile time for monomorphic usage
261/// - SIMD acceleration: Delegates to vectorised implementations for maximum throughput
262/// - Memory efficiency: Optional mask processing avoids unnecessary allocations
263///
264/// # Example Usage
265/// ```rust,ignore
266/// use simd_kernels::kernels::comparison::{cmp_numeric, ComparisonOperator};
267///
268/// let lhs = &[1i32, 2, 3, 4];
269/// let rhs = &[1i32, 1, 4, 3];
270/// let result = cmp_numeric(lhs, rhs, None, ComparisonOperator::Equals)?;
271/// // Result: [true, false, false, false]
272/// ```
273#[inline(always)]
274pub fn cmp_numeric<T: Numeric + Copy + 'static>(
275    lhs: &[T],
276    rhs: &[T],
277    mask: Option<&Bitmask>,
278    op: ComparisonOperator,
279) -> Result<BooleanArray<()>, KernelError> {
280    let len = lhs.len();
281    let mut out = new_bool_bitmask(len);
282    cmp_numeric_to(lhs, rhs, mask, op, &mut out)?;
283    Ok(BooleanArray {
284        data: out.into(),
285        null_mask: mask.cloned(),
286        len,
287        _phantom: PhantomData,
288    })
289}
290
/// SIMD-accelerated compare bitmask
///
/// Compare two packed bool bitmask windows, using the given operator.
/// Each view is a `(bitmask, bit offset, length)` tuple; the offsets are bit
/// offsets into each mask.
/// The mask, if provided, is ANDed after the comparison.
///
/// For the six equality/ordering operators all offsets, including the mask's,
/// must be 64-bit aligned (i.e., offset % 64 == 0). `In` and `NotIn` are
/// forwarded to `in_mask_simd` / `not_in_mask_simd` before that check is made.
/// Any other operator produces an all-false result.
///
/// `cmp_bool` in this module wraps this lower level kernel with null-aware
/// semantics.
///
/// # Errors
/// - Propagates the error from `confirm_equal_len` when the two windows differ
///   in length.
/// - `KernelError::InvalidArguments` when an offset is not a multiple of 64.
#[cfg(feature = "simd")]
pub fn cmp_bitmask_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    mask: Option<BitmaskVT<'_>>,
    op: ComparisonOperator,
) -> Result<Bitmask, KernelError> {
    // We have some code duplication here with the `std` version,
    // but unifying them means a const LANE generic on the non-simd path,
    // and adding a higher level dispatch layer creates additional indirection
    // and 9 args instead of 4, hence why it's this way.

    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
    let (lhs_mask, lhs_offset, len) = lhs;
    let (rhs_mask, rhs_offset, _) = rhs;

    // Handle 'In' and 'NotIn' early
    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
        let mut out = match op {
            ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
            ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
            _ => unreachable!(),
        };
        if let Some(mask_slice) = mask {
            out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
        }
        return Ok(out);
    }

    // Word-aligned offsets
    if lhs_offset % 64 != 0
        || rhs_offset % 64 != 0
        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
    {
        return Err(KernelError::InvalidArguments(format!(
            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
            lhs_offset,
            rhs_offset,
            mask.as_ref().map(|(_, mo, _)| mo)
        )));
    }

    // Precompute word indices/counts. The mask view carries a *bit* offset just
    // like the operands, so it has to be converted to a word index as well.
    let lhs_word_start = lhs_offset / 64;
    let rhs_word_start = rhs_offset / 64;
    let mask_words: Option<(&Bitmask, usize)> = mask.as_ref().map(|&(m, mo, _)| (m, mo / 64));
    let n_words = (len + 63) / 64;

    // Allocate output
    let mut out = Bitmask::new_set_all(len, false);

    let simd_chunks = n_words / LANES;
    let tail_words = n_words % LANES;
    let mut word_idx = 0;

    // SAFETY (all `word_unchecked` / `set_word_unchecked` calls below): word
    // indices stay below `n_words` relative to each view's start. `out` was
    // allocated for `len` bits; the reads rely on every input bitmask storing
    // at least `n_words` words from its start word, which is not checked here.

    // SIMD main path: LANES words per iteration.
    for chunk in 0..simd_chunks {
        let base = chunk * LANES;

        let mut lhs_arr = [0u64; LANES];
        let mut rhs_arr = [0u64; LANES];
        // Without a mask every bit counts as valid.
        let mut mask_arr = [!0u64; LANES];

        for lane in 0..LANES {
            lhs_arr[lane] = unsafe { lhs_mask.word_unchecked(lhs_word_start + base + lane) };
            rhs_arr[lane] = unsafe { rhs_mask.word_unchecked(rhs_word_start + base + lane) };
            if let Some((m, mask_word_start)) = mask_words {
                mask_arr[lane] = unsafe { m.word_unchecked(mask_word_start + base + lane) };
            }
        }
        let lhs_v = Simd::<u64, LANES>::from_array(lhs_arr);
        let rhs_v = Simd::<u64, LANES>::from_array(rhs_arr);
        let mask_v = Simd::<u64, LANES>::from_array(mask_arr);

        // Bitwise formulations of the boolean comparisons (false < true).
        let cmp_v = match op {
            ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
            ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
            ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
            ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
            ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
            ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
            _ => Simd::splat(0),
        };
        let result_v = cmp_v & mask_v;

        for lane in 0..LANES {
            unsafe {
                out.set_word_unchecked(word_idx, result_v[lane]);
            }
            word_idx += 1;
        }
    }

    // Remaining words when `n_words % LANES != 0`; handled one word at a time.
    let tail_base = simd_chunks * LANES;
    for tail in 0..tail_words {
        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + tail_base + tail) };
        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + tail_base + tail) };
        let m = if let Some((m, mask_word_start)) = mask_words {
            unsafe { m.word_unchecked(mask_word_start + tail_base + tail) }
        } else {
            !0u64
        };
        let cmp = match op {
            ComparisonOperator::Equals => !(a ^ b),
            ComparisonOperator::NotEquals => a ^ b,
            ComparisonOperator::GreaterThan => a & (!b),
            ComparisonOperator::LessThan => (!a) & b,
            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
            _ => 0,
        } & m;
        unsafe {
            out.set_word_unchecked(word_idx, cmp);
        }
        word_idx += 1;
    }

    // Whole words were written, so tidy up any bits in the final word that lie
    // beyond `len`.
    out.mask_trailing_bits();
    Ok(out)
}
435
436/// Performs vectorised boolean array comparisons with null mask handling.
437///
438/// High-performance SIMD-accelerated comparison function for boolean arrays with automatic null
439/// mask merging and operator-specific optimisations. Supports all comparison operators through
440/// efficient bitmask operations with configurable lane counts for architecture-specific tuning.
441///
442/// # Type Parameters
443/// - `LANES`: Number of SIMD lanes to process simultaneously
444///
445/// # Parameters
446/// - `lhs`: Left-hand side boolean array view as `(array, offset, length)` tuple
447/// - `rhs`: Right-hand side boolean array view as `(array, offset, length)` tuple  
448/// - `op`: Comparison operator (Equals, NotEquals, In, NotIn, IsNull, IsNotNull, etc.)
449///
450/// # Returns
451/// `Result<BooleanArray<()>, KernelError>` containing comparison results with merged null semantics.
452pub fn cmp_bool<const LANES: usize>(
453    lhs: BooleanAVT<'_, ()>,
454    rhs: BooleanAVT<'_, ()>,
455    op: ComparisonOperator,
456) -> Result<BooleanArray<()>, KernelError>
457where
458{
459    let (lhs_arr, lhs_off, len) = lhs;
460    let (rhs_arr, rhs_off, rlen) = rhs;
461    debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
462
463    #[cfg(feature = "simd")]
464    let merged_null_mask: Option<Bitmask> =
465        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
466            (None, None) => None,
467            (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
468            (Some(a), Some(b)) => {
469                let am = (a, lhs_off, len);
470                let bm = (b, rhs_off, len);
471                Some(and_masks_simd::<LANES>(am, bm))
472            }
473        };
474
475    #[cfg(not(feature = "simd"))]
476    let merged_null_mask: Option<Bitmask> =
477        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
478            (None, None) => None,
479            (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
480            (Some(a), Some(b)) => {
481                let am = (a, lhs_off, len);
482                let bm = (b, rhs_off, len);
483                Some(and_masks(am, bm))
484            }
485        };
486
487    let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
488
489    let data = match op {
490        ComparisonOperator::Equals
491        | ComparisonOperator::NotEquals
492        | ComparisonOperator::LessThan
493        | ComparisonOperator::LessThanOrEqualTo
494        | ComparisonOperator::GreaterThan
495        | ComparisonOperator::GreaterThanOrEqualTo
496        | ComparisonOperator::In
497        | ComparisonOperator::NotIn => {
498            #[cfg(feature = "simd")]
499            let res = cmp_bitmask_simd::<LANES>(
500                (&lhs_arr.data, lhs_off, len),
501                (&rhs_arr.data, rhs_off, len),
502                mask_slice,
503                op,
504            )?;
505            #[cfg(not(feature = "simd"))]
506            let res = cmp_bitmask_std(
507                (&lhs_arr.data, lhs_off, len),
508                (&rhs_arr.data, rhs_off, len),
509                mask_slice,
510                op,
511            )?;
512            res
513        }
514        ComparisonOperator::IsNull => {
515            #[cfg(feature = "simd")]
516            let data = match merged_null_mask.as_ref() {
517                Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
518                None => Bitmask::new_set_all(len, false),
519            };
520            #[cfg(not(feature = "simd"))]
521            let data = match merged_null_mask.as_ref() {
522                Some(mask) => not_mask((mask, 0, len)),
523                None => Bitmask::new_set_all(len, false),
524            };
525            return Ok(BooleanArray {
526                data,
527                null_mask: None,
528                len,
529                _phantom: PhantomData,
530            });
531        }
532        ComparisonOperator::IsNotNull => {
533            let data = match merged_null_mask.as_ref() {
534                Some(mask) => mask.slice_clone(0, len),
535                None => Bitmask::new_set_all(len, true),
536            };
537            return Ok(BooleanArray {
538                data,
539                null_mask: None,
540                len,
541                _phantom: PhantomData,
542            });
543        }
544        ComparisonOperator::Between => {
545            return Err(KernelError::InvalidArguments(
546                "Set operations are not defined for Bool arrays".to_owned(),
547            ));
548        }
549    };
550
551    Ok(BooleanArray {
552        data,
553        null_mask: merged_null_mask,
554        len,
555        _phantom: PhantomData,
556    })
557}
558
559/// Compare two packed bool bitmask slices over a window, using the given operator.
560/// The offsets are bit offsets into each mask.
561/// The mask, if provided, is ANDed after the comparison.
562/// Requires that all offsets are 64-bit aligned (i.e., offset % 64 == 0).
563///
564/// This lower level kernel can be orchestrated by apply_cmp_bool which
565/// wraps it into a BoolWindow with null-aware semantics.
566#[cfg(not(feature = "simd"))]
567pub fn cmp_bitmask_std(
568    lhs: BitmaskVT<'_>,
569    rhs: BitmaskVT<'_>,
570    mask: Option<BitmaskVT<'_>>,
571    op: ComparisonOperator,
572) -> Result<Bitmask, KernelError> {
573    // We have some code duplication here with the `simd` version,
574    // but unifying then means a const LANE generic on the non-simd path,
575    // and adding a higher level dispatch layer create additional indirection
576    // and 9 args instead of 4, hence why it's this way.
577
578    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
579    let (lhs_mask, lhs_offset, len) = lhs;
580    let (rhs_mask, rhs_offset, _) = rhs;
581
582    // Handle 'In' and 'NotIn' early
583
584    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
585        let mut out = match op {
586            ComparisonOperator::In => in_mask(lhs, rhs),
587            ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
588            _ => unreachable!(),
589        };
590        if let Some(mask_slice) = mask {
591            out = and_masks((&out, 0, out.len), mask_slice);
592        }
593        return Ok(out);
594    }
595
596    // Word-aligned offsets
597    if lhs_offset % 64 != 0
598        || rhs_offset % 64 != 0
599        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
600    {
601        return Err(KernelError::InvalidArguments(format!(
602            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
603            lhs_offset,
604            rhs_offset,
605            mask.as_ref().map(|(_, mo, _)| mo)
606        )));
607    }
608
609    // Precompute word indices/counts
610    let lhs_word_start = lhs_offset / 64;
611    let rhs_word_start = rhs_offset / 64;
612    let n_words = (len + 63) / 64;
613
614    // Allocate output
615    let mut out = Bitmask::new_set_all(len, false);
616
617    let words = n_words;
618    let tail = len % 64;
619    let mask_mask_opt = mask;
620
621    // Word-aligned loop
622    for w in 0..words {
623        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
624        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
625        let valid_bits =
626            mask_mask_opt
627                .as_ref()
628                .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
629                    mask_mask.word_unchecked(mask_word_start + w)
630                });
631        let word_cmp = match op {
632            ComparisonOperator::Equals => !(a ^ b),
633            ComparisonOperator::NotEquals => a ^ b,
634            ComparisonOperator::GreaterThan => a & (!b),
635            ComparisonOperator::LessThan => (!a) & b,
636            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
637            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
638            _ => 0,
639        };
640        let final_bits = word_cmp & valid_bits;
641        unsafe {
642            out.set_word_unchecked(w, final_bits);
643        }
644    }
645
646    // Tail often caused by `n % LANES != 0`; uses scalar fallback.
647
648    let base = words * 64;
649    for i in 0..tail {
650        let idx_lhs = lhs_offset + base + i;
651        let idx_rhs = rhs_offset + base + i;
652        let mask_valid =
653            mask_mask_opt
654                .as_ref()
655                .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
656                    let mask_idx = mask_word_start * 64 + base + i;
657                    if mask_idx < *mask_len {
658                        mask_mask.get_unchecked(mask_idx)
659                    } else {
660                        false
661                    }
662                });
663        if !mask_valid {
664            continue;
665        }
666        if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
667            continue;
668        }
669        let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
670        let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
671        let res = match op {
672            ComparisonOperator::Equals => a == b,
673            ComparisonOperator::NotEquals => a != b,
674            ComparisonOperator::GreaterThan => a & !b,
675            ComparisonOperator::LessThan => !a & b,
676            ComparisonOperator::GreaterThanOrEqualTo => a | !b,
677            ComparisonOperator::LessThanOrEqualTo => !a | b,
678            _ => false,
679        };
680        if res {
681            out.set(base + i, true)
682        }
683    }
684    out.mask_trailing_bits();
685    Ok(out)
686}
687
688// String and dictionary
689
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $fn_name_to:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Compare UTF-8 string or dictionary array windows using the specified
        /// comparison operator. Each input is an `(array, offset, length)` view.
        /// Bit `i` of `output` is set when the comparison of the `i`-th pair of
        /// strings holds; bits are only ever set, never cleared, so `output` is
        /// expected to arrive zeroed. Rows where either side is null are skipped.
        /// Operators other than the six equality/ordering comparisons set no bits.
        ///
        /// # Errors
        /// - Propagates the error from `confirm_equal_len` when the two windows
        ///   differ in length.
        /// - `KernelError::InvalidArguments` when a null mask has capacity for
        ///   fewer than `offset + length` bits.
        ///
        /// # Panics
        /// Panics if output capacity < llen.
        #[inline(always)]
        pub fn $fn_name_to<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;
            assert!(output.capacity() >= llen, concat!(stringify!($fn_name_to), ": output capacity too small"));

            let lhs_valid = larr.null_mask.as_ref();
            let rhs_valid = rarr.null_mask.as_ref();

            // Validate the masks before any bit is read from them.
            if let Some(m) = lhs_valid {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rhs_valid {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            for i in 0..llen {
                // The masks are read in place at `offset + i`, so no copy of the
                // window is needed.
                // SAFETY: loff + i < loff + llen and roff + i < roff + rlen, both
                // within the mask capacities validated above.
                let both_valid = lhs_valid.map_or(true, |m| unsafe { m.get_unchecked(loff + i) })
                    && rhs_valid.map_or(true, |m| unsafe { m.get_unchecked(roff + i) });
                if !both_valid {
                    continue;
                }
                // SAFETY: relies on the caller's views lying inside their arrays;
                // that is not checked here.
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Compare UTF-8 string or dictionary array windows using the specified
        /// comparison operator, returning a freshly allocated `BooleanArray`.
        ///
        /// The result's `null_mask` is the merge of both sides' windowed null
        /// masks (`None` when neither side has one).
        ///
        /// # Errors
        /// Propagates any error from the `*_to` kernel, including the length
        /// mismatch check between the two windows.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            let mut out = new_bool_bitmask(llen);
            // Run the kernel first: it validates the window lengths and mask
            // capacities, so the masks are only sliced once they are known good.
            $fn_name_to((larr, loff, llen), (rarr, roff, rlen), op, &mut out)?;
            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, llen));
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
778
779impl_cmp_numeric!(cmp_i32, cmp_i32_to, i32, W32, i32);
780impl_cmp_numeric!(cmp_u32, cmp_u32_to, u32, W32, i32);
781impl_cmp_numeric!(cmp_i64, cmp_i64_to, i64, W64, i64);
782impl_cmp_numeric!(cmp_u64, cmp_u64_to, u64, W64, i64);
783impl_cmp_numeric!(cmp_f32, cmp_f32_to, f32, W32, i32);
784impl_cmp_numeric!(cmp_f64, cmp_f64_to, f64, W64, i64);
785impl_cmp_utf8_slice!(cmp_str_str,   cmp_str_str_to,   StringAVT<'a, T>,     StringAVT<'a, T>,      [ 'a, T: Integer ]);
786impl_cmp_utf8_slice!(cmp_str_dict,  cmp_str_dict_to,  StringAVT<'a, T>,     CategoricalAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
787impl_cmp_utf8_slice!(cmp_dict_str,  cmp_dict_str_to,  CategoricalAVT<'a, T>,     StringAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
788impl_cmp_utf8_slice!(cmp_dict_dict, cmp_dict_dict_to, CategoricalAVT<'a, T>,     CategoricalAVT<'a, T>,      [ 'a, T: Integer ]);
789
790#[cfg(test)]
791mod tests {
792    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};
793
794    use crate::kernels::comparison::{
795        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict,
796    };
797
798    #[cfg(feature = "simd")]
799    use crate::kernels::comparison::{W64, cmp_bitmask_simd};
800
801    use crate::operators::ComparisonOperator;
802
803    /// --- helpers --------------------------------------------------------------
804
805    fn bm(bits: &[bool]) -> Bitmask {
806        let mut m = Bitmask::new_set_all(bits.len(), false);
807        for (i, &b) in bits.iter().enumerate() {
808            m.set(i, b);
809        }
810        m
811    }
812
813    /// Assert BooleanArray ⇢ expected value bits & expected null bits.
814    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
815        assert_eq!(arr.len, expect.len());
816        for i in 0..expect.len() {
817            assert_eq!(arr.data.get(i), expect[i], "value bit {i}");
818        }
819        match (arr.null_mask.as_ref(), expect_mask) {
820            (None, None) => {}
821            (Some(m), Some(exp)) => {
822                for (i, &b) in exp.iter().enumerate() {
823                    assert_eq!(m.get(i), b, "null-bit {i}");
824                }
825            }
826            _ => panic!("mask mismatch"),
827        }
828    }
829
830    /// Tiny helpers to build test String / Dict arrays.
831    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
832        StringArray::<T>::from_slice(v)
833    }
834
835    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
836        let owned: Vec<&str> = vals.to_vec();
837        CategoricalArray::<T>::from_values(owned)
838    }
839
840    // NUMERIC
841
/// Every comparison operator on two null-free `i32` buffers.
#[test]
fn numeric_compare_no_nulls() {
    let lhs = vec64![1i32, 2, 3, 4];
    let rhs = vec64![1i32, 1, 4, 4];

    // (operator, expected value bits); no null mask is expected in any case.
    let cases = [
        (ComparisonOperator::Equals, [true, false, false, true]),
        (ComparisonOperator::NotEquals, [false, true, true, false]),
        (ComparisonOperator::LessThan, [false, false, true, false]),
        (ComparisonOperator::LessThanOrEqualTo, [true, false, true, true]),
        (ComparisonOperator::GreaterThan, [false, true, false, false]),
        (ComparisonOperator::GreaterThanOrEqualTo, [true, true, false, true]),
    ];

    for (op, expected) in cases {
        let got = cmp_i32(&lhs, &rhs, None, op).unwrap();
        assert_bool(&got, &expected, None);
    }
}
861
/// `cmp_numeric` on `u64` input with a null mask that invalidates the last row.
#[test]
fn numeric_compare_with_nulls_generic_dispatch() {
    let lhs = vec64![1u64, 5, 9, 10];
    let rhs = vec64![0u64, 5, 8, 11];
    // Only the final row is masked out.
    let validity = [true, true, true, false];
    let mask = bm(&validity);

    let got = cmp_numeric(&lhs, &rhs, Some(&mask), ComparisonOperator::GreaterThan).unwrap();

    // The masked row reports a false value bit and a cleared null bit.
    let expected_bits = [true, false, true, false];
    assert_bool(&got, &expected_bits, Some(&validity[..]));
}
877
878    // BOOLEAN
879
/// Bitmask-vs-bitmask comparison for `Equals`, `LessThan` and `GreaterThan`.
#[cfg(feature = "simd")]
#[test]
fn bool_compare_all_ops() {
    let lhs = bm(&[true, false, true, false]);
    let rhs = bm(&[true, true, false, false]);

    // Runs the SIMD bitmask kernel over the full windows with no null mask and
    // wraps the resulting bits so `assert_bool` can inspect them.
    let run = |op: ComparisonOperator| -> BooleanArray<()> {
        let bits = cmp_bitmask_simd::<W64>(
            (&lhs, 0, lhs.len()),
            (&rhs, 0, rhs.len()),
            None,
            op,
        )
        .unwrap();
        BooleanArray::from_bitmask(bits, None)
    };

    let eq = run(ComparisonOperator::Equals);
    let lt = run(ComparisonOperator::LessThan);
    let gt = run(ComparisonOperator::GreaterThan);

    assert_bool(&eq, &[true, false, false, true], None);
    assert_bool(&lt, &[false, true, false, false], None);
    assert_bool(&gt, &[false, false, true, false], None);
}
923
924    // UTF & DICTIONARY
925
/// String-vs-dictionary equality where the string side carries a null mask.
#[test]
fn string_vs_dict_compare_with_nulls() {
    // Row 1 of the string side is null.
    let validity = [true, false, true];
    let mut strings = str_arr::<u32>(&["x", "y", "z"]);
    strings.null_mask = Some(bm(&validity));
    let dict = dict_arr::<u32>(&["x", "w", "zz"]);

    let got = cmp_str_dict(
        (&strings, 0, strings.len()),
        (&dict, 0, dict.data.len()),
        ComparisonOperator::Equals,
    )
    .unwrap();

    assert_bool(&got, &[true, false, false], Some(&validity[..]));
}
936
/// Same comparison as the unwindowed test, but over rows 1..4 of padded arrays.
#[test]
fn string_vs_dict_compare_with_nulls_chunk() {
    let mut strings = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
    strings.null_mask = Some(bm(&[true, true, false, true, true]));
    let dict = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);

    // Window: offset 1, length 3 on both sides.
    let (offset, len) = (1, 3);
    let got = cmp_str_dict(
        (&strings, offset, len),
        (&dict, offset, len),
        ComparisonOperator::Equals,
    )
    .unwrap();

    // Expected null bits are the windowed slice of the input mask.
    assert_bool(&got, &[true, false, false], Some(&[true, false, true]));
}
947
/// Dictionary-vs-dictionary `GreaterThan` with no null masks.
#[test]
fn dict_vs_dict_compare_gt() {
    let left = dict_arr::<u32>(&["apple", "pear", "banana"]);
    let right = dict_arr::<u32>(&["ant", "pear", "apricot"]);

    let got = cmp_dict_dict(
        (&left, 0, left.data.len()),
        (&right, 0, right.data.len()),
        ComparisonOperator::GreaterThan,
    )
    .unwrap();

    assert_bool(&got, &[true, false, true], None);
}
957
/// Dictionary-vs-dictionary `GreaterThan` over rows 1..4 of padded arrays.
#[test]
fn dict_vs_dict_compare_gt_chunk() {
    let left = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
    let right = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);

    // Window: offset 1, length 3 on both sides.
    let (offset, len) = (1, 3);
    let got = cmp_dict_dict(
        (&left, offset, len),
        (&right, offset, len),
        ComparisonOperator::GreaterThan,
    )
    .unwrap();

    assert_bool(&got, &[true, false, true], None);
}
967
/// Dictionary-vs-string `LessThanOrEqualTo` with no null masks.
#[test]
fn dict_vs_string_compare_le() {
    let dict = dict_arr::<u32>(&["a", "b", "c"]);
    let strings = str_arr::<u32>(&["b", "b", "d"]);

    let got = cmp_dict_str(
        (&dict, 0, dict.data.len()),
        (&strings, 0, strings.len()),
        ComparisonOperator::LessThanOrEqualTo,
    )
    .unwrap();

    assert_bool(&got, &[true, true, true], None);
}
978
/// Dictionary-vs-string `LessThanOrEqualTo` over rows 1..4 of padded arrays.
#[test]
fn dict_vs_string_compare_le_chunk() {
    let dict = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
    let strings = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);

    // Window: offset 1, length 3 on both sides.
    let (offset, len) = (1, 3);
    let got = cmp_dict_str(
        (&dict, offset, len),
        (&strings, offset, len),
        ComparisonOperator::LessThanOrEqualTo,
    )
    .unwrap();

    assert_bool(&got, &[true, true, true], None);
}
989}