vortex_array/arrays/decimal/compute/
between.rs1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12 fn between(
14 &self,
15 arr: &DecimalArray,
16 lower: &dyn Array,
17 upper: &dyn Array,
18 options: &BetweenOptions,
19 ) -> VortexResult<Option<ArrayRef>> {
20 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24 return Ok(None);
25 };
26
27 let nullability =
29 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31 match_each_decimal_value_type!(arr.values_type(), |D| {
32 between_unpack::<D>(arr, lower, upper, nullability, options)
33 })
34 }
35}
36
37fn between_unpack<T: NativeDecimalType>(
38 arr: &DecimalArray,
39 lower: Scalar,
40 upper: Scalar,
41 nullability: Nullability,
42 options: &BetweenOptions,
43) -> VortexResult<Option<ArrayRef>> {
44 let Some(lower_value) = lower
45 .as_decimal()
46 .decimal_value()
47 .and_then(|v| T::try_from(v))
48 else {
49 vortex_bail!("invalid lower bound Scalar: {lower}");
50 };
51 let Some(upper_value) = upper
52 .as_decimal()
53 .decimal_value()
54 .and_then(|v| T::try_from(v))
55 else {
56 vortex_bail!("invalid upper bound Scalar: {upper}");
57 };
58
59 let lower_op = match options.lower_strict {
60 StrictComparison::Strict => |a, b| a < b,
61 StrictComparison::NonStrict => |a, b| a <= b,
62 };
63
64 let upper_op = match options.upper_strict {
65 StrictComparison::Strict => |a, b| a < b,
66 StrictComparison::NonStrict => |a, b| a <= b,
67 };
68
69 Ok(Some(between_impl::<T>(
70 arr,
71 lower_value,
72 upper_value,
73 nullability,
74 lower_op,
75 upper_op,
76 )))
77}
78
// Register the decimal `BetweenKernel` implementation in this file through
// `BetweenKernelAdapter`. Presumably this makes it discoverable by the `between` compute
// entry point — confirm against the `register_kernel!` definition.
register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
80
81fn between_impl<T: NativeDecimalType>(
82 arr: &DecimalArray,
83 lower: T,
84 upper: T,
85 nullability: Nullability,
86 lower_op: impl Fn(T, T) -> bool,
87 upper_op: impl Fn(T, T) -> bool,
88) -> ArrayRef {
89 let buffer = arr.buffer::<T>();
90 BoolArray::new(
91 BooleanBuffer::collect_bool(buffer.len(), |idx| {
92 let value = buffer[idx];
93 lower_op(lower, value) & upper_op(value, upper)
94 }),
95 arr.validity().clone().union_nullability(nullability),
96 )
97 .into_array()
98}
99
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // Builds a non-nullable constant decimal bound with the same length as `array`.
        let constant = |raw: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(raw),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
        };
        let lower = constant(100);
        let upper = constant(400);

        // (lower strictness, upper strictness, expected mask)
        let cases = [
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let result = between(
                array.as_ref(),
                lower.as_ref(),
                upper.as_ref(),
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected);
        }
    }

    /// Canonicalizes `array` to a bool array and collects its bits.
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        let bools = array.to_canonical().unwrap().into_bool().unwrap();
        bools.boolean_buffer().iter().collect()
    }
}