Skip to main content

simd_kernels/kernels/
logical.rs

1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Logical Operations Kernels Module** - *Boolean Logic and Set Operations*
6//!
7//! Logical operation kernels providing efficient boolean algebra, set membership testing,
8//! and range operations with SIMD acceleration and null-aware semantics. Critical foundation
9//! for query execution, filtering predicates, and analytical data processing workflows.
10//!
11//! ## Core Operations
12//! - **Boolean algebra**: AND, OR, XOR, NOT operations on boolean arrays with bitmask optimisation
13//! - **Set membership**: IN and NOT IN operations with hash-based lookup optimisation
14//! - **Range operations**: BETWEEN predicates for numeric and string data types
15//! - **Pattern matching**: String pattern matching with optimised prefix/suffix detection
16//! - **Null-aware logic**: Three-valued logic implementation following SQL semantics
17//! - **Compound predicates**: Efficient evaluation of complex multi-condition expressions
18
19include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::collections::HashSet;
22use std::hash::Hash;
23use std::marker::PhantomData;
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30    Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31    NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46/// Builds the Boolean result buffer.
47/// `len` – number of rows that will be written.
48#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50    Bitmask::new_set_all(len, false)
51}
52
53// Between
54
// Generates a pair of BETWEEN kernels for one primitive element type.
//
// Macro arguments, in order:
//   $name      - allocating kernel returning a `BooleanArray<()>`
//   $name_to   - kernel writing into a caller-supplied `Bitmask`
//   $ty        - element type of the `lhs` / `rhs` slices
//   $mask_elem - element type used for the `std::simd::Mask` in the SIMD path
//   $lanes     - SIMD lane count used in the SIMD path
macro_rules! impl_between_numeric {
    ($name:ident, $name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Test if LHS values fall between RHS min/max bounds (both bounds inclusive).
        /// The output Bitmask must have capacity >= lhs.len().
        ///
        /// `rhs` is either `[min, max]` (applied to every row) or `2 * lhs.len()` values
        /// laid out as `[min0, max0, min1, max1, ...]` (one pair per row).
        ///
        /// When `has_nulls` is `true` and `mask` is `Some`, rows whose mask bit is unset are
        /// left untouched. When `has_nulls` is `false` the mask is not consulted.
        ///
        /// This function only ever sets matching bits to `true`; it never clears a bit, so
        /// `output` should already hold `false` for every row of interest.
        ///
        /// # Errors
        /// Returns `KernelError::InvalidArguments` if `rhs.len()` is neither `2` nor
        /// `2 * lhs.len()`, or if `mask` is provided with capacity below `lhs.len()`.
        ///
        /// # Panics
        /// Panics if `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            // Accept a single [min, max] pair or one pair per LHS row.
            if rhs.len() != 2 && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(
                    format!("between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})", len, rhs.len())
                ));
            }

            // The unchecked mask reads below rely on this capacity check.
            if let Some(m) = mask {
                if m.capacity() < len {
                    return Err(KernelError::InvalidArguments(
                        format!("between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})", m.capacity(), len)
                    ));
                }
            }
            assert!(output.capacity() >= len, concat!(stringify!($name_to), ": output capacity too small"));

            // SIMD fast-path
            #[cfg(feature = "simd")]
            {
                // Check if both arrays are 64-byte aligned for SIMD
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    // Only the null-free, single-[min, max] case is vectorised.
                    if !has_nulls && rhs.len() == 2 {
                        let min_v = V::splat(rhs[0]);
                        let max_v = V::splat(rhs[1]);

                        let mut i = 0usize;
                        while i + N <= len {
                            let x = V::from_slice(&lhs[i..i + N]);
                            let m: M = x.simd_ge(min_v) & x.simd_le(max_v);
                            let bm = m.to_bitmask();

                            // Bit `l` of `bm` corresponds to row `i + l`.
                            for l in 0..N {
                                if ((bm >> l) & 1) == 1 {
                                    output.set(i + l, true);
                                }
                            }
                            i += N;
                        }
                        // fall back to scalar for tail
                        for j in i..len {
                            if lhs[j] >= rhs[0] && lhs[j] <= rhs[1] {
                                output.set(j, true);
                            }
                        }

                        return Ok(());
                    }
                }
                // Fall through to scalar path if alignment check failed
            }

            // Scalar / null-aware path
            if rhs.len() == 2 {
                let (min, max) = (rhs[0], rhs[1]);
                for i in 0..len {
                    // A row counts as valid when nulls are disabled, no mask exists,
                    // or its mask bit is set. `i < len <= mask capacity` was checked above.
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        output.set(i, true);
                    }
                }
            } else {
                // per-row min / max
                for i in 0..len {
                    let min = rhs[i * 2];
                    let max = rhs[i * 2 + 1];
                    if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                        && lhs[i] >= min
                        && lhs[i] <= max
                    {
                        output.set(i, true);
                    }
                }
            }

            Ok(())
        }

        /// Test if LHS values fall between RHS min/max bounds, producing boolean result array.
        ///
        /// Allocates a fresh result buffer of `lhs.len()` bits, delegates to the `_to`
        /// variant, and returns a `BooleanArray` whose `null_mask` is a clone of `mask`.
        /// Errors from the `_to` variant are propagated unchanged.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out_data = new_bool_buffer(len);
            $name_to(lhs, rhs, mask, has_nulls, &mut out_data)?;
            Ok(BooleanArray {
                data: out_data.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData
            })
        }
    };
}
172
173// floats
174
175#[inline(always)]
176fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
177    lhs: &[T],
178    rhs: &[T],
179    mask: Option<&Bitmask>,
180    has_nulls: bool,
181) -> Result<BooleanArray<()>, KernelError> {
182    let len = lhs.len();
183    let mut out = new_bool_buffer(len);
184    let _ = confirm_mask_capacity(len, mask)?;
185    if rhs.len() == 2 {
186        let (min, max) = (rhs[0], rhs[1]);
187        for i in 0..len {
188            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
189                && lhs[i] >= min
190                && lhs[i] <= max
191            {
192                out.set(i, true);
193            }
194        }
195    } else {
196        for i in 0..len {
197            let min = rhs[i * 2];
198            let max = rhs[i * 2 + 1];
199            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
200                && lhs[i] >= min
201                && lhs[i] <= max
202            {
203                out.set(i, true);
204            }
205        }
206    }
207
208    Ok(BooleanArray {
209        data: out.into(),
210        null_mask: mask.cloned(),
211        len,
212        _phantom: PhantomData,
213    })
214}
215
216// In and Not In
217
// Generates a pair of IN (set membership) kernels for one integer element type.
//
// Macro arguments, in order:
//   $name      - allocating kernel returning a `BooleanArray<()>`
//   $name_to   - kernel writing into a caller-supplied `Bitmask`
//   $ty        - element type of the `lhs` / `rhs` slices
//   $lanes     - SIMD lane count used in the SIMD path
//   $mask_elem - element type used for the `Mask` in the SIMD path
macro_rules! impl_in_int {
    ($name:ident, $name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Test membership of LHS integer values in RHS set.
        /// The output Bitmask must have capacity >= lhs.len().
        ///
        /// When `has_nulls` is `true` and `mask` is `Some`, rows whose mask bit is unset are
        /// left untouched. When `has_nulls` is `false`, or when no mask is supplied, every
        /// row is treated as valid. This holds on every code path; previously the SIMD
        /// path panicked on `has_nulls == true` with `mask == None` while the scalar path
        /// accepted it.
        ///
        /// This function only ever sets matching bits to `true`; it never clears a bit, so
        /// `output` should already hold `false` for every row of interest.
        ///
        /// # Errors
        /// Propagates any error returned by `confirm_mask_capacity`.
        ///
        /// # Panics
        /// Panics if `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($name_to), ": output capacity too small")
            );

            #[cfg(feature = "simd")]
            {
                // Check if both arrays are 64-byte aligned for SIMD
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    // Small RHS: compare each LHS vector against every RHS value in turn.
                    if rhs.len() <= 16 {
                        let mut i = 0;
                        let rhs_simd = rhs;
                        // The mask only participates when nulls are flagged AND a mask
                        // exists, mirroring the scalar fallback below.
                        let row_mask = if has_nulls { mask } else { None };
                        match row_mask {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs_simd {
                                        m |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    let bm = m.to_bitmask();
                                    // Bit `l` of `bm` corresponds to row `i + l`.
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for the rows that do not fill a vector.
                                for j in i..len {
                                    if rhs_simd.contains(&lhs[j]) {
                                        output.set(j, true);
                                    }
                                }
                                return Ok(());
                            }
                            Some(mb) => {
                                // ---- SIMD + nulls: use bitmask_to_simd_mask
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    // valid lanes
                                    let lane_mask = bitmask_to_simd_mask::<$lanes, $mask_elem>(
                                        mask_bytes, i, len,
                                    );
                                    let mut in_mask = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs_simd {
                                        in_mask |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                    }
                                    // Only set bits for lanes that are both valid and match RHS
                                    let valid_in = lane_mask & in_mask;
                                    let bm = valid_in.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    // SAFETY (assumed): `confirm_mask_capacity(len, mask)`
                                    // succeeded above and `j < len`.
                                    // NOTE(review): confirm that helper's contract.
                                    if unsafe { mb.get_unchecked(j) } && rhs_simd.contains(&lhs[j]) {
                                        output.set(j, true);
                                    }
                                }
                                return Ok(());
                            }
                        }
                    }
                }
                // Fall through to scalar path if alignment check failed
            }

            // Scalar fallback (null-aware and large-RHS)
            let set: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for i in 0..len {
                if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
                    && set.contains(&lhs[i])
                {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Test membership of LHS integer values in RHS set, producing boolean result array.
        ///
        /// Allocates a fresh result buffer of `lhs.len()` bits, delegates to the `_to`
        /// variant, and returns a `BooleanArray` whose `null_mask` is a clone of `mask`.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
337
/// Implements SIMD/Scalar IN kernel for floats, handling NaN semantics and optional null mask.
///
/// Macro arguments, in order: allocating kernel name, `_to` kernel name, element type,
/// SIMD lane count, and the element type used for the SIMD `Mask`.
macro_rules! impl_in_float {
    (
        $fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Test membership of LHS floating-point values in RHS set with NaN handling.
        /// The output Bitmask must have capacity >= lhs.len().
        ///
        /// A row matches when it compares equal to some RHS value, or when both the row
        /// and that RHS value are NaN (so NaN is considered a member of a set containing
        /// NaN).
        ///
        /// When `has_nulls` is `true` and `mask` is `Some`, rows whose mask bit is unset are
        /// left untouched. When `has_nulls` is `false`, or when no mask is supplied, every
        /// row is treated as valid. This holds on every code path; previously the SIMD
        /// path panicked on `has_nulls == true` with `mask == None` while the scalar path
        /// accepted it.
        ///
        /// This function only ever sets matching bits to `true`; it never clears a bit, so
        /// `output` should already hold `false` for every row of interest.
        ///
        /// # Errors
        /// Propagates any error returned by `confirm_mask_capacity`.
        ///
        /// # Panics
        /// Panics if `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );

            #[cfg(feature = "simd")]
            {
                // Check if both arrays are 64-byte aligned for SIMD
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};
                    // Small RHS: compare each LHS vector against every RHS value in turn.
                    if rhs.len() <= 16 {
                        let mut i = 0;
                        // The mask only participates when nulls are flagged AND a mask
                        // exists, mirroring the scalar fallback below.
                        let row_mask = if has_nulls { mask } else { None };
                        match row_mask {
                            None => {
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        // Equal, or NaN on both sides.
                                        let vmask = x.simd_eq(Simd::<$ty, $lanes>::splat(v))
                                            | (x.is_nan() & Simd::<$ty, $lanes>::splat(v).is_nan());
                                        m |= vmask;
                                    }
                                    let bm = m.to_bitmask();
                                    // Bit `l` of `bm` corresponds to row `i + l`.
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                // Scalar tail for the rows that do not fill a vector.
                                for j in i..len {
                                    let x = lhs[j];
                                    if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                        output.set(j, true);
                                    }
                                }
                                return Ok(());
                            }
                            Some(mb) => {
                                let mask_bytes = mb.as_bytes();
                                while i + $lanes <= len {
                                    let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                    let lane_mask = bitmask_to_simd_mask::<$lanes, $mask_elem>(
                                        mask_bytes, i, len,
                                    );
                                    let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                    for &v in rhs {
                                        let vmask = x.simd_eq(Simd::<$ty, $lanes>::splat(v))
                                            | (x.is_nan() & Simd::<$ty, $lanes>::splat(v).is_nan());
                                        m |= vmask;
                                    }
                                    // Keep only lanes that are both valid and matching.
                                    let m = m & lane_mask;
                                    let bm = m.to_bitmask();
                                    for l in 0..$lanes {
                                        if ((bm >> l) & 1) == 1 {
                                            output.set(i + l, true);
                                        }
                                    }
                                    i += $lanes;
                                }
                                for j in i..len {
                                    // SAFETY (assumed): `confirm_mask_capacity(len, mask)`
                                    // succeeded above and `j < len`.
                                    // NOTE(review): confirm that helper's contract.
                                    if unsafe { mb.get_unchecked(j) } {
                                        let x = lhs[j];
                                        if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                            output.set(j, true);
                                        }
                                    }
                                }
                                return Ok(());
                            }
                        }
                    }
                }
                // Fall through to scalar path if alignment check failed
            }

            // Scalar fallback
            for i in 0..len {
                // Skip rows flagged null by the mask.
                if has_nulls && !mask.map_or(true, |m| unsafe { m.get_unchecked(i) }) {
                    continue;
                }
                let x = lhs[i];
                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Test membership of LHS floating-point values in RHS set with NaN handling.
        ///
        /// Allocates a fresh result buffer of `lhs.len()` bits, delegates to the `_to`
        /// variant, and returns a `BooleanArray` whose `null_mask` is a clone of `mask`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $fn_name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
464
// Concrete IN kernels.
//
// Correct MaskElement types per std::simd
// (each invocation below pairs the element type with the signed integer of the same
// width, e.g. u8 -> i8, u32 -> i32, f64 -> i64).
//
// Argument order for `impl_in_int!` / `impl_in_float!`:
//   (array-returning fn, `_to` fn, element type, lane count, mask element type).
// The `W8` / `W16` / `W32` / `W64` lane constants are presumably supplied by the
// `simd_lanes.rs` file included at the top of this module — TODO confirm.
// The 8- and 16-bit variants only exist with the `extended_numeric_types` feature.
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_i8, in_i8_to, i8, W8, i8);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_u8, in_u8_to, u8, W8, i8);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_i16, in_i16_to, i16, W16, i16);
#[cfg(feature = "extended_numeric_types")]
impl_in_int!(in_u16, in_u16_to, u16, W16, i16);
impl_in_int!(in_i32, in_i32_to, i32, W32, i32);
impl_in_int!(in_u32, in_u32_to, u32, W32, i32);
impl_in_int!(in_i64, in_i64_to, i64, W64, i64);
impl_in_int!(in_u64, in_u64_to, u64, W64, i64);
impl_in_float!(in_f32, in_f32_to, f32, W32, i32);
impl_in_float!(in_f64, in_f64_to, f64, W64, i64);

