vortex_array/arrays/decimal/compute/
between.rs1use vortex_buffer::BitBuffer;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::arrays::BoolArray;
12use crate::arrays::DecimalArray;
13use crate::arrays::DecimalVTable;
14use crate::dtype::NativeDecimalType;
15use crate::dtype::Nullability;
16use crate::match_each_decimal_value_type;
17use crate::scalar::Scalar;
18use crate::scalar_fn::fns::between::BetweenKernel;
19use crate::scalar_fn::fns::between::BetweenOptions;
20use crate::scalar_fn::fns::between::StrictComparison;
21use crate::vtable::ValidityHelper;
22
23impl BetweenKernel for DecimalVTable {
24 fn between(
25 arr: &DecimalArray,
26 lower: &ArrayRef,
27 upper: &ArrayRef,
28 options: &BetweenOptions,
29 _ctx: &mut ExecutionCtx,
30 ) -> VortexResult<Option<ArrayRef>> {
31 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
35 return Ok(None);
36 };
37
38 let nullability =
40 arr.dtype.nullability() | lower.dtype().nullability() | upper.dtype().nullability();
41
42 match_each_decimal_value_type!(arr.values_type(), |D| {
43 between_unpack::<D>(arr, lower, upper, nullability, options)
44 })
45 }
46}
47
48fn between_unpack<T: NativeDecimalType>(
49 arr: &DecimalArray,
50 lower: Scalar,
51 upper: Scalar,
52 nullability: Nullability,
53 options: &BetweenOptions,
54) -> VortexResult<Option<ArrayRef>> {
55 let Some(lower_value) = lower
56 .as_decimal()
57 .decimal_value()
58 .and_then(|v| v.cast::<T>())
59 else {
60 vortex_bail!(
61 "invalid lower bound Scalar: {lower}, expected {:?}",
62 T::DECIMAL_TYPE
63 )
64 };
65 let Some(upper_value) = upper
66 .as_decimal()
67 .decimal_value()
68 .and_then(|v| v.cast::<T>())
69 else {
70 vortex_bail!(
71 "invalid upper bound Scalar: {upper}, expected {:?}",
72 T::DECIMAL_TYPE
73 )
74 };
75
76 let lower_op = match options.lower_strict {
77 StrictComparison::Strict => |a, b| a < b,
78 StrictComparison::NonStrict => |a, b| a <= b,
79 };
80
81 let upper_op = match options.upper_strict {
82 StrictComparison::Strict => |a, b| a < b,
83 StrictComparison::NonStrict => |a, b| a <= b,
84 };
85
86 Ok(Some(between_impl::<T>(
87 arr,
88 lower_value,
89 upper_value,
90 nullability,
91 lower_op,
92 upper_op,
93 )))
94}
95
96fn between_impl<T: NativeDecimalType>(
97 arr: &DecimalArray,
98 lower: T,
99 upper: T,
100 nullability: Nullability,
101 lower_op: impl Fn(T, T) -> bool,
102 upper_op: impl Fn(T, T) -> bool,
103) -> ArrayRef {
104 let buffer = arr.buffer::<T>();
105 BoolArray::new(
106 BitBuffer::collect_bool(buffer.len(), |idx| {
107 let value = buffer[idx];
108 lower_op(lower, value) & upper_op(value, upper)
109 }),
110 arr.validity().clone().union_nullability(nullability),
111 )
112 .into_array()
113}