vortex_array/arrays/decimal/compute/
between.rs1use arrow_buffer::BooleanBuffer;
2use vortex_dtype::Nullability;
3use vortex_error::{VortexResult, vortex_bail};
4use vortex_scalar::{DecimalValue, Scalar};
5
6use crate::arrays::{BoolArray, DecimalArray, DecimalEncoding, NativeDecimalType};
7use crate::compute::{BetweenKernel, BetweenKernelAdapter, BetweenOptions, StrictComparison};
8use crate::{Array, ArrayRef, match_each_decimal_value_type, register_kernel};
9
10impl BetweenKernel for DecimalEncoding {
11 fn between(
13 &self,
14 arr: &DecimalArray,
15 lower: &dyn Array,
16 upper: &dyn Array,
17 options: &BetweenOptions,
18 ) -> VortexResult<Option<ArrayRef>> {
19 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
23 return Ok(None);
24 };
25
26 let nullability =
28 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
29
30 match_each_decimal_value_type!(arr.values_type(), |($D, $CTor)| {
31 between_unpack::<$D>(
32 arr,
33 lower,
34 upper,
35 |d| match d {
36 $CTor(v) => Some(v),
37 _ => None,
38 },
39 nullability,
40 options,
41 )
42 })
43 }
44}
45
46fn between_unpack<T: NativeDecimalType>(
47 arr: &DecimalArray,
48 lower: Scalar,
49 upper: Scalar,
50 unpack: impl Fn(DecimalValue) -> Option<T> + Copy,
51 nullability: Nullability,
52 options: &BetweenOptions,
53) -> VortexResult<Option<ArrayRef>> {
54 let Some(lower_value) = lower.as_decimal().decimal_value().and_then(unpack) else {
55 vortex_bail!("invalid lower bound Scalar: {lower}");
56 };
57 let Some(upper_value) = upper.as_decimal().decimal_value().and_then(unpack) else {
58 vortex_bail!("invalid upper bound Scalar: {upper}");
59 };
60
61 let lower_op = match options.lower_strict {
62 StrictComparison::Strict => |a, b| a < b,
63 StrictComparison::NonStrict => |a, b| a <= b,
64 };
65
66 let upper_op = match options.upper_strict {
67 StrictComparison::Strict => |a, b| a < b,
68 StrictComparison::NonStrict => |a, b| a <= b,
69 };
70
71 Ok(Some(between_impl::<T>(
72 arr,
73 lower_value,
74 upper_value,
75 nullability,
76 lower_op,
77 upper_op,
78 )))
79}
80
// Register the decimal `BetweenKernel`, wrapped in `BetweenKernelAdapter` and lifted, with the
// kernel registry; presumably this is how the generic `between` entry point finds it for
// `DecimalEncoding` — confirm against the `register_kernel!` definition.
register_kernel!(BetweenKernelAdapter(DecimalEncoding).lift());
82
83fn between_impl<T: NativeDecimalType>(
84 arr: &DecimalArray,
85 lower: T,
86 upper: T,
87 nullability: Nullability,
88 lower_op: impl Fn(T, T) -> bool,
89 upper_op: impl Fn(T, T) -> bool,
90) -> ArrayRef {
91 let buffer = arr.buffer::<T>();
92 BoolArray::new(
93 BooleanBuffer::collect_bool(buffer.len(), |idx| {
94 let value = buffer[idx];
95 lower_op(lower, value) & upper_op(value, upper)
96 }),
97 arr.validity().clone().union_nullability(nullability),
98 )
99 .into_array()
100}
101
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::Array;
    use crate::arrays::{ConstantArray, DecimalArray};
    use crate::compute::{BetweenOptions, StrictComparison, between};
    use crate::validity::Validity;

    /// Builds a non-nullable constant array of `len` copies of the decimal `value`.
    fn decimal_constant(value: i128, decimal_type: DecimalDType, len: usize) -> ConstantArray {
        ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(value),
                decimal_type,
                Nullability::NonNullable,
            ),
            len,
        )
    }

    #[test]
    fn test_between() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // The bounds coincide with the first and last values, so each side's strictness
        // decides only whether that endpoint is included.
        let lower = decimal_constant(100, decimal_type, array.len());
        let upper = decimal_constant(400, decimal_type, array.len());

        // (lower_strict, upper_strict, expected mask) — all four combinations.
        let cases = [
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                vec![false, true, true, true],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                vec![true, true, true, false],
            ),
            (
                StrictComparison::Strict,
                StrictComparison::Strict,
                vec![false, true, true, false],
            ),
            (
                StrictComparison::NonStrict,
                StrictComparison::NonStrict,
                vec![true, true, true, true],
            ),
        ];

        for (case, (lower_strict, upper_strict, expected)) in cases.into_iter().enumerate() {
            let result = between(
                &array,
                &lower,
                &upper,
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            )
            .unwrap();
            assert_eq!(bool_to_vec(&result), expected, "case {case}");
        }
    }

    /// Canonicalizes `array` to a boolean array and collects its bits (validity is ignored).
    fn bool_to_vec(array: &dyn Array) -> Vec<bool> {
        array
            .to_canonical()
            .unwrap()
            .into_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect()
    }
}