1use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27impl CompareKernel for ALP {
30 fn compare(
31 lhs: ArrayView<'_, Self>,
32 rhs: &ArrayRef,
33 operator: CompareOperator,
34 _ctx: &mut ExecutionCtx,
35 ) -> VortexResult<Option<ArrayRef>> {
36 if lhs.patches().is_some() {
37 return Ok(None);
39 }
40 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41 return Ok(None);
43 }
44
45 if let Some(const_scalar) = rhs.as_constant() {
46 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47 vortex_err!(
48 "ALP Compare RHS had the wrong type {}, expected {}",
49 const_scalar,
50 const_scalar.dtype()
51 )
52 })?;
53
54 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55 match pscalar.typed_value::<T>() {
56 Some(value) => return alp_scalar_compare(lhs, value, operator),
57 None => vortex_bail!(
58 "Failed to convert scalar {:?} to ALP type {:?}",
59 pscalar,
60 pscalar.ptype()
61 ),
62 }
63 });
64 }
65
66 Ok(None)
67 }
68}
69
70fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74 alp: ArrayView<ALP>,
75 value: F,
76 operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79 F::ALPInt: Into<Scalar>,
80 <F as ALPFloat>::ALPInt: Debug,
81{
82 if alp.patches().is_some() {
84 return Ok(None);
85 }
86
87 let exponents = alp.exponents();
88 let encoded = F::encode_single(value, alp.exponents());
91 match encoded {
92 Some(encoded) => {
93 let s = ConstantArray::new(encoded, alp.len());
94 Ok(Some(
95 alp.encoded()
96 .binary(s.into_array(), Operator::from(operator))?,
97 ))
98 }
99 None => match operator {
100 CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103 CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106 CompareOperator::Gt | CompareOperator::Gte => {
107 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110 if is_not_finite {
111 Ok(Some(
112 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113 ))
114 } else {
115 Ok(Some(
116 alp.encoded().binary(
117 ConstantArray::new(F::encode_above(value, exponents), alp.len())
118 .into_array(),
119 Operator::Gte,
123 )?,
124 ))
125 }
126 }
127 CompareOperator::Lt | CompareOperator::Lte => {
128 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131 if is_not_finite {
132 Ok(Some(
133 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134 ))
135 } else {
136 Ok(Some(
137 alp.encoded().binary(
138 ConstantArray::new(F::encode_below(value, exponents), alp.len())
139 .into_array(),
140 Operator::Lte,
143 )?,
144 ))
145 }
146 }
147 },
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::f32;

    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare` and unwraps the outer `VortexResult`, leaving the `Option`
    /// for the caller to inspect.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar> + Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    /// Asserts that `actual` equals a boolean array holding `len` copies of `expected`.
    fn assert_all(actual: ArrayRef, expected: bool, len: usize) {
        let want = BoolArray::from_iter(vec![expected; len]);
        assert_arrays_eq!(actual, want);
    }

    /// Equality against an encodable constant: one that differs from the data and one that
    /// matches it.
    #[test]
    fn basic_comparison_test() {
        const LEN: usize = 1025;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([1.234f32; LEN]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_none());

        let ints = alp
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(ints.as_slice::<i32>(), vec![1234; LEN]);

        for (probe, all_equal) in [(1.3_f32, false), (1.234f32, true)] {
            let outcome = alp_scalar_compare(alp.as_view(), probe, CompareOperator::Eq)
                .unwrap()
                .unwrap();
            assert_all(outcome, all_equal, LEN);
        }
    }

    /// `Eq` / `NotEq` against a constant with more digits than the encoded data carries.
    #[test]
    fn comparison_with_unencodable_value() {
        const LEN: usize = 1025;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([1.234f32; LEN]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_none());

        let ints = alp
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(ints.as_slice::<i32>(), vec![1234; LEN]);

        let checks = [
            (1.234444_f32, CompareOperator::Eq, false),
            (1.234444f32, CompareOperator::NotEq, true),
        ];
        for (probe, operator, expected) in checks {
            let outcome = alp_scalar_compare(alp.as_view(), probe, operator)
                .unwrap()
                .unwrap();
            assert_all(outcome, expected, LEN);
        }
    }

    /// Ordering operators against a constant slightly above the (constant) data.
    #[test]
    fn comparison_range() {
        const LEN: usize = 10;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([0.0605_f32; LEN]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_none());

        let ints = alp
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(ints.as_slice::<i32>(), vec![605; LEN]);

        let checks = [
            (CompareOperator::Gte, false),
            (CompareOperator::Gt, false),
            (CompareOperator::Lte, true),
            (CompareOperator::Lt, true),
        ];
        for (operator, expected) in checks {
            let outcome = alp_scalar_compare(alp.as_view(), 0.06051_f32, operator)
                .unwrap()
                .unwrap();
            assert_all(outcome, expected, LEN);
        }
    }

    /// Ordering operators on an all-zero array against small negative values, negative zero and
    /// a positive value.
    #[test]
    fn comparison_zeroes() {
        const LEN: usize = 10;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([0.0_f32; LEN]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_none());

        let ints = alp
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(ints.as_slice::<i32>(), vec![0; LEN]);

        let checks = [
            (-0.00000001_f32, CompareOperator::Gte, true),
            (-0.0_f32, CompareOperator::Gte, true),
            (-0.0000000001f32, CompareOperator::Gt, true),
            (-0.0_f32, CompareOperator::Gt, true),
            (0.06051_f32, CompareOperator::Lte, true),
            (0.06051_f32, CompareOperator::Lt, true),
            (-0.00001_f32, CompareOperator::Lt, false),
        ];
        for (probe, operator, expected) in checks {
            let outcome = test_alp_compare(alp.as_view(), probe, operator).unwrap();
            assert_all(outcome, expected, LEN);
        }
    }

    /// An array that needs patches makes the scalar comparison decline with `None`.
    #[test]
    fn compare_with_patches() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, f32::consts::E, 1_000_000.9]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();
        assert!(alp.patches().is_some());

        let outcome =
            alp_scalar_compare(alp.as_view(), 1_000_000.9_f32, CompareOperator::Eq).unwrap();
        assert!(outcome.is_none());
    }

    /// Comparing with a null constant through the generic `binary` entry point gives all nulls.
    #[test]
    fn compare_to_null() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();

        let null_scalar = Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable));
        let nulls = ConstantArray::new(null_scalar, floats.len());

        let outcome = alp
            .into_array()
            .binary(nulls.into_array(), Operator::Eq)
            .unwrap();
        let want = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(outcome, want);
    }

    /// `Gt` against NaN and infinities is decided by the sign of the constant.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();

        let outcome = test_alp_compare(alp.as_view(), value, CompareOperator::Gt).unwrap();
        assert_all(outcome, result, 10);
    }

    /// `Lt` against NaN and infinities is decided by the sign of the constant.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp = alp_encode(floats.as_view(), None, &mut ctx).unwrap();

        let outcome = test_alp_compare(alp.as_view(), value, CompareOperator::Lt).unwrap();
        assert_all(outcome, result, 10);
    }
}