Skip to main content

vortex_array/arrays/decimal/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::BoolArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::Decimal;
15use crate::dtype::NativeDecimalType;
16use crate::dtype::Nullability;
17use crate::dtype::i256;
18use crate::match_each_decimal_value_type;
19use crate::scalar::Scalar;
20use crate::scalar_fn::fns::between::BetweenKernel;
21use crate::scalar_fn::fns::between::BetweenOptions;
22use crate::scalar_fn::fns::between::StrictComparison;
23
24impl BetweenKernel for Decimal {
25    fn between(
26        arr: ArrayView<'_, Decimal>,
27        lower: &ArrayRef,
28        upper: &ArrayRef,
29        options: &BetweenOptions,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // NOTE: We know that the precision and scale were already checked to be equal by the main
33        // `between` entrypoint function.
34
35        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
36            return Ok(None);
37        };
38
39        // NOTE: we know that have checked before that the lower and upper bounds are not all null.
40        let nullability =
41            arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
42
43        match_each_decimal_value_type!(arr.values_type(), |D| {
44            between_unpack::<D>(arr, lower, upper, nullability, options)
45        })
46    }
47}
48
49fn between_unpack<T: NativeDecimalType>(
50    arr: ArrayView<'_, Decimal>,
51    lower: Scalar,
52    upper: Scalar,
53    nullability: Nullability,
54    options: &BetweenOptions,
55) -> VortexResult<Option<ArrayRef>> {
56    let Some(lower_dv) = lower.as_decimal().decimal_value() else {
57        // Null lower bound — fall back to canonical path.
58        return Ok(None);
59    };
60    let Some(upper_dv) = upper.as_decimal().decimal_value() else {
61        // Null upper bound — fall back to canonical path.
62        return Ok(None);
63    };
64
65    // Try to cast the bound scalar to the array's storage type T.
66    //
67    // If the cast fails, the bound's value is outside [T::MIN, T::MAX].  For all signed
68    // NativeDecimalType implementations the minimum is negative and the maximum is positive, so
69    // we can determine the direction of overflow from the sign of the bound:
70    //   • non-negative and doesn't fit in T  ⟹  value > T::MAX
71    //   • negative and doesn't fit in T      ⟹  value < T::MIN
72    //
73    // From the direction we can answer the comparison immediately:
74    //   lower > T::MAX: no array value (≤ T::MAX) satisfies lower ≤ value  → all-false
75    //   lower < T::MIN: every array value (≥ T::MIN) satisfies lower ≤ value → no lower constraint
76    //   upper > T::MAX: every array value (≤ T::MAX) satisfies value ≤ upper → no upper constraint
77    //   upper < T::MIN: no array value (≥ T::MIN) satisfies value ≤ upper   → all-false
78    //
79    // Both the strict and non-strict forms lead to the same conclusion because the overflow is
80    // by at least one integer, so no boundary element can make the strict form differ.
81    let lower_value: Option<T> = match lower_dv.cast::<T>() {
82        Some(v) => Some(v),
83        None => {
84            if lower_dv.as_i256() >= i256::ZERO {
85                return Ok(Some(
86                    ConstantArray::new(Scalar::bool(false, nullability), arr.len()).into_array(),
87                ));
88            }
89            None
90        }
91    };
92
93    let upper_value: Option<T> = match upper_dv.cast::<T>() {
94        Some(v) => Some(v),
95        None => {
96            if upper_dv.as_i256() < i256::ZERO {
97                return Ok(Some(
98                    ConstantArray::new(Scalar::bool(false, nullability), arr.len()).into_array(),
99                ));
100            }
101            None
102        }
103    };
104
105    let lower_op = match options.lower_strict {
106        StrictComparison::Strict => |a, b| a < b,
107        StrictComparison::NonStrict => |a, b| a <= b,
108    };
109
110    let upper_op = match options.upper_strict {
111        StrictComparison::Strict => |a, b| a < b,
112        StrictComparison::NonStrict => |a, b| a <= b,
113    };
114
115    Ok(Some(between_impl::<T>(
116        arr,
117        lower_value,
118        upper_value,
119        nullability,
120        lower_op,
121        upper_op,
122    )))
123}
124
125fn between_impl<T: NativeDecimalType>(
126    arr: ArrayView<'_, Decimal>,
127    lower: Option<T>,
128    upper: Option<T>,
129    nullability: Nullability,
130    lower_op: impl Fn(T, T) -> bool,
131    upper_op: impl Fn(T, T) -> bool,
132) -> ArrayRef {
133    let buffer = arr.buffer::<T>();
134    BoolArray::new(
135        BitBuffer::collect_bool(buffer.len(), |idx| {
136            let value = buffer[idx];
137            lower.is_none_or(|l| lower_op(l, value)) & upper.is_none_or(|u| upper_op(value, u))
138        }),
139        arr.validity()
140            .vortex_expect("validity should be derivable")
141            .union_nullability(nullability),
142    )
143    .into_array()
144}