Skip to main content

vortex_alp/alp/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::DynArray;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALPArray;
22use crate::ALPFloat;
23use crate::ALPVTable;
24use crate::match_each_alp_float_ptype;
25
26// TODO(joe): add fuzzing.
27
28impl CompareKernel for ALPVTable {
29    fn compare(
30        lhs: &ALPArray,
31        rhs: &ArrayRef,
32        operator: CompareOperator,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        if lhs.patches().is_some() {
36            // TODO(joe): support patches
37            return Ok(None);
38        }
39        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
40            // TODO(joe): support nullability
41            return Ok(None);
42        }
43
44        if let Some(const_scalar) = rhs.as_constant() {
45            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
46                vortex_err!(
47                    "ALP Compare RHS had the wrong type {}, expected {}",
48                    const_scalar,
49                    const_scalar.dtype()
50                )
51            })?;
52
53            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
54                match pscalar.typed_value::<T>() {
55                    Some(value) => return alp_scalar_compare(lhs, value, operator),
56                    None => vortex_bail!(
57                        "Failed to convert scalar {:?} to ALP type {:?}",
58                        pscalar,
59                        pscalar.ptype()
60                    ),
61                }
62            });
63        }
64
65        Ok(None)
66    }
67}
68
69/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
70/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
71/// encode into the ALP domain.
72fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
73    alp: &ALPArray,
74    value: F,
75    operator: CompareOperator,
76) -> VortexResult<Option<ArrayRef>>
77where
78    F::ALPInt: Into<Scalar>,
79    <F as ALPFloat>::ALPInt: Debug,
80{
81    // TODO(joe): support patches, this is checked above.
82    if alp.patches().is_some() {
83        return Ok(None);
84    }
85
86    let exponents = alp.exponents();
87    // If the scalar doesn't fit into the ALP domain,
88    // it cannot be equal to any values in the encoded array.
89    let encoded = F::encode_single(value, alp.exponents());
90    match encoded {
91        Some(encoded) => {
92            let s = ConstantArray::new(encoded, alp.len());
93            Ok(Some(
94                alp.encoded()
95                    .binary(s.into_array(), Operator::from(operator))?,
96            ))
97        }
98        None => match operator {
99            // Since this value is not encodable it cannot be equal to any value in the encoded
100            // array.
101            CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
102            // Since this value is not encodable it cannot be equal to any value in the encoded
103            // array, hence != to all values in the encoded array.
104            CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
105            CompareOperator::Gt | CompareOperator::Gte => {
106                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
107                // All values in the encoded array are definitely finite
108                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
109                if is_not_finite {
110                    Ok(Some(
111                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
112                    ))
113                } else {
114                    Ok(Some(
115                        alp.encoded().binary(
116                            ConstantArray::new(F::encode_above(value, exponents), alp.len())
117                                .into_array(),
118                            // Since the encoded value is unencodable gte is equivalent to gt.
119                            // Consider a value v, between two encodable values v_l (just less) and
120                            // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
121                            Operator::Gte,
122                        )?,
123                    ))
124                }
125            }
126            CompareOperator::Lt | CompareOperator::Lte => {
127                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
128                // All values in the encoded array are definitely finite
129                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
130                if is_not_finite {
131                    Ok(Some(
132                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
133                    ))
134                } else {
135                    Ok(Some(
136                        alp.encoded().binary(
137                            ConstantArray::new(F::encode_below(value, exponents), alp.len())
138                                .into_array(),
139                            // Since the encoded values unencodable lt is equivalent to lte.
140                            // See Gt | Gte for further explanation.
141                            Operator::Lte,
142                        )?,
143                    ))
144                }
145            }
146        },
147    }
148}
149
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare` and unwraps the `VortexResult`, leaving the `Option` that
    /// tells whether the comparison was supported.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: &ALPArray,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    /// ALP-encodes a primitive array holding `len` copies of `value`.
    fn encode_repeated(value: f32, len: usize) -> ALPArray {
        let floats = PrimitiveArray::from_iter(vec![value; len]);
        alp_encode(&floats, None).unwrap()
    }

    /// Asserts that `alp` has no patches and consists of `len` encoded integers equal to
    /// `expected`.
    fn assert_encoded_as(alp: &ALPArray, expected: i32, len: usize) {
        assert!(alp.patches().is_none());
        assert_eq!(
            alp.encoded().to_primitive().as_slice::<i32>(),
            vec![expected; len]
        );
    }

    /// Asserts that `actual` holds exactly `len` booleans, all equal to `expected`.
    fn assert_all(actual: ArrayRef, expected: bool, len: usize) {
        let expected = BoolArray::from_iter(vec![expected; len]);
        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn basic_comparison_test() {
        let alp = encode_repeated(1.234, 1025);
        assert_encoded_as(&alp, 1234, 1025);

        let no_match = test_alp_compare(&alp, 1.3_f32, CompareOperator::Eq).unwrap();
        assert_all(no_match, false, 1025);

        let all_match = test_alp_compare(&alp, 1.234f32, CompareOperator::Eq).unwrap();
        assert_all(all_match, true, 1025);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let alp = encode_repeated(1.234, 1025);
        assert_encoded_as(&alp, 1234, 1025);

        // The extra digits are deliberate: they make the scalar different from 1.234.
        #[allow(clippy::excessive_precision)]
        let eq = test_alp_compare(&alp, 1.234444_f32, CompareOperator::Eq).unwrap();
        assert_all(eq, false, 1025);

        // The extra digits are deliberate: they make the scalar different from 1.234.
        #[allow(clippy::excessive_precision)]
        let not_eq = test_alp_compare(&alp, 1.234444f32, CompareOperator::NotEq).unwrap();
        assert_all(not_eq, true, 1025);
    }

    #[test]
    fn comparison_range() {
        let alp = encode_repeated(0.0605, 10);
        assert_encoded_as(&alp, 605, 10);

        // Every array element is 0.0605, which is strictly below the scalar 0.06051.
        let cases = [
            // !(0.0605_f32 >= 0.06051_f32)
            (CompareOperator::Gte, false),
            // !(0.0605_f32 > 0.06051_f32)
            (CompareOperator::Gt, false),
            // 0.0605_f32 <= 0.06051_f32
            (CompareOperator::Lte, true),
            // 0.0605_f32 < 0.06051_f32
            (CompareOperator::Lt, true),
        ];
        for (operator, expected) in cases {
            let result = test_alp_compare(&alp, 0.06051_f32, operator).unwrap();
            assert_all(result, expected, 10);
        }
    }

    #[test]
    fn comparison_zeroes() {
        let alp = encode_repeated(0.0, 10);
        assert_encoded_as(&alp, 0, 10);

        // (scalar, operator, expected result of `0.0 <operator> scalar` for every element)
        let cases = [
            (-0.00000001_f32, CompareOperator::Gte, true),
            (-0.0_f32, CompareOperator::Gte, true),
            (-0.0000000001f32, CompareOperator::Gt, true),
            (-0.0_f32, CompareOperator::Gt, true),
            (0.06051_f32, CompareOperator::Lte, true),
            (0.06051_f32, CompareOperator::Lt, true),
            (-0.00001_f32, CompareOperator::Lt, false),
        ];
        for (value, operator, expected) in cases {
            let result = test_alp_compare(&alp, value, operator).unwrap();
            assert_all(result, expected, 10);
        }
    }

    #[test]
    fn compare_with_patches() {
        let floats =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let alp = alp_encode(&floats, None).unwrap();
        assert!(alp.patches().is_some());

        // Not supported!
        assert!(test_alp_compare(&alp, 1_000_000.9_f32, CompareOperator::Eq).is_none())
    }

    #[test]
    fn compare_to_null() {
        let alp = encode_repeated(1.234, 10);

        let nulls = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            10,
        );

        let result = alp
            .into_array()
            .binary(nulls.into_array(), Operator::Eq)
            .unwrap();
        // Comparing to null yields null results
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(result, expected);
    }

    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let alp = encode_repeated(1.234, 10);

        let greater = test_alp_compare(&alp, value, CompareOperator::Gt).unwrap();
        assert_all(greater, result, 10);
    }

    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let alp = encode_repeated(1.234, 10);

        let less = test_alp_compare(&alp, value, CompareOperator::Lt).unwrap();
        assert_all(less, result, 10);
    }
}