vortex_array/arrays/decimal/compute/
between.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12    // Determine if the values are between the lower and upper bounds
13    fn between(
14        &self,
15        arr: &DecimalArray,
16        lower: &dyn Array,
17        upper: &dyn Array,
18        options: &BetweenOptions,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // NOTE: We know that the precision and scale were already checked to be equal by the main
21        // `between` entrypoint function.
22
23        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24            return Ok(None);
25        };
26
27        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
28        let nullability =
29            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31        match_each_decimal_value_type!(arr.values_type(), |D| {
32            between_unpack::<D>(arr, lower, upper, nullability, options)
33        })
34    }
35}
36
37fn between_unpack<T: NativeDecimalType>(
38    arr: &DecimalArray,
39    lower: Scalar,
40    upper: Scalar,
41    nullability: Nullability,
42    options: &BetweenOptions,
43) -> VortexResult<Option<ArrayRef>> {
44    let Some(lower_value) = lower
45        .as_decimal()
46        .decimal_value()
47        .and_then(|v| v.cast::<T>())
48    else {
49        vortex_bail!(
50            "invalid lower bound Scalar: {lower}, expected {:?}",
51            T::VALUES_TYPE
52        )
53    };
54    let Some(upper_value) = upper
55        .as_decimal()
56        .decimal_value()
57        .and_then(|v| v.cast::<T>())
58    else {
59        vortex_bail!(
60            "invalid upper bound Scalar: {upper}, expected {:?}",
61            T::VALUES_TYPE
62        )
63    };
64
65    let lower_op = match options.lower_strict {
66        StrictComparison::Strict => |a, b| a < b,
67        StrictComparison::NonStrict => |a, b| a <= b,
68    };
69
70    let upper_op = match options.upper_strict {
71        StrictComparison::Strict => |a, b| a < b,
72        StrictComparison::NonStrict => |a, b| a <= b,
73    };
74
75    Ok(Some(between_impl::<T>(
76        arr,
77        lower_value,
78        upper_value,
79        nullability,
80        lower_op,
81        upper_op,
82    )))
83}
84
// Hands the `BetweenKernel` impl for `DecimalVTable`, wrapped in `BetweenKernelAdapter` and
// lifted via `.lift()`, to `register_kernel!`.
// NOTE(review): presumably this makes the kernel discoverable by the generic `between`
// compute function — confirm against the `register_kernel!` macro definition.
85register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
86
87fn between_impl<T: NativeDecimalType>(
88    arr: &DecimalArray,
89    lower: T,
90    upper: T,
91    nullability: Nullability,
92    lower_op: impl Fn(T, T) -> bool,
93    upper_op: impl Fn(T, T) -> bool,
94) -> ArrayRef {
95    let buffer = arr.buffer::<T>();
96    BoolArray::new(
97        BooleanBuffer::collect_bool(buffer.len(), |idx| {
98            let value = buffer[idx];
99            lower_op(lower, value) & upper_op(value, upper)
100        }),
101        arr.validity().clone().union_nullability(nullability),
102    )
103    .into_array()
104}
105
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a constant array of `len` copies of a non-nullable `i128` decimal `value`.
    fn constant_bound(value: i128, dtype: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(DecimalValue::I128(value), dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len)
    }

    /// Canonicalizes `array` to a boolean array and collects its bits into a `Vec`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let bools = array.to_canonical().unwrap().into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // The bounds coincide with the first and last values so that strictness is observable.
        let lower = constant_bound(100i128, decimal_type, array.len());
        let upper = constant_bound(400i128, decimal_type, array.len());

        let cases = [
            // Strict lower bound, non-strict upper bound
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            // Non-strict lower bound, strict upper bound
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let result = between(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }
}