// Concrete BETWEEN kernels.
//
// NOTE: `impl_between_numeric!` takes its last two arguments in the opposite order to
// the IN macros: (array-returning fn, `_to` fn, element type, mask element type, lane count).
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_i8, between_i8_to, i8, i8, W8);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_u8, between_u8_to, u8, i8, W8);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_i16, between_i16_to, i16, i16, W16);
#[cfg(feature = "extended_numeric_types")]
impl_between_numeric!(between_u16, between_u16_to, u16, i16, W16);

impl_between_numeric!(between_i32, between_i32_to, i32, i32, W32);
impl_between_numeric!(between_u32, between_u32_to, u32, i32, W32);
impl_between_numeric!(between_i64, between_i64_to, i64, i64, W64);
impl_between_numeric!(between_u64, between_u64_to, u64, i64, W64);
impl_between_numeric!(between_f32, between_f32_to, f32, i32, W32);
impl_between_numeric!(between_f64, between_f64_to, f64, i64, W64);
496
497// String and dictionary
498
499/// Test if LHS string values fall lexicographically between RHS min/max bounds.
500#[inline(always)]
501pub fn cmp_str_between<'a, T: Integer>(
502    lhs: StringAVT<'a, T>,
503    rhs: StringAVT<'a, T>,
504) -> Result<BooleanArray<()>, KernelError> {
505    let (larr, loff, llen) = lhs;
506    let (rarr, roff, rlen) = rhs;
507
508    if rlen < 2 {
509        return Err(KernelError::InvalidArguments(format!(
510            "str_between: RHS must contain at least two values (got {})",
511            rlen
512        )));
513    }
514    let min = rarr.get(roff).unwrap_or("");
515    let max = rarr.get(roff + 1).unwrap_or("");
516    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
517    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
518
519    let mut out = new_bool_buffer(llen);
520
521    for i in 0..llen {
522        if mask
523            .as_ref()
524            .map_or(true, |m| unsafe { m.get_unchecked(i) })
525        {
526            let s = unsafe { larr.get_str_unchecked(loff + i) };
527            if s >= min && s <= max {
528                unsafe { out.set_unchecked(i, true) };
529            }
530        }
531    }
532
533    Ok(BooleanArray {
534        data: out.into(),
535        null_mask: mask,
536        len: llen,
537        _phantom: PhantomData,
538    })
539}
540
541#[inline(always)]
542/// Test membership of LHS string values in RHS string set.
543pub fn cmp_str_in<'a, T: Integer>(
544    lhs: StringAVT<'a, T>,
545    rhs: StringAVT<'a, T>,
546) -> Result<BooleanArray<()>, KernelError> {
547    let (larr, loff, llen) = lhs;
548    let (rarr, roff, rlen) = rhs;
549
550    let set: HashSet<&str> = (0..rlen)
551        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
552        .collect();
553
554    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
555    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
556
557    let mut out = new_bool_buffer(llen);
558
559    for i in 0..llen {
560        if mask
561            .as_ref()
562            .map_or(true, |m| unsafe { m.get_unchecked(i) })
563        {
564            let s = unsafe { larr.get_str_unchecked(loff + i) };
565            if set.contains(s) {
566                unsafe { out.set_unchecked(i, true) };
567            }
568        }
569    }
570    Ok(BooleanArray {
571        data: out.into(),
572        null_mask: mask,
573        len: llen,
574        _phantom: PhantomData,
575    })
576}
577
578// Public functions
579
580/// Test if values fall between min/max bounds for comparable numeric types.
581pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
582    lhs: &[T],
583    rhs: &[T],
584) -> Result<BooleanArray<()>, KernelError> {
585    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
586        return between_i32(
587            unsafe { std::mem::transmute(lhs) },
588            unsafe { std::mem::transmute(rhs) },
589            None,
590            false,
591        );
592    }
593    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
594        return between_u32(
595            unsafe { std::mem::transmute(lhs) },
596            unsafe { std::mem::transmute(rhs) },
597            None,
598            false,
599        );
600    }
601    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
602        return between_i64(
603            unsafe { std::mem::transmute(lhs) },
604            unsafe { std::mem::transmute(rhs) },
605            None,
606            false,
607        );
608    }
609    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
610        return between_u64(
611            unsafe { std::mem::transmute(lhs) },
612            unsafe { std::mem::transmute(rhs) },
613            None,
614            false,
615        );
616    }
617    // Fallback – floats or any other PartialOrd type
618    between_generic(lhs, rhs, None, false)
619}
620
621/// Mask-aware variant
622#[inline(always)]
623pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
624    lhs: &[T],
625    rhs: &[T],
626    mask: Option<&Bitmask>,
627) -> Result<BooleanArray<()>, KernelError> {
628    let has_nulls = mask.is_some();
629    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
630        return between_i32(
631            unsafe { std::mem::transmute(lhs) },
632            unsafe { std::mem::transmute(rhs) },
633            mask,
634            has_nulls,
635        );
636    }
637    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
638        return between_u32(
639            unsafe { std::mem::transmute(lhs) },
640            unsafe { std::mem::transmute(rhs) },
641            mask,
642            has_nulls,
643        );
644    }
645    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
646        return between_i64(
647            unsafe { std::mem::transmute(lhs) },
648            unsafe { std::mem::transmute(rhs) },
649            mask,
650            has_nulls,
651        );
652    }
653    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
654        return between_u64(
655            unsafe { std::mem::transmute(lhs) },
656            unsafe { std::mem::transmute(rhs) },
657            mask,
658            has_nulls,
659        );
660    }
661    between_generic(lhs, rhs, mask, has_nulls)
662}
663
664// In and Not In
665
666/// Test membership in set for hashable types using hash-based lookup.
667pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
668    lhs: &[T],
669    rhs: &[T],
670) -> Result<BooleanArray<()>, KernelError> {
671    // i32
672    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
673        return in_i32(
674            unsafe { std::mem::transmute(lhs) },
675            unsafe { std::mem::transmute(rhs) },
676            None,
677            false,
678        );
679    }
680    // u32
681    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
682        return in_u32(
683            unsafe { std::mem::transmute(lhs) },
684            unsafe { std::mem::transmute(rhs) },
685            None,
686            false,
687        );
688    }
689    // i64
690    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
691        return in_i64(
692            unsafe { std::mem::transmute(lhs) },
693            unsafe { std::mem::transmute(rhs) },
694            None,
695            false,
696        );
697    }
698    // u64
699    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
700        return in_u64(
701            unsafe { std::mem::transmute(lhs) },
702            unsafe { std::mem::transmute(rhs) },
703            None,
704            false,
705        );
706    }
707    // i16
708    #[cfg(feature = "extended_numeric_types")]
709    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
710        return in_i16(
711            unsafe { std::mem::transmute(lhs) },
712            unsafe { std::mem::transmute(rhs) },
713            None,
714            false,
715        );
716    }
717    // u16
718    #[cfg(feature = "extended_numeric_types")]
719    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
720        return in_u16(
721            unsafe { std::mem::transmute(lhs) },
722            unsafe { std::mem::transmute(rhs) },
723            None,
724            false,
725        );
726    }
727    // i8
728    #[cfg(feature = "extended_numeric_types")]
729    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
730        return in_i8(
731            unsafe { std::mem::transmute(lhs) },
732            unsafe { std::mem::transmute(rhs) },
733            None,
734            false,
735        );
736    }
737    // u8
738    #[cfg(feature = "extended_numeric_types")]
739    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
740        return in_u8(
741            unsafe { std::mem::transmute(lhs) },
742            unsafe { std::mem::transmute(rhs) },
743            None,
744            false,
745        );
746    }
747    return Err(KernelError::UnsupportedType(
748        "cmp_in: unsupported type for SIMD in".into(),
749    ));
750}
751
752/// Mask-aware variant
753#[inline(always)]
754pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
755    lhs: &[T],
756    rhs: &[T],
757    mask: Option<&Bitmask>,
758) -> Result<BooleanArray<()>, KernelError> {
759    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
760        return in_i32(
761            unsafe { std::mem::transmute(lhs) },
762            unsafe { std::mem::transmute(rhs) },
763            mask,
764            mask.is_some(),
765        );
766    }
767    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
768        return in_u32(
769            unsafe { std::mem::transmute(lhs) },
770            unsafe { std::mem::transmute(rhs) },
771            mask,
772            mask.is_some(),
773        );
774    }
775    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
776        return in_i64(
777            unsafe { std::mem::transmute(lhs) },
778            unsafe { std::mem::transmute(rhs) },
779            mask,
780            mask.is_some(),
781        );
782    }
783    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
784        return in_u64(
785            unsafe { std::mem::transmute(lhs) },
786            unsafe { std::mem::transmute(rhs) },
787            mask,
788            mask.is_some(),
789        );
790    }
791    return Err(KernelError::UnsupportedType(
792        "cmp_in_mask: unsupported type (expected integer type)".into(),
793    ));
794}
795
796/// SIMD-aware, type-specific dispatch for cmp_in_f_mask and cmp_in_f
797#[inline(always)]
798pub fn cmp_in_f_mask<T: Float + Copy>(
799    lhs: &[T],
800    rhs: &[T],
801    mask: Option<&Bitmask>,
802) -> Result<BooleanArray<()>, KernelError> {
803    if TypeId::of::<T>() == TypeId::of::<f32>() {
804        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
805        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
806        in_f32(lhs, rhs, mask, mask.is_some())
807    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
808        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
809        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
810        in_f64(lhs, rhs, mask, mask.is_some())
811    } else {
812        unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
813    }
814}
815
816#[inline(always)]
817/// Test membership in set for floating-point types with NaN handling.
818pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
819    if TypeId::of::<T>() == TypeId::of::<f32>() {
820        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
821        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
822        in_f32(lhs, rhs, None, false)
823    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
824        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
825        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
826        in_f64(lhs, rhs, None, false)
827    } else {
828        unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
829    }
830}
831
832// String and dictionary
833
834/// Test if floating-point values fall between bounds with NaN handling.
835pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
836    lhs: &[T],
837    rhs: &[T],
838) -> Result<BooleanArray<()>, KernelError> {
839    between_generic(lhs, rhs, None, false)
840}
841
842/// Test if dictionary/categorical values fall between lexicographic bounds.
843pub fn cmp_dict_between<'a, T: Integer>(
844    lhs: CategoricalAVT<'a, T>,
845    rhs: CategoricalAVT<'a, T>,
846) -> Result<BooleanArray<()>, KernelError> {
847    let (larr, loff, llen) = lhs;
848    let (rarr, roff, _rlen) = rhs;
849
850    let min = rarr.get(roff).unwrap_or("");
851    let max = rarr.get(roff + 1).unwrap_or("");
852    let mask = larr.null_mask.as_ref();
853    let _ = confirm_mask_capacity(larr.data.len(), mask)?;
854    let has_nulls = mask.is_some();
855
856    let mut out = new_bool_buffer(llen);
857    for i in 0..llen {
858        let li = loff + i;
859        if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
860            let s = unsafe { larr.get_str_unchecked(li) };
861            if s >= min && s <= max {
862                unsafe { out.set_unchecked(i, true) };
863            }
864        }
865    }
866    Ok(BooleanArray {
867        data: out.into(),
868        null_mask: mask.map(|m| m.slice_clone(loff, llen)),
869        len: llen,
870        _phantom: PhantomData,
871    })
872}
873
874/// Returns `true` for each row in `lhs` whose string value also appears
875/// anywhere in `rhs`, respecting null masks on both sides.
876/// Returns `true` for each row in `lhs` whose string value also appears
877/// anywhere in `rhs`, respecting null masks on both sides.
878pub fn cmp_dict_in<'a, T: Integer + Hash>(
879    lhs: CategoricalAVT<'a, T>,
880    rhs: CategoricalAVT<'a, T>,
881) -> Result<BooleanArray<()>, KernelError> {
882    let (larr, loff, llen) = lhs;
883    let (rarr, roff, rlen) = rhs;
884    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
885    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
886
887    let mut out = Bitmask::new_set_all(llen, false);
888
889    if (larr.unique_values.len() == rarr.unique_values.len())
890        && (larr.unique_values.len() <= MAX_DICT_CHECK)
891    {
892        let mut same_dict = true;
893        for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
894            if a != b {
895                same_dict = false;
896                break;
897            }
898        }
899
900        if same_dict {
901            let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
902            for i in 0..llen {
903                if mask
904                    .as_ref()
905                    .map_or(true, |m| unsafe { m.get_unchecked(i) })
906                {
907                    let code = larr.data[loff + i];
908                    if rhs_codes.contains(&code) {
909                        unsafe { out.set_unchecked(i, true) };
910                    }
911                }
912            }
913            return Ok(BooleanArray {
914                data: out.into(),
915                null_mask: mask,
916                len: llen,
917                _phantom: PhantomData,
918            });
919        }
920    }
921
922    let rhs_strings: HashSet<&str> = (0..rlen)
923        .filter(|&i| {
924            rarr.null_mask
925                .as_ref()
926                .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
927        })
928        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
929        .collect();
930
931    for i in 0..llen {
932        if mask
933            .as_ref()
934            .map_or(true, |m| unsafe { m.get_unchecked(i) })
935        {
936            let s = unsafe { larr.get_str_unchecked(loff + i) };
937            if rhs_strings.contains(s) {
938                unsafe { out.set_unchecked(i, true) };
939            }
940        }
941    }
942
943    Ok(BooleanArray {
944        data: out.into(),
945        null_mask: mask,
946        len: llen,
947        _phantom: PhantomData,
948    })
949}
950
951// Is Null and Not null predicates
952
953/// Generate boolean mask indicating null elements in any array type.
954pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
955    let not_null = is_not_null_array(arr)?;
956    Ok(!not_null)
957}
958/// Generate boolean mask indicating non-null elements in any array type.
959pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
960    let len = arr.len();
961    let mut data = Bitmask::new_set_all(len, false);
962
963    if let Some(mask) = arr.null_mask() {
964        data = mask.clone();
965    } else {
966        data.fill(true);
967    }
968    Ok(BooleanArray {
969        data,
970        null_mask: None,
971        len,
972        _phantom: PhantomData,
973    })
974}
975
// Array-level dispatch: IN, NOT IN and BETWEEN
977/// Test membership of array elements in values set, dispatching by array type.
978pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
979    match (input, values) {
980        (
981            Array::NumericArray(NumericArray::Int32(a)),
982            Array::NumericArray(NumericArray::Int32(b)),
983        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
984        (
985            Array::NumericArray(NumericArray::Int64(a)),
986            Array::NumericArray(NumericArray::Int64(b)),
987        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
988        (
989            Array::NumericArray(NumericArray::UInt32(a)),
990            Array::NumericArray(NumericArray::UInt32(b)),
991        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
992        (
993            Array::NumericArray(NumericArray::UInt64(a)),
994            Array::NumericArray(NumericArray::UInt64(b)),
995        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
996        (
997            Array::NumericArray(NumericArray::Float32(a)),
998            Array::NumericArray(NumericArray::Float32(b)),
999        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1000        (
1001            Array::NumericArray(NumericArray::Float64(a)),
1002            Array::NumericArray(NumericArray::Float64(b)),
1003        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1004        (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
1005            cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
1006        }
1007        (Array::BooleanArray(a), Array::BooleanArray(b)) => {
1008            cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
1009        }
1010        (
1011            Array::TextArray(TextArray::Categorical32(a)),
1012            Array::TextArray(TextArray::Categorical32(b)),
1013        ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
1014        _ => unimplemented!(),
1015    }
1016}
1017
1018#[inline(always)]
1019/// Test non-membership of array elements in values set, dispatching by array type.
1020pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
1021    let result = in_array(input, values)?;
1022    Ok(!result)
1023}
1024
1025/// Test if array elements fall between min/max bounds, dispatching by array type.
1026pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
1027    macro_rules! between_case {
1028        ($variant:ident, $cmp:ident) => {{
1029            let arr = match input {
1030                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1031                _ => unreachable!(),
1032            };
1033            let mins = match min {
1034                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1035                _ => unreachable!(),
1036            };
1037            let maxs = match max {
1038                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1039                _ => unreachable!(),
1040            };
1041            let rhs: Vec64<_> = mins
1042                .data
1043                .iter()
1044                .zip(&maxs.data)
1045                .flat_map(|(&lo, &hi)| [lo, hi])
1046                .collect();
1047            Ok(Array::BooleanArray(
1048                $cmp(
1049                    &arr.data,
1050                    &rhs,
1051                    arr.null_mask.as_ref(),
1052                    arr.null_mask.is_some(),
1053                )?
1054                .into(),
1055            ))
1056        }};
1057    }
1058
1059    match (input, min, max) {
1060        (
1061            Array::NumericArray(NumericArray::Int32(..)),
1062            Array::NumericArray(NumericArray::Int32(..)),
1063            Array::NumericArray(NumericArray::Int32(..)),
1064        ) => between_case!(Int32, between_i32),
1065        (
1066            Array::NumericArray(NumericArray::Int64(..)),
1067            Array::NumericArray(NumericArray::Int64(..)),
1068            Array::NumericArray(NumericArray::Int64(..)),
1069        ) => between_case!(Int64, between_i64),
1070        (
1071            Array::NumericArray(NumericArray::UInt32(..)),
1072            Array::NumericArray(NumericArray::UInt32(..)),
1073            Array::NumericArray(NumericArray::UInt32(..)),
1074        ) => between_case!(UInt32, between_u32),
1075        (
1076            Array::NumericArray(NumericArray::UInt64(..)),
1077            Array::NumericArray(NumericArray::UInt64(..)),
1078            Array::NumericArray(NumericArray::UInt64(..)),
1079        ) => between_case!(UInt64, between_u64),
1080        (
1081            Array::NumericArray(NumericArray::Float32(..)),
1082            Array::NumericArray(NumericArray::Float32(..)),
1083            Array::NumericArray(NumericArray::Float32(..)),
1084        ) => between_case!(Float32, between_generic),
1085        (
1086            Array::NumericArray(NumericArray::Float64(..)),
1087            Array::NumericArray(NumericArray::Float64(..)),
1088            Array::NumericArray(NumericArray::Float64(..)),
1089        ) => between_case!(Float64, between_generic),
1090        _ => Err(KernelError::UnsupportedType(
1091            "Unsupported Type Error.".to_string(),
1092        )),
1093    }
1094}
1095
1096/// Bitwise NOT of a bit-packed boolean mask window.
1097/// Offset is a bit offset; len is in bits.
1098/// Requires offset % 64 == 0 for word-level SIMD processing.
1099#[inline]
1100pub fn not_bool<const LANES: usize>(
1101    src: BooleanAVT<'_, ()>,
1102) -> Result<BooleanArray<()>, KernelError>
1103where
1104{
1105    let (arr, offset, len) = src;
1106
1107    if offset % 64 != 0 {
1108        return Err(KernelError::InvalidArguments(format!(
1109            "not_bool: offset must be 64-bit aligned (got offset={})",
1110            offset
1111        )));
1112    }
1113
1114    let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1115
1116    let data = {
1117        #[cfg(feature = "simd")]
1118        {
1119            minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1120        }
1121        #[cfg(not(feature = "simd"))]
1122        {
1123            minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1124        }
1125    };
1126
1127    Ok(BooleanArray {
1128        data,
1129        null_mask,
1130        len,
1131        _phantom: core::marker::PhantomData,
1132    })
1133}
1134
1135/// Logical AND/OR/XOR of two bit-packed boolean masks over a window.
1136/// Offsets are bit offsets. Length is in bits.
1137/// Panics if offsets are not 64-bit aligned.
1138pub fn apply_logical_bool<const LANES: usize>(
1139    lhs: BooleanAVT<'_, ()>,
1140    rhs: BooleanAVT<'_, ()>,
1141    op: LogicalOperator,
1142) -> Result<BooleanArray<()>, KernelError>
1143where
1144{
1145    let (lhs_arr, lhs_off, len) = lhs;
1146    let (rhs_arr, rhs_off, rlen) = rhs;
1147
1148    if len != rlen {
1149        return Err(KernelError::LengthMismatch(format!(
1150            "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1151            len, rlen
1152        )));
1153    }
1154    if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1155        return Err(KernelError::InvalidArguments(format!(
1156            "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1157            lhs_off, rhs_off
1158        )));
1159    }
1160
1161    // Apply bitmask kernel for the logical operation.
1162
1163    #[cfg(feature = "simd")]
1164    let data = match op {
1165        LogicalOperator::And => {
1166            and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1167        }
1168        LogicalOperator::Or => {
1169            or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1170        }
1171        LogicalOperator::Xor => {
1172            xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1173        }
1174    };
1175
1176    // Merge validity (null) masks using AND
1177    #[cfg(feature = "simd")]
1178    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1179        (None, None) => None,
1180        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1181        (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1182            (a, lhs_off, len),
1183            (b, rhs_off, len),
1184        )),
1185    };
1186
1187    #[cfg(not(feature = "simd"))]
1188    let data = match op {
1189        LogicalOperator::And => {
1190            and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1191        }
1192        LogicalOperator::Or => {
1193            or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1194        }
1195        LogicalOperator::Xor => {
1196            xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1197        }
1198    };
1199
1200    #[cfg(not(feature = "simd"))]
1201    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1202        (None, None) => None,
1203        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1204        (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1205    };
1206
1207    Ok(BooleanArray {
1208        data,
1209        null_mask,
1210        len,
1211        _phantom: PhantomData,
1212    })
1213}
1214
1215#[cfg(test)]
1216mod tests {
1217    use minarrow::structs::variants::categorical::CategoricalArray;
1218    use minarrow::structs::variants::float::FloatArray;
1219    use minarrow::structs::variants::integer::IntegerArray;
1220    use minarrow::structs::variants::string::StringArray;
1221    use minarrow::{Array, Bitmask, BooleanArray, vec64};
1222
1223    use super::*;
1224
1225    // --- helpers ---
1226
1227    fn bm(bits: &[bool]) -> Bitmask {
1228        let mut m = Bitmask::new_set_all(bits.len(), false);
1229        for (i, &b) in bits.iter().enumerate() {
1230            m.set(i, b);
1231        }
1232        m
1233    }
1234
1235    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
1236        assert_eq!(arr.len, expect.len(), "length mismatch");
1237        for i in 0..expect.len() {
1238            assert_eq!(arr.data.get(i), expect[i], "val @ {i}");
1239        }
1240        match (expect_mask, &arr.null_mask) {
1241            (None, None) => {}
1242            (Some(exp), Some(mask)) => {
1243                for (i, &b) in exp.iter().enumerate() {
1244                    assert_eq!(mask.get(i), b, "mask @ {i}");
1245                }
1246            }
1247            (None, Some(mask)) => {
1248                // all mask bits should be true
1249                for i in 0..arr.len {
1250                    assert!(mask.get(i), "unexpected false mask @ {i}");
1251                }
1252            }
1253            (Some(_), None) => panic!("expected null mask"),
1254        }
1255    }
1256
1257    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
1258        IntegerArray::from_slice(data)
1259    }
1260    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
1261        FloatArray::from_slice(data)
1262    }
1263    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
1264        StringArray::<T>::from_slice(vals)
1265    }
1266    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
1267        let owned: Vec<&str> = vals.to_vec();
1268        CategoricalArray::<T>::from_values(owned)
1269    }
1270    //  BETWEEN
1271
1272    #[test]
1273    fn between_i32_scalar_rhs() {
1274        let lhs = vec64![1, 3, 5, 7];
1275        let rhs = vec64![2, 6];
1276        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1277        assert_bool(&out, &[false, true, true, false], None);
1278    }
1279
1280    #[test]
1281    fn between_i32_per_row_rhs() {
1282        let lhs = vec64![5, 9, 2, 8];
1283        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9]; // min/max for each row
1284        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1285        assert_bool(&out, &[true, false, true, true], None);
1286    }
1287
1288    #[test]
1289    fn between_i32_nulls_propagate() {
1290        let lhs = vec64![5, 9, 2, 8];
1291        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
1292        let mask = bm(&[true, false, true, true]);
1293        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1294        assert_bool(
1295            &out,
1296            &[true, false, true, true],
1297            Some(&[true, false, true, true]),
1298        );
1299    }
1300
1301    #[cfg(feature = "extended_numeric_types")]
1302    #[test]
1303    fn between_i16_works() {
1304        let lhs = vec64![10i16, 12, 99];
1305        let rhs = vec64![10i16, 12];
1306        let out = in_i16(&lhs, &rhs, None, false).unwrap();
1307        assert_bool(&out, &[true, true, false], None);
1308    }
1309
1310    #[test]
1311    fn between_f64_scalar_and_nulls() {
1312        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
1313        let rhs = vec64![4.0, 10.0];
1314        let mask = bm(&[true, false, true, true]);
1315        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
1316        assert_bool(
1317            &out,
1318            &[false, false, true, false],
1319            Some(&[true, false, true, true]),
1320        );
1321    }
1322
1323    #[test]
1324    fn between_f32_generic_dispatch() {
1325        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
1326        let rhs = vec64![0.0, 1.0];
1327        let out = cmp_between(&lhs, &rhs).unwrap();
1328        assert_bool(&out, &[true, true, false, false], None);
1329    }
1330
1331    #[test]
1332    fn between_masked_dispatch() {
1333        let lhs = vec64![1i32, 2, 3];
1334        let rhs = vec64![0, 2];
1335        let mask = bm(&[true, false, true]);
1336        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
1337        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
1338    }
1339
1340    // IN
1341
1342    #[test]
1343    fn in_i32_small_rhs() {
1344        let lhs = vec64![1, 2, 3, 4, 5];
1345        let rhs = vec64![2, 4];
1346        let out = in_i32(&lhs, &rhs, None, false).unwrap();
1347        assert_bool(&out, &[false, true, false, true, false], None);
1348    }
1349
1350    #[test]
1351    fn in_i32_with_nulls() {
1352        let lhs = vec64![7, 8, 9];
1353        let rhs = vec64![8];
1354        let mask = bm(&[true, false, true]);
1355        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1356        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
1357    }
1358
1359    #[test]
1360    fn in_i64_large_rhs() {
1361        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
1362        let rhs: Vec<i64> = (2..10).collect();
1363        let out = in_i64(&lhs, &rhs, None, false).unwrap();
1364        assert_bool(&out, &[false, true, true, true, true, false], None);
1365    }
1366
1367    #[cfg(feature = "extended_numeric_types")]
1368    #[test]
1369    fn in_u8_small_rhs() {
1370        let lhs = vec64![1u8, 2, 3, 4];
1371        let rhs = vec64![2u8, 3];
1372        let out = in_u8(&lhs, &rhs, None, false).unwrap();
1373        assert_bool(&out, &[false, true, true, false], None);
1374    }
1375
1376    #[test]
1377    fn in_float_nan_and_normal() {
1378        let lhs = vec64![1.0f32, f32::NAN, 7.0];
1379        let rhs = vec64![f32::NAN, 7.0];
1380        let out = in_f32(&lhs, &rhs, None, false).unwrap();
1381        assert_bool(&out, &[false, true, true], None);
1382    }
1383
    // String and dictionary BETWEEN / IN
