vortex_array/arrays/decimal/compute/
between.rs1use arrow_buffer::BooleanBuffer;
5use vortex_dtype::Nullability;
6use vortex_error::{VortexResult, vortex_bail};
7use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
8
9use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
10use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
11use crate::vtable::ValidityHelper;
12use crate::{Array, ArrayRef, IntoArray, register_kernel};
13
14impl BetweenKernel for DecimalVTable {
15 fn between(
17 &self,
18 arr: &DecimalArray,
19 lower: &dyn Array,
20 upper: &dyn Array,
21 options: &BetweenOptions,
22 ) -> VortexResult<Option<ArrayRef>> {
23 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
27 return Ok(None);
28 };
29
30 let nullability =
32 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
33
34 match_each_decimal_value_type!(arr.values_type(), |D| {
35 between_unpack::<D>(arr, lower, upper, nullability, options)
36 })
37 }
38}
39
40fn between_unpack<T: NativeDecimalType>(
41 arr: &DecimalArray,
42 lower: Scalar,
43 upper: Scalar,
44 nullability: Nullability,
45 options: &BetweenOptions,
46) -> VortexResult<Option<ArrayRef>> {
47 let Some(lower_value) = lower
48 .as_decimal()
49 .decimal_value()
50 .and_then(|v| v.cast::<T>())
51 else {
52 vortex_bail!(
53 "invalid lower bound Scalar: {lower}, expected {:?}",
54 T::VALUES_TYPE
55 )
56 };
57 let Some(upper_value) = upper
58 .as_decimal()
59 .decimal_value()
60 .and_then(|v| v.cast::<T>())
61 else {
62 vortex_bail!(
63 "invalid upper bound Scalar: {upper}, expected {:?}",
64 T::VALUES_TYPE
65 )
66 };
67
68 let lower_op = match options.lower_strict {
69 StrictComparison::Strict => |a, b| a < b,
70 StrictComparison::NonStrict => |a, b| a <= b,
71 };
72
73 let upper_op = match options.upper_strict {
74 StrictComparison::Strict => |a, b| a < b,
75 StrictComparison::NonStrict => |a, b| a <= b,
76 };
77
78 Ok(Some(between_impl::<T>(
79 arr,
80 lower_value,
81 upper_value,
82 nullability,
83 lower_op,
84 upper_op,
85 )))
86}
87
88register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
89
90fn between_impl<T: NativeDecimalType>(
91 arr: &DecimalArray,
92 lower: T,
93 upper: T,
94 nullability: Nullability,
95 lower_op: impl Fn(T, T) -> bool,
96 upper_op: impl Fn(T, T) -> bool,
97) -> ArrayRef {
98 let buffer = arr.buffer::<T>();
99 BoolArray::new(
100 BooleanBuffer::collect_bool(buffer.len(), |idx| {
101 let value = buffer[idx];
102 lower_op(lower, value) & upper_op(value, upper)
103 }),
104 arr.validity().clone().union_nullability(nullability),
105 )
106 .into_array()
107}
108
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Canonicalises `array` as a boolean array and copies its bits into a `Vec<bool>`.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let canonical = array.to_canonical().unwrap();
        let bools = canonical.into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }

    /// Builds a non-nullable constant decimal array of `len` copies of `value`.
    fn constant_bound(value: i128, dtype: DecimalDType, len: usize) -> ConstantArray {
        let scalar = Scalar::decimal(DecimalValue::I128(value), dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len)
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // The bounds sit exactly on the first and last values so that strictness
        // decides whether those two endpoints are included.
        let lower = constant_bound(100i128, decimal_type, array.len());
        let upper = constant_bound(400i128, decimal_type, array.len());

        let run = |lower_strict: StrictComparison, upper_strict: StrictComparison| {
            let options = BetweenOptions {
                lower_strict,
                upper_strict,
            };
            let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();
            bool_to_vec(&result)
        };

        // Exclusive lower bound, inclusive upper bound.
        assert_eq!(
            run(StrictComparison::Strict, StrictComparison::NonStrict),
            vec![false, true, true, true]
        );

        // Inclusive lower bound, exclusive upper bound.
        assert_eq!(
            run(StrictComparison::NonStrict, StrictComparison::Strict),
            vec![true, true, true, false]
        );
    }
}