1use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27impl CompareKernel for ALP {
30 fn compare(
31 lhs: ArrayView<'_, Self>,
32 rhs: &ArrayRef,
33 operator: CompareOperator,
34 _ctx: &mut ExecutionCtx,
35 ) -> VortexResult<Option<ArrayRef>> {
36 if lhs.patches().is_some() {
37 return Ok(None);
39 }
40 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41 return Ok(None);
43 }
44
45 if let Some(const_scalar) = rhs.as_constant() {
46 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47 vortex_err!(
48 "ALP Compare RHS had the wrong type {}, expected {}",
49 const_scalar,
50 const_scalar.dtype()
51 )
52 })?;
53
54 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55 match pscalar.typed_value::<T>() {
56 Some(value) => return alp_scalar_compare(lhs, value, operator),
57 None => vortex_bail!(
58 "Failed to convert scalar {:?} to ALP type {:?}",
59 pscalar,
60 pscalar.ptype()
61 ),
62 }
63 });
64 }
65
66 Ok(None)
67 }
68}
69
70fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74 alp: ArrayView<ALP>,
75 value: F,
76 operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79 F::ALPInt: Into<Scalar>,
80 <F as ALPFloat>::ALPInt: Debug,
81{
82 if alp.patches().is_some() {
84 return Ok(None);
85 }
86
87 let exponents = alp.exponents();
88 let encoded = F::encode_single(value, alp.exponents());
91 match encoded {
92 Some(encoded) => {
93 let s = ConstantArray::new(encoded, alp.len());
94 Ok(Some(
95 alp.encoded()
96 .binary(s.into_array(), Operator::from(operator))?,
97 ))
98 }
99 None => match operator {
100 CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103 CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106 CompareOperator::Gt | CompareOperator::Gte => {
107 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110 if is_not_finite {
111 Ok(Some(
112 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113 ))
114 } else {
115 Ok(Some(
116 alp.encoded().binary(
117 ConstantArray::new(F::encode_above(value, exponents), alp.len())
118 .into_array(),
119 Operator::Gte,
123 )?,
124 ))
125 }
126 }
127 CompareOperator::Lt | CompareOperator::Lte => {
128 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131 if is_not_finite {
132 Ok(Some(
133 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134 ))
135 } else {
136 Ok(Some(
137 alp.encoded().binary(
138 ConstantArray::new(F::encode_below(value, exponents), alp.len())
139 .into_array(),
140 Operator::Lte,
143 )?,
144 ))
145 }
146 }
147 },
148 }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs [`alp_scalar_compare`] and unwraps the `VortexResult`, returning the `Option`
    /// so callers can check whether the pushdown was taken.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    /// `Eq` against a value not present in the array and against the array's own value.
    #[test]
    fn basic_comparison_test() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r = alp_scalar_compare(encoded.as_view(), 1.3_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(encoded.as_view(), 1.234f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    /// `Eq`/`NotEq` against a value with more precision than the array's exponents can
    /// represent collapse to constant `false`/`true`.
    #[test]
    fn comparison_with_unencodable_value() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r_eq = alp_scalar_compare(encoded.as_view(), 1.234444_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        let r_neq = alp_scalar_compare(encoded.as_view(), 1.234444f32, CompareOperator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    /// Ordering operators against an unencodable value just above the array's value.
    #[test]
    fn comparison_range() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![605; 10]);

        let r_gte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    /// Ordering comparisons of an all-zero array against tiny values and signed zero.
    /// `0.0` is expected to compare greater than `-0.0`.
    #[test]
    fn comparison_zeroes() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![0; 10]);

        let r_gte =
            test_alp_compare(encoded.as_view(), -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt =
            test_alp_compare(encoded.as_view(), -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(encoded.as_view(), -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    /// An array with patches must decline the pushdown (`None`).
    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_some());

        assert!(
            alp_scalar_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq)
                .unwrap()
                .is_none()
        );
    }

    /// Comparing against a null constant (via the generic `binary` entry point) yields an
    /// all-null result.
    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    /// `x > v` for non-finite `v` depends only on the sign of `v`. That covers
    /// negative-sign NaN, not just the default positive NaN.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    /// `x < v` for non-finite `v` depends only on the sign of `v`. That covers
    /// negative-sign NaN, not just the default positive NaN.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}