Skip to main content

simd_kernels/kernels/
comparison.rs

1// Copyright (c) 2025 Peter Bower
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Comparison Operations Kernels Module** - *High-Performance Element-wise Comparison Operations*
6//!
7//! Optimised comparison kernels providing comprehensive element-wise comparison operations
8//! across numeric, string, and categorical data types with SIMD acceleration and null-aware semantics.
9//! Foundation for filtering, conditional logic, and analytical query processing.
10//!
11//! ## Core Operations
12//! - **Numeric comparisons**: Equal, not equal, greater than, less than, greater/less than or equal
13//! - **String comparisons**: UTF-8 aware lexicographic ordering with efficient prefix matching
14//! - **Categorical comparisons**: Dictionary-encoded comparisons avoiding string materialisation
15//! - **Null-aware semantics**: Proper three-valued logic handling (true/false/null)
16//! - **SIMD vectorisation**: Hardware-accelerated bulk comparison operations
17//! - **Bitmask operations**: Efficient boolean result representation using bit manipulation
18
19include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21// SIMD
22use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34    and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41/// Returns a new Bitmask for boolean buffers, all bits cleared (false).
42#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44    Bitmask::new_set_all(len, false)
45}
46
47/// Merge two Bitmasks using bitwise AND, or propagate one if only one is present.
48fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49    match (a, b) {
50        (None, None) => None,
51        (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52        (Some(x), Some(y)) => {
53            let mut out = Bitmask::new_set_all(len, true);
54            for i in 0..len {
55                unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56            }
57            Some(out)
58        }
59    }
60}
61
62// Int and float
63
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Zero-allocation variant: writes comparison results into the caller's `output`.
        ///
        /// Compares `lhs[i]` with `rhs[i]` using `op` and sets bit `i` of `output` when the
        /// comparison holds. Bits are only ever set to `true`; every other bit of `output`
        /// is left exactly as the caller passed it in. Positions whose `mask` bit is
        /// cleared are skipped, and operators other than the six ordering/equality
        /// operators never set a bit.
        ///
        /// With the `simd` feature enabled, the vectorised path is taken only when both
        /// slices pass `is_simd_aligned` and no `mask` is supplied; all other calls use
        /// the element-by-element loop.
        ///
        /// # Errors
        /// Returns the error produced by `confirm_equal_len` when `lhs` and `rhs` differ
        /// in length.
        ///
        /// # Panics
        /// Panics if `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            confirm_equal_len("compare numeric length mismatch", len, rhs.len())?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );

            // Single-element comparison, shared by the SIMD remainder and the scalar loop.
            let compare_at = |idx: usize| -> bool {
                let (a, b) = (lhs[idx], rhs[idx]);
                match op {
                    ComparisonOperator::Equals => a == b,
                    ComparisonOperator::NotEquals => a != b,
                    ComparisonOperator::LessThan => a < b,
                    ComparisonOperator::LessThanOrEqualTo => a <= b,
                    ComparisonOperator::GreaterThan => a > b,
                    ComparisonOperator::GreaterThanOrEqualTo => a >= b,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                const N: usize = $lanes;

                // Vectorised path: aligned inputs and no validity mask to honour.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && mask.is_none() {
                    let blocks = lhs.chunks_exact(N).zip(rhs.chunks_exact(N));
                    for (block, (l_chunk, r_chunk)) in blocks.enumerate() {
                        let a = Simd::<$ty, N>::from_slice(l_chunk);
                        let b = Simd::<$ty, N>::from_slice(r_chunk);
                        let hits: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        // One bit per lane; copy the set ones into the output.
                        let bits = hits.to_bitmask();
                        let base = block * N;
                        for lane in 0..N {
                            if ((bits >> lane) & 1) == 1 {
                                unsafe { output.set_unchecked(base + lane, true) };
                            }
                        }
                    }

                    // Remainder (`len % N` trailing elements) is handled one by one.
                    for idx in (len - len % N)..len {
                        if compare_at(idx) {
                            unsafe { output.set_unchecked(idx, true) };
                        }
                    }

                    return Ok(());
                }
                // Otherwise fall through: unaligned input or a mask is present.
            }

            // Scalar path: honours the validity mask when one is given.
            for idx in 0..len {
                let is_valid = mask.map_or(true, |m| unsafe { m.get_unchecked(idx) });
                if is_valid && compare_at(idx) {
                    unsafe { output.set_unchecked(idx, true) };
                }
            }
            Ok(())
        }

        /// Allocating variant of the type-specific comparison kernel.
        ///
        /// Creates a cleared result bitmask of `lhs.len()` bits, fills it through the
        /// zero-allocation sibling, and wraps it in a `BooleanArray`.
        ///
        /// # Parameters
        /// - `lhs`: Left-hand side slice for comparison
        /// - `rhs`: Right-hand side slice for comparison
        /// - `mask`: Optional validity mask; cleared positions are skipped (their result
        ///   bit stays `false`) and the mask is cloned into the result's `null_mask`
        /// - `op`: Comparison operator to apply
        ///
        /// # Returns
        /// `Result<BooleanArray<()>, KernelError>` containing the comparison results.
        ///
        /// # Errors
        /// Propagates any error from the zero-allocation variant (e.g. mismatched
        /// input lengths).
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let mut bits = new_bool_bitmask(lhs.len());
            $fn_name_to(lhs, rhs, mask, op, &mut bits)?;
            Ok(BooleanArray {
                data: bits.into(),
                null_mask: mask.cloned(),
                len: lhs.len(),
                _phantom: PhantomData,
            })
        }
    };
}
195
196/// Zero-allocation variant: writes directly to caller's output buffer.
197///
198/// Unified numeric comparison dispatch with optional null mask support.
199/// The output Bitmask must have capacity >= lhs.len().
200#[inline(always)]
201pub fn cmp_numeric_to<T: Numeric + Copy + 'static>(
202    lhs: &[T],
203    rhs: &[T],
204    mask: Option<&Bitmask>,
205    op: ComparisonOperator,
206    output: &mut Bitmask,
207) -> Result<(), KernelError> {
208    macro_rules! dispatch {
209        ($t:ty, $f:ident) => {
210            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
211                return $f(
212                    unsafe { std::mem::transmute(lhs) },
213                    unsafe { std::mem::transmute(rhs) },
214                    mask,
215                    op,
216                    output,
217                );
218            }
219        };
220    }
221    dispatch!(i32, cmp_i32_to);
222    dispatch!(i64, cmp_i64_to);
223    dispatch!(u32, cmp_u32_to);
224    dispatch!(u64, cmp_u64_to);
225    dispatch!(f32, cmp_f32_to);
226    dispatch!(f64, cmp_f64_to);
227
228    unreachable!("Unsupported numeric type for compare_numeric");
229}
230
231/// Unified numeric comparison dispatch with optional null mask support.
232///
233/// High-performance generic comparison function that dispatches to type-specific SIMD implementations
234/// based on runtime type identification. Supports all numeric types with optional null mask filtering
235/// and comprehensive error handling for mismatched lengths and unsupported types.
236///
237/// # Type Parameters
238/// - `T`: Numeric type implementing `Numeric + Copy + 'static` (i32, i64, u32, u64, f32, f64)
239///
240/// # Parameters
241/// - `lhs`: Left-hand side numeric slice for comparison
242/// - `rhs`: Right-hand side numeric slice for comparison
243/// - `mask`: Optional validity mask applied after comparison (AND operation)
244/// - `op`: Comparison operator to apply (Equals, NotEquals, LessThan, etc.)
245///
246/// # Returns
247/// `Result<BooleanArray<()>, KernelError>` containing comparison results or error details.
248///
249/// # Dispatch Strategy
250/// Uses `TypeId` runtime checking to dispatch to optimal type-specific implementations:
251/// - `i32`/`u32`: 32-bit integer SIMD kernels with W32 lane configuration
252/// - `i64`/`u64`: 64-bit integer SIMD kernels with W64 lane configuration
253/// - `f32`/`f64`: IEEE 754 floating-point SIMD kernels with specialised NaN handling
254///
255/// # Error Conditions
256/// - `KernelError::LengthMismatch`: Input slices have different lengths
257/// - `KernelError::InvalidArguments`: Unsupported numeric type (unreachable in practice)
258///
259/// # Performance Benefits
260/// - Zero-cost dispatch: Type resolution optimised away at compile time for monomorphic usage
261/// - SIMD acceleration: Delegates to vectorised implementations for maximum throughput
262/// - Memory efficiency: Optional mask processing avoids unnecessary allocations
263///
264/// # Example Usage
265/// ```rust,ignore
266/// use simd_kernels::kernels::comparison::{cmp_numeric, ComparisonOperator};
267///
268/// let lhs = &[1i32, 2, 3, 4];
269/// let rhs = &[1i32, 1, 4, 3];
270/// let result = cmp_numeric(lhs, rhs, None, ComparisonOperator::Equals)?;
271/// // Result: [true, false, false, false]
272/// ```
273#[inline(always)]
274pub fn cmp_numeric<T: Numeric + Copy + 'static>(
275    lhs: &[T],
276    rhs: &[T],
277    mask: Option<&Bitmask>,
278    op: ComparisonOperator,
279) -> Result<BooleanArray<()>, KernelError> {
280    let len = lhs.len();
281    let mut out = new_bool_bitmask(len);
282    cmp_numeric_to(lhs, rhs, mask, op, &mut out)?;
283    Ok(BooleanArray {
284        data: out.into(),
285        null_mask: mask.cloned(),
286        len,
287        _phantom: PhantomData,
288    })
289}
290
/// SIMD-accelerated compare bitmask
///
/// Compare two packed bool bitmask slices over a window, using the given operator.
/// Each view is `(bitmask, bit offset, length)`; the offsets are bit offsets into
/// each mask. The validity `mask`, if provided, is ANDed after the comparison.
///
/// For the ordering/equality operators all offsets (including the validity mask's)
/// must be 64-bit aligned (i.e., `offset % 64 == 0`). `In` / `NotIn` are delegated to
/// `in_mask_simd` / `not_in_mask_simd` before that check is made. Any other operator
/// yields an all-false result.
///
/// `cmp_bool` in this module is the null-aware wrapper around this lower level kernel.
///
/// # Errors
/// - The error from `confirm_equal_len` when the two windows differ in length.
/// - `KernelError::InvalidArguments` when an offset is not a multiple of 64.
#[cfg(feature = "simd")]
pub fn cmp_bitmask_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    mask: Option<BitmaskVT<'_>>,
    op: ComparisonOperator,
) -> Result<Bitmask, KernelError> {
    // We have some code duplication here with the `std` version,
    // but unifying them means a const LANE generic on the non-simd path,
    // and adding a higher level dispatch layer creates additional indirection
    // and 9 args instead of 4, hence why it's this way.

    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
    let (lhs_mask, lhs_offset, len) = lhs;
    let (rhs_mask, rhs_offset, _) = rhs;

    // Handle 'In' and 'NotIn' early
    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
        let mut out = match op {
            ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
            ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
            _ => unreachable!(),
        };
        if let Some(mask_slice) = mask {
            out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
        }
        return Ok(out);
    }

    // Word-aligned offsets
    if lhs_offset % 64 != 0
        || rhs_offset % 64 != 0
        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
    {
        return Err(KernelError::InvalidArguments(format!(
            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
            lhs_offset,
            rhs_offset,
            mask.as_ref().map(|(_, mo, _)| mo)
        )));
    }

    // Precompute word indices/counts. Every offset is a *bit* offset, so all three
    // (lhs, rhs and the validity mask) are converted to a word index the same way.
    let lhs_word_start = lhs_offset / 64;
    let rhs_word_start = rhs_offset / 64;
    let mask_words: Option<(&Bitmask, usize)> =
        mask.as_ref().map(|(m, mo, _)| (*m, *mo / 64));
    let n_words = (len + 63) / 64;

    // Allocate output
    let mut out = Bitmask::new_set_all(len, false);

    type Word = u64;
    let simd_chunks = n_words / LANES;
    let tail_words = n_words % LANES;

    // SIMD main path: LANES words per iteration.
    for chunk in 0..simd_chunks {
        let chunk_base = chunk * LANES;

        let mut lhs_arr = [0u64; LANES];
        let mut rhs_arr = [0u64; LANES];
        // All-ones means "every bit valid" when no mask was supplied.
        let mut mask_arr = [!0u64; LANES];

        for lane in 0..LANES {
            let w = chunk_base + lane;
            // Unchecked word reads: the views are trusted to cover `len` bits.
            lhs_arr[lane] = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
            rhs_arr[lane] = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
            if let Some((m, mask_word_start)) = mask_words {
                mask_arr[lane] = unsafe { m.word_unchecked(mask_word_start + w) };
            }
        }
        let lhs_v = Simd::<Word, LANES>::from_array(lhs_arr);
        let rhs_v = Simd::<Word, LANES>::from_array(rhs_arr);
        let mask_v = Simd::<Word, LANES>::from_array(mask_arr);

        // Boolean ordering with false < true, evaluated 64 bits at a time.
        let cmp_v = match op {
            ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
            ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
            ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
            ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
            ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
            ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
            _ => Simd::splat(0),
        };
        let result_v = cmp_v & mask_v;

        for lane in 0..LANES {
            unsafe {
                out.set_word_unchecked(chunk_base + lane, result_v[lane]);
            }
        }
    }

    // Tail caused by `n_words % LANES != 0`; uses scalar word-at-a-time fallback.
    let tail_base = simd_chunks * LANES;
    for tail in 0..tail_words {
        let w = tail_base + tail;
        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
        let valid = match mask_words {
            Some((m, mask_word_start)) => unsafe { m.word_unchecked(mask_word_start + w) },
            None => !0u64,
        };
        let cmp = match op {
            ComparisonOperator::Equals => !(a ^ b),
            ComparisonOperator::NotEquals => a ^ b,
            ComparisonOperator::GreaterThan => a & (!b),
            ComparisonOperator::LessThan => (!a) & b,
            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
            _ => 0,
        } & valid;
        unsafe {
            out.set_word_unchecked(w, cmp);
        }
    }

    // Whole words were written, so tidy up the bits past `len` in the last word.
    out.mask_trailing_bits();
    Ok(out)
}
435
436/// Performs vectorised boolean array comparisons with null mask handling.
437///
438/// High-performance SIMD-accelerated comparison function for boolean arrays with automatic null
439/// mask merging and operator-specific optimisations. Supports all comparison operators through
440/// efficient bitmask operations with configurable lane counts for architecture-specific tuning.
441///
442/// # Type Parameters
443/// - `LANES`: Number of SIMD lanes to process simultaneously
444///
445/// # Parameters
446/// - `lhs`: Left-hand side boolean array view as `(array, offset, length)` tuple
447/// - `rhs`: Right-hand side boolean array view as `(array, offset, length)` tuple  
448/// - `op`: Comparison operator (Equals, NotEquals, In, NotIn, IsNull, IsNotNull, etc.)
449///
450/// # Returns
451/// `Result<BooleanArray<()>, KernelError>` containing comparison results with merged null semantics.
452pub fn cmp_bool<const LANES: usize>(
453    lhs: BooleanAVT<'_, ()>,
454    rhs: BooleanAVT<'_, ()>,
455    op: ComparisonOperator,
456) -> Result<BooleanArray<()>, KernelError>
457where
458{
459    let (lhs_arr, lhs_off, len) = lhs;
460    let (rhs_arr, rhs_off, rlen) = rhs;
461    debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
462
463    #[cfg(feature = "simd")]
464    let merged_null_mask: Option<Bitmask> =
465        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
466            (None, None) => None,
467            (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
468            (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
469            (Some(a), Some(b)) => {
470                let am = (a, lhs_off, len);
471                let bm = (b, rhs_off, len);
472                Some(and_masks_simd::<LANES>(am, bm))
473            }
474        };
475
476    #[cfg(not(feature = "simd"))]
477    let merged_null_mask: Option<Bitmask> =
478        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
479            (None, None) => None,
480            (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
481            (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
482            (Some(a), Some(b)) => {
483                let am = (a, lhs_off, len);
484                let bm = (b, rhs_off, len);
485                Some(and_masks(am, bm))
486            }
487        };
488
489    let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
490
491    let data = match op {
492        ComparisonOperator::Equals
493        | ComparisonOperator::NotEquals
494        | ComparisonOperator::LessThan
495        | ComparisonOperator::LessThanOrEqualTo
496        | ComparisonOperator::GreaterThan
497        | ComparisonOperator::GreaterThanOrEqualTo
498        | ComparisonOperator::In
499        | ComparisonOperator::NotIn => {
500            #[cfg(feature = "simd")]
501            let res = cmp_bitmask_simd::<LANES>(
502                (&lhs_arr.data, lhs_off, len),
503                (&rhs_arr.data, rhs_off, len),
504                mask_slice,
505                op,
506            )?;
507            #[cfg(not(feature = "simd"))]
508            let res = cmp_bitmask_std(
509                (&lhs_arr.data, lhs_off, len),
510                (&rhs_arr.data, rhs_off, len),
511                mask_slice,
512                op,
513            )?;
514            res
515        }
516        ComparisonOperator::IsNull => {
517            #[cfg(feature = "simd")]
518            let data = match merged_null_mask.as_ref() {
519                Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
520                None => Bitmask::new_set_all(len, false),
521            };
522            #[cfg(not(feature = "simd"))]
523            let data = match merged_null_mask.as_ref() {
524                Some(mask) => not_mask((mask, 0, len)),
525                None => Bitmask::new_set_all(len, false),
526            };
527            return Ok(BooleanArray {
528                data,
529                null_mask: None,
530                len,
531                _phantom: PhantomData,
532            });
533        }
534        ComparisonOperator::IsNotNull => {
535            let data = match merged_null_mask.as_ref() {
536                Some(mask) => mask.slice_clone(0, len),
537                None => Bitmask::new_set_all(len, true),
538            };
539            return Ok(BooleanArray {
540                data,
541                null_mask: None,
542                len,
543                _phantom: PhantomData,
544            });
545        }
546        ComparisonOperator::Between => {
547            return Err(KernelError::InvalidArguments(
548                "Set operations are not defined for Bool arrays".to_owned(),
549            ));
550        }
551    };
552
553    Ok(BooleanArray {
554        data,
555        null_mask: merged_null_mask,
556        len,
557        _phantom: PhantomData,
558    })
559}
560
561/// Compare two packed bool bitmask slices over a window, using the given operator.
562/// The offsets are bit offsets into each mask.
563/// The mask, if provided, is ANDed after the comparison.
564/// Requires that all offsets are 64-bit aligned (i.e., offset % 64 == 0).
565///
566/// This lower level kernel can be orchestrated by apply_cmp_bool which
567/// wraps it into a BoolWindow with null-aware semantics.
568#[cfg(not(feature = "simd"))]
569pub fn cmp_bitmask_std(
570    lhs: BitmaskVT<'_>,
571    rhs: BitmaskVT<'_>,
572    mask: Option<BitmaskVT<'_>>,
573    op: ComparisonOperator,
574) -> Result<Bitmask, KernelError> {
575    // We have some code duplication here with the `simd` version,
576    // but unifying then means a const LANE generic on the non-simd path,
577    // and adding a higher level dispatch layer create additional indirection
578    // and 9 args instead of 4, hence why it's this way.
579
580    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
581    let (lhs_mask, lhs_offset, len) = lhs;
582    let (rhs_mask, rhs_offset, _) = rhs;
583
584    // Handle 'In' and 'NotIn' early
585
586    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
587        let mut out = match op {
588            ComparisonOperator::In => in_mask(lhs, rhs),
589            ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
590            _ => unreachable!(),
591        };
592        if let Some(mask_slice) = mask {
593            out = and_masks((&out, 0, out.len), mask_slice);
594        }
595        return Ok(out);
596    }
597
598    // Word-aligned offsets
599    if lhs_offset % 64 != 0
600        || rhs_offset % 64 != 0
601        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
602    {
603        return Err(KernelError::InvalidArguments(format!(
604            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
605            lhs_offset,
606            rhs_offset,
607            mask.as_ref().map(|(_, mo, _)| mo)
608        )));
609    }
610
611    // Precompute word indices/counts
612    let lhs_word_start = lhs_offset / 64;
613    let rhs_word_start = rhs_offset / 64;
614    let n_words = (len + 63) / 64;
615
616    // Allocate output
617    let mut out = Bitmask::new_set_all(len, false);
618
619    let words = n_words;
620    let tail = len % 64;
621    let mask_mask_opt = mask;
622
623    // Word-aligned loop
624    for w in 0..words {
625        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
626        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
627        let valid_bits =
628            mask_mask_opt
629                .as_ref()
630                .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
631                    mask_mask.word_unchecked(mask_word_start + w)
632                });
633        let word_cmp = match op {
634            ComparisonOperator::Equals => !(a ^ b),
635            ComparisonOperator::NotEquals => a ^ b,
636            ComparisonOperator::GreaterThan => a & (!b),
637            ComparisonOperator::LessThan => (!a) & b,
638            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
639            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
640            _ => 0,
641        };
642        let final_bits = word_cmp & valid_bits;
643        unsafe {
644            out.set_word_unchecked(w, final_bits);
645        }
646    }
647
648    // Tail often caused by `n % LANES != 0`; uses scalar fallback.
649
650    let base = words * 64;
651    for i in 0..tail {
652        let idx_lhs = lhs_offset + base + i;
653        let idx_rhs = rhs_offset + base + i;
654        let mask_valid =
655            mask_mask_opt
656                .as_ref()
657                .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
658                    let mask_idx = mask_word_start * 64 + base + i;
659                    if mask_idx < *mask_len {
660                        mask_mask.get_unchecked(mask_idx)
661                    } else {
662                        false
663                    }
664                });
665        if !mask_valid {
666            continue;
667        }
668        if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
669            continue;
670        }
671        let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
672        let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
673        let res = match op {
674            ComparisonOperator::Equals => a == b,
675            ComparisonOperator::NotEquals => a != b,
676            ComparisonOperator::GreaterThan => a & !b,
677            ComparisonOperator::LessThan => !a & b,
678            ComparisonOperator::GreaterThanOrEqualTo => a | !b,
679            ComparisonOperator::LessThanOrEqualTo => !a | b,
680            _ => false,
681        };
682        if res {
683            out.set(base + i, true)
684        }
685    }
686    out.mask_trailing_bits();
687    Ok(out)
688}
689
690// String and dictionary
691
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $fn_name_to:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Compare UTF-8 string or dictionary arrays using the specified comparison operator.
        /// Both inputs are `(array, offset, length)` views. Bit `i` of `output` is set when
        /// the strings at `offset + i` satisfy `op`; bits are only ever set to `true`, and
        /// positions where either side is null are skipped. Operators other than the six
        /// ordering/equality operators never set a bit.
        ///
        /// The null masks are read in place (at `offset + i`), so no temporary mask is
        /// allocated.
        ///
        /// # Errors
        /// - The error from `confirm_equal_len` when the two windows differ in length.
        /// - `KernelError::InvalidArguments` when a null mask's capacity does not cover
        ///   its window (`offset + length`).
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than the window length.
        #[inline(always)]
        pub fn $fn_name_to<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;
            assert!(output.capacity() >= llen, concat!(stringify!($fn_name_to), ": output capacity too small"));

            let lhs_mask = larr.null_mask.as_ref();
            let rhs_mask = rarr.null_mask.as_ref();

            // Validate the masks *before* any bit is read from them.
            if let Some(m) = lhs_mask {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rhs_mask {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            for i in 0..llen {
                // SAFETY: `loff + i` / `roff + i` stay below the capacities checked above.
                let l_valid = lhs_mask.map_or(true, |m| unsafe { m.get_unchecked(loff + i) });
                let r_valid = rhs_mask.map_or(true, |m| unsafe { m.get_unchecked(roff + i) });
                if !(l_valid && r_valid) {
                    continue;
                }
                // Unchecked string access: the views are trusted to lie inside their arrays.
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Compare UTF-8 string or dictionary arrays using the specified comparison operator.
        ///
        /// Allocating wrapper around the zero-allocation variant: creates a cleared result
        /// bitmask sized to the lhs window, runs the comparison, and attaches a null mask
        /// that is the AND of both inputs' masks over the window (or `None` when neither
        /// input has one).
        ///
        /// # Errors
        /// Propagates every error of the zero-allocation variant, including a length
        /// mismatch between the two windows.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            let mut out = new_bool_bitmask(llen);
            // Forward the real rhs length so a mismatch is reported instead of being
            // hidden, and so the masks are validated before they are sliced below.
            $fn_name_to((larr, loff, llen), (rarr, roff, rlen), op, &mut out)?;
            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, llen));
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
780
781impl_cmp_numeric!(cmp_i32, cmp_i32_to, i32, W32, i32);
782impl_cmp_numeric!(cmp_u32, cmp_u32_to, u32, W32, i32);
783impl_cmp_numeric!(cmp_i64, cmp_i64_to, i64, W64, i64);
784impl_cmp_numeric!(cmp_u64, cmp_u64_to, u64, W64, i64);
785impl_cmp_numeric!(cmp_f32, cmp_f32_to, f32, W32, i32);
786impl_cmp_numeric!(cmp_f64, cmp_f64_to, f64, W64, i64);
787impl_cmp_utf8_slice!(cmp_str_str,   cmp_str_str_to,   StringAVT<'a, T>,     StringAVT<'a, T>,      [ 'a, T: Integer ]);
788impl_cmp_utf8_slice!(cmp_str_dict,  cmp_str_dict_to,  StringAVT<'a, T>,     CategoricalAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
789impl_cmp_utf8_slice!(cmp_dict_str,  cmp_dict_str_to,  CategoricalAVT<'a, T>,     StringAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
790impl_cmp_utf8_slice!(cmp_dict_dict, cmp_dict_dict_to, CategoricalAVT<'a, T>,     CategoricalAVT<'a, T>,      [ 'a, T: Integer ]);
791
#[cfg(test)]
mod tests {
    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};

    use crate::kernels::comparison::{
        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict,
    };

    #[cfg(feature = "simd")]
    use crate::kernels::comparison::{W64, cmp_bitmask_simd};

    use crate::operators::ComparisonOperator;

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Builds a `Bitmask` whose bit `i` mirrors `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut mask = Bitmask::new_set_all(bits.len(), false);
        bits.iter().enumerate().for_each(|(pos, &bit)| {
            mask.set(pos, bit);
        });
        mask
    }

    /// Checks a `BooleanArray` against the expected value bits and, when
    /// given, the expected null-mask bits.
    ///
    /// Panics if exactly one of the actual / expected null masks is present.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len());
        for (i, &want) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), want, "value bit {i}");
        }

        match (arr.null_mask.as_ref(), expect_mask) {
            (Some(actual), Some(wanted)) => {
                for (i, &valid) in wanted.iter().enumerate() {
                    assert_eq!(actual.get(i), valid, "null-bit {i}");
                }
            }
            (None, None) => {}
            _ => panic!("mask mismatch"),
        }
    }

    /// Builds a `StringArray` from string literals.
    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(v)
    }

    /// Builds a `CategoricalArray` from string literals.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        CategoricalArray::<T>::from_values(vals.to_vec())
    }

    // ---------------------------------------------------------------------
    // Numeric
    // ---------------------------------------------------------------------

    #[test]
    fn numeric_compare_no_nulls() {
        let a = vec64![1i32, 2, 3, 4];
        let b = vec64![1i32, 1, 4, 4];

        // One row per operator: (operator, expected value bits).
        let cases = [
            (ComparisonOperator::Equals, [true, false, false, true]),
            (ComparisonOperator::NotEquals, [false, true, true, false]),
            (ComparisonOperator::LessThan, [false, false, true, false]),
            (ComparisonOperator::LessThanOrEqualTo, [true, false, true, true]),
            (ComparisonOperator::GreaterThan, [false, true, false, false]),
            (ComparisonOperator::GreaterThanOrEqualTo, [true, true, false, true]),
        ];

        for (op, want) in cases {
            let got = cmp_i32(&a, &b, None, op).unwrap();
            assert_bool(&got, &want, None);
        }
    }

    #[test]
    fn numeric_compare_with_nulls_generic_dispatch() {
        // The final row is flagged invalid by the mask.
        let a = vec64![1u64, 5, 9, 10];
        let b = vec64![0u64, 5, 8, 11];
        let validity = bm(&[true, true, true, false]);

        let out = cmp_numeric(&a, &b, Some(&validity), ComparisonOperator::GreaterThan).unwrap();

        // Value bits are only meaningful where the mask marks the row valid.
        let want_bits = [true, false, true, false];
        let want_mask = [true, true, true, false];
        assert_bool(&out, &want_bits, Some(&want_mask));
    }

    // ---------------------------------------------------------------------
    // Boolean
    // ---------------------------------------------------------------------

    #[cfg(feature = "simd")]
    #[test]
    fn bool_compare_all_ops() {
        let a = bm(&[true, false, true, false]);
        let b = bm(&[true, true, false, false]);

        // Same full-length windows for every operator under test.
        let run = |op: ComparisonOperator| {
            cmp_bitmask_simd::<W64>((&a, 0, a.len()), (&b, 0, b.len()), None, op).unwrap()
        };

        let eq = run(ComparisonOperator::Equals);
        let lt = run(ComparisonOperator::LessThan);
        let gt = run(ComparisonOperator::GreaterThan);

        assert_bool(
            &BooleanArray::from_bitmask(eq, None),
            &[true, false, false, true],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(lt, None),
            &[false, true, false, false],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(gt, None),
            &[false, false, true, false],
            None,
        );
    }

    // ---------------------------------------------------------------------
    // UTF-8 & dictionary
    // ---------------------------------------------------------------------

    #[test]
    fn string_vs_dict_compare_with_nulls() {
        let mut left = str_arr::<u32>(&["x", "y", "z"]);
        left.null_mask = Some(bm(&[true, false, true]));
        let right = dict_arr::<u32>(&["x", "w", "zz"]);

        let got = cmp_str_dict(
            (&left, 0, left.len()),
            (&right, 0, right.data.len()),
            ComparisonOperator::Equals,
        )
        .unwrap();

        assert_bool(&got, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn string_vs_dict_compare_with_nulls_chunk() {
        // Compare only the inner three rows; the "pad" rows sit outside the window.
        let mut left = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
        left.null_mask = Some(bm(&[true, true, false, true, true]));
        let right = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);

        let got = cmp_str_dict((&left, 1, 3), (&right, 1, 3), ComparisonOperator::Equals).unwrap();

        assert_bool(&got, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn dict_vs_dict_compare_gt() {
        let left = dict_arr::<u32>(&["apple", "pear", "banana"]);
        let right = dict_arr::<u32>(&["ant", "pear", "apricot"]);

        let got = cmp_dict_dict(
            (&left, 0, left.data.len()),
            (&right, 0, right.data.len()),
            ComparisonOperator::GreaterThan,
        )
        .unwrap();

        assert_bool(&got, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_dict_compare_gt_chunk() {
        let left = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
        let right = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);

        let got =
            cmp_dict_dict((&left, 1, 3), (&right, 1, 3), ComparisonOperator::GreaterThan).unwrap();

        assert_bool(&got, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le() {
        let left = dict_arr::<u32>(&["a", "b", "c"]);
        let right = str_arr::<u32>(&["b", "b", "d"]);

        let got = cmp_dict_str(
            (&left, 0, left.data.len()),
            (&right, 0, right.len()),
            ComparisonOperator::LessThanOrEqualTo,
        )
        .unwrap();

        assert_bool(&got, &[true, true, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le_chunk() {
        let left = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let right = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);

        let got = cmp_dict_str(
            (&left, 1, 3),
            (&right, 1, 3),
            ComparisonOperator::LessThanOrEqualTo,
        )
        .unwrap();

        assert_bool(&got, &[true, true, true], None);
    }
}