Skip to main content

vortex_array/arrays/primitive/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::IntoArray;
10use crate::arrays::BoolArray;
11use crate::arrays::PrimitiveArray;
12use crate::arrays::PrimitiveVTable;
13use crate::dtype::NativePType;
14use crate::dtype::Nullability;
15use crate::match_each_native_ptype;
16use crate::scalar_fn::fns::between::BetweenKernel;
17use crate::scalar_fn::fns::between::BetweenOptions;
18use crate::scalar_fn::fns::between::StrictComparison;
19use crate::vtable::ValidityHelper;
20
21impl BetweenKernel for PrimitiveVTable {
22    fn between(
23        arr: &PrimitiveArray,
24        lower: &ArrayRef,
25        upper: &ArrayRef,
26        options: &BetweenOptions,
27        _ctx: &mut ExecutionCtx,
28    ) -> VortexResult<Option<ArrayRef>> {
29        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
30            return Ok(None);
31        };
32
33        // Note, we know that have checked before that the lower and upper bounds are not constant
34        // null values
35
36        let nullability =
37            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
38
39        Ok(Some(match_each_native_ptype!(arr.ptype(), |P| {
40            between_impl::<P>(
41                arr,
42                P::try_from(&lower)?,
43                P::try_from(&upper)?,
44                nullability,
45                options,
46            )
47        })))
48    }
49}
50
51fn between_impl<T: NativePType + Copy>(
52    arr: &PrimitiveArray,
53    lower: T,
54    upper: T,
55    nullability: Nullability,
56    options: &BetweenOptions,
57) -> ArrayRef {
58    match (options.lower_strict, options.upper_strict) {
59        // Note: these comparisons are explicitly passed in to allow function impl inlining
60        (StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
61            arr,
62            lower,
63            NativePType::is_lt,
64            upper,
65            NativePType::is_lt,
66            nullability,
67        ),
68        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
69            arr,
70            lower,
71            NativePType::is_lt,
72            upper,
73            NativePType::is_le,
74            nullability,
75        ),
76        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
77            arr,
78            lower,
79            NativePType::is_le,
80            upper,
81            NativePType::is_lt,
82            nullability,
83        ),
84        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
85            arr,
86            lower,
87            NativePType::is_le,
88            upper,
89            NativePType::is_le,
90            nullability,
91        ),
92    }
93}
94
95fn between_impl_<T>(
96    arr: &PrimitiveArray,
97    lower: T,
98    lower_fn: impl Fn(T, T) -> bool,
99    upper: T,
100    upper_fn: impl Fn(T, T) -> bool,
101    nullability: Nullability,
102) -> ArrayRef
103where
104    T: NativePType + Copy,
105{
106    let slice = arr.as_slice::<T>();
107    BoolArray::new(
108        BitBuffer::collect_bool(slice.len(), |idx| {
109            // We only iterate upto arr len and |arr| == |slice|.
110            let i = unsafe { *slice.get_unchecked(idx) };
111            lower_fn(lower, i) & upper_fn(i, upper)
112        }),
113        arr.validity().clone().union_nullability(nullability),
114    )
115    .into_array()
116}