vortex_array/arrays/primitive/compute/between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::NativePType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_native_ptype;
8use vortex_error::VortexResult;
9
10use crate::Array;
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::arrays::BoolArray;
14use crate::arrays::PrimitiveArray;
15use crate::arrays::PrimitiveVTable;
16use crate::compute::BetweenKernel;
17use crate::compute::BetweenKernelAdapter;
18use crate::compute::BetweenOptions;
19use crate::compute::StrictComparison;
20use crate::register_kernel;
21use crate::vtable::ValidityHelper;
22
23impl BetweenKernel for PrimitiveVTable {
24    fn between(
25        &self,
26        arr: &PrimitiveArray,
27        lower: &dyn Array,
28        upper: &dyn Array,
29        options: &BetweenOptions,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
32            return Ok(None);
33        };
34
35        // Note, we know that have checked before that the lower and upper bounds are not constant
36        // null values
37
38        let nullability =
39            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
40
41        Ok(Some(match_each_native_ptype!(arr.ptype(), |P| {
42            between_impl::<P>(
43                arr,
44                P::try_from(lower)?,
45                P::try_from(upper)?,
46                nullability,
47                options,
48            )
49        })))
50    }
51}
52
53register_kernel!(BetweenKernelAdapter(PrimitiveVTable).lift());
54
55fn between_impl<T: NativePType + Copy>(
56    arr: &PrimitiveArray,
57    lower: T,
58    upper: T,
59    nullability: Nullability,
60    options: &BetweenOptions,
61) -> ArrayRef {
62    match (options.lower_strict, options.upper_strict) {
63        // Note: these comparisons are explicitly passed in to allow function impl inlining
64        (StrictComparison::Strict, StrictComparison::Strict) => between_impl_(
65            arr,
66            lower,
67            NativePType::is_lt,
68            upper,
69            NativePType::is_lt,
70            nullability,
71        ),
72        (StrictComparison::Strict, StrictComparison::NonStrict) => between_impl_(
73            arr,
74            lower,
75            NativePType::is_lt,
76            upper,
77            NativePType::is_le,
78            nullability,
79        ),
80        (StrictComparison::NonStrict, StrictComparison::Strict) => between_impl_(
81            arr,
82            lower,
83            NativePType::is_le,
84            upper,
85            NativePType::is_lt,
86            nullability,
87        ),
88        (StrictComparison::NonStrict, StrictComparison::NonStrict) => between_impl_(
89            arr,
90            lower,
91            NativePType::is_le,
92            upper,
93            NativePType::is_le,
94            nullability,
95        ),
96    }
97}
98
99fn between_impl_<T>(
100    arr: &PrimitiveArray,
101    lower: T,
102    lower_fn: impl Fn(T, T) -> bool,
103    upper: T,
104    upper_fn: impl Fn(T, T) -> bool,
105    nullability: Nullability,
106) -> ArrayRef
107where
108    T: NativePType + Copy,
109{
110    let slice = arr.as_slice::<T>();
111    BoolArray::from_bit_buffer(
112        BitBuffer::collect_bool(slice.len(), |idx| {
113            // We only iterate upto arr len and |arr| == |slice|.
114            let i = unsafe { *slice.get_unchecked(idx) };
115            lower_fn(lower, i) & upper_fn(i, upper)
116        }),
117        arr.validity().clone().union_nullability(nullability),
118    )
119    .into_array()
120}