vortex_array/arrays/decimal/compute/
between.rs1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12 fn between(
14 &self,
15 arr: &DecimalArray,
16 lower: &dyn Array,
17 upper: &dyn Array,
18 options: &BetweenOptions,
19 ) -> VortexResult<Option<ArrayRef>> {
20 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24 return Ok(None);
25 };
26
27 let nullability =
29 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31 match_each_decimal_value_type!(arr.values_type(), |D| {
32 between_unpack::<D>(arr, lower, upper, nullability, options)
33 })
34 }
35}
36
37fn between_unpack<T: NativeDecimalType>(
38 arr: &DecimalArray,
39 lower: Scalar,
40 upper: Scalar,
41 nullability: Nullability,
42 options: &BetweenOptions,
43) -> VortexResult<Option<ArrayRef>> {
44 let Some(lower_value) = lower
45 .as_decimal()
46 .decimal_value()
47 .and_then(|v| v.cast::<T>())
48 else {
49 vortex_bail!(
50 "invalid lower bound Scalar: {lower}, expected {:?}",
51 T::VALUES_TYPE
52 )
53 };
54 let Some(upper_value) = upper
55 .as_decimal()
56 .decimal_value()
57 .and_then(|v| v.cast::<T>())
58 else {
59 vortex_bail!(
60 "invalid upper bound Scalar: {upper}, expected {:?}",
61 T::VALUES_TYPE
62 )
63 };
64
65 let lower_op = match options.lower_strict {
66 StrictComparison::Strict => |a, b| a < b,
67 StrictComparison::NonStrict => |a, b| a <= b,
68 };
69
70 let upper_op = match options.upper_strict {
71 StrictComparison::Strict => |a, b| a < b,
72 StrictComparison::NonStrict => |a, b| a <= b,
73 };
74
75 Ok(Some(between_impl::<T>(
76 arr,
77 lower_value,
78 upper_value,
79 nullability,
80 lower_op,
81 upper_op,
82 )))
83}
84
85register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
86
87fn between_impl<T: NativeDecimalType>(
88 arr: &DecimalArray,
89 lower: T,
90 upper: T,
91 nullability: Nullability,
92 lower_op: impl Fn(T, T) -> bool,
93 upper_op: impl Fn(T, T) -> bool,
94) -> ArrayRef {
95 let buffer = arr.buffer::<T>();
96 BoolArray::new(
97 BooleanBuffer::collect_bool(buffer.len(), |idx| {
98 let value = buffer[idx];
99 lower_op(lower, value) & upper_op(value, upper)
100 }),
101 arr.validity().clone().union_nullability(nullability),
102 )
103 .into_array()
104}
105
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a constant, non-nullable `i128`-backed decimal array of length `len`.
    fn constant_bound(value: i128, dtype: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(DecimalValue::I128(value), dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len)
    }

    /// Canonicalizes `array` to a boolean array and collects its bits.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let canonical = array.to_canonical().unwrap();
        let bools = canonical.into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }

    /// Checks both mixed strict/non-strict combinations against bounds that
    /// coincide with the first and last elements of the array.
    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        let lower = constant_bound(100i128, decimal_type, array.len());
        let upper = constant_bound(400i128, decimal_type, array.len());

        // (lower strictness, upper strictness, expected mask)
        let cases = [
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let options = BetweenOptions {
                lower_strict,
                upper_strict,
            };
            let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }
}