vortex_alp/alp/compute/
between.rs1use std::fmt::Debug;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::compute::BetweenKernel;
10use vortex_array::compute::BetweenKernelAdapter;
11use vortex_array::compute::BetweenOptions;
12use vortex_array::compute::StrictComparison;
13use vortex_array::compute::between;
14use vortex_array::register_kernel;
15use vortex_dtype::NativeDType;
16use vortex_dtype::NativePType;
17use vortex_dtype::Nullability;
18use vortex_error::VortexResult;
19use vortex_scalar::Scalar;
20
21use crate::ALPArray;
22use crate::ALPFloat;
23use crate::ALPVTable;
24use crate::match_each_alp_float_ptype;
25
26impl BetweenKernel for ALPVTable {
27 fn between(
28 &self,
29 array: &ALPArray,
30 lower: &dyn Array,
31 upper: &dyn Array,
32 options: &BetweenOptions,
33 ) -> VortexResult<Option<ArrayRef>> {
34 let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
35 return Ok(None);
36 };
37
38 if array.patches().is_some() {
39 return Ok(None);
40 }
41
42 let nullability =
43 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
44
45 match_each_alp_float_ptype!(array.ptype(), |F| {
46 between_impl::<F>(
47 array,
48 F::try_from(lower)?,
49 F::try_from(upper)?,
50 nullability,
51 options,
52 )
53 })
54 .map(Some)
55 }
56}
57
58register_kernel!(BetweenKernelAdapter(ALPVTable).lift());
59
60fn between_impl<T: NativePType + ALPFloat>(
61 array: &ALPArray,
62 lower: T,
63 upper: T,
64 nullability: Nullability,
65 options: &BetweenOptions,
66) -> VortexResult<ArrayRef>
67where
68 Scalar: From<T::ALPInt>,
69 <T as ALPFloat>::ALPInt: NativeDType + Debug,
70{
71 let exponents = array.exponents();
72
73 let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
80 .map(|x| (x, options.lower_strict))
81 .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
82
83 let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
85 .map(|x| (x, options.upper_strict))
86 .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
87
88 let options = BetweenOptions {
89 lower_strict,
90 upper_strict,
91 };
92
93 between(
94 array.encoded(),
95 ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).as_ref(),
96 ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).as_ref(),
97 &options,
98 )
99}
100
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::BetweenOptions;
    use vortex_array::compute::StrictComparison;
    use vortex_dtype::Nullability;

    use crate::ALPArray;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    /// Runs `between_impl` on a single-element ALP array and returns that element's
    /// boolean result.
    fn between_test(arr: &ALPArray, lower: f32, upper: f32, options: &BetweenOptions) -> bool {
        let res = between_impl(arr, lower, upper, Nullability::Nullable, options)
            .unwrap()
            .to_bool()
            .bit_buffer()
            .iter()
            .collect_vec();
        assert_eq!(res.len(), 1);

        res[0]
    }

    #[test]
    fn comparison_range() {
        let value = 0.0605_f32;
        let array = PrimitiveArray::from_iter([value; 1]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 1]
        );

        assert!(between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        ));

        assert!(!between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
        ));

        assert!(!between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        ));

        assert!(between_test(
            &encoded,
            0.060499_f32,
            0.06051,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        ));

        assert!(between_test(
            &encoded,
            0.06_f32,
            0.06051,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        ))
    }

    /// Non-encodable bounds that sit just on the far side of the stored value must
    /// exclude it, even when the caller asked for inclusive comparisons: the kernel
    /// has to switch those sides to a strict comparison.
    #[test]
    fn non_encodable_bounds_exclude_value() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 1]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());

        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };

        // Lower bound slightly above the value (0.0605 < 0.06051).
        assert!(!between_test(&encoded, 0.06051_f32, 0.07, &inclusive));
        // Upper bound slightly below the value (0.060499 < 0.0605).
        assert!(!between_test(&encoded, 0.05_f32, 0.060499, &inclusive));
        // Range entirely above the value.
        assert!(!between_test(&encoded, 0.061_f32, 0.07, &inclusive));
        // Range entirely below the value.
        assert!(!between_test(&encoded, 0.05_f32, 0.06, &inclusive));
        // Sanity check: a range that contains the value still matches.
        assert!(between_test(&encoded, 0.05_f32, 0.07, &inclusive));
    }
}