1use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27impl CompareKernel for ALP {
30 fn compare(
31 lhs: ArrayView<'_, Self>,
32 rhs: &ArrayRef,
33 operator: CompareOperator,
34 _ctx: &mut ExecutionCtx,
35 ) -> VortexResult<Option<ArrayRef>> {
36 if lhs.patches().is_some() {
37 return Ok(None);
39 }
40 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41 return Ok(None);
43 }
44
45 if let Some(const_scalar) = rhs.as_constant() {
46 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47 vortex_err!(
48 "ALP Compare RHS had the wrong type {}, expected {}",
49 const_scalar,
50 const_scalar.dtype()
51 )
52 })?;
53
54 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55 match pscalar.typed_value::<T>() {
56 Some(value) => return alp_scalar_compare(lhs, value, operator),
57 None => vortex_bail!(
58 "Failed to convert scalar {:?} to ALP type {:?}",
59 pscalar,
60 pscalar.ptype()
61 ),
62 }
63 });
64 }
65
66 Ok(None)
67 }
68}
69
70fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74 alp: ArrayView<ALP>,
75 value: F,
76 operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79 F::ALPInt: Into<Scalar>,
80 <F as ALPFloat>::ALPInt: Debug,
81{
82 if alp.patches().is_some() {
84 return Ok(None);
85 }
86
87 let exponents = alp.exponents();
88 let encoded = F::encode_single(value, alp.exponents());
91 match encoded {
92 Some(encoded) => {
93 let s = ConstantArray::new(encoded, alp.len());
94 Ok(Some(
95 alp.encoded()
96 .binary(s.into_array(), Operator::from(operator))?,
97 ))
98 }
99 None => match operator {
100 CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103 CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106 CompareOperator::Gt | CompareOperator::Gte => {
107 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110 if is_not_finite {
111 Ok(Some(
112 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113 ))
114 } else {
115 Ok(Some(
116 alp.encoded().binary(
117 ConstantArray::new(F::encode_above(value, exponents), alp.len())
118 .into_array(),
119 Operator::Gte,
123 )?,
124 ))
125 }
126 }
127 CompareOperator::Lt | CompareOperator::Lte => {
128 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131 if is_not_finite {
132 Ok(Some(
133 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134 ))
135 } else {
136 Ok(Some(
137 alp.encoded().binary(
138 ConstantArray::new(F::encode_below(value, exponents), alp.len())
139 .into_array(),
140 Operator::Lte,
143 )?,
144 ))
145 }
146 }
147 },
148 }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare`, panicking on error and handing back its optional result.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        let outcome = alp_scalar_compare(alp, value, operator);
        outcome.unwrap()
    }

    /// Asserts that `actual` equals a boolean array made of `len` copies of `flag`.
    fn assert_all(actual: ArrayRef, flag: bool, len: usize) {
        let expected = BoolArray::from_iter(vec![flag; len]);
        assert_arrays_eq!(actual, expected);
    }

    /// Equality against a value that differs from, and a value that equals, every element.
    #[test]
    fn basic_comparison_test() {
        const LEN: usize = 1025;

        let floats = PrimitiveArray::from_iter([1.234f32; LEN]);
        let alp_array = alp_encode(&floats, None).unwrap();
        assert!(alp_array.patches().is_none());
        assert_eq!(
            alp_array.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; LEN]
        );

        for (probe, flag) in [(1.3_f32, false), (1.234f32, true)] {
            let actual = alp_scalar_compare(alp_array.as_view(), probe, CompareOperator::Eq)
                .unwrap()
                .unwrap();
            assert_all(actual, flag, LEN);
        }
    }

    /// `Eq` and `NotEq` against a literal carrying more digits than the array's values.
    #[test]
    fn comparison_with_unencodable_value() {
        const LEN: usize = 1025;

        let floats = PrimitiveArray::from_iter([1.234f32; LEN]);
        let alp_array = alp_encode(&floats, None).unwrap();
        assert!(alp_array.patches().is_none());
        assert_eq!(
            alp_array.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; LEN]
        );

        // The literal is written with extra digits on purpose; silence the lint for it.
        #[allow(clippy::excessive_precision)]
        let probe = 1.234444_f32;

        for (operator, flag) in [(CompareOperator::Eq, false), (CompareOperator::NotEq, true)] {
            let actual = alp_scalar_compare(alp_array.as_view(), probe, operator)
                .unwrap()
                .unwrap();
            assert_all(actual, flag, LEN);
        }
    }

    /// Ordering comparisons against a value slightly above every element.
    #[test]
    fn comparison_range() {
        const LEN: usize = 10;

        let floats = PrimitiveArray::from_iter([0.0605_f32; LEN]);
        let alp_array = alp_encode(&floats, None).unwrap();
        assert!(alp_array.patches().is_none());
        assert_eq!(
            alp_array.encoded().to_primitive().as_slice::<i32>(),
            vec![605; LEN]
        );

        let expectations = [
            (CompareOperator::Gte, false),
            (CompareOperator::Gt, false),
            (CompareOperator::Lte, true),
            (CompareOperator::Lt, true),
        ];
        for (operator, flag) in expectations {
            let actual = alp_scalar_compare(alp_array.as_view(), 0.06051_f32, operator)
                .unwrap()
                .unwrap();
            assert_all(actual, flag, LEN);
        }
    }

    /// Ordering comparisons of an all-zero array against values around zero.
    #[test]
    fn comparison_zeroes() {
        const LEN: usize = 10;

        let floats = PrimitiveArray::from_iter([0.0_f32; LEN]);
        let alp_array = alp_encode(&floats, None).unwrap();
        assert!(alp_array.patches().is_none());
        assert_eq!(
            alp_array.encoded().to_primitive().as_slice::<i32>(),
            vec![0; LEN]
        );

        let expectations = [
            (-0.00000001_f32, CompareOperator::Gte, true),
            (-0.0_f32, CompareOperator::Gte, true),
            (-0.0000000001f32, CompareOperator::Gt, true),
            (-0.0_f32, CompareOperator::Gt, true),
            (0.06051_f32, CompareOperator::Lte, true),
            (0.06051_f32, CompareOperator::Lt, true),
            (-0.00001_f32, CompareOperator::Lt, false),
        ];
        for (probe, operator, flag) in expectations {
            let actual = test_alp_compare(alp_array.as_view(), probe, operator).unwrap();
            assert_all(actual, flag, LEN);
        }
    }

    /// An array that needs patches yields no result from the scalar comparison.
    #[test]
    fn compare_with_patches() {
        let floats =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let alp_array = alp_encode(&floats, None).unwrap();
        assert!(alp_array.patches().is_some());

        let outcome =
            alp_scalar_compare(alp_array.as_view(), 1_000_000.9_f32, CompareOperator::Eq).unwrap();
        assert!(outcome.is_none());
    }

    /// Comparing with a null constant through the generic `binary` entry point gives nulls.
    #[test]
    fn compare_to_null() {
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp_array = alp_encode(&floats, None).unwrap();

        let null_dtype = DType::Primitive(PType::F32, Nullability::Nullable);
        let nulls = ConstantArray::new(Scalar::null(null_dtype), floats.len());

        let actual = alp_array
            .into_array()
            .binary(nulls.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(actual, expected);
    }

    /// `Gt` against NaN and infinities: the result follows the sign of the probe value.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp_array = alp_encode(&floats, None).unwrap();

        let actual = test_alp_compare(alp_array.as_view(), value, CompareOperator::Gt).unwrap();
        assert_all(actual, result, 10);
    }

    /// `Lt` against NaN and infinities: the result follows the sign of the probe value.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let floats = PrimitiveArray::from_iter([1.234f32; 10]);
        let alp_array = alp_encode(&floats, None).unwrap();

        let actual = test_alp_compare(alp_array.as_view(), value, CompareOperator::Lt).unwrap();
        assert_all(actual, result, 10);
    }
}