vortex_array/arrays/decimal/compute/between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBuffer;
5use vortex_dtype::Nullability;
6use vortex_error::{VortexResult, vortex_bail};
7use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
8
9use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
10use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
11use crate::vtable::ValidityHelper;
12use crate::{Array, ArrayRef, IntoArray, register_kernel};
13
14impl BetweenKernel for DecimalVTable {
15    // Determine if the values are between the lower and upper bounds
16    fn between(
17        &self,
18        arr: &DecimalArray,
19        lower: &dyn Array,
20        upper: &dyn Array,
21        options: &BetweenOptions,
22    ) -> VortexResult<Option<ArrayRef>> {
23        // NOTE: We know that the precision and scale were already checked to be equal by the main
24        // `between` entrypoint function.
25
26        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
27            return Ok(None);
28        };
29
30        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
31        let nullability =
32            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
33
34        match_each_decimal_value_type!(arr.values_type(), |D| {
35            between_unpack::<D>(arr, lower, upper, nullability, options)
36        })
37    }
38}
39
40fn between_unpack<T: NativeDecimalType>(
41    arr: &DecimalArray,
42    lower: Scalar,
43    upper: Scalar,
44    nullability: Nullability,
45    options: &BetweenOptions,
46) -> VortexResult<Option<ArrayRef>> {
47    let Some(lower_value) = lower
48        .as_decimal()
49        .decimal_value()
50        .and_then(|v| v.cast::<T>())
51    else {
52        vortex_bail!(
53            "invalid lower bound Scalar: {lower}, expected {:?}",
54            T::VALUES_TYPE
55        )
56    };
57    let Some(upper_value) = upper
58        .as_decimal()
59        .decimal_value()
60        .and_then(|v| v.cast::<T>())
61    else {
62        vortex_bail!(
63            "invalid upper bound Scalar: {upper}, expected {:?}",
64            T::VALUES_TYPE
65        )
66    };
67
68    let lower_op = match options.lower_strict {
69        StrictComparison::Strict => |a, b| a < b,
70        StrictComparison::NonStrict => |a, b| a <= b,
71    };
72
73    let upper_op = match options.upper_strict {
74        StrictComparison::Strict => |a, b| a < b,
75        StrictComparison::NonStrict => |a, b| a <= b,
76    };
77
78    Ok(Some(between_impl::<T>(
79        arr,
80        lower_value,
81        upper_value,
82        nullability,
83        lower_op,
84        upper_op,
85    )))
86}
87
// Registers `DecimalVTable`'s `BetweenKernel` implementation through `BetweenKernelAdapter`.
// Presumably this is what lets the generic `between` compute function dispatch to this kernel
// for decimal arrays — confirm against `register_kernel!`.
register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
89
90fn between_impl<T: NativeDecimalType>(
91    arr: &DecimalArray,
92    lower: T,
93    upper: T,
94    nullability: Nullability,
95    lower_op: impl Fn(T, T) -> bool,
96    upper_op: impl Fn(T, T) -> bool,
97) -> ArrayRef {
98    let buffer = arr.buffer::<T>();
99    BoolArray::new(
100        BooleanBuffer::collect_bool(buffer.len(), |idx| {
101            let value = buffer[idx];
102            lower_op(lower, value) & upper_op(value, upper)
103        }),
104        arr.validity().clone().union_nullability(nullability),
105    )
106    .into_array()
107}
108
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a non-nullable constant decimal array of length `len`, holding `value`.
    fn decimal_constant(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(value),
                decimal_type,
                Nullability::NonNullable,
            ),
            len,
        )
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        let len = array.len();
        let lower = decimal_constant(100i128, decimal_type, len);
        let upper = decimal_constant(400i128, decimal_type, len);

        let cases = [
            // Strict lower bound, non-strict upper bound: 100 is excluded, 400 is included.
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            // Non-strict lower bound, strict upper bound: 100 is included, 400 is excluded.
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let result = between(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }

    /// Canonicalizes `array` to a bool array and collects its values (validity is ignored).
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let bools = array.to_canonical().unwrap().into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }
}