Skip to main content

vortex_alp/alp/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27// TODO(joe): add fuzzing.
28
29impl CompareKernel for ALP {
30    fn compare(
31        lhs: ArrayView<'_, Self>,
32        rhs: &ArrayRef,
33        operator: CompareOperator,
34        _ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        if lhs.patches().is_some() {
37            // TODO(joe): support patches
38            return Ok(None);
39        }
40        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41            // TODO(joe): support nullability
42            return Ok(None);
43        }
44
45        if let Some(const_scalar) = rhs.as_constant() {
46            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47                vortex_err!(
48                    "ALP Compare RHS had the wrong type {}, expected {}",
49                    const_scalar,
50                    const_scalar.dtype()
51                )
52            })?;
53
54            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55                match pscalar.typed_value::<T>() {
56                    Some(value) => return alp_scalar_compare(lhs, value, operator),
57                    None => vortex_bail!(
58                        "Failed to convert scalar {:?} to ALP type {:?}",
59                        pscalar,
60                        pscalar.ptype()
61                    ),
62                }
63            });
64        }
65
66        Ok(None)
67    }
68}
69
70/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
71/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
72/// encode into the ALP domain.
73fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74    alp: ArrayView<ALP>,
75    value: F,
76    operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79    F::ALPInt: Into<Scalar>,
80    <F as ALPFloat>::ALPInt: Debug,
81{
82    // TODO(joe): support patches, this is checked above.
83    if alp.patches().is_some() {
84        return Ok(None);
85    }
86
87    let exponents = alp.exponents();
88    // If the scalar doesn't fit into the ALP domain,
89    // it cannot be equal to any values in the encoded array.
90    let encoded = F::encode_single(value, alp.exponents());
91    match encoded {
92        Some(encoded) => {
93            let s = ConstantArray::new(encoded, alp.len());
94            Ok(Some(
95                alp.encoded()
96                    .binary(s.into_array(), Operator::from(operator))?,
97            ))
98        }
99        None => match operator {
100            // Since this value is not encodable it cannot be equal to any value in the encoded
101            // array.
102            CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103            // Since this value is not encodable it cannot be equal to any value in the encoded
104            // array, hence != to all values in the encoded array.
105            CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106            CompareOperator::Gt | CompareOperator::Gte => {
107                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
108                // All values in the encoded array are definitely finite
109                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110                if is_not_finite {
111                    Ok(Some(
112                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113                    ))
114                } else {
115                    Ok(Some(
116                        alp.encoded().binary(
117                            ConstantArray::new(F::encode_above(value, exponents), alp.len())
118                                .into_array(),
119                            // Since the encoded value is unencodable gte is equivalent to gt.
120                            // Consider a value v, between two encodable values v_l (just less) and
121                            // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
122                            Operator::Gte,
123                        )?,
124                    ))
125                }
126            }
127            CompareOperator::Lt | CompareOperator::Lte => {
128                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
129                // All values in the encoded array are definitely finite
130                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131                if is_not_finite {
132                    Ok(Some(
133                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134                    ))
135                } else {
136                    Ok(Some(
137                        alp.encoded().binary(
138                            ConstantArray::new(F::encode_below(value, exponents), alp.len())
139                                .into_array(),
140                            // Since the encoded values unencodable lt is equivalent to lte.
141                            // See Gt | Gte for further explanation.
142                            Operator::Lte,
143                        )?,
144                    ))
145                }
146            }
147        },
148    }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs `alp_scalar_compare`, unwrapping the `VortexResult` but leaving the `Option`
    /// (`None` means the kernel declined the comparison).
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r = alp_scalar_compare(encoded.as_view(), 1.3_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(encoded.as_view(), 1.234f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r_eq = alp_scalar_compare(encoded.as_view(), 1.234444_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        let r_neq = alp_scalar_compare(encoded.as_view(), 1.234444f32, CompareOperator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![605; 10]);

        // !(0.0605_f32 >= 0.06051_f32);
        let r_gte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        // !(0.0605_f32 > 0.06051_f32);
        let r_gt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        // 0.0605_f32 <= 0.06051_f32;
        let r_lte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        // 0.0605_f32 < 0.06051_f32;
        let r_lt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![0; 10]);

        let r_gte =
            test_alp_compare(encoded.as_view(), -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt =
            test_alp_compare(encoded.as_view(), -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(encoded.as_view(), -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_some());

        // Not supported!
        assert!(
            alp_scalar_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        // Comparing to null yields null results
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    // totalOrder: -NaN < -Inf < finite < +Inf < +NaN, so `finite > value` holds exactly when
    // `value` is negative.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    // totalOrder: -NaN < -Inf < finite < +Inf < +NaN, so `finite < value` holds exactly when
    // `value` is positive.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}