simd_kernels/kernels/
comparison.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Comparison Operations Kernels Module** - *High-Performance Element-wise Comparison Operations*
5//!
6//! Optimised comparison kernels providing comprehensive element-wise comparison operations
7//! across numeric, string, and categorical data types with SIMD acceleration and null-aware semantics.
8//! Foundation for filtering, conditional logic, and analytical query processing.
9//!
10//! ## Core Operations
11//! - **Numeric comparisons**: Equal, not equal, greater than, less than, greater/less than or equal
12//! - **String comparisons**: UTF-8 aware lexicographic ordering with efficient prefix matching
13//! - **Categorical comparisons**: Dictionary-encoded comparisons avoiding string materialisation
14//! - **Null-aware semantics**: Proper three-valued logic handling (true/false/null)
15//! - **SIMD vectorisation**: Hardware-accelerated bulk comparison operations
16//! - **Bitmask operations**: Efficient boolean result representation using bit manipulation
17
18include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20// SIMD
21use core::simd::{LaneCount, SupportedLaneCount};
22use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34    and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41/// Returns a new Bitmask for boolean buffers, all bits cleared (false).
42#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44    Bitmask::new_set_all(len, false)
45}
46
47/// Merge two Bitmasks using bitwise AND, or propagate one if only one is present.
48fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49    match (a, b) {
50        (None, None) => None,
51        (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52        (Some(x), Some(y)) => {
53            let mut out = Bitmask::new_set_all(len, true);
54            for i in 0..len {
55                unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56            }
57            Some(out)
58        }
59    }
60}
61
62// Int and float
63
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Element-wise comparison of two equally sized slices of one numeric type.
        ///
        /// When the `simd` feature is enabled, both slices pass `is_simd_aligned`, and no
        /// validity mask is supplied, the input is compared one SIMD vector (the configured
        /// lane count for this type) at a time, with a scalar loop for the remainder.
        /// Every other case uses the scalar path.
        ///
        /// # Parameters
        /// - `lhs`: Left-hand side slice for comparison
        /// - `rhs`: Right-hand side slice for comparison
        /// - `mask`: Optional validity mask. Positions whose bit is cleared are not compared
        ///   (their output bit stays `false`), and the mask is returned as the result's
        ///   `null_mask`.
        /// - `op`: Comparison operator to apply. Operators other than `Equals`, `NotEquals`,
        ///   `LessThan`, `LessThanOrEqualTo`, `GreaterThan` and `GreaterThanOrEqualTo`
        ///   yield `false` for every element.
        ///
        /// # Returns
        /// A `BooleanArray<()>` of length `lhs.len()` holding the comparison results.
        ///
        /// # Errors
        /// - `lhs` and `rhs` differ in length (error produced by `confirm_equal_len`).
        /// - `KernelError::InvalidArguments` if `mask` holds fewer than `lhs.len()` bits.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("compare numeric length mismatch", len, rhs.len())?;

            // The masked scalar path reads `mask` with unchecked accessors, so a mask that
            // does not cover every element must be rejected instead of read out of bounds.
            if let Some(m) = mask {
                if m.len() < len {
                    return Err(KernelError::InvalidArguments(format!(
                        "compare numeric mask too short (expected ≥ {}, got {})",
                        len,
                        m.len()
                    )));
                }
            }

            let mut out = new_bool_bitmask(len);

            // Scalar semantics shared by the SIMD remainder loop and the scalar path.
            let scalar_cmp = |a: $ty, b: $ty| -> bool {
                match op {
                    ComparisonOperator::Equals => a == b,
                    ComparisonOperator::NotEquals => a != b,
                    ComparisonOperator::LessThan => a < b,
                    ComparisonOperator::LessThanOrEqualTo => a <= b,
                    ComparisonOperator::GreaterThan => a > b,
                    ComparisonOperator::GreaterThanOrEqualTo => a >= b,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                // SIMD is used only for aligned, mask-free input; everything else falls
                // through to the scalar loop below.
                if mask.is_none() && is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                    const N: usize = $lanes;
                    let mut i = 0;
                    while i + N <= len {
                        let a = Simd::<$ty, N>::from_slice(&lhs[i..i + N]);
                        let b = Simd::<$ty, N>::from_slice(&rhs[i..i + N]);
                        let m: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        let bits = m.to_bitmask();
                        for l in 0..N {
                            if ((bits >> l) & 1) == 1 {
                                // SAFETY: `i + l < i + N <= len` and `out` holds `len` bits.
                                unsafe { out.set_unchecked(i + l, true) };
                            }
                        }
                        i += N;
                    }
                    // Remainder left over when `len % N != 0`.
                    for j in i..len {
                        if scalar_cmp(lhs[j], rhs[j]) {
                            // SAFETY: `j < len` and `out` holds `len` bits.
                            unsafe { out.set_unchecked(j, true) };
                        }
                    }

                    return Ok(BooleanArray {
                        data: out.into(),
                        null_mask: None,
                        len,
                        _phantom: PhantomData,
                    });
                }
            }

            // Scalar path: unaligned input, a validity mask, or SIMD disabled.
            for (i, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
                // SAFETY: `m.len() >= len` was verified above and `i < len`.
                let valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
                if valid && scalar_cmp(a, b) {
                    // SAFETY: `out` was allocated with `len` bits and `i < len`.
                    unsafe { out.set_unchecked(i, true) };
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
179
180/// Unified numeric comparison dispatch with optional null mask support.
181///
182/// High-performance generic comparison function that dispatches to type-specific SIMD implementations
183/// based on runtime type identification. Supports all numeric types with optional null mask filtering
184/// and comprehensive error handling for mismatched lengths and unsupported types.
185///
186/// # Type Parameters
187/// - `T`: Numeric type implementing `Numeric + Copy + 'static` (i32, i64, u32, u64, f32, f64)
188///
189/// # Parameters
190/// - `lhs`: Left-hand side numeric slice for comparison
191/// - `rhs`: Right-hand side numeric slice for comparison  
192/// - `mask`: Optional validity mask applied after comparison (AND operation)
193/// - `op`: Comparison operator to apply (Equals, NotEquals, LessThan, etc.)
194///
195/// # Returns
196/// `Result<BooleanArray<()>, KernelError>` containing comparison results or error details.
197///
198/// # Dispatch Strategy
199/// Uses `TypeId` runtime checking to dispatch to optimal type-specific implementations:
200/// - `i32`/`u32`: 32-bit integer SIMD kernels with W32 lane configuration
201/// - `i64`/`u64`: 64-bit integer SIMD kernels with W64 lane configuration  
202/// - `f32`/`f64`: IEEE 754 floating-point SIMD kernels with specialised NaN handling
203///
204/// # Error Conditions
205/// - `KernelError::LengthMismatch`: Input slices have different lengths
206/// - `KernelError::InvalidArguments`: Unsupported numeric type (unreachable in practice)
207///
208/// # Performance Benefits
209/// - Zero-cost dispatch: Type resolution optimised away at compile time for monomorphic usage
210/// - SIMD acceleration: Delegates to vectorised implementations for maximum throughput
211/// - Memory efficiency: Optional mask processing avoids unnecessary allocations
212///
213/// # Example Usage
214/// ```rust,ignore
215/// use simd_kernels::kernels::comparison::{cmp_numeric, ComparisonOperator};
216///
217/// let lhs = &[1i32, 2, 3, 4];
218/// let rhs = &[1i32, 1, 4, 3];
219/// let result = cmp_numeric(lhs, rhs, None, ComparisonOperator::Equals)?;
220/// // Result: [true, false, false, false]
221/// ```
222#[inline(always)]
223pub fn cmp_numeric<T: Numeric + Copy + 'static>(
224    lhs: &[T],
225    rhs: &[T],
226    mask: Option<&Bitmask>,
227    op: ComparisonOperator,
228) -> Result<BooleanArray<()>, KernelError> {
229    macro_rules! dispatch {
230        ($t:ty, $f:ident) => {
231            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
232                return $f(
233                    unsafe { std::mem::transmute(lhs) },
234                    unsafe { std::mem::transmute(rhs) },
235                    mask,
236                    op,
237                );
238            }
239        };
240    }
241    dispatch!(i32, cmp_i32);
242    dispatch!(i64, cmp_i64);
243    dispatch!(u32, cmp_u32);
244    dispatch!(u64, cmp_u64);
245    dispatch!(f32, cmp_f32);
246    dispatch!(f64, cmp_f64);
247
248    unreachable!("Unsupported numeric type for compare_numeric");
249}
250
/// SIMD-accelerated compare bitmask
///
/// Compares two packed bool bitmask windows over `len` bits using `op`.
/// Each window is `(bitmask, bit_offset, len)`; the optional `mask` window uses the same
/// bit-offset convention and is ANDed with the comparison result.
///
/// # Type Parameters
/// - `LANES`: number of `u64` words compared per SIMD step.
///
/// # Errors
/// - Length mismatch between `lhs` and `rhs` (error produced by `confirm_equal_len`).
/// - `KernelError::InvalidArguments` if, for the ordering/equality operators, any offset
///   is not a multiple of 64. `In`/`NotIn` are delegated to the set kernels before that
///   check.
///
/// This lower level kernel is orchestrated by `cmp_bool`, which supplies merged
/// null masks and wraps the result in a `BooleanArray`.
#[cfg(feature = "simd")]
pub fn cmp_bitmask_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    mask: Option<BitmaskVT<'_>>,
    op: ComparisonOperator,
) -> Result<Bitmask, KernelError>
where
    LaneCount<LANES>: SupportedLaneCount,
{
    // We have some code duplication here with the `std` version,
    // but unifying them means a const LANE generic on the non-simd path,
    // and adding a higher level dispatch layer creates additional indirection
    // and 9 args instead of 4, hence why it's this way.

    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
    let (lhs_mask, lhs_offset, len) = lhs;
    let (rhs_mask, rhs_offset, _) = rhs;

    // Handle 'In' and 'NotIn' early
    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
        let mut out = match op {
            ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
            ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
            _ => unreachable!(),
        };
        if let Some(mask_slice) = mask {
            out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
        }
        return Ok(out);
    }

    // Word-aligned offsets
    if lhs_offset % 64 != 0
        || rhs_offset % 64 != 0
        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
    {
        return Err(KernelError::InvalidArguments(format!(
            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
            lhs_offset,
            rhs_offset,
            mask.as_ref().map(|(_, mo, _)| mo)
        )));
    }

    // Word indices. Every offset is a *bit* offset, so the mask's offset must be converted
    // to a word index exactly like lhs/rhs before it is combined with word counts.
    let lhs_word_start = lhs_offset / 64;
    let rhs_word_start = rhs_offset / 64;
    let mask_words: Option<(&Bitmask, usize)> = mask.as_ref().map(|&(m, mo, _)| (m, mo / 64));
    let n_words = (len + 63) / 64;

    let mut out = Bitmask::new_set_all(len, false);

    type Word = u64;
    let simd_chunks = n_words / LANES;
    let tail_words = n_words % LANES;
    let mut word_idx = 0;

    // SIMD main path: LANES words per iteration.
    for chunk in 0..simd_chunks {
        let base = chunk * LANES;

        let mut lhs_arr = [0u64; LANES];
        let mut rhs_arr = [0u64; LANES];
        // Without a mask every bit is treated as valid.
        let mut mask_arr = [!0u64; LANES];

        for lane in 0..LANES {
            lhs_arr[lane] = unsafe { lhs_mask.word_unchecked(lhs_word_start + base + lane) };
            rhs_arr[lane] = unsafe { rhs_mask.word_unchecked(rhs_word_start + base + lane) };
            if let Some((m, mask_word_start)) = mask_words {
                mask_arr[lane] = unsafe { m.word_unchecked(mask_word_start + base + lane) };
            }
        }
        let lhs_v = Simd::<Word, LANES>::from_array(lhs_arr);
        let rhs_v = Simd::<Word, LANES>::from_array(rhs_arr);
        let mask_v = Simd::<Word, LANES>::from_array(mask_arr);

        // Boolean ordering treats `true > false`.
        let cmp_v = match op {
            ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
            ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
            ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
            ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
            ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
            ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
            _ => Simd::splat(0),
        };
        let result_v = cmp_v & mask_v;

        for lane in 0..LANES {
            unsafe {
                out.set_word_unchecked(word_idx, result_v[lane]);
            }
            word_idx += 1;
        }
    }

    // Tail words left over when `n_words % LANES != 0`; handled one word at a time.
    let base = simd_chunks * LANES;
    for tail in 0..tail_words {
        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + base + tail) };
        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + base + tail) };
        let m = mask_words.map_or(!0u64, |(m, mask_word_start)| unsafe {
            m.word_unchecked(mask_word_start + base + tail)
        });
        let cmp = match op {
            ComparisonOperator::Equals => !(a ^ b),
            ComparisonOperator::NotEquals => a ^ b,
            ComparisonOperator::GreaterThan => a & (!b),
            ComparisonOperator::LessThan => (!a) & b,
            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
            _ => 0,
        } & m;
        unsafe {
            out.set_word_unchecked(word_idx, cmp);
        }
        word_idx += 1;
    }

    // The last word may carry bits past `len`; clear them.
    out.mask_trailing_bits();
    Ok(out)
}
396
397/// Performs vectorised boolean array comparisons with null mask handling.
398///
399/// High-performance SIMD-accelerated comparison function for boolean arrays with automatic null
400/// mask merging and operator-specific optimisations. Supports all comparison operators through
401/// efficient bitmask operations with configurable lane counts for architecture-specific tuning.
402///
403/// # Type Parameters
404/// - `LANES`: Number of SIMD lanes to process simultaneously
405///
406/// # Parameters
407/// - `lhs`: Left-hand side boolean array view as `(array, offset, length)` tuple
408/// - `rhs`: Right-hand side boolean array view as `(array, offset, length)` tuple  
409/// - `op`: Comparison operator (Equals, NotEquals, In, NotIn, IsNull, IsNotNull, etc.)
410///
411/// # Returns
412/// `Result<BooleanArray<()>, KernelError>` containing comparison results with merged null semantics.
413pub fn cmp_bool<const LANES: usize>(
414    lhs: BooleanAVT<'_, ()>,
415    rhs: BooleanAVT<'_, ()>,
416    op: ComparisonOperator,
417) -> Result<BooleanArray<()>, KernelError>
418where
419    LaneCount<LANES>: SupportedLaneCount,
420{
421    let (lhs_arr, lhs_off, len) = lhs;
422    let (rhs_arr, rhs_off, rlen) = rhs;
423    debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
424
425    #[cfg(feature = "simd")]
426    let merged_null_mask: Option<Bitmask> =
427        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
428            (None, None) => None,
429            (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
430            (Some(a), Some(b)) => {
431                let am = (a, lhs_off, len);
432                let bm = (b, rhs_off, len);
433                Some(and_masks_simd::<LANES>(am, bm))
434            }
435        };
436
437    #[cfg(not(feature = "simd"))]
438    let merged_null_mask: Option<Bitmask> =
439        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
440            (None, None) => None,
441            (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
442            (Some(a), Some(b)) => {
443                let am = (a, lhs_off, len);
444                let bm = (b, rhs_off, len);
445                Some(and_masks(am, bm))
446            }
447        };
448
449    let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
450
451    let data = match op {
452        ComparisonOperator::Equals
453        | ComparisonOperator::NotEquals
454        | ComparisonOperator::LessThan
455        | ComparisonOperator::LessThanOrEqualTo
456        | ComparisonOperator::GreaterThan
457        | ComparisonOperator::GreaterThanOrEqualTo
458        | ComparisonOperator::In
459        | ComparisonOperator::NotIn => {
460            #[cfg(feature = "simd")]
461            let res = cmp_bitmask_simd::<LANES>(
462                (&lhs_arr.data, lhs_off, len),
463                (&rhs_arr.data, rhs_off, len),
464                mask_slice,
465                op,
466            )?;
467            #[cfg(not(feature = "simd"))]
468            let res = cmp_bitmask_std(
469                (&lhs_arr.data, lhs_off, len),
470                (&rhs_arr.data, rhs_off, len),
471                mask_slice,
472                op,
473            )?;
474            res
475        }
476        ComparisonOperator::IsNull => {
477            #[cfg(feature = "simd")]
478            let data = match merged_null_mask.as_ref() {
479                Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
480                None => Bitmask::new_set_all(len, false),
481            };
482            #[cfg(not(feature = "simd"))]
483            let data = match merged_null_mask.as_ref() {
484                Some(mask) => not_mask((mask, 0, len)),
485                None => Bitmask::new_set_all(len, false),
486            };
487            return Ok(BooleanArray {
488                data,
489                null_mask: None,
490                len,
491                _phantom: PhantomData,
492            });
493        }
494        ComparisonOperator::IsNotNull => {
495            let data = match merged_null_mask.as_ref() {
496                Some(mask) => mask.slice_clone(0, len),
497                None => Bitmask::new_set_all(len, true),
498            };
499            return Ok(BooleanArray {
500                data,
501                null_mask: None,
502                len,
503                _phantom: PhantomData,
504            });
505        }
506        ComparisonOperator::Between => {
507            return Err(KernelError::InvalidArguments(
508                "Set operations are not defined for Bool arrays".to_owned(),
509            ));
510        }
511    };
512
513    Ok(BooleanArray {
514        data,
515        null_mask: merged_null_mask,
516        len,
517        _phantom: PhantomData,
518    })
519}
520
521/// Compare two packed bool bitmask slices over a window, using the given operator.
522/// The offsets are bit offsets into each mask.
523/// The mask, if provided, is ANDed after the comparison.
524/// Requires that all offsets are 64-bit aligned (i.e., offset % 64 == 0).
525///
526/// This lower level kernel can be orchestrated by apply_cmp_bool which
527/// wraps it into a BoolWindow with null-aware semantics.
528#[cfg(not(feature = "simd"))]
529pub fn cmp_bitmask_std(
530    lhs: BitmaskVT<'_>,
531    rhs: BitmaskVT<'_>,
532    mask: Option<BitmaskVT<'_>>,
533    op: ComparisonOperator,
534) -> Result<Bitmask, KernelError> {
535    // We have some code duplication here with the `simd` version,
536    // but unifying then means a const LANE generic on the non-simd path,
537    // and adding a higher level dispatch layer create additional indirection
538    // and 9 args instead of 4, hence why it's this way.
539
540    confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
541    let (lhs_mask, lhs_offset, len) = lhs;
542    let (rhs_mask, rhs_offset, _) = rhs;
543
544    // Handle 'In' and 'NotIn' early
545
546    if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
547        let mut out = match op {
548            ComparisonOperator::In => in_mask(lhs, rhs),
549            ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
550            _ => unreachable!(),
551        };
552        if let Some(mask_slice) = mask {
553            out = and_masks((&out, 0, out.len), mask_slice);
554        }
555        return Ok(out);
556    }
557
558    // Word-aligned offsets
559    if lhs_offset % 64 != 0
560        || rhs_offset % 64 != 0
561        || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
562    {
563        return Err(KernelError::InvalidArguments(format!(
564            "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
565            lhs_offset,
566            rhs_offset,
567            mask.as_ref().map(|(_, mo, _)| mo)
568        )));
569    }
570
571    // Precompute word indices/counts
572    let lhs_word_start = lhs_offset / 64;
573    let rhs_word_start = rhs_offset / 64;
574    let n_words = (len + 63) / 64;
575
576    // Allocate output
577    let mut out = Bitmask::new_set_all(len, false);
578
579    let words = n_words;
580    let tail = len % 64;
581    let mask_mask_opt = mask;
582
583    // Word-aligned loop
584    for w in 0..words {
585        let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
586        let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
587        let valid_bits =
588            mask_mask_opt
589                .as_ref()
590                .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
591                    mask_mask.word_unchecked(mask_word_start + w)
592                });
593        let word_cmp = match op {
594            ComparisonOperator::Equals => !(a ^ b),
595            ComparisonOperator::NotEquals => a ^ b,
596            ComparisonOperator::GreaterThan => a & (!b),
597            ComparisonOperator::LessThan => (!a) & b,
598            ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
599            ComparisonOperator::LessThanOrEqualTo => (!a) | b,
600            _ => 0,
601        };
602        let final_bits = word_cmp & valid_bits;
603        unsafe {
604            out.set_word_unchecked(w, final_bits);
605        }
606    }
607
608    // Tail often caused by `n % LANES != 0`; uses scalar fallback.
609
610    let base = words * 64;
611    for i in 0..tail {
612        let idx_lhs = lhs_offset + base + i;
613        let idx_rhs = rhs_offset + base + i;
614        let mask_valid =
615            mask_mask_opt
616                .as_ref()
617                .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
618                    let mask_idx = mask_word_start * 64 + base + i;
619                    if mask_idx < *mask_len {
620                        mask_mask.get_unchecked(mask_idx)
621                    } else {
622                        false
623                    }
624                });
625        if !mask_valid {
626            continue;
627        }
628        if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
629            continue;
630        }
631        let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
632        let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
633        let res = match op {
634            ComparisonOperator::Equals => a == b,
635            ComparisonOperator::NotEquals => a != b,
636            ComparisonOperator::GreaterThan => a & !b,
637            ComparisonOperator::LessThan => !a & b,
638            ComparisonOperator::GreaterThanOrEqualTo => a | !b,
639            ComparisonOperator::LessThanOrEqualTo => !a | b,
640            _ => false,
641        };
642        if res {
643            out.set(base + i, true)
644        }
645    }
646    out.mask_trailing_bits();
647    Ok(out)
648}
649
650// String and dictionary
651
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Compare UTF-8 string or dictionary array windows element-wise using `op`.
        ///
        /// Each side is an `(array, offset, length)` window, and each element is compared
        /// as the value returned by `get_str_unchecked`. Positions that are null on either
        /// side are not compared, so their output bit stays `false`. The result's
        /// `null_mask` is the AND of both windows' masks, or the single mask that is present.
        ///
        /// Operators other than `Equals`, `NotEquals`, `GreaterThan`, `LessThan`,
        /// `GreaterThanOrEqualTo` and `LessThanOrEqualTo` yield `false` for every element.
        ///
        /// # Errors
        /// - Window length mismatch (error produced by `confirm_equal_len`).
        /// - `KernelError::InvalidArguments` if either null mask's capacity does not cover
        ///   `offset + length` of its window.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;

            // Validate mask capacity *before* slicing, so an undersized mask is reported
            // as an error instead of being handed to `slice_clone`.
            if let Some(m) = larr.null_mask.as_ref() {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rarr.null_mask.as_ref() {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));

            // A position is compared only when it is valid on both sides; a missing mask
            // means every position on that side is valid.
            let is_valid = |i: usize| -> bool {
                // SAFETY: both sliced masks were built with `llen` (== `rlen`) bits and
                // callers pass `i < llen`.
                lhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) })
                    && rhs_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) })
            };

            let mut out = new_bool_bitmask(llen);
            for i in 0..llen {
                if !is_valid(i) {
                    continue;
                }
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    out.set(i, true);
                }
            }
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
720
721impl_cmp_numeric!(cmp_i32, i32, W32, i32);
722impl_cmp_numeric!(cmp_u32, u32, W32, i32);
723impl_cmp_numeric!(cmp_i64, i64, W64, i64);
724impl_cmp_numeric!(cmp_u64, u64, W64, i64);
725impl_cmp_numeric!(cmp_f32, f32, W32, i32);
726impl_cmp_numeric!(cmp_f64, f64, W64, i64);
727impl_cmp_utf8_slice!(cmp_str_str,   StringAVT<'a, T>,     StringAVT<'a, T>,      [ 'a, T: Integer ]);
728impl_cmp_utf8_slice!(cmp_str_dict,  StringAVT<'a, T>,     CategoricalAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
729impl_cmp_utf8_slice!(cmp_dict_str,  CategoricalAVT<'a, T>,     StringAVT<'a, U>,      [ 'a, T: Integer, U: Integer ]);
730impl_cmp_utf8_slice!(cmp_dict_dict, CategoricalAVT<'a, T>,     CategoricalAVT<'a, T>,      [ 'a, T: Integer ]);
731
#[cfg(test)]
mod tests {
    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};

    use crate::kernels::comparison::{
        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict, cmp_str_str,
    };

    #[cfg(feature = "simd")]
    use crate::kernels::comparison::{W64, cmp_bitmask_simd};

    use crate::operators::ComparisonOperator;

    // --- helpers --------------------------------------------------------------
    // (Plain comment: a `///` here would silently attach as rustdoc to `bm`.)

    /// Builds a `Bitmask` of `bits.len()` bits where bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut m = Bitmask::new_set_all(bits.len(), false);
        for (i, &b) in bits.iter().enumerate() {
            m.set(i, b);
        }
        m
    }

    /// Asserts that `arr` has length `expect.len()` and value bits equal to `expect`.
    ///
    /// When `expect_mask` is `None`, `arr` must carry no null mask. When it is `Some`, `arr`
    /// must carry a null mask whose leading bits equal `expect_mask`.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len());
        for (i, &b) in expect.iter().enumerate() {
            assert_eq!(arr.data.get(i), b, "value bit {i}");
        }
        match (arr.null_mask.as_ref(), expect_mask) {
            (None, None) => {}
            (Some(m), Some(exp)) => {
                for (i, &b) in exp.iter().enumerate() {
                    assert_eq!(m.get(i), b, "null-bit {i}");
                }
            }
            (actual, expected) => panic!(
                "mask mismatch: actual null mask present = {}, expected present = {}",
                actual.is_some(),
                expected.is_some()
            ),
        }
    }

    /// Tiny helpers to build test String / Dict arrays.
    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(v)
    }

    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        let owned: Vec<&str> = vals.to_vec();
        CategoricalArray::<T>::from_values(owned)
    }

    // NUMERIC

    #[test]
    fn numeric_compare_no_nulls() {
        let a = vec64![1i32, 2, 3, 4];
        let b = vec64![1i32, 1, 4, 4];

        let eq = cmp_i32(&a, &b, None, ComparisonOperator::Equals).unwrap();
        let neq = cmp_i32(&a, &b, None, ComparisonOperator::NotEquals).unwrap();
        let lt = cmp_i32(&a, &b, None, ComparisonOperator::LessThan).unwrap();
        let le = cmp_i32(&a, &b, None, ComparisonOperator::LessThanOrEqualTo).unwrap();
        let gt = cmp_i32(&a, &b, None, ComparisonOperator::GreaterThan).unwrap();
        let ge = cmp_i32(&a, &b, None, ComparisonOperator::GreaterThanOrEqualTo).unwrap();

        assert_bool(&eq, &[true, false, false, true], None);
        assert_bool(&neq, &[false, true, true, false], None);
        assert_bool(&lt, &[false, false, true, false], None);
        assert_bool(&le, &[true, false, true, true], None);
        assert_bool(&gt, &[false, true, false, false], None);
        assert_bool(&ge, &[true, true, false, true], None);
    }

    #[test]
    fn numeric_compare_with_nulls_generic_dispatch() {
        // last element masked-out
        let a = vec64![1u64, 5, 9, 10];
        let b = vec64![0u64, 5, 8, 11];
        let mask = bm(&[true, true, true, false]);

        let out = cmp_numeric(&a, &b, Some(&mask), ComparisonOperator::GreaterThan).unwrap();
        // result bits for valid rows only
        assert_bool(
            &out,
            &[true, false, true, false],
            Some(&[true, true, true, false]),
        );
    }

    // BOOLEAN

    #[cfg(feature = "simd")]
    #[test]
    fn bool_compare_all_ops() {
        let a = bm(&[true, false, true, false]);
        let b = bm(&[true, true, false, false]);
        let eq = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::Equals,
        )
        .unwrap();
        let lt = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::LessThan,
        )
        .unwrap();
        let gt = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::GreaterThan,
        )
        .unwrap();

        assert_bool(
            &BooleanArray::from_bitmask(eq, None),
            &[true, false, false, true],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(lt, None),
            &[false, true, false, false],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(gt, None),
            &[false, false, true, false],
            None,
        );
    }

    // UTF & DICTIONARY

    #[test]
    fn string_vs_string_compare() {
        let lhs = str_arr::<u32>(&["a", "b", "c"]);
        let rhs = str_arr::<u32>(&["b", "b", "a"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let eq = cmp_str_str(lhs_slice, rhs_slice, ComparisonOperator::Equals).unwrap();
        let lt = cmp_str_str(lhs_slice, rhs_slice, ComparisonOperator::LessThan).unwrap();
        let gt = cmp_str_str(lhs_slice, rhs_slice, ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&eq, &[false, true, false], None);
        assert_bool(&lt, &[true, false, false], None);
        assert_bool(&gt, &[false, false, true], None);
    }

    #[test]
    fn string_vs_string_compare_chunk() {
        let lhs = str_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let rhs = str_arr::<u32>(&["pad", "b", "b", "a", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res = cmp_str_str(lhs_slice, rhs_slice, ComparisonOperator::LessThan).unwrap();
        assert_bool(&res, &[true, false, false], None);
    }

    #[test]
    fn string_vs_dict_compare_with_nulls() {
        let mut lhs = str_arr::<u32>(&["x", "y", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = dict_arr::<u32>(&["x", "w", "zz"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.data.len());
        let res = cmp_str_dict(lhs_slice, rhs_slice, ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn string_vs_dict_compare_with_nulls_chunk() {
        let mut lhs = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res = cmp_str_dict(lhs_slice, rhs_slice, ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn dict_vs_dict_compare_gt() {
        let lhs = dict_arr::<u32>(&["apple", "pear", "banana"]);
        let rhs = dict_arr::<u32>(&["ant", "pear", "apricot"]);
        let lhs_slice = (&lhs, 0, lhs.data.len());
        let rhs_slice = (&rhs, 0, rhs.data.len());
        let res = cmp_dict_dict(lhs_slice, rhs_slice, ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_dict_compare_gt_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
        let rhs = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res = cmp_dict_dict(lhs_slice, rhs_slice, ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le() {
        let lhs = dict_arr::<u32>(&["a", "b", "c"]);
        let rhs = str_arr::<u32>(&["b", "b", "d"]);
        let lhs_slice = (&lhs, 0, lhs.data.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let res =
            cmp_dict_str(lhs_slice, rhs_slice, ComparisonOperator::LessThanOrEqualTo).unwrap();
        assert_bool(&res, &[true, true, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let rhs = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res =
            cmp_dict_str(lhs_slice, rhs_slice, ComparisonOperator::LessThanOrEqualTo).unwrap();
        assert_bool(&res, &[true, true, true], None);
    }
}