Skip to main content

vortex_alp/alp/compute/
compare.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
// TODO(joe): add fuzzing.
28
29impl CompareKernel for ALP {
30    fn compare(
31        lhs: ArrayView<'_, Self>,
32        rhs: &ArrayRef,
33        operator: CompareOperator,
34        _ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        if lhs.patches().is_some() {
37            // TODO(joe): support patches
38            return Ok(None);
39        }
40        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41            // TODO(joe): support nullability
42            return Ok(None);
43        }
44
45        if let Some(const_scalar) = rhs.as_constant() {
46            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47                vortex_err!(
48                    "ALP Compare RHS had the wrong type {}, expected {}",
49                    const_scalar,
50                    const_scalar.dtype()
51                )
52            })?;
53
54            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55                match pscalar.typed_value::<T>() {
56                    Some(value) => return alp_scalar_compare(lhs, value, operator),
57                    None => vortex_bail!(
58                        "Failed to convert scalar {:?} to ALP type {:?}",
59                        pscalar,
60                        pscalar.ptype()
61                    ),
62                }
63            });
64        }
65
66        Ok(None)
67    }
68}
69
70/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
71/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
72/// encode into the ALP domain.
73fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74    alp: ArrayView<ALP>,
75    value: F,
76    operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79    F::ALPInt: Into<Scalar>,
80    <F as ALPFloat>::ALPInt: Debug,
81{
82    // TODO(joe): support patches, this is checked above.
83    if alp.patches().is_some() {
84        return Ok(None);
85    }
86
87    let exponents = alp.exponents();
88    // If the scalar doesn't fit into the ALP domain,
89    // it cannot be equal to any values in the encoded array.
90    let encoded = F::encode_single(value, alp.exponents());
91    match encoded {
92        Some(encoded) => {
93            let s = ConstantArray::new(encoded, alp.len());
94            Ok(Some(
95                alp.encoded()
96                    .binary(s.into_array(), Operator::from(operator))?,
97            ))
98        }
99        None => match operator {
100            // Since this value is not encodable it cannot be equal to any value in the encoded
101            // array.
102            CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103            // Since this value is not encodable it cannot be equal to any value in the encoded
104            // array, hence != to all values in the encoded array.
105            CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106            CompareOperator::Gt | CompareOperator::Gte => {
107                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
108                // All values in the encoded array are definitely finite
109                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110                if is_not_finite {
111                    Ok(Some(
112                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113                    ))
114                } else {
115                    Ok(Some(
116                        alp.encoded().binary(
117                            ConstantArray::new(F::encode_above(value, exponents), alp.len())
118                                .into_array(),
119                            // Since the encoded value is unencodable gte is equivalent to gt.
120                            // Consider a value v, between two encodable values v_l (just less) and
121                            // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
122                            Operator::Gte,
123                        )?,
124                    ))
125                }
126            }
127            CompareOperator::Lt | CompareOperator::Lte => {
128                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
129                // All values in the encoded array are definitely finite
130                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131                if is_not_finite {
132                    Ok(Some(
133                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134                    ))
135                } else {
136                    Ok(Some(
137                        alp.encoded().binary(
138                            ConstantArray::new(F::encode_below(value, exponents), alp.len())
139                                .into_array(),
140                            // Since the encoded values unencodable lt is equivalent to lte.
141                            // See Gt | Gte for further explanation.
142                            Operator::Lte,
143                        )?,
144                    ))
145                }
146            }
147        },
148    }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ToCanonical;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare`, panicking on error.
    ///
    /// Returns `None` when the kernel declines to handle the input (e.g. patched arrays).
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        // A value not present in the array.
        let r = test_alp_compare(encoded.as_view(), 1.3_f32, CompareOperator::Eq).unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        // The value every element holds.
        let r = test_alp_compare(encoded.as_view(), 1.234f32, CompareOperator::Eq).unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r_eq =
            test_alp_compare(encoded.as_view(), 1.234444_f32, CompareOperator::Eq).unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        let r_neq =
            test_alp_compare(encoded.as_view(), 1.234444f32, CompareOperator::NotEq).unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        // !(0.0605_f32 >= 0.06051_f32)
        let r_gte =
            test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        // !(0.0605_f32 > 0.06051_f32)
        let r_gt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        // 0.0605_f32 <= 0.06051_f32
        let r_lte =
            test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        // 0.0605_f32 < 0.06051_f32
        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        let r_gte =
            test_alp_compare(encoded.as_view(), -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt =
            test_alp_compare(encoded.as_view(), -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        // These expectations order 0.0 strictly above -0.0 (totalOrder), so `0.0 > -0.0`.
        let r_gt = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(encoded.as_view(), -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_some());

        // Patched arrays are not supported yet, so the kernel must decline.
        assert!(
            test_alp_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq).is_none()
        );
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        // Comparing to null yields null results.
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    // Non-finite values use totalOrder: -NaN < -Inf < finite < +Inf < +NaN.
    // `-f32::NAN` covers the negative-NaN branch (sign bit set on a NaN).
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}