1use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27impl CompareKernel for ALP {
30 fn compare(
31 lhs: ArrayView<'_, Self>,
32 rhs: &ArrayRef,
33 operator: CompareOperator,
34 _ctx: &mut ExecutionCtx,
35 ) -> VortexResult<Option<ArrayRef>> {
36 if lhs.patches().is_some() {
37 return Ok(None);
39 }
40 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41 return Ok(None);
43 }
44
45 if let Some(const_scalar) = rhs.as_constant() {
46 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47 vortex_err!(
48 "ALP Compare RHS had the wrong type {}, expected {}",
49 const_scalar,
50 const_scalar.dtype()
51 )
52 })?;
53
54 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55 match pscalar.typed_value::<T>() {
56 Some(value) => return alp_scalar_compare(lhs, value, operator),
57 None => vortex_bail!(
58 "Failed to convert scalar {:?} to ALP type {:?}",
59 pscalar,
60 pscalar.ptype()
61 ),
62 }
63 });
64 }
65
66 Ok(None)
67 }
68}
69
70fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74 alp: ArrayView<ALP>,
75 value: F,
76 operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79 F::ALPInt: Into<Scalar>,
80 <F as ALPFloat>::ALPInt: Debug,
81{
82 if alp.patches().is_some() {
84 return Ok(None);
85 }
86
87 let exponents = alp.exponents();
88 let encoded = F::encode_single(value, alp.exponents());
91 match encoded {
92 Some(encoded) => {
93 let s = ConstantArray::new(encoded, alp.len());
94 Ok(Some(
95 alp.encoded()
96 .binary(s.into_array(), Operator::from(operator))?,
97 ))
98 }
99 None => match operator {
100 CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103 CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106 CompareOperator::Gt | CompareOperator::Gte => {
107 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110 if is_not_finite {
111 Ok(Some(
112 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113 ))
114 } else {
115 Ok(Some(
116 alp.encoded().binary(
117 ConstantArray::new(F::encode_above(value, exponents), alp.len())
118 .into_array(),
119 Operator::Gte,
123 )?,
124 ))
125 }
126 }
127 CompareOperator::Lt | CompareOperator::Lte => {
128 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131 if is_not_finite {
132 Ok(Some(
133 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134 ))
135 } else {
136 Ok(Some(
137 alp.encoded().binary(
138 ConstantArray::new(F::encode_below(value, exponents), alp.len())
139 .into_array(),
140 Operator::Lte,
143 )?,
144 ))
145 }
146 }
147 },
148 }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ToCanonical;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare`, panicking on error, and returns its optional result.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r = alp_scalar_compare(encoded.as_view(), 1.3_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(encoded.as_view(), 1.234f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r_eq = alp_scalar_compare(encoded.as_view(), 1.234444_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        let r_neq = alp_scalar_compare(encoded.as_view(), 1.234444f32, CompareOperator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        // 0.0605 >= 0.06051 is false.
        let r_gte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        // 0.0605 > 0.06051 is false.
        let r_gt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        // 0.0605 <= 0.06051 is true.
        let r_lte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        // 0.0605 < 0.06051 is true.
        let r_lt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        let r_gte =
            test_alp_compare(encoded.as_view(), -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt =
            test_alp_compare(encoded.as_view(), -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(encoded.as_view(), -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_some());

        // Patched arrays are not handled by the scalar fast path.
        assert!(
            alp_scalar_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq)
                .unwrap()
                .is_none()
        )
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        // Comparing against a null constant yields an all-null result.
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}