Skip to main content

vortex_alp/alp/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27// TODO(joe): add fuzzing.
28
29impl CompareKernel for ALP {
30    fn compare(
31        lhs: ArrayView<'_, Self>,
32        rhs: &ArrayRef,
33        operator: CompareOperator,
34        _ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        if lhs.patches().is_some() {
37            // TODO(joe): support patches
38            return Ok(None);
39        }
40        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41            // TODO(joe): support nullability
42            return Ok(None);
43        }
44
45        if let Some(const_scalar) = rhs.as_constant() {
46            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47                vortex_err!(
48                    "ALP Compare RHS had the wrong type {}, expected {}",
49                    const_scalar,
50                    const_scalar.dtype()
51                )
52            })?;
53
54            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55                match pscalar.typed_value::<T>() {
56                    Some(value) => return alp_scalar_compare(lhs, value, operator),
57                    None => vortex_bail!(
58                        "Failed to convert scalar {:?} to ALP type {:?}",
59                        pscalar,
60                        pscalar.ptype()
61                    ),
62                }
63            });
64        }
65
66        Ok(None)
67    }
68}
69
70/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
71/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
72/// encode into the ALP domain.
73fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74    alp: ArrayView<ALP>,
75    value: F,
76    operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79    F::ALPInt: Into<Scalar>,
80    <F as ALPFloat>::ALPInt: Debug,
81{
82    // TODO(joe): support patches, this is checked above.
83    if alp.patches().is_some() {
84        return Ok(None);
85    }
86
87    let exponents = alp.exponents();
88    // If the scalar doesn't fit into the ALP domain,
89    // it cannot be equal to any values in the encoded array.
90    let encoded = F::encode_single(value, alp.exponents());
91    match encoded {
92        Some(encoded) => {
93            let s = ConstantArray::new(encoded, alp.len());
94            Ok(Some(
95                alp.encoded()
96                    .binary(s.into_array(), Operator::from(operator))?,
97            ))
98        }
99        None => match operator {
100            // Since this value is not encodable it cannot be equal to any value in the encoded
101            // array.
102            CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103            // Since this value is not encodable it cannot be equal to any value in the encoded
104            // array, hence != to all values in the encoded array.
105            CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106            CompareOperator::Gt | CompareOperator::Gte => {
107                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
108                // All values in the encoded array are definitely finite
109                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110                if is_not_finite {
111                    Ok(Some(
112                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113                    ))
114                } else {
115                    Ok(Some(
116                        alp.encoded().binary(
117                            ConstantArray::new(F::encode_above(value, exponents), alp.len())
118                                .into_array(),
119                            // Since the encoded value is unencodable gte is equivalent to gt.
120                            // Consider a value v, between two encodable values v_l (just less) and
121                            // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
122                            Operator::Gte,
123                        )?,
124                    ))
125                }
126            }
127            CompareOperator::Lt | CompareOperator::Lte => {
128                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
129                // All values in the encoded array are definitely finite
130                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131                if is_not_finite {
132                    Ok(Some(
133                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134                    ))
135                } else {
136                    Ok(Some(
137                        alp.encoded().binary(
138                            ConstantArray::new(F::encode_below(value, exponents), alp.len())
139                                .into_array(),
140                            // Since the encoded values unencodable lt is equivalent to lte.
141                            // See Gt | Gte for further explanation.
142                            Operator::Lte,
143                        )?,
144                    ))
145                }
146            }
147        },
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::f32;

    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs [`alp_scalar_compare`], unwrapping the `Result` but leaving the `Option` so callers
    /// can still assert whether the kernel handled the comparison.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r = alp_scalar_compare(encoded.as_view(), 1.3_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(encoded.as_view(), 1.234f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![1234; 1025]);

        let r_eq = alp_scalar_compare(encoded.as_view(), 1.234444_f32, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        let r_neq = alp_scalar_compare(encoded.as_view(), 1.234444f32, CompareOperator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![605; 10]);

        // !(0.0605_f32 >= 0.06051_f32);
        let r_gte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        // !(0.0605_f32 > 0.06051_f32);
        let r_gt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        // 0.0605_f32 <= 0.06051_f32;
        let r_lte = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        // 0.0605_f32 < 0.06051_f32;
        let r_lt = alp_scalar_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());
        let encoded_prim = encoded
            .encoded()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)
            .unwrap();
        assert_eq!(encoded_prim.as_slice::<i32>(), vec![0; 10]);

        let r_gte =
            test_alp_compare(encoded.as_view(), -0.00000001_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt =
            test_alp_compare(encoded.as_view(), -0.0000000001f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(encoded.as_view(), -0.0_f32, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(encoded.as_view(), 0.06051_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(encoded.as_view(), -0.00001_f32, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array = PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_some());

        // Not supported!
        assert!(
            alp_scalar_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq)
                .unwrap()
                .is_none()
        )
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        // Comparing to null yields null results
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    /// Under IEEE 754 totalOrder every finite value is greater than `-Inf` and `-NaN`, and
    /// less than `Inf` and `NaN`.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    /// Mirror of `compare_to_non_finite_gt` for the `<` operator.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();

        let r = test_alp_compare(encoded.as_view(), value, CompareOperator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}