1use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::DynArray;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALPArray;
22use crate::ALPFloat;
23use crate::ALPVTable;
24use crate::match_each_alp_float_ptype;
25
26impl CompareKernel for ALPVTable {
29 fn compare(
30 lhs: &ALPArray,
31 rhs: &ArrayRef,
32 operator: CompareOperator,
33 _ctx: &mut ExecutionCtx,
34 ) -> VortexResult<Option<ArrayRef>> {
35 if lhs.patches().is_some() {
36 return Ok(None);
38 }
39 if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
40 return Ok(None);
42 }
43
44 if let Some(const_scalar) = rhs.as_constant() {
45 let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
46 vortex_err!(
47 "ALP Compare RHS had the wrong type {}, expected {}",
48 const_scalar,
49 const_scalar.dtype()
50 )
51 })?;
52
53 match_each_alp_float_ptype!(pscalar.ptype(), |T| {
54 match pscalar.typed_value::<T>() {
55 Some(value) => return alp_scalar_compare(lhs, value, operator),
56 None => vortex_bail!(
57 "Failed to convert scalar {:?} to ALP type {:?}",
58 pscalar,
59 pscalar.ptype()
60 ),
61 }
62 });
63 }
64
65 Ok(None)
66 }
67}
68
69fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
73 alp: &ALPArray,
74 value: F,
75 operator: CompareOperator,
76) -> VortexResult<Option<ArrayRef>>
77where
78 F::ALPInt: Into<Scalar>,
79 <F as ALPFloat>::ALPInt: Debug,
80{
81 if alp.patches().is_some() {
83 return Ok(None);
84 }
85
86 let exponents = alp.exponents();
87 let encoded = F::encode_single(value, alp.exponents());
90 match encoded {
91 Some(encoded) => {
92 let s = ConstantArray::new(encoded, alp.len());
93 Ok(Some(
94 alp.encoded()
95 .binary(s.into_array(), Operator::from(operator))?,
96 ))
97 }
98 None => match operator {
99 CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
102 CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
105 CompareOperator::Gt | CompareOperator::Gte => {
106 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
109 if is_not_finite {
110 Ok(Some(
111 ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
112 ))
113 } else {
114 Ok(Some(
115 alp.encoded().binary(
116 ConstantArray::new(F::encode_above(value, exponents), alp.len())
117 .into_array(),
118 Operator::Gte,
122 )?,
123 ))
124 }
125 }
126 CompareOperator::Lt | CompareOperator::Lte => {
127 let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
130 if is_not_finite {
131 Ok(Some(
132 ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
133 ))
134 } else {
135 Ok(Some(
136 alp.encoded().binary(
137 ConstantArray::new(F::encode_below(value, exponents), alp.len())
138 .into_array(),
139 Operator::Lte,
142 )?,
143 ))
144 }
145 }
146 },
147 }
148}
149
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Calls `alp_scalar_compare` and unwraps the `VortexResult`, leaving the
    /// "handled / not handled" `Option` for the test to inspect.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: &ALPArray,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r = alp_scalar_compare(&encoded, 1.3_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(&encoded, 1.234f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        // The extra digits are deliberate: they make the value unrepresentable at the
        // array's exponents, which is what this test exercises.
        #[allow(clippy::excessive_precision)]
        let r_eq = alp_scalar_compare(&encoded, 1.234444_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        // Same deliberately over-precise literal as above.
        #[allow(clippy::excessive_precision)]
        let r_neq = alp_scalar_compare(&encoded, 1.234444f32, CompareOperator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        let r_gte = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        let r_gte = test_alp_compare(&encoded, -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(&encoded, -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = test_alp_compare(&encoded, -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(&encoded, -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(&encoded, 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(&encoded, 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(&encoded, -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    /// Ordering operators against an exactly representable value take the
    /// integer fast path and must agree with plain float comparison.
    #[rstest]
    #[case(CompareOperator::Gt, false)]
    #[case(CompareOperator::Gte, true)]
    #[case(CompareOperator::Lt, false)]
    #[case(CompareOperator::Lte, true)]
    fn compare_to_encodable_ordering(#[case] operator: CompareOperator, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, 1.234f32, operator).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_some());

        // Patched arrays are not handled by the kernel yet.
        assert!(
            alp_scalar_compare(&encoded, 1_000_000.9_f32, CompareOperator::Eq)
                .unwrap()
                .is_none()
        )
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, false)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gte(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lte(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}