vortex_array/arrays/decimal/compute/between.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{DecimalValue, NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12    // Determine if the values are between the lower and upper bounds
13    fn between(
14        &self,
15        arr: &DecimalArray,
16        lower: &dyn Array,
17        upper: &dyn Array,
18        options: &BetweenOptions,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // NOTE: We know that the precision and scale were already checked to be equal by the main
21        // `between` entrypoint function.
22
23        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24            return Ok(None);
25        };
26
27        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
28        let nullability =
29            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31        match_each_decimal_value_type!(arr.values_type(), |($D, $CTor)| {
32           between_unpack::<$D>(
33                arr,
34                lower,
35                upper,
36                |d| match d {
37                    $CTor(v) => Some(v),
38                    _ => None,
39                },
40                nullability,
41                options,
42             )
43        })
44    }
45}
46
47fn between_unpack<T: NativeDecimalType>(
48    arr: &DecimalArray,
49    lower: Scalar,
50    upper: Scalar,
51    unpack: impl Fn(DecimalValue) -> Option<T> + Copy,
52    nullability: Nullability,
53    options: &BetweenOptions,
54) -> VortexResult<Option<ArrayRef>> {
55    let Some(lower_value) = lower.as_decimal().decimal_value().and_then(unpack) else {
56        vortex_bail!("invalid lower bound Scalar: {lower}");
57    };
58    let Some(upper_value) = upper.as_decimal().decimal_value().and_then(unpack) else {
59        vortex_bail!("invalid upper bound Scalar: {upper}");
60    };
61
62    let lower_op = match options.lower_strict {
63        StrictComparison::Strict => |a, b| a < b,
64        StrictComparison::NonStrict => |a, b| a <= b,
65    };
66
67    let upper_op = match options.upper_strict {
68        StrictComparison::Strict => |a, b| a < b,
69        StrictComparison::NonStrict => |a, b| a <= b,
70    };
71
72    Ok(Some(between_impl::<T>(
73        arr,
74        lower_value,
75        upper_value,
76        nullability,
77        lower_op,
78        upper_op,
79    )))
80}
81
82register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
83
84fn between_impl<T: NativeDecimalType>(
85    arr: &DecimalArray,
86    lower: T,
87    upper: T,
88    nullability: Nullability,
89    lower_op: impl Fn(T, T) -> bool,
90    upper_op: impl Fn(T, T) -> bool,
91) -> ArrayRef {
92    let buffer = arr.buffer::<T>();
93    BoolArray::new(
94        BooleanBuffer::collect_bool(buffer.len(), |idx| {
95            let value = buffer[idx];
96            lower_op(lower, value) & upper_op(value, upper)
97        }),
98        arr.validity().clone().union_nullability(nullability),
99    )
100    .into_array()
101}
102
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// A non-nullable constant decimal bound: `len` copies of `value`, stored as `i128`.
    fn i128_bound(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(
            DecimalValue::I128(value),
            decimal_type,
            Nullability::NonNullable,
        );
        ConstantArray::new(scalar, len)
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable);

        // The bounds equal the first and last elements, so strictness alone decides whether
        // each endpoint is included.
        let lower = i128_bound(100, decimal_type, array.len());
        let upper = i128_bound(400, decimal_type, array.len());

        let evaluate = |lower_strict: StrictComparison, upper_strict: StrictComparison| {
            let options = BetweenOptions {
                lower_strict,
                upper_strict,
            };
            let result =
                between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();
            bool_to_vec(&result)
        };

        // (lower, upper]: the lower endpoint is excluded.
        assert_eq!(
            evaluate(StrictComparison::Strict, StrictComparison::NonStrict),
            vec![false, true, true, true]
        );

        // [lower, upper): the upper endpoint is excluded.
        assert_eq!(
            evaluate(StrictComparison::NonStrict, StrictComparison::Strict),
            vec![true, true, true, false]
        );
    }

    /// Canonicalizes `array` into a bool array and collects its bits into a `Vec`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        array
            .to_canonical()
            .unwrap()
            .into_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect()
    }
}