vortex_array/arrays/decimal/compute/
between.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{DecimalValue, Scalar};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalEncoding, NativeDecimalType};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::{Array, ArrayRef, match_each_decimal_value_type, register_kernel};
9
10impl BetweenKernel for DecimalEncoding {
11    // Determine if the values are between the lower and upper bounds
12    fn between(
13        &self,
14        arr: &DecimalArray,
15        lower: &dyn Array,
16        upper: &dyn Array,
17        options: &BetweenOptions,
18    ) -> VortexResult<Option<ArrayRef>> {
19        // NOTE: We know that the precision and scale were already checked to be equal by the main
20        // `between` entrypoint function.
21
22        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
23            return Ok(None);
24        };
25
26        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
27        let nullability =
28            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
29
30        match_each_decimal_value_type!(arr.values_type(), |($D, $CTor)| {
31           between_unpack::<$D>(
32                arr,
33                lower,
34                upper,
35                |d| match d {
36                    $CTor(v) => Some(v),
37                    _ => None,
38                },
39                nullability,
40                options,
41             )
42        })
43    }
44}
45
46fn between_unpack<T: NativeDecimalType>(
47    arr: &DecimalArray,
48    lower: Scalar,
49    upper: Scalar,
50    unpack: impl Fn(DecimalValue) -> Option<T> + Copy,
51    nullability: Nullability,
52    options: &BetweenOptions,
53) -> VortexResult<Option<ArrayRef>> {
54    let Some(lower_value) = lower.as_decimal().decimal_value().and_then(unpack) else {
55        vortex_bail!("invalid lower bound Scalar: {lower}");
56    };
57    let Some(upper_value) = upper.as_decimal().decimal_value().and_then(unpack) else {
58        vortex_bail!("invalid upper bound Scalar: {upper}");
59    };
60
61    let lower_op = match options.lower_strict {
62        StrictComparison::Strict => |a, b| a < b,
63        StrictComparison::NonStrict => |a, b| a <= b,
64    };
65
66    let upper_op = match options.upper_strict {
67        StrictComparison::Strict => |a, b| a < b,
68        StrictComparison::NonStrict => |a, b| a <= b,
69    };
70
71    Ok(Some(between_impl::<T>(
72        arr,
73        lower_value,
74        upper_value,
75        nullability,
76        lower_op,
77        upper_op,
78    )))
79}
80
// Hand the `DecimalEncoding` between kernel, wrapped in `BetweenKernelAdapter`, to the
// `register_kernel!` macro.
// NOTE(review): presumably this is what makes the kernel reachable from the generic `between`
// entrypoint — confirm against the macro definition.
register_kernel!(BetweenKernelAdapter(DecimalEncoding).lift());
82
83fn between_impl<T: NativeDecimalType>(
84    arr: &DecimalArray,
85    lower: T,
86    upper: T,
87    nullability: Nullability,
88    lower_op: impl Fn(T, T) -> bool,
89    upper_op: impl Fn(T, T) -> bool,
90) -> ArrayRef {
91    let buffer = arr.buffer::<T>();
92    BoolArray::new(
93        BooleanBuffer::collect_bool(buffer.len(), |idx| {
94            let value = buffer[idx];
95            lower_op(lower, value) & upper_op(value, upper)
96        }),
97        arr.validity().clone().union_nullability(nullability),
98    )
99    .into_array()
100}
101
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a non-nullable constant decimal array of length `len` holding `value`.
    fn constant_bound(value: i128, dtype: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(DecimalValue::I128(value), dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len)
    }

    /// Canonicalizes `array` to a boolean array and collects its bits into a `Vec`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let canonical = array.to_canonical().unwrap();
        let bools = canonical.into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        let lower = constant_bound(100i128, decimal_type, array.len());
        let upper = constant_bound(400i128, decimal_type, array.len());

        // Strict lower bound, non-strict upper bound
        let strict_lower_opts = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let strict_lower = between(&array, &lower, &upper, &strict_lower_opts).unwrap();
        assert_eq!(bool_to_vec(&strict_lower), vec![false, true, true, true]);

        // Non-strict lower bound, strict upper bound
        let strict_upper_opts = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let strict_upper = between(&array, &lower, &upper, &strict_upper_opts).unwrap();
        assert_eq!(bool_to_vec(&strict_upper), vec![true, true, true, false]);
    }
}
173}