vortex_array/arrays/decimal/compute/
between.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::NativeDecimalType;
6use vortex_dtype::Nullability;
7use vortex_dtype::match_each_decimal_value_type;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_scalar::Scalar;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::compute::BetweenKernel;
19use crate::compute::BetweenKernelAdapter;
20use crate::compute::BetweenOptions;
21use crate::compute::StrictComparison;
22use crate::register_kernel;
23use crate::vtable::ValidityHelper;
24
25impl BetweenKernel for DecimalVTable {
26 fn between(
28 &self,
29 arr: &DecimalArray,
30 lower: &dyn Array,
31 upper: &dyn Array,
32 options: &BetweenOptions,
33 ) -> VortexResult<Option<ArrayRef>> {
34 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
38 return Ok(None);
39 };
40
41 let nullability =
43 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
44
45 match_each_decimal_value_type!(arr.values_type(), |D| {
46 between_unpack::<D>(arr, lower, upper, nullability, options)
47 })
48 }
49}
50
51fn between_unpack<T: NativeDecimalType>(
52 arr: &DecimalArray,
53 lower: Scalar,
54 upper: Scalar,
55 nullability: Nullability,
56 options: &BetweenOptions,
57) -> VortexResult<Option<ArrayRef>> {
58 let Some(lower_value) = lower
59 .as_decimal()
60 .decimal_value()
61 .and_then(|v| v.cast::<T>())
62 else {
63 vortex_bail!(
64 "invalid lower bound Scalar: {lower}, expected {:?}",
65 T::DECIMAL_TYPE
66 )
67 };
68 let Some(upper_value) = upper
69 .as_decimal()
70 .decimal_value()
71 .and_then(|v| v.cast::<T>())
72 else {
73 vortex_bail!(
74 "invalid upper bound Scalar: {upper}, expected {:?}",
75 T::DECIMAL_TYPE
76 )
77 };
78
79 let lower_op = match options.lower_strict {
80 StrictComparison::Strict => |a, b| a < b,
81 StrictComparison::NonStrict => |a, b| a <= b,
82 };
83
84 let upper_op = match options.upper_strict {
85 StrictComparison::Strict => |a, b| a < b,
86 StrictComparison::NonStrict => |a, b| a <= b,
87 };
88
89 Ok(Some(between_impl::<T>(
90 arr,
91 lower_value,
92 upper_value,
93 nullability,
94 lower_op,
95 upper_op,
96 )))
97}
98
99register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
100
101fn between_impl<T: NativeDecimalType>(
102 arr: &DecimalArray,
103 lower: T,
104 upper: T,
105 nullability: Nullability,
106 lower_op: impl Fn(T, T) -> bool,
107 upper_op: impl Fn(T, T) -> bool,
108) -> ArrayRef {
109 let buffer = arr.buffer::<T>();
110 BoolArray::from_bit_buffer(
111 BitBuffer::collect_bool(buffer.len(), |idx| {
112 let value = buffer[idx];
113 lower_op(lower, value) & upper_op(value, upper)
114 }),
115 arr.validity().clone().union_nullability(nullability),
116 )
117 .into_array()
118}
119
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::DecimalArray;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::compute::between;
    use crate::validity::Validity;

    /// Builds a non-nullable constant decimal array holding `len` copies of `value`.
    fn decimal_constant(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(
            DecimalValue::I128(value),
            decimal_type,
            Nullability::NonNullable,
        );
        ConstantArray::new(scalar, len)
    }

    /// Runs `between` with the given per-end strictness and returns the result as booleans.
    fn between_bools(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<bool> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let result = between(array, lower, upper, &options).unwrap();
        bool_to_vec(&result)
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );
        let lower = decimal_constant(100, decimal_type, array.len());
        let upper = decimal_constant(400, decimal_type, array.len());

        // Strict lower bound: the element equal to the lower bound is excluded.
        assert_eq!(
            between_bools(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                StrictComparison::Strict,
                StrictComparison::NonStrict,
            ),
            vec![false, true, true, true]
        );

        // Strict upper bound: the element equal to the upper bound is excluded.
        assert_eq!(
            between_bools(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                StrictComparison::NonStrict,
                StrictComparison::Strict,
            ),
            vec![true, true, true, false]
        );
    }

    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        array.to_bool().bit_buffer().iter().collect()
    }
}