Skip to main content

vortex_array/arrays/primitive/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::NativePType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_native_ptype;
8use vortex_error::VortexResult;
9
10use crate::Array;
11use crate::ArrayRef;
12use crate::ExecutionCtx;
13use crate::IntoArray;
14use crate::arrays::BoolArray;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::PrimitiveVTable;
17use crate::expr::BetweenKernel;
18use crate::expr::BetweenOptions;
19use crate::expr::StrictComparison;
20use crate::vtable::ValidityHelper;
21
22impl BetweenKernel for PrimitiveVTable {
23    fn between(
24        arr: &PrimitiveArray,
25        lower: &dyn Array,
26        upper: &dyn Array,
27        options: &BetweenOptions,
28        _ctx: &mut ExecutionCtx,
29    ) -> VortexResult<Option<ArrayRef>> {
30        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
31            return Ok(None);
32        };
33
34        // Note, we know that have checked before that the lower and upper bounds are not constant
35        // null values
36
37        let nullability =
38            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
39
40        Ok(Some(match_each_native_ptype!(arr.ptype(), |P| {
41            between_impl::<P>(
42                arr,
43                P::try_from(&lower)?,
44                P::try_from(&upper)?,
45                nullability,
46                options,
47            )
48        })))
49    }
50}
51
52fn between_impl<T: NativePType + Copy>(
53    arr: &PrimitiveArray,
54    lower: T,
55    upper: T,
56    nullability: Nullability,
57    options: &BetweenOptions,
58) -> ArrayRef {
59    match (options.lower_strict, options.upper_strict) {
60        // Note: these comparisons are explicitly passed in to allow function impl inlining
61        (StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
62            arr,
63            lower,
64            NativePType::is_lt,
65            upper,
66            NativePType::is_lt,
67            nullability,
68        ),
69        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
70            arr,
71            lower,
72            NativePType::is_lt,
73            upper,
74            NativePType::is_le,
75            nullability,
76        ),
77        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
78            arr,
79            lower,
80            NativePType::is_le,
81            upper,
82            NativePType::is_lt,
83            nullability,
84        ),
85        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
86            arr,
87            lower,
88            NativePType::is_le,
89            upper,
90            NativePType::is_le,
91            nullability,
92        ),
93    }
94}
95
96fn between_impl_<T>(
97    arr: &PrimitiveArray,
98    lower: T,
99    lower_fn: impl Fn(T, T) -> bool,
100    upper: T,
101    upper_fn: impl Fn(T, T) -> bool,
102    nullability: Nullability,
103) -> ArrayRef
104where
105    T: NativePType + Copy,
106{
107    let slice = arr.as_slice::<T>();
108    BoolArray::new(
109        BitBuffer::collect_bool(slice.len(), |idx| {
110            // We only iterate upto arr len and |arr| == |slice|.
111            let i = unsafe { *slice.get_unchecked(idx) };
112            lower_fn(lower, i) & upper_fn(i, upper)
113        }),
114        arr.validity().clone().union_nullability(nullability),
115    )
116    .into_array()
117}