vortex_array/arrays/decimal/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::NativeDecimalType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_decimal_value_type;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_scalar::Scalar;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::compute::BetweenKernel;
19use crate::compute::BetweenKernelAdapter;
20use crate::compute::BetweenOptions;
21use crate::compute::StrictComparison;
22use crate::register_kernel;
23use crate::vtable::ValidityHelper;
24
25impl BetweenKernel for DecimalVTable {
26    // Determine if the values are between the lower and upper bounds
27    fn between(
28        &self,
29        arr: &DecimalArray,
30        lower: &dyn Array,
31        upper: &dyn Array,
32        options: &BetweenOptions,
33    ) -> VortexResult<Option<ArrayRef>> {
34        // NOTE: We know that the precision and scale were already checked to be equal by the main
35        // `between` entrypoint function.
36
37        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
38            return Ok(None);
39        };
40
41        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
42        let nullability =
43            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
44
45        match_each_decimal_value_type!(arr.values_type(), |D| {
46            between_unpack::<D>(arr, lower, upper, nullability, options)
47        })
48    }
49}
50
51fn between_unpack<T: NativeDecimalType>(
52    arr: &DecimalArray,
53    lower: Scalar,
54    upper: Scalar,
55    nullability: Nullability,
56    options: &BetweenOptions,
57) -> VortexResult<Option<ArrayRef>> {
58    let Some(lower_value) = lower
59        .as_decimal()
60        .decimal_value()
61        .and_then(|v| v.cast::<T>())
62    else {
63        vortex_bail!(
64            "invalid lower bound Scalar: {lower}, expected {:?}",
65            T::DECIMAL_TYPE
66        )
67    };
68    let Some(upper_value) = upper
69        .as_decimal()
70        .decimal_value()
71        .and_then(|v| v.cast::<T>())
72    else {
73        vortex_bail!(
74            "invalid upper bound Scalar: {upper}, expected {:?}",
75            T::DECIMAL_TYPE
76        )
77    };
78
79    let lower_op = match options.lower_strict {
80        StrictComparison::Strict => |a, b| a < b,
81        StrictComparison::NonStrict => |a, b| a <= b,
82    };
83
84    let upper_op = match options.upper_strict {
85        StrictComparison::Strict => |a, b| a < b,
86        StrictComparison::NonStrict => |a, b| a <= b,
87    };
88
89    Ok(Some(between_impl::<T>(
90        arr,
91        lower_value,
92        upper_value,
93        nullability,
94        lower_op,
95        upper_op,
96    )))
97}
98
99register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
100
101fn between_impl<T: NativeDecimalType>(
102    arr: &DecimalArray,
103    lower: T,
104    upper: T,
105    nullability: Nullability,
106    lower_op: impl Fn(T, T) -> bool,
107    upper_op: impl Fn(T, T) -> bool,
108) -> ArrayRef {
109    let buffer = arr.buffer::<T>();
110    BoolArray::from_bit_buffer(
111        BitBuffer::collect_bool(buffer.len(), |idx| {
112            let value = buffer[idx];
113            lower_op(lower, value) & upper_op(value, upper)
114        }),
115        arr.validity().clone().union_nullability(nullability),
116    )
117    .into_array()
118}
119
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::DecimalArray;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::compute::between;
    use crate::validity::Validity;

    /// Converts an array to a bool array and collects its bits into a `Vec<bool>`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let bools = array.to_bool();
        bools.bit_buffer().iter().collect()
    }

    /// Checks `between` over `[100, 200, 300, 400]` with constant bounds 100 and 400 for both
    /// mixed strict/non-strict combinations.
    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // Builds a non-nullable constant bound with the same length as `array`.
        let constant_bound = |raw: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(raw),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        let cases = [
            // Strict lower bound, non-strict upper bound
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            // Non-strict lower bound, strict upper bound
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let options = BetweenOptions {
                lower_strict,
                upper_strict,
            };
            let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }
}