vortex_array/arrays/decimal/compute/
between.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::NativeDecimalType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_decimal_value_type;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::expr::BetweenKernel;
19use crate::expr::BetweenOptions;
20use crate::expr::StrictComparison;
21use crate::scalar::Scalar;
22use crate::vtable::ValidityHelper;
23
24impl BetweenKernel for DecimalVTable {
25 fn between(
26 arr: &DecimalArray,
27 lower: &dyn Array,
28 upper: &dyn Array,
29 options: &BetweenOptions,
30 _ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
36 return Ok(None);
37 };
38
39 let nullability =
41 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
42
43 match_each_decimal_value_type!(arr.values_type(), |D| {
44 between_unpack::<D>(arr, lower, upper, nullability, options)
45 })
46 }
47}
48
49fn between_unpack<T: NativeDecimalType>(
50 arr: &DecimalArray,
51 lower: Scalar,
52 upper: Scalar,
53 nullability: Nullability,
54 options: &BetweenOptions,
55) -> VortexResult<Option<ArrayRef>> {
56 let Some(lower_value) = lower
57 .as_decimal()
58 .decimal_value()
59 .and_then(|v| v.cast::<T>())
60 else {
61 vortex_bail!(
62 "invalid lower bound Scalar: {lower}, expected {:?}",
63 T::DECIMAL_TYPE
64 )
65 };
66 let Some(upper_value) = upper
67 .as_decimal()
68 .decimal_value()
69 .and_then(|v| v.cast::<T>())
70 else {
71 vortex_bail!(
72 "invalid upper bound Scalar: {upper}, expected {:?}",
73 T::DECIMAL_TYPE
74 )
75 };
76
77 let lower_op = match options.lower_strict {
78 StrictComparison::Strict => |a, b| a < b,
79 StrictComparison::NonStrict => |a, b| a <= b,
80 };
81
82 let upper_op = match options.upper_strict {
83 StrictComparison::Strict => |a, b| a < b,
84 StrictComparison::NonStrict => |a, b| a <= b,
85 };
86
87 Ok(Some(between_impl::<T>(
88 arr,
89 lower_value,
90 upper_value,
91 nullability,
92 lower_op,
93 upper_op,
94 )))
95}
96
97fn between_impl<T: NativeDecimalType>(
98 arr: &DecimalArray,
99 lower: T,
100 upper: T,
101 nullability: Nullability,
102 lower_op: impl Fn(T, T) -> bool,
103 upper_op: impl Fn(T, T) -> bool,
104) -> ArrayRef {
105 let buffer = arr.buffer::<T>();
106 BoolArray::new(
107 BitBuffer::collect_bool(buffer.len(), |idx| {
108 let value = buffer[idx];
109 lower_op(lower, value) & upper_op(value, upper)
110 }),
111 arr.validity().clone().union_nullability(nullability),
112 )
113 .into_array()
114}