vortex_array/arrays/decimal/compute/
between.rs1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{DecimalValue, NativeDecimalType, Scalar, match_each_decimal_value_type};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalVTable};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::vtable::ValidityHelper;
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl BetweenKernel for DecimalVTable {
12 fn between(
14 &self,
15 arr: &DecimalArray,
16 lower: &dyn Array,
17 upper: &dyn Array,
18 options: &BetweenOptions,
19 ) -> VortexResult<Option<ArrayRef>> {
20 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
24 return Ok(None);
25 };
26
27 let nullability =
29 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
30
31 match_each_decimal_value_type!(arr.values_type(), |($D, $CTor)| {
32 between_unpack::<$D>(
33 arr,
34 lower,
35 upper,
36 |d| match d {
37 $CTor(v) => Some(v),
38 _ => None,
39 },
40 nullability,
41 options,
42 )
43 })
44 }
45}
46
47fn between_unpack<T: NativeDecimalType>(
48 arr: &DecimalArray,
49 lower: Scalar,
50 upper: Scalar,
51 unpack: impl Fn(DecimalValue) -> Option<T> + Copy,
52 nullability: Nullability,
53 options: &BetweenOptions,
54) -> VortexResult<Option<ArrayRef>> {
55 let Some(lower_value) = lower.as_decimal().decimal_value().and_then(unpack) else {
56 vortex_bail!("invalid lower bound Scalar: {lower}");
57 };
58 let Some(upper_value) = upper.as_decimal().decimal_value().and_then(unpack) else {
59 vortex_bail!("invalid upper bound Scalar: {upper}");
60 };
61
62 let lower_op = match options.lower_strict {
63 StrictComparison::Strict => |a, b| a < b,
64 StrictComparison::NonStrict => |a, b| a <= b,
65 };
66
67 let upper_op = match options.upper_strict {
68 StrictComparison::Strict => |a, b| a < b,
69 StrictComparison::NonStrict => |a, b| a <= b,
70 };
71
72 Ok(Some(between_impl::<T>(
73 arr,
74 lower_value,
75 upper_value,
76 nullability,
77 lower_op,
78 upper_op,
79 )))
80}
81
82register_kernel!(BetweenKernelAdapter(DecimalVTable).lift());
83
84fn between_impl<T: NativeDecimalType>(
85 arr: &DecimalArray,
86 lower: T,
87 upper: T,
88 nullability: Nullability,
89 lower_op: impl Fn(T, T) -> bool,
90 upper_op: impl Fn(T, T) -> bool,
91) -> ArrayRef {
92 let buffer = arr.buffer::<T>();
93 BoolArray::new(
94 BooleanBuffer::collect_bool(buffer.len(), |idx| {
95 let value = buffer[idx];
96 lower_op(lower, value) & upper_op(value, upper)
97 }),
98 arr.validity().clone().union_nullability(nullability),
99 )
100 .into_array()
101}
102
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Runs `between` over the unscaled decimals `[100, 200, 300, 400]` (precision 3, scale 2)
    /// with constant bounds `100` and `400` and the given strictness, returning the booleans.
    fn between_100_and_400(
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<bool> {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        let constant = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
        };
        let lower = constant(100);
        let upper = constant(400);

        let result = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict,
                upper_strict,
            },
        )
        .unwrap();
        bool_to_vec(&result)
    }

    #[test]
    fn test_between() {
        assert_eq!(
            between_100_and_400(StrictComparison::Strict, StrictComparison::NonStrict),
            vec![false, true, true, true]
        );
        assert_eq!(
            between_100_and_400(StrictComparison::NonStrict, StrictComparison::Strict),
            vec![true, true, true, false]
        );
    }

    #[test]
    fn test_between_both_strict() {
        // Both endpoints are excluded.
        assert_eq!(
            between_100_and_400(StrictComparison::Strict, StrictComparison::Strict),
            vec![false, true, true, false]
        );
    }

    #[test]
    fn test_between_both_non_strict() {
        // Both endpoints are included.
        assert_eq!(
            between_100_and_400(StrictComparison::NonStrict, StrictComparison::NonStrict),
            vec![true, true, true, true]
        );
    }

    /// Canonicalizes a boolean array and collects its values (ignoring validity).
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        array
            .to_canonical()
            .unwrap()
            .into_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect()
    }
}