Skip to main content

simd_kernels/kernels/
binary.rs

// Copyright (c) 2025 Peter Bower
// SPDX-License-Identifier: AGPL-3.0-or-later
// Commercial licensing available. See LICENSE and LICENSING.md.

//! # **Binary Operations Kernels Module** - *High-Performance Element-wise Binary Operations*
//!
//! Binary operation kernels that apply element-wise operations between pairs of arrays,
//! with null-aware semantics and SIMD acceleration. They are a core building block for
//! analytical workloads that need efficient pairwise data transformations.
//!
//! ## Core Operations
//! - **Numeric comparisons**: greater than, less than, and equality operations across all numeric types
//! - **String operations**: string comparison with UTF-8 aware lexicographic ordering
//! - **Categorical operations**: dictionary-encoded string comparisons with optimised lookups
//! - **Set operations**: membership testing (`IN` / `NOT IN`) with efficient hash-based implementations
//! - **Range operations**: `BETWEEN` operations for numeric, string, and categorical data types
//! - **Logical combinations**: AND, OR, XOR operations on boolean arrays with bitmask optimisation
//!
//! ## Architecture Overview
//! The module provides a unified interface for binary operations across heterogeneous data types:
//!
//! - **Type-aware dispatch**: automatic selection of optimised kernels based on input types
//! - **Memory layout optimisation**: direct array-to-array operations that minimise intermediate allocations
//! - **Null propagation**: handling of null values following Apache Arrow semantics
//! - **SIMD vectorisation**: hardware-accelerated operations on compatible data types
26
27include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
28
29use std::hash::Hash;
30use std::marker::PhantomData;
31
32use minarrow::traits::type_unions::Float;
33use minarrow::utils::confirm_equal_len;
34use minarrow::{Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, Numeric, StringAVT};
35
36#[cfg(not(feature = "simd"))]
37use crate::kernels::bitmask::std::{and_masks, not_mask};
38#[cfg(not(feature = "simd"))]
39use crate::kernels::comparison::cmp_bitmask_std;
40use crate::kernels::comparison::{
41    cmp_dict_dict, cmp_dict_str, cmp_numeric, cmp_str_dict, cmp_str_str,
42};
43use crate::kernels::logical::{
44    cmp_between, cmp_dict_between, cmp_dict_in, cmp_in, cmp_in_f, cmp_str_between, cmp_str_in,
45};
46use crate::operators::ComparisonOperator;
47use minarrow::enums::error::KernelError;
48
49/// Returns a new Bitmask for boolean buffers, all bits cleared (false).
50#[inline(always)]
51fn new_bool_bitmask(len_bits: usize) -> Bitmask {
52    Bitmask::new_set_all(len_bits, false)
53}
54
55/// Returns a new Bitmask for boolean buffers, all bits set (true).
56#[inline(always)]
57fn full_bool_bitmask(len_bits: usize) -> Bitmask {
58    Bitmask::new_set_all(len_bits, true)
59}
60
61/// Merge two optional Bitmasks into a new output mask, computing per-row AND.
62/// Returns None if both inputs are None (output is dense).
63#[inline]
64fn merge_bitmasks_to_new(
65    lhs: Option<&Bitmask>,
66    rhs: Option<&Bitmask>,
67    len: usize,
68) -> Option<Bitmask> {
69    if let Some(m) = lhs {
70        debug_assert!(
71            m.capacity() >= len,
72            "lhs null mask too small: capacity {} < len {}",
73            m.capacity(),
74            len
75        );
76    }
77    if let Some(m) = rhs {
78        debug_assert!(
79            m.capacity() >= len,
80            "rhs null mask too small: capacity {} < len {}",
81            m.capacity(),
82            len
83        );
84    }
85
86    match (lhs, rhs) {
87        (None, None) => None,
88
89        (Some(l), None) | (None, Some(l)) => {
90            let mut out = Bitmask::new_set_all(len, true);
91            for i in 0..len {
92                unsafe {
93                    out.set_unchecked(i, l.get_unchecked(i));
94                }
95            }
96            Some(out)
97        }
98
99        (Some(l), Some(r)) => {
100            let mut out = Bitmask::new_set_all(len, true);
101            for i in 0..len {
102                unsafe {
103                    out.set_unchecked(i, l.get_unchecked(i) && r.get_unchecked(i));
104                }
105            }
106            Some(out)
107        }
108    }
109}
110
// Numeric and Float
112
113/// Applies comparison operations between numeric arrays with comprehensive operator support.
114///
115/// Performs element-wise comparison operations between two numeric arrays using the specified
116/// comparison operator. Supports the full range of SQL comparison semantics including
117/// set membership operations and null-aware comparisons.
118///
119/// ## Parameters
120/// * `lhs` - Left-hand side numeric array for comparison
121/// * `rhs` - Right-hand side numeric array for comparison  
122/// * `mask` - Optional bitmask indicating valid elements in input arrays
123/// * `op` - Comparison operator defining the comparison semantics to apply
124///
125/// ## Returns
126/// Returns `Result<BooleanArray<()>, KernelError>` containing:
127/// - **Success**: Boolean array with comparison results
128/// - **Error**: KernelError if comparison operation fails
129///
130/// ## Supported Operations
131/// - **Basic comparisons**: `<`, `<=`, `>`, `>=`, `==`, `!=`
132/// - **Set operations**: `IN`, `NOT IN` for membership testing
133/// - **Range operations**: `BETWEEN` for range inclusion testing
134/// - **Null operations**: `IS NULL`, `IS NOT NULL` for null checking
135///
136/// ## Examples
137/// ```rust,ignore
138/// use simd_kernels::kernels::binary::apply_cmp;
139/// use simd_kernels::operators::ComparisonOperator;
140///
141/// let lhs = [1, 2, 3, 4];
142/// let rhs = [2, 2, 2, 2];
143/// let result = apply_cmp(&lhs, &rhs, None, ComparisonOperator::LessThan).unwrap();
144/// // Result: [true, false, false, false]
145/// ```
146pub fn apply_cmp<T>(
147    lhs: &[T],
148    rhs: &[T],
149    mask: Option<&Bitmask>,
150    op: ComparisonOperator,
151) -> Result<BooleanArray<()>, KernelError>
152where
153    T: Numeric + Copy + Hash + Eq + PartialOrd + 'static,
154{
155    let len = lhs.len();
156    match op {
157        ComparisonOperator::Between => {
158            let mut out = cmp_between(lhs, rhs)?;
159            out.null_mask = mask.cloned();
160            Ok(out)
161        }
162        ComparisonOperator::In => {
163            let mut out = cmp_in(lhs, rhs)?;
164            out.null_mask = mask.cloned();
165            Ok(out)
166        }
167        ComparisonOperator::NotIn => {
168            let mut out = cmp_in(lhs, rhs)?;
169            for i in 0..len {
170                unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
171            }
172            out.null_mask = mask.cloned();
173            Ok(out)
174        }
175        ComparisonOperator::IsNull => {
176            let data = match mask {
177                Some(m) => {
178                    #[cfg(feature = "simd")]
179                    {
180                        minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((m, 0, len))
181                    }
182                    #[cfg(not(feature = "simd"))]
183                    {
184                        not_mask((m, 0, len))
185                    }
186                }
187                None => new_bool_bitmask(len),
188            };
189            Ok(BooleanArray {
190                data,
191                null_mask: None,
192                len,
193                _phantom: std::marker::PhantomData,
194            })
195        }
196        ComparisonOperator::IsNotNull => {
197            let data = match mask {
198                Some(m) => m.slice_clone(0, len),
199                None => full_bool_bitmask(len),
200            };
201            Ok(BooleanArray {
202                data,
203                null_mask: None,
204                len,
205                _phantom: std::marker::PhantomData,
206            })
207        }
208        _ => {
209            let mut out = cmp_numeric(lhs, rhs, mask, op)?;
210            out.null_mask = mask.cloned();
211            Ok(out)
212        }
213    }
214}
215
216/// Applies comparison operations between floating-point arrays with IEEE 754 compliance.
217///
218/// Performs element-wise floating-point comparisons with proper handling of IEEE 754
219/// special values (NaN, infinity). Implements comprehensive comparison semantics for
220/// floating-point data including set operations and null-aware processing.
221///
222/// ## Parameters
223/// * `lhs` - Left-hand side floating-point array for comparison
224/// * `rhs` - Right-hand side floating-point array for comparison
225/// * `mask` - Optional bitmask indicating valid elements in arrays
226/// * `op` - Comparison operator specifying the comparison type to perform
227///
228/// ## Returns
229/// Returns `Result<BooleanArray<()>, KernelError>` containing:
230/// - **Success**: Boolean array with IEEE 754 compliant comparison results
231/// - **Error**: KernelError if floating-point comparison fails
232/// ```
233pub fn apply_cmp_f<T>(
234    lhs: &[T],
235    rhs: &[T],
236    mask: Option<&Bitmask>,
237    op: ComparisonOperator,
238) -> Result<BooleanArray<()>, KernelError>
239where
240    T: Float + Numeric + Copy + 'static,
241{
242    let len = lhs.len();
243    match op {
244        ComparisonOperator::Between => {
245            let mut out = cmp_between(lhs, rhs)?;
246            out.null_mask = mask.cloned();
247            Ok(out)
248        }
249        ComparisonOperator::In => {
250            let mut out = cmp_in_f(lhs, rhs)?;
251            out.null_mask = mask.cloned();
252            Ok(out)
253        }
254        ComparisonOperator::NotIn => {
255            let mut out = cmp_in_f(lhs, rhs)?;
256            for i in 0..len {
257                unsafe { out.data.set_unchecked(i, !out.data.get_unchecked(i)) };
258            }
259            out.null_mask = mask.cloned();
260            Ok(out)
261        }
262        ComparisonOperator::IsNull => {
263            let data = match mask {
264                Some(m) => {
265                    #[cfg(feature = "simd")]
266                    {
267                        minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((m, 0, len))
268                    }
269                    #[cfg(not(feature = "simd"))]
270                    {
271                        not_mask((m, 0, len))
272                    }
273                }
274                None => new_bool_bitmask(len),
275            };
276            Ok(BooleanArray {
277                data,
278                null_mask: None,
279                len,
280                _phantom: std::marker::PhantomData,
281            })
282        }
283        ComparisonOperator::IsNotNull => {
284            let data = match mask {
285                Some(m) => m.slice_clone(0, len),
286                None => full_bool_bitmask(len),
287            };
288            Ok(BooleanArray {
289                data,
290                null_mask: None,
291                len,
292                _phantom: std::marker::PhantomData,
293            })
294        }
295        ComparisonOperator::Equals | ComparisonOperator::NotEquals => {
296            let mut out = cmp_numeric(lhs, rhs, mask, op)?;
297            // Patch NaN-pairs for legacy semantics
298            for i in 0..len {
299                let is_valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
300                if is_valid && lhs[i].is_nan() && rhs[i].is_nan() {
301                    match op {
302                        ComparisonOperator::Equals => unsafe { out.data.set_unchecked(i, true) },
303                        ComparisonOperator::NotEquals => unsafe {
304                            out.data.set_unchecked(i, false)
305                        },
306                        _ => {}
307                    }
308                }
309            }
310            out.null_mask = mask.cloned();
311            Ok(out)
312        }
313        _ => {
314            let mut out = cmp_numeric(lhs, rhs, mask, op)?;
315            out.null_mask = mask.cloned();
316            Ok(out)
317        }
318    }
319}
320
321/// Boolean Bit packed
322///
323/// Note that this function delegates to SIMD or not SIMD within the inner cmp_bool
324/// module, given bool is self-contained as a datatype.
325/// "Elementwise boolean bitwise SIMD comparison falling back to scalar if simd not enabled.
326/// Returns `BooleanArray<()>`."
327#[inline(always)]
328pub fn apply_cmp_bool(
329    lhs: BooleanAVT<'_, ()>,
330    rhs: BooleanAVT<'_, ()>,
331    op: ComparisonOperator,
332) -> Result<BooleanArray<()>, KernelError> {
333    let (lhs_arr, lhs_off, len) = lhs;
334    let (rhs_arr, rhs_off, rlen) = rhs;
335    confirm_equal_len("apply_cmp_bool_windowed: window length mismatch", len, rlen)?;
336
337    // Merge windowed null masks
338    #[cfg(feature = "simd")]
339    let merged_null_mask: Option<Bitmask> =
340        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
341            (None, None) => None,
342            (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
343            (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
344            (Some(a), Some(b)) => {
345                use minarrow::kernels::bitmask::simd::and_masks_simd;
346                let am = (a, lhs_off, len);
347                let bm = (b, rhs_off, len);
348                Some(and_masks_simd::<W8>(am, bm))
349            }
350        };
351
352    #[cfg(not(feature = "simd"))]
353    let merged_null_mask: Option<Bitmask> =
354        match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
355            (None, None) => None,
356            (Some(m), None) => Some(m.slice_clone(lhs_off, len)),
357            (None, Some(m)) => Some(m.slice_clone(rhs_off, len)),
358            (Some(a), Some(b)) => {
359                let am = (a, lhs_off, len);
360                let bm = (b, rhs_off, len);
361                Some(and_masks(am, bm))
362            }
363        };
364
365    let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
366
367    #[cfg(feature = "simd")]
368    let data = match op {
369        ComparisonOperator::Equals
370        | ComparisonOperator::NotEquals
371        | ComparisonOperator::LessThan
372        | ComparisonOperator::LessThanOrEqualTo
373        | ComparisonOperator::GreaterThan
374        | ComparisonOperator::GreaterThanOrEqualTo
375        | ComparisonOperator::In
376        | ComparisonOperator::NotIn => crate::kernels::comparison::cmp_bitmask_simd::<W8>(
377            (&lhs_arr.data, lhs_off, len),
378            (&rhs_arr.data, rhs_off, len),
379            mask_slice,
380            op,
381        )?,
382        ComparisonOperator::IsNull => {
383            let data = match merged_null_mask.as_ref() {
384                Some(mask) => minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((mask, 0, len)),
385                None => Bitmask::new_set_all(len, false),
386            };
387            return Ok(BooleanArray {
388                data,
389                null_mask: None,
390                len,
391                _phantom: PhantomData,
392            });
393        }
394        ComparisonOperator::IsNotNull => {
395            let data = match merged_null_mask.as_ref() {
396                Some(mask) => mask.slice_clone(0, len),
397                None => Bitmask::new_set_all(len, true),
398            };
399            return Ok(BooleanArray {
400                data,
401                null_mask: None,
402                len,
403                _phantom: PhantomData,
404            });
405        }
406        ComparisonOperator::Between => {
407            return Err(KernelError::InvalidArguments(
408                "Set operations are not defined for Bool arrays".to_owned(),
409            ));
410        }
411    };
412
413    #[cfg(not(feature = "simd"))]
414    let data = match op {
415        ComparisonOperator::Equals
416        | ComparisonOperator::NotEquals
417        | ComparisonOperator::LessThan
418        | ComparisonOperator::LessThanOrEqualTo
419        | ComparisonOperator::GreaterThan
420        | ComparisonOperator::GreaterThanOrEqualTo
421        | ComparisonOperator::In
422        | ComparisonOperator::NotIn => cmp_bitmask_std(
423            (&lhs_arr.data, lhs_off, len),
424            (&rhs_arr.data, rhs_off, len),
425            mask_slice,
426            op,
427        )?,
428        ComparisonOperator::IsNull => {
429            let data = match merged_null_mask.as_ref() {
430                Some(mask) => not_mask((mask, 0, len)),
431                None => Bitmask::new_set_all(len, false),
432            };
433            return Ok(BooleanArray {
434                data,
435                null_mask: None,
436                len,
437                _phantom: PhantomData,
438            });
439        }
440        ComparisonOperator::IsNotNull => {
441            let data = match merged_null_mask.as_ref() {
442                Some(mask) => mask.slice_clone(0, len),
443                None => Bitmask::new_set_all(len, true),
444            };
445            return Ok(BooleanArray {
446                data,
447                null_mask: None,
448                len,
449                _phantom: PhantomData,
450            });
451        }
452        ComparisonOperator::Between => {
453            return Err(KernelError::InvalidArguments(
454                "Set operations are not defined for Bool arrays".to_owned(),
455            ));
456        }
457    };
458
459    Ok(BooleanArray {
460        data,
461        null_mask: merged_null_mask,
462        len,
463        _phantom: PhantomData,
464    })
465}
466
// Utf8 and Dictionary (categorical)
468
469/// Applies comparison operations between corresponding string elements from string arrays.
470///
471/// Performs element-wise string comparison using lexicographic ordering with
472/// UTF-8 awareness and efficient null handling.
473///
474/// # Parameters
475/// - `lhs`: Left-hand string array view tuple `(StringArray, offset, length)`
476/// - `rhs`: Right-hand string array view tuple `(StringArray, offset, length)`
477/// - `op`: Comparison operator (Eq, Ne, Lt, Le, Gt, Ge, In, NotIn)
478///
479/// # String Comparison
480/// - Uses Rust's standard UTF-8 aware lexicographic ordering
481/// - Null strings handled consistently across all operations
482/// - Set operations support string membership testing
483///
484/// # Returns
485/// `Result<BooleanArray<()>, KernelError>` where true elements satisfy the comparison.
486///
487/// # Performance
488/// - Optimised string comparison avoiding unnecessary allocations
489/// - Efficient null mask processing with bitwise operations
490/// - Dictionary-style operations for set membership testing
491pub fn apply_cmp_str<T: Integer>(
492    lhs: StringAVT<T>,
493    rhs: StringAVT<T>,
494    op: ComparisonOperator,
495) -> Result<BooleanArray<()>, KernelError> {
496    // Destructure slice windows
497    let (larr, loff, llen) = lhs;
498    let (rarr, roff, rlen) = rhs;
499
500    assert_eq!(llen, rlen, "apply_cmp_str: slice lengths must match");
501
502    let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
503    let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
504    let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
505
506    let mut out = match op {
507        ComparisonOperator::Between => cmp_str_between((larr, loff, llen), (rarr, roff, rlen)),
508        ComparisonOperator::In => cmp_str_in((larr, loff, llen), (rarr, roff, rlen)),
509        ComparisonOperator::NotIn => {
510            let mut b = cmp_str_in((larr, loff, llen), (rarr, roff, rlen))?;
511            debug_assert!(
512                b.data.capacity() >= llen,
513                "bitmask capacity {} < needed len {}",
514                b.data.capacity(),
515                llen
516            );
517            for i in 0..llen {
518                unsafe { b.data.set_unchecked(i, !b.data.get_unchecked(i)) };
519            }
520            Ok(b)
521        }
522        ComparisonOperator::IsNull => {
523            let data = match null_mask.as_ref() {
524                Some(m) => {
525                    #[cfg(feature = "simd")]
526                    {
527                        minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((m, 0, llen))
528                    }
529                    #[cfg(not(feature = "simd"))]
530                    {
531                        not_mask((m, 0, llen))
532                    }
533                }
534                None => new_bool_bitmask(llen),
535            };
536            return Ok(BooleanArray {
537                data,
538                null_mask: None,
539                len: llen,
540                _phantom: std::marker::PhantomData,
541            });
542        }
543        ComparisonOperator::IsNotNull => {
544            let data = match null_mask.as_ref() {
545                Some(m) => m.slice_clone(0, llen),
546                None => full_bool_bitmask(llen),
547            };
548            return Ok(BooleanArray {
549                data,
550                null_mask: None,
551                len: llen,
552                _phantom: std::marker::PhantomData,
553            });
554        }
555        _ => cmp_str_str((larr, loff, llen), (rarr, roff, rlen), op),
556    }?;
557    out.null_mask = null_mask;
558    out.len = llen;
559    Ok(out)
560}
561
562/// Applies comparison operations between string array elements and categorical dictionary values.
563///
564/// Performs element-wise comparison where left operands are strings and right operands
565/// are resolved from a categorical array's dictionary.
566///
567/// # Parameters
568/// - `lhs`: String array view tuple `(StringArray, offset, length)`
569/// - `rhs`: Categorical array view tuple `(CategoricalArray, offset, length)`
570/// - `op`: Comparison operator (Eq, Ne, Lt, Le, Gt, Ge, In, NotIn)
571///
572/// # Type Parameters
573/// - `T`: Integer type for string array offsets
574/// - `U`: Integer type for categorical array indices
575///
576/// # Returns
577/// `Result<BooleanArray<()>, KernelError>` where true elements satisfy the comparison.
578///
579/// # Performance
580/// Dictionary lookups amortised across categorical comparisons with caching opportunities.
581pub fn apply_cmp_str_dict<T: Integer, U: Integer>(
582    lhs: StringAVT<T>,
583    rhs: CategoricalAVT<U>,
584    op: ComparisonOperator,
585) -> Result<BooleanArray<()>, KernelError> {
586    let (larr, loff, llen) = lhs;
587    let (rarr, roff, rlen) = rhs;
588    assert_eq!(llen, rlen, "apply_cmp_str_dict: slice lengths must match");
589
590    // TODO: Avoid double clone - merge/slice bitmasks in one go
591    let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
592    let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
593    let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
594
595    let mut out = cmp_str_dict((larr, loff, llen), (rarr, roff, rlen), op)?;
596    out.null_mask = null_mask;
597    out.len = llen;
598    Ok(out)
599}
600
601/// Applies comparison operations between categorical dictionary values and string array elements.
602///
603/// Performs element-wise comparison where left operands are resolved from a categorical
604/// array's dictionary and right operands are strings.
605///
606/// # Parameters
607/// - `lhs`: Categorical array view tuple `(CategoricalArray, offset, length)`
608/// - `rhs`: String array view tuple `(StringArray, offset, length)`
609/// - `op`: Comparison operator (Eq, Ne, Lt, Le, Gt, Ge, In, NotIn)
610///
611/// # Type Parameters
612/// - `T`: Integer type for categorical array indices
613/// - `U`: Integer type for string array offsets
614///
615/// # Returns
616/// `Result<BooleanArray<()>, KernelError>` where true elements satisfy the comparison.
617///
618/// # Performance
619/// Dictionary lookups optimised with categorical encoding efficiency.
620pub fn apply_cmp_dict_str<T: Integer, U: Integer>(
621    lhs: CategoricalAVT<T>,
622    rhs: StringAVT<U>,
623    op: ComparisonOperator,
624) -> Result<BooleanArray<()>, KernelError> {
625    let (larr, loff, llen) = lhs;
626    let (rarr, roff, rlen) = rhs;
627    assert_eq!(llen, rlen, "apply_cmp_dict_str: slice lengths must match");
628
629    // TODO: Avoid double clone - merge/slice bitmasks in one go
630    let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
631    let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
632    let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
633
634    let mut out = cmp_dict_str((larr, loff, llen), (rarr, roff, rlen), op)?;
635    out.null_mask = null_mask;
636    out.len = llen;
637    Ok(out)
638}
639
640/// Applies comparison operations between corresponding categorical dictionary values.
641///
642/// Performs element-wise comparison by resolving both operands from their respective
643/// categorical dictionaries and comparing the resulting string values.
644///
645/// # Parameters
646/// - `lhs`: Left categorical array view tuple `(CategoricalArray, offset, length)`
647/// - `rhs`: Right categorical array view tuple `(CategoricalArray, offset, length)`
648/// - `op`: Comparison operator (Eq, Ne, Lt, Le, Gt, Ge, In, NotIn)
649///
650/// # Type Parameters
651/// - `T`: Integer type for categorical array indices (must implement `Hash`)
652///
653/// # Returns
654/// `Result<BooleanArray<()>, KernelError>` where true elements satisfy the comparison.
655///
656/// # Performance
657/// - Dictionary lookups amortised across bulk categorical operations
658/// - Hash-based optimisations for set membership operations
659/// - Efficient categorical code comparison where possible
660pub fn apply_cmp_dict<T: Integer + Hash>(
661    lhs: CategoricalAVT<T>,
662    rhs: CategoricalAVT<T>,
663    op: ComparisonOperator,
664) -> Result<BooleanArray<()>, KernelError> {
665    let (larr, loff, llen) = lhs;
666    let (rarr, roff, rlen) = rhs;
667    assert_eq!(llen, rlen, "apply_cmp_dict: slice lengths must match");
668    let lmask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
669    let rmask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, rlen));
670    let null_mask = merge_bitmasks_to_new(lmask.as_ref(), rmask.as_ref(), llen);
671    let mut out = match op {
672        ComparisonOperator::Between => cmp_dict_between((larr, loff, llen), (rarr, roff, rlen)),
673        ComparisonOperator::In => cmp_dict_in((larr, loff, llen), (rarr, roff, rlen)),
674        ComparisonOperator::NotIn => {
675            let mut b = cmp_dict_in((larr, loff, llen), (rarr, roff, rlen))?;
676            for i in 0..llen {
677                unsafe {
678                    b.data.set_unchecked(i, !b.data.get_unchecked(i));
679                }
680            }
681            Ok(b)
682        }
683        ComparisonOperator::IsNull => {
684            let data = match null_mask.as_ref() {
685                Some(m) => {
686                    #[cfg(feature = "simd")]
687                    {
688                        minarrow::kernels::bitmask::simd::not_mask_simd::<W8>((m, 0, llen))
689                    }
690                    #[cfg(not(feature = "simd"))]
691                    {
692                        not_mask((m, 0, llen))
693                    }
694                }
695                None => new_bool_bitmask(llen),
696            };
697            return Ok(BooleanArray {
698                data,
699                null_mask: None,
700                len: llen,
701                _phantom: std::marker::PhantomData,
702            });
703        }
704        ComparisonOperator::IsNotNull => {
705            let data = match null_mask.as_ref() {
706                Some(m) => m.slice_clone(0, llen),
707                None => full_bool_bitmask(llen),
708            };
709            return Ok(BooleanArray {
710                data,
711                null_mask: None,
712                len: llen,
713                _phantom: std::marker::PhantomData,
714            });
715        }
716        _ => cmp_dict_dict((larr, loff, llen), (rarr, roff, rlen), op),
717    }?;
718    out.null_mask = null_mask;
719    out.len = llen;
720    Ok(out)
721}
722
723#[cfg(test)]
724mod tests {
725    use minarrow::structs::variants::categorical::CategoricalArray;
726    use minarrow::structs::variants::string::StringArray;
727    use minarrow::{Bitmask, BooleanArray, MaskedArray, vec64};
728
729    use super::*;
730
731    // --- Helpers ---
732    fn bm(bools: &[bool]) -> Bitmask {
733        Bitmask::from_bools(bools)
734    }
735    fn bool_arr(bools: &[bool]) -> BooleanArray<()> {
736        BooleanArray::from_slice(bools)
737    }
738
739    // ----------- Numeric & Float -----------
740    #[test]
741    fn test_apply_cmp_numeric_all_ops() {
742        let a = vec64![1, 2, 3, 4, 5, 6];
743        let b = vec64![3, 2, 1, 4, 5, 0];
744        let mask = bm(&[true, false, true, true, true, true]);
745
746        // Standard operators
747        for &op in &[
748            ComparisonOperator::Equals,
749            ComparisonOperator::NotEquals,
750            ComparisonOperator::LessThan,
751            ComparisonOperator::LessThanOrEqualTo,
752            ComparisonOperator::GreaterThan,
753            ComparisonOperator::GreaterThanOrEqualTo,
754        ] {
755            let arr = apply_cmp(&a, &b, Some(&mask), op).unwrap();
756            for i in 0..a.len() {
757                let expect = match op {
758                    ComparisonOperator::Equals => a[i] == b[i],
759                    ComparisonOperator::NotEquals => a[i] != b[i],
760                    ComparisonOperator::LessThan => a[i] < b[i],
761                    ComparisonOperator::LessThanOrEqualTo => a[i] <= b[i],
762                    ComparisonOperator::GreaterThan => a[i] > b[i],
763                    ComparisonOperator::GreaterThanOrEqualTo => a[i] >= b[i],
764                    _ => unreachable!(),
765                };
766                if mask.get(i) {
767                    assert_eq!(arr.data.get(i), expect);
768                } else {
769                    assert_eq!(arr.get(i), None);
770                }
771            }
772            assert_eq!(arr.null_mask, Some(mask.clone()));
773        }
774    }
775
776    #[test]
777    fn test_apply_cmp_numeric_between_in_notin() {
778        let a = vec64![4, 2, 3, 5];
779        let mask = bm(&[true, true, false, true]);
780        // Between [2, 4] (all lhs compared to range 2..=4)
781        let rhs = vec64![2, 4];
782        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::Between).unwrap();
783        assert_eq!(arr.data.get(0), true); // 4 in [2,4]
784        assert_eq!(arr.data.get(1), true); // 2 in [2,4]
785        assert_eq!(arr.get(2), None);
786        assert_eq!(arr.data.get(3), false); // 5 in [2,4]
787        // In: a in b
788        let rhs = vec64![2, 3, 4];
789        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::In).unwrap();
790        assert_eq!(arr.data.get(0), true); // 4 in [2,3,4]
791        assert_eq!(arr.data.get(1), true); // 2 in [2,3,4]
792        assert_eq!(arr.get(2), None);
793        assert_eq!(arr.data.get(3), false); // 5 not in [2,3,4]
794        // NotIn: inverted
795        let arr = apply_cmp(&a, &rhs, Some(&mask), ComparisonOperator::NotIn).unwrap();
796        assert_eq!(arr.data.get(0), false);
797        assert_eq!(arr.data.get(1), false);
798        assert_eq!(arr.get(2), None);
799        assert_eq!(arr.data.get(3), true);
800    }
801
802    #[test]
803    fn test_apply_cmp_numeric_isnull_isnotnull() {
804        let a = vec64![1, 2, 3];
805        let mask = bm(&[true, false, true]); // position 1 is null
806        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNull).unwrap();
807        assert_eq!(arr.data.get(0), false); // present -> not null
808        assert_eq!(arr.data.get(1), true); // absent -> is null
809        assert_eq!(arr.data.get(2), false); // present -> not null
810        assert_eq!(arr.null_mask, None); // result is always valid
811        let arr = apply_cmp(&a, &a, Some(&mask), ComparisonOperator::IsNotNull).unwrap();
812        assert_eq!(arr.data.get(0), true); // present -> is not null
813        assert_eq!(arr.data.get(1), false); // absent -> is null
814        assert_eq!(arr.data.get(2), true); // present -> is not null
815        assert_eq!(arr.null_mask, None); // result is always valid
816    }
817
818    #[test]
819    fn test_apply_cmp_numeric_edge_cases() {
820        // Empty
821        let a: [i32; 0] = [];
822        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
823        assert_eq!(arr.len, 0);
824        assert!(arr.null_mask.is_none());
825        // All mask None
826        let a = vec64![7];
827        let arr = apply_cmp(&a, &a, None, ComparisonOperator::Equals).unwrap();
828        assert_eq!(arr.data.get(0), true);
829        assert!(arr.null_mask.is_none());
830    }
831
832    #[test]
833    fn test_apply_cmp_f_all_ops_nan_patch() {
834        let a = vec64![1.0, 2.0, f32::NAN, f32::NAN];
835        let b = vec64![1.0, 3.0, f32::NAN, 0.0];
836        let mask = bm(&[true, true, true, false]);
837        // Equals/NotEquals patches NaN==NaN to true/false
838        for &op in &[ComparisonOperator::Equals, ComparisonOperator::NotEquals] {
839            let arr = apply_cmp_f(&a, &b, Some(&mask), op).unwrap();
840            assert_eq!(arr.data.get(2), matches!(op, ComparisonOperator::Equals)) // true for ==, false for !=
841        }
842        // In/NotIn
843        let arr = apply_cmp_f(&a, &b, Some(&mask), ComparisonOperator::In).unwrap();
844        assert_eq!(arr.data.get(0), true); // 1.0 in 1.0
845        assert_eq!(arr.data.get(1), false);
846    }
847
848    #[test]
849    fn test_cmp_bool_w8() {
850        let a = bool_arr(&[true, false, true]);
851        let b = bool_arr(&[false, false, true]);
852        let op = ComparisonOperator::Equals;
853        let arr = apply_cmp_bool((&a, 0, a.len()), (&b, 0, b.len()), op).unwrap();
854        assert!(!arr.data.get(0));
855        assert!(arr.data.get(1));
856        assert!(arr.data.get(2));
857        println!("mask bytes: {:02x?}", arr.data.bits);
858        println!("get(0): {}", arr.data.get(0));
859        println!("get(1): {}", arr.data.get(1));
860        println!("get(2): {}", arr.data.get(2));
861        println!("lhs: {:?}", a);
862        println!("rhs: {:?}", b);
863        println!(
864            "{}: mask bytes: {:?} get(0): {} get(1): {} get(2): {}",
865            stringify!($test_name),
866            arr.data.as_slice(),
867            arr.data.get(0),
868            arr.data.get(1),
869            arr.data.get(2)
870        );
871
872        // NotEquals
873        let arr = apply_cmp_bool(
874            (&a, 0, a.len()),
875            (&b, 0, b.len()),
876            ComparisonOperator::NotEquals,
877        )
878        .unwrap();
879        assert!(arr.data.get(0));
880        assert!(!arr.data.get(1));
881        assert!(!arr.data.get(2));
882
883        // LessThan
884        let arr = apply_cmp_bool(
885            (&a, 0, a.len()),
886            (&b, 0, b.len()),
887            ComparisonOperator::LessThan,
888        )
889        .unwrap();
890        assert!(!arr.data.get(0));
891        assert!(!arr.data.get(1));
892        assert!(!arr.data.get(2));
893
894        // All null masks
895        let mut a = bool_arr(&[true, false]);
896        a.null_mask = Some(bm(&[true, false]));
897        let mut b = bool_arr(&[true, false]);
898        b.null_mask = Some(bm(&[true, true]));
899        let arr = apply_cmp_bool(
900            (&a, 0, a.len()),
901            (&b, 0, b.len()),
902            ComparisonOperator::Equals,
903        )
904        .unwrap();
905        assert!(arr.null_mask.as_ref().unwrap().get(0));
906        assert!(!arr.null_mask.as_ref().unwrap().get(1));
907    }
908
909    #[test]
910    fn test_bool_is_null() {
911        let a = bool_arr(&[true, false]);
912        let b = bool_arr(&[false, true]);
913        let arr = apply_cmp_bool(
914            (&a, 0, a.len()),
915            (&b, 0, b.len()),
916            ComparisonOperator::IsNull,
917        )
918        .unwrap();
919        assert!(!arr.data.get(0));
920        assert!(!arr.data.get(1));
921        let arr = apply_cmp_bool(
922            (&a, 0, a.len()),
923            (&b, 0, b.len()),
924            ComparisonOperator::IsNotNull,
925        )
926        .unwrap();
927        assert!(arr.data.get(0));
928        assert!(arr.data.get(1));
929    }
930
931    // ----------- String/Utf8 -----------
932
933    #[test]
934    fn test_apply_cmp_str_all_ops() {
935        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "qux"]);
936        let b = StringArray::<u32>::from_slice(&["foo", "baz", "baz", "quux"]);
937        let mut a2 = a.clone();
938        a2.set_null(2);
939        let a_slice = (&a, 0, a.len());
940        let b_slice = (&b, 0, b.len());
941        let a2_slice = (&a2, 0, a2.len());
942
943        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
944        assert_eq!(arr.data.get(0), true); // "foo" == "foo"
945        assert_eq!(arr.data.get(1), false); // "bar" != "baz"
946        assert_eq!(arr.data.get(2), true); // "baz" == "baz"
947        assert_eq!(arr.data.get(3), false);
948
949        // NotEquals
950        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotEquals).unwrap();
951        assert_eq!(arr.data.get(0), false);
952        assert_eq!(arr.data.get(1), true);
953        assert_eq!(arr.data.get(2), false);
954        assert_eq!(arr.data.get(3), true);
955
956        // LessThan
957        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::LessThan).unwrap();
958        assert_eq!(arr.data.get(0), false); // "foo" < "foo"
959        assert_eq!(arr.data.get(1), true); // "bar" < "baz"
960        assert_eq!(arr.data.get(2), false);
961        assert_eq!(arr.data.get(3), false);
962
963        // Null merging
964        let mut b2 = b.clone();
965        b2.set_null(1);
966        let b2_slice = (&b2, 0, b2.len());
967        let arr = apply_cmp_str(a2_slice, b2_slice, ComparisonOperator::Equals).unwrap();
968        assert!(!arr.null_mask.as_ref().unwrap().get(2));
969        assert!(!arr.null_mask.as_ref().unwrap().get(1));
970        assert!(arr.null_mask.as_ref().unwrap().get(0));
971        assert!(arr.null_mask.as_ref().unwrap().get(3));
972    }
973
974    #[test]
975    fn test_apply_cmp_str_all_ops_chunk() {
976        let a = StringArray::<u32>::from_slice(&["x", "foo", "bar", "baz", "qux", "y"]);
977        let b = StringArray::<u32>::from_slice(&["q", "foo", "baz", "baz", "quux", "z"]);
978        // Chunk [1,2,3,4]
979        let a_slice = (&a, 1, 4);
980        let b_slice = (&b, 1, 4);
981        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Equals).unwrap();
982        assert_eq!(arr.data.get(0), true); // "foo" == "foo"
983        assert_eq!(arr.data.get(1), false); // "bar" != "baz"
984        assert_eq!(arr.data.get(2), true); // "baz" == "baz"
985        assert_eq!(arr.data.get(3), false);
986    }
987
988    #[test]
989    fn test_apply_cmp_str_set_ops() {
990        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz"]);
991        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz"]);
992        let a_slice = (&a, 0, a.len());
993        let b_slice = (&b, 0, b.len());
994        // Between
995        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
996        assert_eq!(arr.len, 3);
997        // In/NotIn
998        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
999        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
1000        for i in 0..a.len() {
1001            assert_eq!(arr.data.get(i), !arr2.data.get(i));
1002        }
1003    }
1004
1005    #[test]
1006    fn test_apply_cmp_str_set_ops_chunk() {
1007        let a = StringArray::<u32>::from_slice(&["foo", "bar", "baz", "w"]);
1008        let b = StringArray::<u32>::from_slice(&["foo", "qux", "baz", "w"]);
1009        // Chunk [1,2]
1010        let a_slice = (&a, 1, 2);
1011        let b_slice = (&b, 1, 2);
1012        // Between
1013        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::Between).unwrap();
1014        assert_eq!(arr.len, 2);
1015        // In/NotIn
1016        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::In).unwrap();
1017        let arr2 = apply_cmp_str(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
1018        for i in 0..2 {
1019            assert_eq!(arr.data.get(i), !arr2.data.get(i));
1020        }
1021    }
1022
1023    #[test]
1024    fn test_apply_cmp_str_isnull_isnotnull() {
1025        let a = StringArray::<u32>::from_slice(&["foo"]);
1026        let b = StringArray::<u32>::from_slice(&["bar"]);
1027        let a_slice = (&a, 0, a.len());
1028        let b_slice = (&b, 0, b.len());
1029        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
1030        assert_eq!(arr.data.get(0), false);
1031        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
1032        assert_eq!(arr.data.get(0), true);
1033    }
1034
1035    #[test]
1036    fn test_apply_cmp_str_isnull_isnotnull_chunk() {
1037        let a = StringArray::<u32>::from_slice(&["pad", "foo"]);
1038        let b = StringArray::<u32>::from_slice(&["pad", "bar"]);
1039        let a_slice = (&a, 1, 1);
1040        let b_slice = (&b, 1, 1);
1041        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
1042        assert_eq!(arr.data.get(0), false);
1043        let arr = apply_cmp_str(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
1044        assert_eq!(arr.data.get(0), true);
1045    }
1046
1047    // ----------- String/Dict -----------
1048
1049    #[test]
1050    fn test_apply_cmp_str_dict() {
1051        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
1052        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);
1053
1054        let s_slice = (&s, 0, s.len());
1055        let dict_slice = (&dict, 0, dict.data.len());
1056        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
1057        assert_eq!(arr.len, 3);
1058
1059        // Null merging
1060        let mut s2 = s.clone();
1061        s2.set_null(0);
1062        let mut d2 = dict.clone();
1063        d2.set_null(1);
1064
1065        let s2_slice = (&s2, 0, s2.len());
1066        let d2_slice = (&d2, 0, d2.data.len());
1067        let arr = apply_cmp_str_dict(s2_slice, d2_slice, ComparisonOperator::Equals).unwrap();
1068
1069        let mask = arr.null_mask.as_ref().unwrap();
1070        assert!(!mask.get(0));
1071        assert!(!mask.get(1));
1072        assert!(mask.get(2));
1073    }
1074
1075    #[test]
1076    fn test_apply_cmp_str_dict_chunk() {
1077        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
1078        let dict = CategoricalArray::<u32>::from_slices(
1079            &[2, 0, 1, 0, 2], // All indices valid for 3 unique values
1080            &["z".into(), "a".into(), "b".into()],
1081        );
1082        // Slice window ["a", "b", "c"] and ["a", "b", "a"]
1083        let s_slice = (&s, 1, 3);
1084        let dict_slice = (&dict, 1, 3);
1085        let arr = apply_cmp_str_dict(s_slice, dict_slice, ComparisonOperator::Equals).unwrap();
1086        assert_eq!(arr.len, 3);
1087    }
1088
1089    // ----------- Dict/Str -----------
1090
1091    #[test]
1092    fn test_apply_cmp_dict_str() {
1093        let dict = CategoricalArray::<u32>::from_slices(&[0, 1, 0], &["a".into(), "b".into()]);
1094        let s = StringArray::<u32>::from_slice(&["a", "b", "c"]);
1095        let dict_slice = (&dict, 0, dict.data.len());
1096        let s_slice = (&s, 0, s.len());
1097        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
1098        assert_eq!(arr.len, 3);
1099    }
1100
1101    #[test]
1102    fn test_apply_cmp_dict_str_chunk() {
1103        let dict = CategoricalArray::<u32>::from_slices(
1104            &[2, 0, 1, 0, 2], // Use only indices 0, 1, 2
1105            &["z".into(), "a".into(), "b".into()],
1106        );
1107        let s = StringArray::<u32>::from_slice(&["pad", "a", "b", "c", "pad2"]);
1108        let dict_slice = (&dict, 1, 3);
1109        let s_slice = (&s, 1, 3);
1110        let arr = apply_cmp_dict_str(dict_slice, s_slice, ComparisonOperator::Equals).unwrap();
1111        assert_eq!(arr.len, 3);
1112    }
1113
1114    // ----------- Dict/Dict -----------
1115
1116    #[test]
1117    fn test_apply_cmp_dict_all_ops() {
1118        let a = CategoricalArray::<u32>::from_slices(
1119            &[0, 1, 2],
1120            &["dog".into(), "cat".into(), "fish".into()],
1121        );
1122        let b = CategoricalArray::<u32>::from_slices(
1123            &[2, 1, 0],
1124            &["fish".into(), "cat".into(), "dog".into()],
1125        );
1126
1127        let a_slice = (&a, 0, a.data.len());
1128        let b_slice = (&b, 0, b.data.len());
1129
1130        // Equals, NotEquals, etc.
1131        for &op in &[
1132            ComparisonOperator::Equals,
1133            ComparisonOperator::NotEquals,
1134            ComparisonOperator::LessThan,
1135            ComparisonOperator::LessThanOrEqualTo,
1136            ComparisonOperator::GreaterThan,
1137            ComparisonOperator::GreaterThanOrEqualTo,
1138        ] {
1139            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
1140            assert_eq!(arr.len, 3);
1141        }
1142        // Between, In, NotIn
1143        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
1144        assert_eq!(arr.len, 3);
1145        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
1146        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
1147        for i in 0..3 {
1148            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
1149        }
1150    }
1151
1152    #[test]
1153    fn test_apply_cmp_dict_all_ops_chunk() {
1154        let a = CategoricalArray::<u32>::from_slices(
1155            &[0, 1, 2, 3, 1], // All indices in 0..4 for 4 unique values
1156            &["pad".into(), "dog".into(), "cat".into(), "fish".into()],
1157        );
1158        let b = CategoricalArray::<u32>::from_slices(
1159            &[3, 2, 1, 0, 2], // All indices in 0..4 for 4 unique values
1160            &["foo".into(), "fish".into(), "cat".into(), "dog".into()],
1161        );
1162        // Slice window [1, 2, 3] and [2, 1, 0]
1163        let a_slice = (&a, 1, 3);
1164        let b_slice = (&b, 1, 3);
1165
1166        for &op in &[
1167            ComparisonOperator::Equals,
1168            ComparisonOperator::NotEquals,
1169            ComparisonOperator::LessThan,
1170            ComparisonOperator::LessThanOrEqualTo,
1171            ComparisonOperator::GreaterThan,
1172            ComparisonOperator::GreaterThanOrEqualTo,
1173        ] {
1174            let arr = apply_cmp_dict(a_slice, b_slice, op).unwrap();
1175            assert_eq!(arr.len, 3);
1176        }
1177        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::Between).unwrap();
1178        assert_eq!(arr.len, 3);
1179        let arr2 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::In).unwrap();
1180        let arr3 = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::NotIn).unwrap();
1181        for i in 0..3 {
1182            assert_eq!(arr2.data.get(i), !arr3.data.get(i));
1183        }
1184    }
1185
1186    #[test]
1187    fn test_apply_cmp_dict_isnull_isnotnull() {
1188        let a = CategoricalArray::<u32>::from_slices(&[0, 1], &["x".into(), "y".into()]);
1189        let b = CategoricalArray::<u32>::from_slices(&[1, 0], &["y".into(), "x".into()]);
1190        let a_slice = (&a, 0, a.data.len());
1191        let b_slice = (&b, 0, b.data.len());
1192        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
1193        assert_eq!(arr.data.get(0), false);
1194        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
1195        assert_eq!(arr.data.get(0), true);
1196    }
1197
1198    #[test]
1199    fn test_apply_cmp_dict_isnull_isnotnull_chunk() {
1200        let a = CategoricalArray::<u32>::from_slices(
1201            &[2, 0, 1, 2],
1202            &["z".into(), "x".into(), "y".into()],
1203        );
1204        let b = CategoricalArray::<u32>::from_slices(
1205            &[2, 1, 0, 1],
1206            &["w".into(), "y".into(), "x".into(), "z".into()],
1207        );
1208        let a_slice = (&a, 1, 2);
1209        let b_slice = (&b, 1, 2);
1210        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
1211        assert_eq!(arr.data.get(0), false);
1212        let arr = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
1213        assert_eq!(arr.data.get(0), true);
1214    }
1215
1216    #[test]
1217    #[should_panic(expected = "All indices must be valid for unique_values")]
1218    fn test_apply_cmp_dict_isnull_isnotnull_chunk_invalid_indices() {
1219        let a = CategoricalArray::<u32>::from_slices(
1220            &[9, 0, 1, 9], // 9 is out-of-bounds for 3 unique values
1221            &["z".into(), "x".into(), "y".into()],
1222        );
1223        let b = CategoricalArray::<u32>::from_slices(
1224            &[2, 1, 0, 3], /* 3 is out-of-bounds for 4 unique values (0..3 is valid, so 3 is valid here) */
1225            &["w".into(), "y".into(), "x".into(), "z".into()],
1226        );
1227        let a_slice = (&a, 1, 2);
1228        let b_slice = (&b, 1, 2);
1229        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNull).unwrap();
1230        let _ = apply_cmp_dict(a_slice, b_slice, ComparisonOperator::IsNotNull).unwrap();
1231    }
1232
1233    // ----------- merge_bitmasks_to_new -----------
1234    #[test]
1235    fn test_merge_bitmasks_to_new_none_none() {
1236        assert!(merge_bitmasks_to_new(None, None, 5).is_none());
1237    }
1238    #[test]
1239    fn test_merge_bitmasks_to_new_some_none() {
1240        let m = bm(&[true, false, true]);
1241        let out = merge_bitmasks_to_new(Some(&m), None, 3).unwrap();
1242        for i in 0..3 {
1243            assert_eq!(out.get(i), m.get(i));
1244        }
1245        let out2 = merge_bitmasks_to_new(None, Some(&m), 3).unwrap();
1246        for i in 0..3 {
1247            assert_eq!(out2.get(i), m.get(i));
1248        }
1249    }
1250    #[test]
1251    fn test_merge_bitmasks_to_new_both_some_and() {
1252        let a = bm(&[true, false, true, true]);
1253        let b = bm(&[true, true, false, true]);
1254        let out = merge_bitmasks_to_new(Some(&a), Some(&b), 4).unwrap();
1255        assert_eq!(out.get(0), true);
1256        assert_eq!(out.get(1), false);
1257        assert_eq!(out.get(2), false);
1258        assert_eq!(out.get(3), true);
1259    }
1260}