1use std::fmt::Debug;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::compute::Operator;
12use vortex_array::compute::compare;
13use vortex_array::expr::CompareKernel;
14use vortex_array::scalar::Scalar;
15use vortex_dtype::NativePType;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_err;
19
20use crate::ALPArray;
21use crate::ALPFloat;
22use crate::ALPVTable;
23use crate::match_each_alp_float_ptype;
24
25impl CompareKernel for ALPVTable {
28 fn compare(
29 lhs: &ALPArray,
30 rhs: &dyn Array,
31 operator: Operator,
32 _ctx: &mut ExecutionCtx,
33 ) -> VortexResult<Option<ArrayRef>> {
34 if lhs.patches().is_some() {
35 return Ok(None);
37 }
38 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
39 return Ok(None);
41 }
42
43 if let Some(const_scalar) = rhs.as_constant() {
44 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
45 vortex_err!(
46 "ALP Compare RHS had the wrong type {}, expected {}",
47 const_scalar,
48 const_scalar.dtype()
49 )
50 })?;
51
52 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
53 match pscalar.typed_value::<T>() {
54 Some(value) => return alp_scalar_compare(lhs, value, operator),
55 None => vortex_bail!(
56 "Failed to convert scalar {:?} to ALP type {:?}",
57 pscalar,
58 pscalar.ptype()
59 ),
60 }
61 });
62 }
63
64 Ok(None)
65 }
66}
67
68fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
72 alp: &ALPArray,
73 value: F,
74 operator: Operator,
75) -> VortexResult<Option<ArrayRef>>
76where
77 F::ALPInt: Into<Scalar>,
78 <F as ALPFloat>::ALPInt: Debug,
79{
80 if alp.patches().is_some() {
82 return Ok(None);
83 }
84
85 let exponents = alp.exponents();
86 let encoded = F::encode_single(value, alp.exponents());
89 match encoded {
90 Some(encoded) => {
91 let s = ConstantArray::new(encoded, alp.len());
92 Ok(Some(compare(alp.encoded(), s.as_ref(), operator)?))
93 }
94 None => match operator {
95 Operator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
98 Operator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
101 Operator::Gt | Operator::Gte => {
102 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
105 if is_not_finite {
106 Ok(Some(
107 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
108 ))
109 } else {
110 Ok(Some(compare(
111 alp.encoded(),
112 ConstantArray::new(F::encode_above(value, exponents), alp.len()).as_ref(),
113 Operator::Gte,
117 )?))
118 }
119 }
120 Operator::Lt | Operator::Lte => {
121 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
124 if is_not_finite {
125 Ok(Some(
126 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
127 ))
128 } else {
129 Ok(Some(compare(
130 alp.encoded(),
131 ConstantArray::new(F::encode_below(value, exponents), alp.len()).as_ref(),
132 Operator::Lte,
135 )?))
136 }
137 }
138 },
139 }
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_array::scalar::Scalar;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::alp_encode;

    /// Runs [`alp_scalar_compare`] and unwraps the `VortexResult`. The `Option` is kept so
    /// tests can tell "handled" (`Some`) apart from "declined" (`None`).
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: &ALPArray,
        value: F,
        operator: Operator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r = alp_scalar_compare(&encoded, 1.3_f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(&encoded, 1.234f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        // The literal deliberately carries more digits than the array's exponents can
        // encode, which is what drives the comparison down the unencodable path.
        #[allow(clippy::excessive_precision)]
        let r_eq = alp_scalar_compare(&encoded, 1.234444_f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        // Same deliberately over-precise literal as above.
        #[allow(clippy::excessive_precision)]
        let r_neq = alp_scalar_compare(&encoded, 1.234444f32, Operator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        let r_gte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        let r_gte = test_alp_compare(&encoded, -0.00000001_f32, Operator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(&encoded, -0.0_f32, Operator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = test_alp_compare(&encoded, -0.0000000001f32, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(&encoded, -0.0_f32, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(&encoded, 0.06051_f32, Operator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(&encoded, 0.06051_f32, Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(&encoded, -0.00001_f32, Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_some());

        // Patched arrays are declined rather than compared incorrectly.
        assert!(
            alp_scalar_compare(&encoded, 1_000_000.9_f32, Operator::Eq)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = compare(encoded.as_ref(), other.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(Operator::Eq, false)]
    #[case(Operator::NotEq, true)]
    fn compare_eq_to_nan(#[case] operator: Operator, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, f32::NAN, operator).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        // `Gt` and `Gte` share the sign-bit shortcut for non-finite values.
        for operator in [Operator::Gt, Operator::Gte] {
            let r = test_alp_compare(&encoded, value, operator).unwrap();
            let expected = BoolArray::from_iter([result; 10]);
            assert_arrays_eq!(r, expected);
        }
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        // `Lt` and `Lte` share the sign-bit shortcut for non-finite values.
        for operator in [Operator::Lt, Operator::Lte] {
            let r = test_alp_compare(&encoded, value, operator).unwrap();
            let expected = BoolArray::from_iter([result; 10]);
            assert_arrays_eq!(r, expected);
        }
    }
}