vortex_array/arrays/primitive/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBuffer;
5use vortex_dtype::{NativePType, Nullability, match_each_native_ptype};
6use vortex_error::VortexResult;
7
8use crate::arrays::{BoolArray, PrimitiveArray, PrimitiveVTable};
9use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
10use crate::vtable::ValidityHelper;
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl BetweenKernel for PrimitiveVTable {
14    fn between(
15        &self,
16        arr: &PrimitiveArray,
17        lower: &dyn Array,
18        upper: &dyn Array,
19        options: &BetweenOptions,
20    ) -> VortexResult<Option<ArrayRef>> {
21        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
22            return Ok(None);
23        };
24
25        // Note, we know that have checked before that the lower and upper bounds are not constant
26        // null values
27
28        let nullability =
29            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31        Ok(Some(match_each_native_ptype!(arr.ptype(), |P| {
32            between_impl::<P>(
33                arr,
34                P::try_from(lower)?,
35                P::try_from(upper)?,
36                nullability,
37                options,
38            )
39        })))
40    }
41}
42
43register_kernel!(BetweenKernelAdapter(PrimitiveVTable).lift());
44
45fn between_impl<T: NativePType + Copy>(
46    arr: &PrimitiveArray,
47    lower: T,
48    upper: T,
49    nullability: Nullability,
50    options: &BetweenOptions,
51) -> ArrayRef {
52    match (options.lower_strict, options.upper_strict) {
53        // Note: these comparisons are explicitly passed in to allow function impl inlining
54        (StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
55            arr,
56            lower,
57            NativePType::is_lt,
58            upper,
59            NativePType::is_lt,
60            nullability,
61        ),
62        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
63            arr,
64            lower,
65            NativePType::is_lt,
66            upper,
67            NativePType::is_le,
68            nullability,
69        ),
70        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
71            arr,
72            lower,
73            NativePType::is_le,
74            upper,
75            NativePType::is_lt,
76            nullability,
77        ),
78        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
79            arr,
80            lower,
81            NativePType::is_le,
82            upper,
83            NativePType::is_le,
84            nullability,
85        ),
86    }
87}
88
89fn between_impl_<T>(
90    arr: &PrimitiveArray,
91    lower: T,
92    lower_fn: impl Fn(T, T) -> bool,
93    upper: T,
94    upper_fn: impl Fn(T, T) -> bool,
95    nullability: Nullability,
96) -> ArrayRef
97where
98    T: NativePType + Copy,
99{
100    let slice = arr.as_slice::<T>();
101    BoolArray::new(
102        BooleanBuffer::collect_bool(slice.len(), |idx| {
103            // We only iterate upto arr len and |arr| == |slice|.
104            let i = unsafe { *slice.get_unchecked(idx) };
105            lower_fn(lower, i) & upper_fn(i, upper)
106        }),
107        arr.validity().clone().union_nullability(nullability),
108    )
109    .into_array()
110}