Skip to main content

vortex_array/arrays/decimal/compute/between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::arrays::BoolArray;
12use crate::arrays::DecimalArray;
13use crate::arrays::DecimalVTable;
14use crate::dtype::NativeDecimalType;
15use crate::dtype::Nullability;
16use crate::match_each_decimal_value_type;
17use crate::scalar::Scalar;
18use crate::scalar_fn::fns::between::BetweenKernel;
19use crate::scalar_fn::fns::between::BetweenOptions;
20use crate::scalar_fn::fns::between::StrictComparison;
21use crate::vtable::ValidityHelper;
22
23impl BetweenKernel for DecimalVTable {
24    fn between(
25        arr: &DecimalArray,
26        lower: &ArrayRef,
27        upper: &ArrayRef,
28        options: &BetweenOptions,
29        _ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        // NOTE: We know that the precision and scale were already checked to be equal by the main
32        // `between` entrypoint function.
33
34        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
35            return Ok(None);
36        };
37
38        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
39        let nullability =
40            arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
41
42        match_each_decimal_value_type!(arr.values_type(), |D| {
43            between_unpack::<D>(arr, lower, upper, nullability, options)
44        })
45    }
46}
47
48fn between_unpack<T: NativeDecimalType>(
49    arr: &DecimalArray,
50    lower: Scalar,
51    upper: Scalar,
52    nullability: Nullability,
53    options: &BetweenOptions,
54) -> VortexResult<Option<ArrayRef>> {
55    let Some(lower_value) = lower
56        .as_decimal()
57        .decimal_value()
58        .and_then(|v| v.cast::<T>())
59    else {
60        vortex_bail!(
61            "invalid lower bound Scalar: {lower}, expected {:?}",
62            T::DECIMAL_TYPE
63        )
64    };
65    let Some(upper_value) = upper
66        .as_decimal()
67        .decimal_value()
68        .and_then(|v| v.cast::<T>())
69    else {
70        vortex_bail!(
71            "invalid upper bound Scalar: {upper}, expected {:?}",
72            T::DECIMAL_TYPE
73        )
74    };
75
76    let lower_op = match options.lower_strict {
77        StrictComparison::Strict => |a, b| a < b,
78        StrictComparison::NonStrict => |a, b| a <= b,
79    };
80
81    let upper_op = match options.upper_strict {
82        StrictComparison::Strict => |a, b| a < b,
83        StrictComparison::NonStrict => |a, b| a <= b,
84    };
85
86    Ok(Some(between_impl::<T>(
87        arr,
88        lower_value,
89        upper_value,
90        nullability,
91        lower_op,
92        upper_op,
93    )))
94}
95
96fn between_impl<T: NativeDecimalType>(
97    arr: &DecimalArray,
98    lower: T,
99    upper: T,
100    nullability: Nullability,
101    lower_op: impl Fn(T, T) -> bool,
102    upper_op: impl Fn(T, T) -> bool,
103) -> ArrayRef {
104    let buffer = arr.buffer::<T>();
105    BoolArray::new(
106        BitBuffer::collect_bool(buffer.len(), |idx| {
107            let value = buffer[idx];
108            lower_op(lower, value) & upper_op(value, upper)
109        }),
110        arr.validity().clone().union_nullability(nullability),
111    )
112    .into_array()
113}