simd_kernels/kernels/
logical.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Logical Operations Kernels Module** - *Boolean Logic and Set Operations*
5//!
6//! Logical operation kernels providing efficient boolean algebra, set membership testing,
7//! and range operations with SIMD acceleration and null-aware semantics. Critical foundation
8//! for query execution, filtering predicates, and analytical data processing workflows.
9//!
10//! ## Core Operations
11//! - **Boolean algebra**: AND, OR, XOR, NOT operations on boolean arrays with bitmask optimisation
12//! - **Set membership**: IN and NOT IN operations with hash-based lookup optimisation
13//! - **Range operations**: BETWEEN predicates for numeric and string data types
14//! - **Pattern matching**: String pattern matching with optimised prefix/suffix detection
15//! - **Null-aware logic**: Three-valued logic implementation following SQL semantics
16//! - **Compound predicates**: Efficient evaluation of complex multi-condition expressions
17
18include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
19
20use std::collections::HashSet;
21use std::hash::Hash;
22use std::marker::PhantomData;
23use std::simd::{LaneCount, SupportedLaneCount};
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30    Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31    NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46/// Builds the Boolean result buffer.
47/// `len` – number of rows that will be written.
48#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50    Bitmask::new_set_all(len, false)
51}
52
53// Between
54
macro_rules! impl_between_numeric {
    ($name:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Inclusive BETWEEN test: marks each `lhs` row whose value `v` satisfies `min <= v <= max`.
        ///
        /// `rhs` supplies the bounds, either as one `[min, max]` pair applied to every row,
        /// or as `2 * lhs.len()` values laid out as `[min0, max0, min1, max1, ...]`.
        /// When `has_nulls` is set, rows whose bit in `mask` is cleared stay `false`.
        /// The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Returns `KernelError::InvalidArguments` when `rhs` has neither accepted length,
        /// or when `mask` has a capacity smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>, // lhs validity mask
            has_nulls: bool
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let uniform_bounds = rhs.len() == 2;

            if !uniform_bounds && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(
                    format!("between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})", len, rhs.len())
                ));
            }

            match mask {
                Some(m) if m.capacity() < len => {
                    return Err(KernelError::InvalidArguments(
                        format!("between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})", m.capacity(), len)
                    ));
                }
                _ => {}
            }

            let mut hits = new_bool_buffer(len);

            // Vectorised path: only for aligned inputs, a single bounds pair and no nulls.
            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && !has_nulls && uniform_bounds {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    let (lo, hi) = (rhs[0], rhs[1]);
                    let lo_v = V::splat(lo);
                    let hi_v = V::splat(hi);

                    let blocks = lhs.chunks_exact(N);
                    let tail = blocks.remainder();
                    let tail_start = len - tail.len();

                    // Full lanes: compare N values at once, then scatter the set bits.
                    for (b, block) in blocks.enumerate() {
                        let v = V::from_slice(block);
                        let in_range: M = v.simd_ge(lo_v) & v.simd_le(hi_v);
                        let bits = in_range.to_bitmask();
                        let base = b * N;
                        for lane in 0..N {
                            if ((bits >> lane) & 1) == 1 {
                                hits.set(base + lane, true);
                            }
                        }
                    }

                    // Leftover rows that do not fill a whole vector.
                    for (k, &x) in tail.iter().enumerate() {
                        if x >= lo && x <= hi {
                            hits.set(tail_start + k, true);
                        }
                    }

                    return Ok(BooleanArray {
                        data: hits.into(),
                        null_mask: mask.cloned(),
                        len,
                        _phantom: PhantomData
                    });
                }
                // Otherwise continue with the scalar loop below.
            }

            // Scalar path: handles nulls, per-row bounds and unaligned inputs.
            let row_is_valid =
                |i: usize| !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) });

            for (i, &x) in lhs.iter().enumerate() {
                let (lo, hi) = if uniform_bounds {
                    (rhs[0], rhs[1])
                } else {
                    (rhs[2 * i], rhs[2 * i + 1])
                };
                if row_is_valid(i) && x >= lo && x <= hi {
                    hits.set(i, true);
                }
            }

            Ok(BooleanArray {
                data: hits.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData
            })
        }
    };
}
160
161// floats
162
163#[inline(always)]
164fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
165    lhs: &[T],
166    rhs: &[T],
167    mask: Option<&Bitmask>,
168    has_nulls: bool,
169) -> Result<BooleanArray<()>, KernelError> {
170    let len = lhs.len();
171    let mut out = new_bool_buffer(len);
172    let _ = confirm_mask_capacity(len, mask)?;
173    if rhs.len() == 2 {
174        let (min, max) = (rhs[0], rhs[1]);
175        for i in 0..len {
176            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
177                && lhs[i] >= min
178                && lhs[i] <= max
179            {
180                out.set(i, true);
181            }
182        }
183    } else {
184        for i in 0..len {
185            let min = rhs[i * 2];
186            let max = rhs[i * 2 + 1];
187            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
188                && lhs[i] >= min
189                && lhs[i] <= max
190            {
191                out.set(i, true);
192            }
193        }
194    }
195
196    Ok(BooleanArray {
197        data: out.into(),
198        null_mask: mask.cloned(),
199        len,
200        _phantom: PhantomData,
201    })
202}
203
204// In and Not In
205
macro_rules! impl_in_int {
    ($name:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Test membership of LHS integer values in the RHS value set.
        ///
        /// Output row `i` is `true` when `lhs[i]` equals any element of `rhs`.
        /// When `has_nulls` is set and a `mask` is supplied, rows whose mask bit is
        /// cleared stay `false`. If `has_nulls` is set but `mask` is `None`, every row
        /// is treated as valid on all code paths (this case used to panic on the
        /// SIMD path only). The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity` when the mask check fails.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            let _ = confirm_mask_capacity(len, mask)?;

            // Validity is consulted only when the caller reports nulls AND supplies a mask.
            let validity: Option<&Bitmask> = if has_nulls { mask } else { None };

            #[cfg(feature = "simd")]
            {
                // SIMD path: aligned inputs and a small RHS that is cheap to broadcast.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    let mask_bytes = validity.map(|mb| mb.as_bytes());
                    let mut i = 0;
                    while i + $lanes <= len {
                        let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);

                        // OR together the equality masks against every RHS value.
                        let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                        for &v in rhs {
                            m |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                        }

                        // Keep only lanes that are valid according to the null mask.
                        if let Some(bytes) = mask_bytes {
                            let lane_mask =
                                bitmask_to_simd_mask::<$lanes, $mask_elem>(bytes, i, len);
                            m = lane_mask & m;
                        }

                        let bm = m.to_bitmask();
                        for l in 0..$lanes {
                            if ((bm >> l) & 1) == 1 {
                                out.set(i + l, true);
                            }
                        }
                        i += $lanes;
                    }

                    // Scalar tail for rows that do not fill a whole vector.
                    for j in i..len {
                        if validity.map_or(true, |mb| unsafe { mb.get_unchecked(j) })
                            && rhs.contains(&lhs[j])
                        {
                            out.set(j, true);
                        }
                    }

                    return Ok(BooleanArray {
                        data: out.into(),
                        null_mask: mask.cloned(),
                        len,
                        _phantom: PhantomData,
                    });
                }
                // Otherwise continue with the scalar fallback below.
            }

            // Scalar fallback (unaligned inputs and large RHS): hash the RHS once.
            let set: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for i in 0..len {
                if validity.map_or(true, |m| unsafe { m.get_unchecked(i) })
                    && set.contains(&lhs[i])
                {
                    out.set(i, true);
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
314
/// Implements the SIMD/scalar IN kernel for floats, handling NaN semantics and an optional null mask.
macro_rules! impl_in_float {
    (
        $fn_name:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        /// Test membership of LHS floating-point values in the RHS value set.
        ///
        /// Output row `i` is `true` when `lhs[i]` equals any element of `rhs`; for this
        /// kernel a NaN on the left matches a NaN on the right.
        /// When `has_nulls` is set and a `mask` is supplied, rows whose mask bit is
        /// cleared stay `false`. If `has_nulls` is set but `mask` is `None`, every row
        /// is treated as valid on all code paths (this case used to panic on the
        /// SIMD path only). The result carries a clone of `mask` as its null mask.
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity` when the mask check fails.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            let _ = confirm_mask_capacity(len, mask)?;

            // Validity is consulted only when the caller reports nulls AND supplies a mask.
            let validity: Option<&Bitmask> = if has_nulls { mask } else { None };

            #[cfg(feature = "simd")]
            {
                // SIMD path: aligned inputs and a small RHS that is cheap to broadcast.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    let mask_bytes = validity.map(|mb| mb.as_bytes());
                    let mut i = 0;
                    while i + $lanes <= len {
                        let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);

                        // A lane matches on plain equality, or when both sides are NaN.
                        let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                        for &v in rhs {
                            let vmask = x.simd_eq(Simd::<$ty, $lanes>::splat(v))
                                | (x.is_nan() & Simd::<$ty, $lanes>::splat(v).is_nan());
                            m |= vmask;
                        }

                        // Keep only lanes that are valid according to the null mask.
                        if let Some(bytes) = mask_bytes {
                            let lane_mask =
                                bitmask_to_simd_mask::<$lanes, $mask_elem>(bytes, i, len);
                            m = m & lane_mask;
                        }

                        let bm = m.to_bitmask();
                        for l in 0..$lanes {
                            if ((bm >> l) & 1) == 1 {
                                out.set(i + l, true);
                            }
                        }
                        i += $lanes;
                    }

                    // Scalar tail for rows that do not fill a whole vector.
                    for j in i..len {
                        if validity.map_or(true, |mb| unsafe { mb.get_unchecked(j) }) {
                            let x = lhs[j];
                            if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                out.set(j, true);
                            }
                        }
                    }

                    return Ok(BooleanArray {
                        data: out.into(),
                        null_mask: mask.cloned(),
                        len,
                        _phantom: PhantomData,
                    });
                }
                // Otherwise continue with the scalar fallback below.
            }

            // Scalar fallback: linear scan of RHS per row (floats are not hashable).
            for i in 0..len {
                if !validity.map_or(true, |m| unsafe { m.get_unchecked(i) }) {
                    continue;
                }
                let x = lhs[i];
                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                    out.set(i, true);
                }
            }
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
430
431// Correct MaskElement types per std::simd
432#[cfg(feature = "extended_numeric_types")]
433impl_in_int!(in_i8, i8, W8, i8);
434#[cfg(feature = "extended_numeric_types")]
435impl_in_int!(in_u8, u8, W8, i8);
436#[cfg(feature = "extended_numeric_types")]
437impl_in_int!(in_i16, i16, W16, i16);
438#[cfg(feature = "extended_numeric_types")]
439impl_in_int!(in_u16, u16, W16, i16);
440impl_in_int!(in_i32, i32, W32, i32);
441impl_in_int!(in_u32, u32, W32, i32);
442impl_in_int!(in_i64, i64, W64, i64);
443impl_in_int!(in_u64, u64, W64, i64);
444impl_in_float!(in_f32, f32, W32, i32);
445impl_in_float!(in_f64, f64, W64, i64);
446
447#[cfg(feature = "extended_numeric_types")]
448impl_between_numeric!(between_i8, i8, i8, W8);
449#[cfg(feature = "extended_numeric_types")]
450impl_between_numeric!(between_u8, u8, i8, W8);
451#[cfg(feature = "extended_numeric_types")]
452impl_between_numeric!(between_i16, i16, i16, W16);
453#[cfg(feature = "extended_numeric_types")]
454impl_between_numeric!(between_u16, u16, i16, W16);
455
456impl_between_numeric!(between_i32, i32, i32, W32);
457impl_between_numeric!(between_u32, u32, i32, W32);
458impl_between_numeric!(between_i64, i64, i64, W64);
459impl_between_numeric!(between_u64, u64, i64, W64);
460impl_between_numeric!(between_f32, f32, i32, W32);
461impl_between_numeric!(between_f64, f64, i64, W64);
462
463// String and dictionary
464
465/// Test if LHS string values fall lexicographically between RHS min/max bounds.
466#[inline(always)]
467pub fn cmp_str_between<'a, T: Integer>(
468    lhs: StringAVT<'a, T>,
469    rhs: StringAVT<'a, T>,
470) -> Result<BooleanArray<()>, KernelError> {
471    let (larr, loff, llen) = lhs;
472    let (rarr, roff, rlen) = rhs;
473
474    if rlen < 2 {
475        return Err(KernelError::InvalidArguments(format!(
476            "str_between: RHS must contain at least two values (got {})",
477            rlen
478        )));
479    }
480    let min = rarr.get(roff).unwrap_or("");
481    let max = rarr.get(roff + 1).unwrap_or("");
482    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
483    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
484
485    let mut out = new_bool_buffer(llen);
486
487    for i in 0..llen {
488        if mask
489            .as_ref()
490            .map_or(true, |m| unsafe { m.get_unchecked(i) })
491        {
492            let s = unsafe { larr.get_str_unchecked(loff + i) };
493            if s >= min && s <= max {
494                unsafe { out.set_unchecked(i, true) };
495            }
496        }
497    }
498
499    Ok(BooleanArray {
500        data: out.into(),
501        null_mask: mask,
502        len: llen,
503        _phantom: PhantomData,
504    })
505}
506
507#[inline(always)]
508/// Test membership of LHS string values in RHS string set.
509pub fn cmp_str_in<'a, T: Integer>(
510    lhs: StringAVT<'a, T>,
511    rhs: StringAVT<'a, T>,
512) -> Result<BooleanArray<()>, KernelError> {
513    let (larr, loff, llen) = lhs;
514    let (rarr, roff, rlen) = rhs;
515
516    let set: HashSet<&str> = (0..rlen)
517        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
518        .collect();
519
520    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
521    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
522
523    let mut out = new_bool_buffer(llen);
524
525    for i in 0..llen {
526        if mask
527            .as_ref()
528            .map_or(true, |m| unsafe { m.get_unchecked(i) })
529        {
530            let s = unsafe { larr.get_str_unchecked(loff + i) };
531            if set.contains(s) {
532                unsafe { out.set_unchecked(i, true) };
533            }
534        }
535    }
536    Ok(BooleanArray {
537        data: out.into(),
538        null_mask: mask,
539        len: llen,
540        _phantom: PhantomData,
541    })
542}
543
544// Public functions
545
546/// Test if values fall between min/max bounds for comparable numeric types.
547pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
548    lhs: &[T],
549    rhs: &[T],
550) -> Result<BooleanArray<()>, KernelError> {
551    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
552        return between_i32(
553            unsafe { std::mem::transmute(lhs) },
554            unsafe { std::mem::transmute(rhs) },
555            None,
556            false,
557        );
558    }
559    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
560        return between_u32(
561            unsafe { std::mem::transmute(lhs) },
562            unsafe { std::mem::transmute(rhs) },
563            None,
564            false,
565        );
566    }
567    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
568        return between_i64(
569            unsafe { std::mem::transmute(lhs) },
570            unsafe { std::mem::transmute(rhs) },
571            None,
572            false,
573        );
574    }
575    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
576        return between_u64(
577            unsafe { std::mem::transmute(lhs) },
578            unsafe { std::mem::transmute(rhs) },
579            None,
580            false,
581        );
582    }
583    // Fallback – floats or any other PartialOrd type
584    between_generic(lhs, rhs, None, false)
585}
586
587/// Mask-aware variant
588#[inline(always)]
589pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
590    lhs: &[T],
591    rhs: &[T],
592    mask: Option<&Bitmask>,
593) -> Result<BooleanArray<()>, KernelError> {
594    let has_nulls = mask.is_some();
595    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
596        return between_i32(
597            unsafe { std::mem::transmute(lhs) },
598            unsafe { std::mem::transmute(rhs) },
599            mask,
600            has_nulls,
601        );
602    }
603    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
604        return between_u32(
605            unsafe { std::mem::transmute(lhs) },
606            unsafe { std::mem::transmute(rhs) },
607            mask,
608            has_nulls,
609        );
610    }
611    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
612        return between_i64(
613            unsafe { std::mem::transmute(lhs) },
614            unsafe { std::mem::transmute(rhs) },
615            mask,
616            has_nulls,
617        );
618    }
619    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
620        return between_u64(
621            unsafe { std::mem::transmute(lhs) },
622            unsafe { std::mem::transmute(rhs) },
623            mask,
624            has_nulls,
625        );
626    }
627    between_generic(lhs, rhs, mask, has_nulls)
628}
629
630// In and Not In
631
632/// Test membership in set for hashable types using hash-based lookup.
633pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
634    lhs: &[T],
635    rhs: &[T],
636) -> Result<BooleanArray<()>, KernelError> {
637    // i32
638    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
639        return in_i32(
640            unsafe { std::mem::transmute(lhs) },
641            unsafe { std::mem::transmute(rhs) },
642            None,
643            false,
644        );
645    }
646    // u32
647    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
648        return in_u32(
649            unsafe { std::mem::transmute(lhs) },
650            unsafe { std::mem::transmute(rhs) },
651            None,
652            false,
653        );
654    }
655    // i64
656    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
657        return in_i64(
658            unsafe { std::mem::transmute(lhs) },
659            unsafe { std::mem::transmute(rhs) },
660            None,
661            false,
662        );
663    }
664    // u64
665    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
666        return in_u64(
667            unsafe { std::mem::transmute(lhs) },
668            unsafe { std::mem::transmute(rhs) },
669            None,
670            false,
671        );
672    }
673    // i16
674    #[cfg(feature = "extended_numeric_types")]
675    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
676        return in_i16(
677            unsafe { std::mem::transmute(lhs) },
678            unsafe { std::mem::transmute(rhs) },
679            None,
680            false,
681        );
682    }
683    // u16
684    #[cfg(feature = "extended_numeric_types")]
685    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
686        return in_u16(
687            unsafe { std::mem::transmute(lhs) },
688            unsafe { std::mem::transmute(rhs) },
689            None,
690            false,
691        );
692    }
693    // i8
694    #[cfg(feature = "extended_numeric_types")]
695    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
696        return in_i8(
697            unsafe { std::mem::transmute(lhs) },
698            unsafe { std::mem::transmute(rhs) },
699            None,
700            false,
701        );
702    }
703    // u8
704    #[cfg(feature = "extended_numeric_types")]
705    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
706        return in_u8(
707            unsafe { std::mem::transmute(lhs) },
708            unsafe { std::mem::transmute(rhs) },
709            None,
710            false,
711        );
712    }
713    return Err(KernelError::UnsupportedType(
714        "cmp_in: unsupported type for SIMD in".into(),
715    ));
716}
717
718/// Mask-aware variant
719#[inline(always)]
720pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
721    lhs: &[T],
722    rhs: &[T],
723    mask: Option<&Bitmask>,
724) -> Result<BooleanArray<()>, KernelError> {
725    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
726        return in_i32(
727            unsafe { std::mem::transmute(lhs) },
728            unsafe { std::mem::transmute(rhs) },
729            mask,
730            mask.is_some(),
731        );
732    }
733    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
734        return in_u32(
735            unsafe { std::mem::transmute(lhs) },
736            unsafe { std::mem::transmute(rhs) },
737            mask,
738            mask.is_some(),
739        );
740    }
741    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
742        return in_i64(
743            unsafe { std::mem::transmute(lhs) },
744            unsafe { std::mem::transmute(rhs) },
745            mask,
746            mask.is_some(),
747        );
748    }
749    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
750        return in_u64(
751            unsafe { std::mem::transmute(lhs) },
752            unsafe { std::mem::transmute(rhs) },
753            mask,
754            mask.is_some(),
755        );
756    }
757    return Err(KernelError::UnsupportedType(
758        "cmp_in_mask: unsupported type (expected integer type)".into(),
759    ));
760}
761
762/// SIMD-aware, type-specific dispatch for cmp_in_f_mask and cmp_in_f
763#[inline(always)]
764pub fn cmp_in_f_mask<T: Float + Copy>(
765    lhs: &[T],
766    rhs: &[T],
767    mask: Option<&Bitmask>,
768) -> Result<BooleanArray<()>, KernelError> {
769    if TypeId::of::<T>() == TypeId::of::<f32>() {
770        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
771        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
772        in_f32(lhs, rhs, mask, mask.is_some())
773    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
774        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
775        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
776        in_f64(lhs, rhs, mask, mask.is_some())
777    } else {
778        unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
779    }
780}
781
782#[inline(always)]
783/// Test membership in set for floating-point types with NaN handling.
784pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
785    if TypeId::of::<T>() == TypeId::of::<f32>() {
786        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
787        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
788        in_f32(lhs, rhs, None, false)
789    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
790        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
791        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
792        in_f64(lhs, rhs, None, false)
793    } else {
794        unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
795    }
796}
797
798// String and dictionary
799
800/// Test if floating-point values fall between bounds with NaN handling.
801pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
802    lhs: &[T],
803    rhs: &[T],
804) -> Result<BooleanArray<()>, KernelError> {
805    between_generic(lhs, rhs, None, false)
806}
807
808/// Test if dictionary/categorical values fall between lexicographic bounds.
809pub fn cmp_dict_between<'a, T: Integer>(
810    lhs: CategoricalAVT<'a, T>,
811    rhs: CategoricalAVT<'a, T>,
812) -> Result<BooleanArray<()>, KernelError> {
813    let (larr, loff, llen) = lhs;
814    let (rarr, roff, _rlen) = rhs;
815
816    let min = rarr.get(roff).unwrap_or("");
817    let max = rarr.get(roff + 1).unwrap_or("");
818    let mask = larr.null_mask.as_ref();
819    let _ = confirm_mask_capacity(larr.data.len(), mask)?;
820    let has_nulls = mask.is_some();
821
822    let mut out = new_bool_buffer(llen);
823    for i in 0..llen {
824        let li = loff + i;
825        if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
826            let s = unsafe { larr.get_str_unchecked(li) };
827            if s > min && s <= max {
828                unsafe { out.set_unchecked(i, true) };
829            }
830        }
831    }
832    Ok(BooleanArray {
833        data: out.into(),
834        null_mask: mask.cloned(),
835        len: llen,
836        _phantom: PhantomData,
837    })
838}
839
840/// Returns `true` for each row in `lhs` whose string value also appears
841/// anywhere in `rhs`, respecting null masks on both sides.
842/// Returns `true` for each row in `lhs` whose string value also appears
843/// anywhere in `rhs`, respecting null masks on both sides.
844pub fn cmp_dict_in<'a, T: Integer + Hash>(
845    lhs: CategoricalAVT<'a, T>,
846    rhs: CategoricalAVT<'a, T>,
847) -> Result<BooleanArray<()>, KernelError> {
848    let (larr, loff, llen) = lhs;
849    let (rarr, roff, rlen) = rhs;
850    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
851    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
852
853    let mut out = Bitmask::new_set_all(llen, false);
854
855    if (larr.unique_values.len() == rarr.unique_values.len())
856        && (larr.unique_values.len() <= MAX_DICT_CHECK)
857    {
858        let mut same_dict = true;
859        for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
860            if a != b {
861                same_dict = false;
862                break;
863            }
864        }
865
866        if same_dict {
867            let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
868            for i in 0..llen {
869                if mask
870                    .as_ref()
871                    .map_or(true, |m| unsafe { m.get_unchecked(i) })
872                {
873                    let code = larr.data[loff + i];
874                    if rhs_codes.contains(&code) {
875                        unsafe { out.set_unchecked(i, true) };
876                    }
877                }
878            }
879            return Ok(BooleanArray {
880                data: out.into(),
881                null_mask: mask,
882                len: llen,
883                _phantom: PhantomData,
884            });
885        }
886    }
887
888    let rhs_strings: HashSet<&str> = (0..rlen)
889        .filter(|&i| {
890            rarr.null_mask
891                .as_ref()
892                .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
893        })
894        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
895        .collect();
896
897    for i in 0..llen {
898        if mask
899            .as_ref()
900            .map_or(true, |m| unsafe { m.get_unchecked(i) })
901        {
902            let s = unsafe { larr.get_str_unchecked(loff + i) };
903            if rhs_strings.contains(s) {
904                unsafe { out.set_unchecked(i, true) };
905            }
906        }
907    }
908
909    Ok(BooleanArray {
910        data: out.into(),
911        null_mask: mask,
912        len: llen,
913        _phantom: PhantomData,
914    })
915}
916
917// Is Null and Not null predicates
918
919/// Generate boolean mask indicating null elements in any array type.
920pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
921    let not_null = is_not_null_array(arr)?;
922    Ok(!not_null)
923}
924/// Generate boolean mask indicating non-null elements in any array type.
925pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
926    let len = arr.len();
927    let mut data = Bitmask::new_set_all(len, false);
928
929    if let Some(mask) = arr.null_mask() {
930        data = mask.clone();
931    } else {
932        data.fill(true);
933    }
934    Ok(BooleanArray {
935        data,
936        null_mask: None,
937        len,
938        _phantom: PhantomData,
939    })
940}
941
// Array-level dispatch: IN, NOT IN, BETWEEN
943/// Test membership of array elements in values set, dispatching by array type.
944pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
945    match (input, values) {
946        (
947            Array::NumericArray(NumericArray::Int32(a)),
948            Array::NumericArray(NumericArray::Int32(b)),
949        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
950        (
951            Array::NumericArray(NumericArray::Int64(a)),
952            Array::NumericArray(NumericArray::Int64(b)),
953        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
954        (
955            Array::NumericArray(NumericArray::UInt32(a)),
956            Array::NumericArray(NumericArray::UInt32(b)),
957        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
958        (
959            Array::NumericArray(NumericArray::UInt64(a)),
960            Array::NumericArray(NumericArray::UInt64(b)),
961        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
962        (
963            Array::NumericArray(NumericArray::Float32(a)),
964            Array::NumericArray(NumericArray::Float32(b)),
965        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
966        (
967            Array::NumericArray(NumericArray::Float64(a)),
968            Array::NumericArray(NumericArray::Float64(b)),
969        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
970        (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
971            cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
972        }
973        (Array::BooleanArray(a), Array::BooleanArray(b)) => {
974            cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
975        }
976        (
977            Array::TextArray(TextArray::Categorical32(a)),
978            Array::TextArray(TextArray::Categorical32(b)),
979        ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
980        _ => unimplemented!(),
981    }
982}
983
984#[inline(always)]
985/// Test non-membership of array elements in values set, dispatching by array type.
986pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
987    let result = in_array(input, values)?;
988    Ok(!result)
989}
990
991/// Test if array elements fall between min/max bounds, dispatching by array type.
992pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
993    macro_rules! between_case {
994        ($variant:ident, $cmp:ident) => {{
995            let arr = match input {
996                Array::NumericArray(NumericArray::$variant(arr)) => arr,
997                _ => unreachable!(),
998            };
999            let mins = match min {
1000                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1001                _ => unreachable!(),
1002            };
1003            let maxs = match max {
1004                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1005                _ => unreachable!(),
1006            };
1007            let rhs: Vec64<_> = mins
1008                .data
1009                .iter()
1010                .zip(&maxs.data)
1011                .flat_map(|(&lo, &hi)| [lo, hi])
1012                .collect();
1013            Ok(Array::BooleanArray(
1014                $cmp(
1015                    &arr.data,
1016                    &rhs,
1017                    arr.null_mask.as_ref(),
1018                    arr.null_mask.is_some(),
1019                )?
1020                .into(),
1021            ))
1022        }};
1023    }
1024
1025    match (input, min, max) {
1026        (
1027            Array::NumericArray(NumericArray::Int32(..)),
1028            Array::NumericArray(NumericArray::Int32(..)),
1029            Array::NumericArray(NumericArray::Int32(..)),
1030        ) => between_case!(Int32, between_i32),
1031        (
1032            Array::NumericArray(NumericArray::Int64(..)),
1033            Array::NumericArray(NumericArray::Int64(..)),
1034            Array::NumericArray(NumericArray::Int64(..)),
1035        ) => between_case!(Int64, between_i64),
1036        (
1037            Array::NumericArray(NumericArray::UInt32(..)),
1038            Array::NumericArray(NumericArray::UInt32(..)),
1039            Array::NumericArray(NumericArray::UInt32(..)),
1040        ) => between_case!(UInt32, between_u32),
1041        (
1042            Array::NumericArray(NumericArray::UInt64(..)),
1043            Array::NumericArray(NumericArray::UInt64(..)),
1044            Array::NumericArray(NumericArray::UInt64(..)),
1045        ) => between_case!(UInt64, between_u64),
1046        (
1047            Array::NumericArray(NumericArray::Float32(..)),
1048            Array::NumericArray(NumericArray::Float32(..)),
1049            Array::NumericArray(NumericArray::Float32(..)),
1050        ) => between_case!(Float32, between_generic),
1051        (
1052            Array::NumericArray(NumericArray::Float64(..)),
1053            Array::NumericArray(NumericArray::Float64(..)),
1054            Array::NumericArray(NumericArray::Float64(..)),
1055        ) => between_case!(Float64, between_generic),
1056        _ => Err(KernelError::UnsupportedType(
1057            "Unsupported Type Error.".to_string(),
1058        )),
1059    }
1060}
1061
1062/// Bitwise NOT of a bit-packed boolean mask window.
1063/// Offset is a bit offset; len is in bits.
1064/// Requires offset % 64 == 0 for word-level SIMD processing.
1065#[inline]
1066pub fn not_bool<const LANES: usize>(
1067    src: BooleanAVT<'_, ()>,
1068) -> Result<BooleanArray<()>, KernelError>
1069where
1070    LaneCount<LANES>: SupportedLaneCount,
1071{
1072    let (arr, offset, len) = src;
1073
1074    if offset % 64 != 0 {
1075        return Err(KernelError::InvalidArguments(format!(
1076            "not_bool: offset must be 64-bit aligned (got offset={})",
1077            offset
1078        )));
1079    }
1080
1081    let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1082
1083    let data = if null_mask.is_none() {
1084        #[cfg(feature = "simd")]
1085        {
1086            minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1087        }
1088        #[cfg(not(feature = "simd"))]
1089        {
1090            minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1091        }
1092    } else {
1093        // clone window – no modifications
1094        arr.data.slice_clone(offset, len)
1095    };
1096
1097    Ok(BooleanArray {
1098        data,
1099        null_mask,
1100        len,
1101        _phantom: core::marker::PhantomData,
1102    })
1103}
1104
1105/// Logical AND/OR/XOR of two bit-packed boolean masks over a window.
1106/// Offsets are bit offsets. Length is in bits.
1107/// Panics if offsets are not 64-bit aligned.
1108pub fn apply_logical_bool<const LANES: usize>(
1109    lhs: BooleanAVT<'_, ()>,
1110    rhs: BooleanAVT<'_, ()>,
1111    op: LogicalOperator,
1112) -> Result<BooleanArray<()>, KernelError>
1113where
1114    LaneCount<LANES>: SupportedLaneCount,
1115{
1116    let (lhs_arr, lhs_off, len) = lhs;
1117    let (rhs_arr, rhs_off, rlen) = rhs;
1118
1119    if len != rlen {
1120        return Err(KernelError::LengthMismatch(format!(
1121            "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1122            len, rlen
1123        )));
1124    }
1125    if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1126        return Err(KernelError::InvalidArguments(format!(
1127            "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1128            lhs_off, rhs_off
1129        )));
1130    }
1131
1132    // Apply bitmask kernel for the logical operation.
1133
1134    #[cfg(feature = "simd")]
1135    let data = match op {
1136        LogicalOperator::And => {
1137            and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1138        }
1139        LogicalOperator::Or => {
1140            or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1141        }
1142        LogicalOperator::Xor => {
1143            xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1144        }
1145    };
1146
1147    // Merge validity (null) masks using AND
1148    #[cfg(feature = "simd")]
1149    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1150        (None, None) => None,
1151        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1152        (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1153            (a, lhs_off, len),
1154            (b, rhs_off, len),
1155        )),
1156    };
1157
1158    #[cfg(not(feature = "simd"))]
1159    let data = match op {
1160        LogicalOperator::And => {
1161            and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1162        }
1163        LogicalOperator::Or => {
1164            or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1165        }
1166        LogicalOperator::Xor => {
1167            xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1168        }
1169    };
1170
1171    #[cfg(not(feature = "simd"))]
1172    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1173        (None, None) => None,
1174        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1175        (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1176    };
1177
1178    Ok(BooleanArray {
1179        data,
1180        null_mask,
1181        len,
1182        _phantom: PhantomData,
1183    })
1184}
1185
1186#[cfg(test)]
1187mod tests {
1188    use minarrow::structs::variants::categorical::CategoricalArray;
1189    use minarrow::structs::variants::float::FloatArray;
1190    use minarrow::structs::variants::integer::IntegerArray;
1191    use minarrow::structs::variants::string::StringArray;
1192    use minarrow::{Array, Bitmask, BooleanArray, vec64};
1193
1194    use super::*;
1195
1196    // --- helpers ---
1197
1198    fn bm(bits: &[bool]) -> Bitmask {
1199        let mut m = Bitmask::new_set_all(bits.len(), false);
1200        for (i, &b) in bits.iter().enumerate() {
1201            m.set(i, b);
1202        }
1203        m
1204    }
1205
1206    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
1207        assert_eq!(arr.len, expect.len(), "length mismatch");
1208        for i in 0..expect.len() {
1209            assert_eq!(arr.data.get(i), expect[i], "val @ {i}");
1210        }
1211        match (expect_mask, &arr.null_mask) {
1212            (None, None) => {}
1213            (Some(exp), Some(mask)) => {
1214                for (i, &b) in exp.iter().enumerate() {
1215                    assert_eq!(mask.get(i), b, "mask @ {i}");
1216                }
1217            }
1218            (None, Some(mask)) => {
1219                // all mask bits should be true
1220                for i in 0..arr.len {
1221                    assert!(mask.get(i), "unexpected false mask @ {i}");
1222                }
1223            }
1224            (Some(_), None) => panic!("expected null mask"),
1225        }
1226    }
1227
1228    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
1229        IntegerArray::from_slice(data)
1230    }
1231    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
1232        FloatArray::from_slice(data)
1233    }
1234    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
1235        StringArray::<T>::from_slice(vals)
1236    }
1237    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
1238        let owned: Vec<&str> = vals.to_vec();
1239        CategoricalArray::<T>::from_values(owned)
1240    }
1241    //  BETWEEN
1242
1243    #[test]
1244    fn between_i32_scalar_rhs() {
1245        let lhs = vec64![1, 3, 5, 7];
1246        let rhs = vec64![2, 6];
1247        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1248        assert_bool(&out, &[false, true, true, false], None);
1249    }
1250
1251    #[test]
1252    fn between_i32_per_row_rhs() {
1253        let lhs = vec64![5, 9, 2, 8];
1254        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9]; // min/max for each row
1255        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1256        assert_bool(&out, &[true, false, true, true], None);
1257    }
1258
1259    #[test]
1260    fn between_i32_nulls_propagate() {
1261        let lhs = vec64![5, 9, 2, 8];
1262        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
1263        let mask = bm(&[true, false, true, true]);
1264        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1265        assert_bool(
1266            &out,
1267            &[true, false, true, true],
1268            Some(&[true, false, true, true]),
1269        );
1270    }
1271
1272    #[cfg(feature = "extended_numeric_types")]
1273    #[test]
1274    fn between_i16_works() {
1275        let lhs = vec64![10i16, 12, 99];
1276        let rhs = vec64![10i16, 12];
1277        let out = in_i16(&lhs, &rhs, None, false).unwrap();
1278        assert_bool(&out, &[true, true, false], None);
1279    }
1280
1281    #[test]
1282    fn between_f64_scalar_and_nulls() {
1283        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
1284        let rhs = vec64![4.0, 10.0];
1285        let mask = bm(&[true, false, true, true]);
1286        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
1287        assert_bool(
1288            &out,
1289            &[false, false, true, false],
1290            Some(&[true, false, true, true]),
1291        );
1292    }
1293
1294    #[test]
1295    fn between_f32_generic_dispatch() {
1296        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
1297        let rhs = vec64![0.0, 1.0];
1298        let out = cmp_between(&lhs, &rhs).unwrap();
1299        assert_bool(&out, &[true, true, false, false], None);
1300    }
1301
1302    #[test]
1303    fn between_masked_dispatch() {
1304        let lhs = vec64![1i32, 2, 3];
1305        let rhs = vec64![0, 2];
1306        let mask = bm(&[true, false, true]);
1307        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
1308        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
1309    }
1310
1311    // IN
1312
1313    #[test]
1314    fn in_i32_small_rhs() {
1315        let lhs = vec64![1, 2, 3, 4, 5];
1316        let rhs = vec64![2, 4];
1317        let out = in_i32(&lhs, &rhs, None, false).unwrap();
1318        assert_bool(&out, &[false, true, false, true, false], None);
1319    }
1320
1321    #[test]
1322    fn in_i32_with_nulls() {
1323        let lhs = vec64![7, 8, 9];
1324        let rhs = vec64![8];
1325        let mask = bm(&[true, false, true]);
1326        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1327        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
1328    }
1329
1330    #[test]
1331    fn in_i64_large_rhs() {
1332        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
1333        let rhs: Vec<i64> = (2..10).collect();
1334        let out = in_i64(&lhs, &rhs, None, false).unwrap();
1335        assert_bool(&out, &[false, true, true, true, true, false], None);
1336    }
1337
1338    #[cfg(feature = "extended_numeric_types")]
1339    #[test]
1340    fn in_u8_small_rhs() {
1341        let lhs = vec64![1u8, 2, 3, 4];
1342        let rhs = vec64![2u8, 3];
1343        let out = in_u8(&lhs, &rhs, None, false).unwrap();
1344        assert_bool(&out, &[false, true, true, false], None);
1345    }
1346
1347    #[test]
1348    fn in_float_nan_and_normal() {
1349        let lhs = vec64![1.0f32, f32::NAN, 7.0];
1350        let rhs = vec64![f32::NAN, 7.0];
1351        let out = in_f32(&lhs, &rhs, None, false).unwrap();
1352        assert_bool(&out, &[false, true, true], None);
1353    }
1354
1355    // BETWEEN / IN
1356
1357    #[test]
1358    fn string_between() {
1359        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
1360        let rhs = str_arr::<u32>(&["b", "y"]);
1361        let lhs_slice = (&lhs, 0, lhs.len());
1362        let rhs_slice = (&rhs, 0, rhs.len());
1363        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1364        assert_bool(&out, &[false, true, false], None);
1365    }
1366
1367    #[test]
1368    fn string_between_chunk() {
1369        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
1370        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
1371        // Windowed: skip first/last for lhs; use a window for rhs
1372        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "zz"]
1373        let rhs_slice = (&rhs, 1, 2); // ["b", "y"]
1374        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1375        assert_bool(&out, &[false, true, false], None);
1376    }
1377
1378    #[test]
1379    fn string_in_basic() {
1380        let lhs = str_arr::<u32>(&["x", "y", "z"]);
1381        let rhs = str_arr::<u32>(&["y", "a"]);
1382        let lhs_slice = (&lhs, 0, lhs.len());
1383        let rhs_slice = (&rhs, 0, rhs.len());
1384        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1385        assert_bool(&out, &[false, true, false], None);
1386    }
1387
1388    #[test]
1389    fn string_in_basic_chunk() {
1390        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
1391        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
1392        let lhs_slice = (&lhs, 1, 3); // ["x", "y", "z"]
1393        let rhs_slice = (&rhs, 1, 2); // ["y", "a"]
1394        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1395        assert_bool(&out, &[false, true, false], None);
1396    }
1397
1398    #[test]
1399    fn dict_between() {
1400        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
1401        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
1402        let lhs_slice = (&lhs, 0, lhs.len());
1403        let rhs_slice = (&rhs, 0, rhs.len());
1404        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1405        assert_bool(&out, &[false, true, false], None);
1406    }
1407
1408    #[test]
1409    fn dict_between_chunk() {
1410        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
1411        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
1412        let lhs_slice = (&lhs, 1, 3); // ["cat", "dog", "emu"]
1413        let rhs_slice = (&rhs, 1, 2); // ["cobra", "dove"]
1414        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1415        assert_bool(&out, &[false, true, false], None);
1416    }
1417
1418    #[test]
1419    fn dict_in_membership() {
1420        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
1421        let rhs = dict_arr::<u32>(&["bb", "dd"]);
1422        let lhs_slice = (&lhs, 0, lhs.len());
1423        let rhs_slice = (&rhs, 0, rhs.len());
1424        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1425        assert_bool(&out, &[false, true, false], None);
1426    }
1427
1428    #[test]
1429    fn dict_in_membership_chunk() {
1430        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
1431        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
1432        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "cc"]
1433        let rhs_slice = (&rhs, 1, 2); // ["bb", "dd"]
1434        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1435        assert_bool(&out, &[false, true, false], None);
1436    }
1437
1438    #[test]
1439    fn string_between_nulls() {
1440        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
1441        lhs.null_mask = Some(bm(&[true, false, true]));
1442        let rhs = str_arr::<u32>(&["a", "zzz"]);
1443        let lhs_slice = (&lhs, 0, lhs.len());
1444        let rhs_slice = (&rhs, 0, rhs.len());
1445        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1446        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1447    }
1448
1449    #[test]
1450    fn string_between_nulls_chunk() {
1451        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
1452        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
1453        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
1454        let lhs_slice = (&lhs, 1, 3); // ["foo", "bar", "baz"]
1455        let rhs_slice = (&rhs, 1, 2); // ["a", "zzz"]
1456        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1457        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1458    }
1459
1460    #[test]
1461    fn dict_in_nulls() {
1462        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
1463        lhs.null_mask = Some(bm(&[false, true, true]));
1464        let rhs = dict_arr::<u32>(&["two", "four"]);
1465        let lhs_slice = (&lhs, 0, lhs.len());
1466        let rhs_slice = (&rhs, 0, rhs.len());
1467        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1468        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1469    }
1470
1471    #[test]
1472    fn dict_in_nulls_chunk() {
1473        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
1474        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
1475        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
1476        let lhs_slice = (&lhs, 1, 3); // ["one", "two", "three"]
1477        let rhs_slice = (&rhs, 1, 2); // ["two", "four"]
1478        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1479        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1480    }
1481
1482    // Boolean/Null
1483
1484    #[test]
1485    fn is_null_and_is_not_null() {
1486        let mut arr = i32_arr(&[1, 2, 0]);
1487        arr.null_mask = Some(bm(&[true, false, true]));
1488        let array = Array::from_int32(arr.clone());
1489
1490        let not_null = is_not_null_array(&array).unwrap();
1491        let is_null = is_null_array(&array).unwrap();
1492
1493        assert_bool(&not_null, &[true, false, true], None);
1494        assert_bool(&is_null, &[false, true, false], None);
1495    }
1496
1497    #[test]
1498    fn is_null_not_null_dense() {
1499        let arr = i32_arr(&[1, 2, 3]);
1500        let array = Array::from_int32(arr.clone());
1501        let is_null = is_null_array(&array).unwrap();
1502        assert_bool(&is_null, &[false, false, false], None);
1503        let not_null = is_not_null_array(&array).unwrap();
1504        assert_bool(&not_null, &[true, true, true], None);
1505    }
1506
1507    //  Array dispatch in_array, not_in_array, between_array ----
1508
1509    #[test]
1510    fn in_array_int32_dispatch() {
1511        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
1512        let vals = Array::from_int32(i32_arr(&[20, 40]));
1513        let out = in_array(&inp, &vals).unwrap();
1514        assert_bool(&out, &[false, true, false], None);
1515
1516        let out_not = not_in_array(&inp, &vals).unwrap();
1517        assert_bool(&out_not, &[true, false, true], None);
1518    }
1519
1520    #[test]
1521    fn in_array_f32_dispatch() {
1522        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
1523        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
1524        let out = in_array(&inp, &vals).unwrap();
1525        assert_bool(&out, &[false, true, true], None);
1526    }
1527
1528    #[test]
1529    fn in_array_string_dispatch() {
1530        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
1531        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
1532        let out = in_array(&inp, &vals).unwrap();
1533        assert_bool(&out, &[false, true, false], None);
1534    }
1535
1536    #[test]
1537    fn in_array_dictionary_dispatch() {
1538        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
1539        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
1540        let out = in_array(&inp, &vals).unwrap();
1541        assert_bool(&out, &[false, true, true], None);
1542    }
1543
1544    #[test]
1545    fn between_array_int32_rows() {
1546        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
1547        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
1548        let max = Array::from_int32(i32_arr(&[10, 20, 30]));
1549
1550        let out = between_array(&inp, &min, &max).unwrap();
1551        match out {
1552            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1553            _ => panic!("expected Bool array"),
1554        }
1555    }
1556
1557    #[test]
1558    fn between_array_float_generic() {
1559        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
1560        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
1561        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));
1562
1563        let out = between_array(&inp, &min, &max).unwrap();
1564        match out {
1565            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1566            _ => panic!("expected Bool"),
1567        }
1568    }
1569
1570    #[test]
1571    fn between_array_type_mismatch() {
1572        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
1573        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
1574        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
1575        let err = between_array(&inp, &min, &max).unwrap_err();
1576        match err {
1577            KernelError::UnsupportedType(_) => {}
1578            _ => panic!("Expected UnsupportedType error"),
1579        }
1580    }
1581
1582    // all integer types, short and long
1583
1584    #[test]
1585    fn in_integers_various_types() {
1586        #[cfg(feature = "extended_numeric_types")]
1587        {
1588            let u8_lhs = vec64![1u8, 2, 3, 5];
1589            let u8_rhs = vec64![3u8, 5, 8];
1590            let out = in_u8(&u8_lhs, &u8_rhs, None, false).unwrap();
1591            assert_bool(&out, &[false, false, true, true], None);
1592
1593            let u16_lhs = vec64![100u16, 200, 300];
1594            let u16_rhs = vec64![200u16, 500];
1595            let out = in_u16(&u16_lhs, &u16_rhs, None, false).unwrap();
1596            assert_bool(&out, &[false, true, false], None);
1597
1598            let i16_lhs = vec64![10i16, 15, 42];
1599            let i16_rhs = vec64![15i16, 42, 77];
1600            let out = in_i16(&i16_lhs, &i16_rhs, None, false).unwrap();
1601            assert_bool(&out, &[false, true, true], None);
1602        }
1603
1604        let u32_lhs = vec64![0u32, 1, 2, 9];
1605        let u32_rhs = vec64![9u32, 1];
1606        let out = in_u32(&u32_lhs, &u32_rhs, None, false).unwrap();
1607        assert_bool(&out, &[false, true, false, true], None);
1608
1609        let i64_lhs = vec64![1i64, 9, 10];
1610        let i64_rhs = vec64![2i64, 10, 20];
1611        let out = in_i64(&i64_lhs, &i64_rhs, None, false).unwrap();
1612        assert_bool(&out, &[false, false, true], None);
1613
1614        let u64_lhs = vec64![1u64, 2, 3, 4];
1615        let u64_rhs = vec64![2u64, 4, 8];
1616        let out = in_u64(&u64_lhs, &u64_rhs, None, false).unwrap();
1617        assert_bool(&out, &[false, true, false, true], None);
1618    }
1619
1620    // empty input edge
1621
// An empty left-hand side must yield a zero-length result for numeric BETWEEN,
// numeric IN, and string IN alike.
#[test]
fn between_and_in_empty_inputs() {
    // Numeric BETWEEN: plain slices are passed, no (array, offset, len) tuple.
    {
        let no_values: [i32; 0] = [];
        let bounds = vec64![0, 1];
        let result = between_i32(&no_values, &bounds, None, false).unwrap();
        assert_eq!(result.len, 0);
    }

    // Numeric IN: same calling convention, with a non-empty candidate set.
    {
        let no_values: [i32; 0] = [];
        let candidates = vec64![1, 2, 3];
        let result = in_i32(&no_values, &candidates, None, false).unwrap();
        assert_eq!(result.len, 0);
    }

    // String IN: uses the windowed (array, offset, len) tuple API over the full arrays.
    {
        let no_strings = str_arr::<u32>(&[]);
        let candidates = str_arr::<u32>(&["a", "b"]);
        let result = cmp_str_in(
            (&no_strings, 0, no_strings.len()),
            (&candidates, 0, candidates.len()),
        )
        .unwrap();
        assert_eq!(result.len, 0);
    }
}
1644
// Windowed variant of the empty-input check. Only the string IN kernel takes
// (array, offset, len) windows, so it is the only one exercised here.
#[test]
fn between_and_in_empty_inputs_chunk() {
    let haystack = str_arr::<u32>(&["x", "y"]);
    let candidates = str_arr::<u32>(&["a", "b", "c"]);

    // Left window starts at offset 1 but spans zero rows; right window is ["b", "c"].
    let empty_window = (&haystack, 1, 0);
    let candidate_window = (&candidates, 1, 2);

    let result = cmp_str_in(empty_window, candidate_window).unwrap();
    assert_eq!(result.len, 0);
}
1655
// Coverage for per-row bounds, including the final row: four values are checked
// against eight bound entries and every row is expected to be in range.
#[test]
fn between_per_row_bounds_on_last_row() {
    let values = vec64![0i32, 10, 20, 30];
    let per_row_bounds = vec64![0, 5, 5, 15, 15, 25, 25, 35];

    let in_range = between_i32(&values, &per_row_bounds, None, false).unwrap();

    assert_bool(&in_range, &[true, true, true, true], None);
}
1664
// Dictionary IN where the two sides carry dictionaries of different sizes
// (3 vs 4 unique values). The test intends this to take the string-matching
// fallback path, in which only "b" is a member of the right-hand side.
#[test]
fn test_cmp_dict_in_force_fallback() {
    // Left operand: 3 unique values, with an explicit all-valid null mask.
    let mut probe = dict_arr::<u32>(&["a", "b", "c", "a"]);
    probe.unique_values = vec64![String::from("a"), String::from("b"), String::from("c")];
    probe.null_mask = Some(bm(&[true; 4]));

    // Right operand: 4 unique values.
    let mut candidates = dict_arr::<u32>(&["b", "x", "y", "z"]);
    candidates.unique_values = vec64![
        String::from("b"),
        String::from("x"),
        String::from("y"),
        String::from("z")
    ];

    let result = cmp_dict_in(
        (&probe, 0, probe.len()),
        (&candidates, 0, candidates.len()),
    )
    .unwrap();

    // Only row 1 ("b") matches; the expected mask is all true.
    assert_bool(&result, &[false, true, false, false], Some(&[true; 4]));
}
1688
// Windowed variant of the dictionary IN fallback test: both operands are
// sliced with (array, offset, len) so only the interior rows take part.
#[test]
fn test_cmp_dict_in_force_fallback_chunk() {
    // Left operand: 6 rows over 5 unique values, with an explicit all-valid null mask.
    let mut probe = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
    probe.unique_values = vec64![
        String::from("z"),
        String::from("a"),
        String::from("b"),
        String::from("c"),
        String::from("q")
    ];
    probe.null_mask = Some(bm(&[true; 6]));

    // Right operand: 5 rows over 4 unique values.
    let mut candidates = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
    candidates.unique_values = vec64![
        String::from("x"),
        String::from("b"),
        String::from("y"),
        String::from("z")
    ];

    // Offset 1, length 4 on each side selects
    // ["a", "b", "c", "a"] from the left and ["b", "x", "y", "z"] from the right.
    let probe_window = (&probe, 1, 4);
    let candidate_window = (&candidates, 1, 4);

    let result = cmp_dict_in(probe_window, candidate_window).unwrap();

    // Only "b" (window index 1) is a member; the expected mask is all true.
    assert_bool(&result, &[false, true, false, false], Some(&[true; 4]));
}
1718
// IN against an empty right-hand array: no row can be a member, so every
// output value must be false.
#[test]
fn test_in_array_empty_rhs() {
    let values = Array::from_int32(i32_arr(&[1, 2, 3]));
    let no_candidates = Array::from_int32(i32_arr(&[]));

    let result = in_array(&values, &no_candidates).unwrap();

    // The input has no null mask, so none is expected on the output
    // (per the original note: no mask means all bits are treated as true).
    assert_bool(&result, &[false, false, false], None);
}
1727}