vortex_array/arrays/primitive/compute/between.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::{NativePType, Nullability, match_each_native_ptype};
3use vortex_error::VortexResult;
4
5use crate::arrays::{BoolArray, PrimitiveArray, PrimitiveVTable};
6use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
7use crate::vtable::ValidityHelper;
8use crate::{Array, ArrayRef, IntoArray, register_kernel};
9
10impl BetweenKernel for PrimitiveVTable {
11    fn between(
12        &self,
13        arr: &PrimitiveArray,
14        lower: &dyn Array,
15        upper: &dyn Array,
16        options: &BetweenOptions,
17    ) -> VortexResult<Option<ArrayRef>> {
18        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
19            return Ok(None);
20        };
21
22        // Note, we know that have checked before that the lower and upper bounds are not constant
23        // null values
24
25        let nullability =
26            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
27
28        Ok(Some(match_each_native_ptype!(arr.ptype(), |P| {
29            between_impl::<P>(
30                arr,
31                P::try_from(lower)?,
32                P::try_from(upper)?,
33                nullability,
34                options,
35            )
36        })))
37    }
38}
39
40register_kernel!(BetweenKernelAdapter(PrimitiveVTable).lift());
41
42fn between_impl<T: NativePType + Copy>(
43    arr: &PrimitiveArray,
44    lower: T,
45    upper: T,
46    nullability: Nullability,
47    options: &BetweenOptions,
48) -> ArrayRef {
49    match (options.lower_strict, options.upper_strict) {
50        // Note: these comparisons are explicitly passed in to allow function impl inlining
51        (StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
52            arr,
53            lower,
54            NativePType::is_lt,
55            upper,
56            NativePType::is_lt,
57            nullability,
58        ),
59        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
60            arr,
61            lower,
62            NativePType::is_lt,
63            upper,
64            NativePType::is_le,
65            nullability,
66        ),
67        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
68            arr,
69            lower,
70            NativePType::is_le,
71            upper,
72            NativePType::is_lt,
73            nullability,
74        ),
75        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
76            arr,
77            lower,
78            NativePType::is_le,
79            upper,
80            NativePType::is_le,
81            nullability,
82        ),
83    }
84}
85
86fn between_impl_<T>(
87    arr: &PrimitiveArray,
88    lower: T,
89    lower_fn: impl Fn(T, T) -> bool,
90    upper: T,
91    upper_fn: impl Fn(T, T) -> bool,
92    nullability: Nullability,
93) -> ArrayRef
94where
95    T: NativePType + Copy,
96{
97    let slice = arr.as_slice::<T>();
98    BoolArray::new(
99        BooleanBuffer::collect_bool(slice.len(), |idx| {
100            // We only iterate upto arr len and |arr| == |slice|.
101            let i = unsafe { *slice.get_unchecked(idx) };
102            lower_fn(lower, i) & upper_fn(i, upper)
103        }),
104        arr.validity().clone().union_nullability(nullability),
105    )
106    .into_array()
107}