Skip to main content

vortex_array/arrays/decimal/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::NativeDecimalType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_decimal_value_type;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::expr::BetweenKernel;
19use crate::expr::BetweenOptions;
20use crate::expr::StrictComparison;
21use crate::scalar::Scalar;
22use crate::vtable::ValidityHelper;
23
24impl BetweenKernel for DecimalVTable {
25    fn between(
26        arr: &DecimalArray,
27        lower: &dyn Array,
28        upper: &dyn Array,
29        options: &BetweenOptions,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // NOTE: We know that the precision and scale were already checked to be equal by the main
33        // `between` entrypoint function.
34
35        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
36            return Ok(None);
37        };
38
39        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
40        let nullability =
41            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
42
43        match_each_decimal_value_type!(arr.values_type(), |D| {
44            between_unpack::<D>(arr, lower, upper, nullability, options)
45        })
46    }
47}
48
49fn between_unpack<T: NativeDecimalType>(
50    arr: &DecimalArray,
51    lower: Scalar,
52    upper: Scalar,
53    nullability: Nullability,
54    options: &BetweenOptions,
55) -> VortexResult<Option<ArrayRef>> {
56    let Some(lower_value) = lower
57        .as_decimal()
58        .decimal_value()
59        .and_then(|v| v.cast::<T>())
60    else {
61        vortex_bail!(
62            "invalid lower bound Scalar: {lower}, expected {:?}",
63            T::DECIMAL_TYPE
64        )
65    };
66    let Some(upper_value) = upper
67        .as_decimal()
68        .decimal_value()
69        .and_then(|v| v.cast::<T>())
70    else {
71        vortex_bail!(
72            "invalid upper bound Scalar: {upper}, expected {:?}",
73            T::DECIMAL_TYPE
74        )
75    };
76
77    let lower_op = match options.lower_strict {
78        StrictComparison::Strict => |a, b| a < b,
79        StrictComparison::NonStrict => |a, b| a <= b,
80    };
81
82    let upper_op = match options.upper_strict {
83        StrictComparison::Strict => |a, b| a < b,
84        StrictComparison::NonStrict => |a, b| a <= b,
85    };
86
87    Ok(Some(between_impl::<T>(
88        arr,
89        lower_value,
90        upper_value,
91        nullability,
92        lower_op,
93        upper_op,
94    )))
95}
96
97fn between_impl<T: NativeDecimalType>(
98    arr: &DecimalArray,
99    lower: T,
100    upper: T,
101    nullability: Nullability,
102    lower_op: impl Fn(T, T) -> bool,
103    upper_op: impl Fn(T, T) -> bool,
104) -> ArrayRef {
105    let buffer = arr.buffer::<T>();
106    BoolArray::new(
107        BitBuffer::collect_bool(buffer.len(), |idx| {
108            let value = buffer[idx];
109            lower_op(lower, value) & upper_op(value, upper)
110        }),
111        arr.validity().clone().union_nullability(nullability),
112    )
113    .into_array()
114}