1385
1386    #[test]
1387    fn string_between() {
1388        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
1389        let rhs = str_arr::<u32>(&["b", "y"]);
1390        let lhs_slice = (&lhs, 0, lhs.len());
1391        let rhs_slice = (&rhs, 0, rhs.len());
1392        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1393        assert_bool(&out, &[false, true, false], None);
1394    }
1395
1396    #[test]
1397    fn string_between_chunk() {
1398        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
1399        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
1400        // Windowed: skip first/last for lhs; use a window for rhs
1401        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "zz"]
1402        let rhs_slice = (&rhs, 1, 2); // ["b", "y"]
1403        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1404        assert_bool(&out, &[false, true, false], None);
1405    }
1406
1407    #[test]
1408    fn string_in_basic() {
1409        let lhs = str_arr::<u32>(&["x", "y", "z"]);
1410        let rhs = str_arr::<u32>(&["y", "a"]);
1411        let lhs_slice = (&lhs, 0, lhs.len());
1412        let rhs_slice = (&rhs, 0, rhs.len());
1413        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1414        assert_bool(&out, &[false, true, false], None);
1415    }
1416
1417    #[test]
1418    fn string_in_basic_chunk() {
1419        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
1420        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
1421        let lhs_slice = (&lhs, 1, 3); // ["x", "y", "z"]
1422        let rhs_slice = (&rhs, 1, 2); // ["y", "a"]
1423        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1424        assert_bool(&out, &[false, true, false], None);
1425    }
1426
1427    #[test]
1428    fn dict_between() {
1429        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
1430        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
1431        let lhs_slice = (&lhs, 0, lhs.len());
1432        let rhs_slice = (&rhs, 0, rhs.len());
1433        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1434        assert_bool(&out, &[false, true, false], None);
1435    }
1436
1437    #[test]
1438    fn dict_between_chunk() {
1439        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
1440        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
1441        let lhs_slice = (&lhs, 1, 3); // ["cat", "dog", "emu"]
1442        let rhs_slice = (&rhs, 1, 2); // ["cobra", "dove"]
1443        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1444        assert_bool(&out, &[false, true, false], None);
1445    }
1446
1447    #[test]
1448    fn dict_in_membership() {
1449        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
1450        let rhs = dict_arr::<u32>(&["bb", "dd"]);
1451        let lhs_slice = (&lhs, 0, lhs.len());
1452        let rhs_slice = (&rhs, 0, rhs.len());
1453        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1454        assert_bool(&out, &[false, true, false], None);
1455    }
1456
1457    #[test]
1458    fn dict_in_membership_chunk() {
1459        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
1460        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
1461        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "cc"]
1462        let rhs_slice = (&rhs, 1, 2); // ["bb", "dd"]
1463        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1464        assert_bool(&out, &[false, true, false], None);
1465    }
1466
1467    #[test]
1468    fn string_between_nulls() {
1469        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
1470        lhs.null_mask = Some(bm(&[true, false, true]));
1471        let rhs = str_arr::<u32>(&["a", "zzz"]);
1472        let lhs_slice = (&lhs, 0, lhs.len());
1473        let rhs_slice = (&rhs, 0, rhs.len());
1474        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1475        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1476    }
1477
1478    #[test]
1479    fn string_between_nulls_chunk() {
1480        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
1481        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
1482        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
1483        let lhs_slice = (&lhs, 1, 3); // ["foo", "bar", "baz"]
1484        let rhs_slice = (&rhs, 1, 2); // ["a", "zzz"]
1485        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1486        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1487    }
1488
1489    #[test]
1490    fn dict_in_nulls() {
1491        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
1492        lhs.null_mask = Some(bm(&[false, true, true]));
1493        let rhs = dict_arr::<u32>(&["two", "four"]);
1494        let lhs_slice = (&lhs, 0, lhs.len());
1495        let rhs_slice = (&rhs, 0, rhs.len());
1496        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1497        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1498    }
1499
1500    #[test]
1501    fn dict_in_nulls_chunk() {
1502        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
1503        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
1504        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
1505        let lhs_slice = (&lhs, 1, 3); // ["one", "two", "three"]
1506        let rhs_slice = (&rhs, 1, 2); // ["two", "four"]
1507        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1508        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1509    }
1510
1511    // Boolean/Null
1512
1513    #[test]
1514    fn is_null_and_is_not_null() {
1515        let mut arr = i32_arr(&[1, 2, 0]);
1516        arr.null_mask = Some(bm(&[true, false, true]));
1517        let array = Array::from_int32(arr.clone());
1518
1519        let not_null = is_not_null_array(&array).unwrap();
1520        let is_null = is_null_array(&array).unwrap();
1521
1522        assert_bool(&not_null, &[true, false, true], None);
1523        assert_bool(&is_null, &[false, true, false], None);
1524    }
1525
1526    #[test]
1527    fn is_null_not_null_dense() {
1528        let arr = i32_arr(&[1, 2, 3]);
1529        let array = Array::from_int32(arr.clone());
1530        let is_null = is_null_array(&array).unwrap();
1531        assert_bool(&is_null, &[false, false, false], None);
1532        let not_null = is_not_null_array(&array).unwrap();
1533        assert_bool(&not_null, &[true, true, true], None);
1534    }
1535
1536    //  Array dispatch in_array, not_in_array, between_array ----
1537
1538    #[test]
1539    fn in_array_int32_dispatch() {
1540        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
1541        let vals = Array::from_int32(i32_arr(&[20, 40]));
1542        let out = in_array(&inp, &vals).unwrap();
1543        assert_bool(&out, &[false, true, false], None);
1544
1545        let out_not = not_in_array(&inp, &vals).unwrap();
1546        assert_bool(&out_not, &[true, false, true], None);
1547    }
1548
1549    #[test]
1550    fn in_array_f32_dispatch() {
1551        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
1552        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
1553        let out = in_array(&inp, &vals).unwrap();
1554        assert_bool(&out, &[false, true, true], None);
1555    }
1556
1557    #[test]
1558    fn in_array_string_dispatch() {
1559        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
1560        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
1561        let out = in_array(&inp, &vals).unwrap();
1562        assert_bool(&out, &[false, true, false], None);
1563    }
1564
1565    #[test]
1566    fn in_array_dictionary_dispatch() {
1567        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
1568        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
1569        let out = in_array(&inp, &vals).unwrap();
1570        assert_bool(&out, &[false, true, true], None);
1571    }
1572
1573    #[test]
1574    fn between_array_int32_rows() {
1575        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
1576        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
1577        let max = Array::from_int32(i32_arr(&[10, 20, 30]));
1578
1579        let out = between_array(&inp, &min, &max).unwrap();
1580        match out {
1581            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1582            _ => panic!("expected Bool array"),
1583        }
1584    }
1585
1586    #[test]
1587    fn between_array_float_generic() {
1588        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
1589        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
1590        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));
1591
1592        let out = between_array(&inp, &min, &max).unwrap();
1593        match out {
1594            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1595            _ => panic!("expected Bool"),
1596        }
1597    }
1598
1599    #[test]
1600    fn between_array_type_mismatch() {
1601        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
1602        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
1603        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
1604        let err = between_array(&inp, &min, &max).unwrap_err();
1605        match err {
1606            KernelError::UnsupportedType(_) => {}
1607            _ => panic!("Expected UnsupportedType error"),
1608        }
1609    }
1610
1611    // all integer types, short and long
1612
/// Exercises the typed `in_*` kernels across several integer widths, checking
/// that each left-hand value is flagged exactly when it occurs in the
/// right-hand set.
#[test]
fn in_integers_various_types() {
    // This block is compiled only when the `extended_numeric_types` feature is on.
    #[cfg(feature = "extended_numeric_types")]
    {
        {
            let probes = vec64![1u8, 2, 3, 5];
            let members = vec64![3u8, 5, 8];
            let hits = in_u8(&probes, &members, None, false).unwrap();
            assert_bool(&hits, &[false, false, true, true], None);
        }
        {
            let probes = vec64![100u16, 200, 300];
            let members = vec64![200u16, 500];
            let hits = in_u16(&probes, &members, None, false).unwrap();
            assert_bool(&hits, &[false, true, false], None);
        }
        {
            let probes = vec64![10i16, 15, 42];
            let members = vec64![15i16, 42, 77];
            let hits = in_i16(&probes, &members, None, false).unwrap();
            assert_bool(&hits, &[false, true, true], None);
        }
    }

    // u32
    {
        let probes = vec64![0u32, 1, 2, 9];
        let members = vec64![9u32, 1];
        let hits = in_u32(&probes, &members, None, false).unwrap();
        assert_bool(&hits, &[false, true, false, true], None);
    }

    // i64
    {
        let probes = vec64![1i64, 9, 10];
        let members = vec64![2i64, 10, 20];
        let hits = in_i64(&probes, &members, None, false).unwrap();
        assert_bool(&hits, &[false, false, true], None);
    }

    // u64
    {
        let probes = vec64![1u64, 2, 3, 4];
        let members = vec64![2u64, 4, 8];
        let hits = in_u64(&probes, &members, None, false).unwrap();
        assert_bool(&hits, &[false, true, false, true], None);
    }
}
1648
1649    // empty input edge
1650
/// An empty left-hand side must yield an output of length zero for the
/// numeric BETWEEN kernel, the numeric IN kernel and the string IN kernel.
#[test]
fn between_and_in_empty_inputs() {
    // BETWEEN over an empty i32 slice (plain slices, no window tuple).
    {
        let values: [i32; 0] = [];
        let bounds = vec64![0, 1];
        let result = between_i32(&values, &bounds, None, false).unwrap();
        assert_eq!(result.len, 0);
    }

    // IN over an empty i32 slice (plain slices, no window tuple).
    {
        let values: [i32; 0] = [];
        let members = vec64![1, 2, 3];
        let result = in_i32(&values, &members, None, false).unwrap();
        assert_eq!(result.len, 0);
    }

    // String IN, which takes (array, offset, len) windows.
    {
        let values = str_arr::<u32>(&[]);
        let members = str_arr::<u32>(&["a", "b"]);
        let result = cmp_str_in(
            (&values, 0, values.len()),
            (&members, 0, members.len()),
        )
        .unwrap();
        assert_eq!(result.len, 0);
    }
}
1673
/// Windowed variant of the empty-input check; only the string IN kernel takes
/// windows. A zero-length left window must give a zero-length result.
#[test]
fn between_and_in_empty_inputs_chunk() {
    let haystack = str_arr::<u32>(&["x", "y"]);
    let needles = str_arr::<u32>(&["a", "b", "c"]);

    let result = cmp_str_in(
        (&haystack, 1, 0), // zero-length window
        (&needles, 1, 2),  // window ["b", "c"]
    )
    .unwrap();

    assert_eq!(result.len, 0);
}
1684
/// Coverage for per-row bounds, including the last row: the right-hand side
/// holds two entries per left-hand value and every row is expected to be `true`.
#[test]
fn between_per_row_bounds_on_last_row() {
    let values = vec64![0i32, 10, 20, 30];
    let bounds = vec64![0, 5, 5, 15, 15, 25, 25, 35];

    let expected = [true, true, true, true];
    assert_bool(
        &between_i32(&values, &bounds, None, false).unwrap(),
        &expected,
        None,
    );
}
1693
/// `cmp_dict_in` where the two dictionaries have different `unique_values`
/// lengths (3 vs 4). Per the original test's intent this should take the
/// string-matching fallback, in which only "b" is a member.
#[test]
fn test_cmp_dict_in_force_fallback() {
    // Left side: four rows, three distinct values, all-true null mask.
    let mut probe = dict_arr::<u32>(&["a", "b", "c", "a"]);
    probe.unique_values = vec64![String::from("a"), String::from("b"), String::from("c")];
    probe.null_mask = Some(bm(&[true, true, true, true]));

    // Right side: four distinct values.
    let mut members = dict_arr::<u32>(&["b", "x", "y", "z"]);
    members.unique_values = vec64![
        String::from("b"),
        String::from("x"),
        String::from("y"),
        String::from("z")
    ];

    let result = cmp_dict_in(
        (&probe, 0, probe.len()),
        (&members, 0, members.len()),
    )
    .unwrap();

    let expected_values = [false, true, false, false];
    let expected_mask = [true, true, true, true];
    assert_bool(&result, &expected_values, Some(&expected_mask));
}
1717
/// Windowed version of the dictionary IN fallback test: both sides are
/// sliced with offset 1 and length 4 before comparison.
#[test]
fn test_cmp_dict_in_force_fallback_chunk() {
    let mut probe = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
    probe.unique_values = vec64![
        String::from("z"),
        String::from("a"),
        String::from("b"),
        String::from("c"),
        String::from("q")
    ];
    probe.null_mask = Some(bm(&[true, true, true, true, true, true]));

    let mut members = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
    members.unique_values = vec64![
        String::from("x"),
        String::from("b"),
        String::from("y"),
        String::from("z")
    ];

    // Windows select ["a", "b", "c", "a"] on the left and
    // ["b", "x", "y", "z"] on the right.
    let result = cmp_dict_in((&probe, 1, 4), (&members, 1, 4)).unwrap();

    // Only "b" (window index 1) is present on the right.
    let expected_values = [false, true, false, false];
    let expected_mask = [true, true, true, true];
    assert_bool(&result, &expected_values, Some(&expected_mask));
}
1747
/// `in_array` against an empty right-hand array: no value can be a member,
/// so every output slot must be `false`.
#[test]
fn test_in_array_empty_rhs() {
    let values = Array::from_int32(i32_arr(&[1, 2, 3]));
    let no_members = Array::from_int32(i32_arr(&[]));

    let result = in_array(&values, &no_members).unwrap();

    // All false; the input had no mask, so none is expected on the output.
    let expected = [false, false, false];
    assert_bool(&result, &expected, None);
}
1756}