vortex_array/arrays/decimal/compute/
between.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12    // Determine if the values are between the lower and upper bounds
13    fn between(
14        &self,
15        arr: &DecimalArray,
16        lower: &dyn Array,
17        upper: &dyn Array,
18        options: &BetweenOptions,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // NOTE: We know that the precision and scale were already checked to be equal by the main
21        // `between` entrypoint function.
22
23        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24            return Ok(None);
25        };
26
27        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
28        let nullability =
29            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31        match_each_decimal_value_type!(arr.values_type(), |D| {
32            between_unpack::<D>(arr, lower, upper, nullability, options)
33        })
34    }
35}
36
37fn between_unpack<T: NativeDecimalType>(
38    arr: &DecimalArray,
39    lower: Scalar,
40    upper: Scalar,
41    nullability: Nullability,
42    options: &BetweenOptions,
43) -> VortexResult<Option<ArrayRef>> {
44    let Some(lower_value) = lower
45        .as_decimal()
46        .decimal_value()
47        .and_then(|v| T::try_from(v))
48    else {
49        vortex_bail!("invalid lower bound Scalar: {lower}");
50    };
51    let Some(upper_value) = upper
52        .as_decimal()
53        .decimal_value()
54        .and_then(|v| T::try_from(v))
55    else {
56        vortex_bail!("invalid upper bound Scalar: {upper}");
57    };
58
59    let lower_op = match options.lower_strict {
60        StrictComparison::Strict => |a, b| a < b,
61        StrictComparison::NonStrict => |a, b| a <= b,
62    };
63
64    let upper_op = match options.upper_strict {
65        StrictComparison::Strict => |a, b| a < b,
66        StrictComparison::NonStrict => |a, b| a <= b,
67    };
68
69    Ok(Some(between_impl::<T>(
70        arr,
71        lower_value,
72        upper_value,
73        nullability,
74        lower_op,
75        upper_op,
76    )))
77}
78
// Registers the `BetweenKernel` implementation for `DecimalVTable`, wrapped in a
// `BetweenKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `between` compute function dispatch
// to this kernel — confirm against the `register_kernel!` macro definition.
register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
80
81fn between_impl<T: NativeDecimalType>(
82    arr: &DecimalArray,
83    lower: T,
84    upper: T,
85    nullability: Nullability,
86    lower_op: impl Fn(T, T) -> bool,
87    upper_op: impl Fn(T, T) -> bool,
88) -> ArrayRef {
89    let buffer = arr.buffer::<T>();
90    BoolArray::new(
91        BooleanBuffer::collect_bool(buffer.len(), |idx| {
92            let value = buffer[idx];
93            lower_op(lower, value) & upper_op(value, upper)
94        }),
95        arr.validity().clone().union_nullability(nullability),
96    )
97    .into_array()
98}
99
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a constant array of `len` copies of a non-nullable `i128` decimal `value`.
    fn constant_bound(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(
            DecimalValue::I128(value),
            decimal_type,
            Nullability::NonNullable,
        );
        ConstantArray::new(scalar, len)
    }

    /// Canonicalizes `array` into a boolean array and collects its bits into a `Vec<bool>`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let canonical = array.to_canonical().unwrap();
        let bools = canonical.into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // The bounds coincide with the first and last values, so strictness decides whether
        // those two elements are included.
        let lower = constant_bound(100i128, decimal_type, array.len());
        let upper = constant_bound(400i128, decimal_type, array.len());

        let cases = [
            // Strict lower bound, non-strict upper bound
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            // Non-strict lower bound, strict upper bound
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let result = between(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }
}