Skip to main content

vortex_alp/alp/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativePType;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_array::scalar_fn::fns::operators::Operator;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20
21use crate::ALP;
22use crate::ALPArrayExt;
23use crate::ALPArraySlotsExt;
24use crate::ALPFloat;
25use crate::match_each_alp_float_ptype;
26
27// TODO(joe): add fuzzing.
28
29impl CompareKernel for ALP {
30    fn compare(
31        lhs: ArrayView<'_, Self>,
32        rhs: &ArrayRef,
33        operator: CompareOperator,
34        _ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        if lhs.patches().is_some() {
37            // TODO(joe): support patches
38            return Ok(None);
39        }
40        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
41            // TODO(joe): support nullability
42            return Ok(None);
43        }
44
45        if let Some(const_scalar) = rhs.as_constant() {
46            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
47                vortex_err!(
48                    "ALP Compare RHS had the wrong type {}, expected {}",
49                    const_scalar,
50                    const_scalar.dtype()
51                )
52            })?;
53
54            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
55                match pscalar.typed_value::<T>() {
56                    Some(value) => return alp_scalar_compare(lhs, value, operator),
57                    None => vortex_bail!(
58                        "Failed to convert scalar {:?} to ALP type {:?}",
59                        pscalar,
60                        pscalar.ptype()
61                    ),
62                }
63            });
64        }
65
66        Ok(None)
67    }
68}
69
70/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
71/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
72/// encode into the ALP domain.
73fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
74    alp: ArrayView<ALP>,
75    value: F,
76    operator: CompareOperator,
77) -> VortexResult<Option<ArrayRef>>
78where
79    F::ALPInt: Into<Scalar>,
80    <F as ALPFloat>::ALPInt: Debug,
81{
82    // TODO(joe): support patches, this is checked above.
83    if alp.patches().is_some() {
84        return Ok(None);
85    }
86
87    let exponents = alp.exponents();
88    // If the scalar doesn't fit into the ALP domain,
89    // it cannot be equal to any values in the encoded array.
90    let encoded = F::encode_single(value, alp.exponents());
91    match encoded {
92        Some(encoded) => {
93            let s = ConstantArray::new(encoded, alp.len());
94            Ok(Some(
95                alp.encoded()
96                    .binary(s.into_array(), Operator::from(operator))?,
97            ))
98        }
99        None => match operator {
100            // Since this value is not encodable it cannot be equal to any value in the encoded
101            // array.
102            CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
103            // Since this value is not encodable it cannot be equal to any value in the encoded
104            // array, hence != to all values in the encoded array.
105            CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
106            CompareOperator::Gt | CompareOperator::Gte => {
107                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
108                // All values in the encoded array are definitely finite
109                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
110                if is_not_finite {
111                    Ok(Some(
112                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
113                    ))
114                } else {
115                    Ok(Some(
116                        alp.encoded().binary(
117                            ConstantArray::new(F::encode_above(value, exponents), alp.len())
118                                .into_array(),
119                            // Since the encoded value is unencodable gte is equivalent to gt.
120                            // Consider a value v, between two encodable values v_l (just less) and
121                            // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
122                            Operator::Gte,
123                        )?,
124                    ))
125                }
126            }
127            CompareOperator::Lt | CompareOperator::Lte => {
128                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
129                // All values in the encoded array are definitely finite
130                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
131                if is_not_finite {
132                    Ok(Some(
133                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
134                    ))
135                } else {
136                    Ok(Some(
137                        alp.encoded().binary(
138                            ConstantArray::new(F::encode_below(value, exponents), alp.len())
139                                .into_array(),
140                            // Since the encoded values unencodable lt is equivalent to lte.
141                            // See Gt | Gte for further explanation.
142                            Operator::Lte,
143                        )?,
144                    ))
145                }
146            }
147        },
148    }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::CompareOperator;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::alp_encode;

    /// Runs [`alp_scalar_compare`] and unwraps the outer `VortexResult`, leaving the `Option`
    /// that tells whether the kernel handled the comparison.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    /// Asserts that `alp <operator> value` is handled by the kernel and evaluates to `expected`
    /// at every position.
    #[track_caller]
    fn assert_compare_all<F: ALPFloat + Into<Scalar>>(
        alp: ArrayView<ALP>,
        value: F,
        operator: CompareOperator,
        expected: bool,
    ) where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        let len = alp.len();
        let actual = test_alp_compare(alp, value, operator).unwrap();
        let expected = BoolArray::from_iter(vec![expected; len]);
        assert_arrays_eq!(actual, expected);
    }

    /// Equality against encodable scalars: one that differs from the data and one that matches.
    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        assert_compare_all(encoded.as_view(), 1.3_f32, CompareOperator::Eq, false);
        assert_compare_all(encoded.as_view(), 1.234f32, CompareOperator::Eq, true);
    }

    /// `Eq` / `NotEq` against a scalar that has no exact ALP encoding.
    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        // The literal deliberately carries more digits than the encoded data so that it is not
        // representable in the array's ALP domain.
        #[allow(clippy::excessive_precision)]
        let unencodable = 1.234444_f32;

        assert_compare_all(encoded.as_view(), unencodable, CompareOperator::Eq, false);
        assert_compare_all(encoded.as_view(), unencodable, CompareOperator::NotEq, true);
    }

    /// Ordering comparisons against a finite scalar slightly above the stored value.
    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        // !(0.0605_f32 >= 0.06051_f32)
        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Gte, false);

        // !(0.0605_f32 > 0.06051_f32)
        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Gt, false);

        // 0.0605_f32 <= 0.06051_f32
        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Lte, true);

        // 0.0605_f32 < 0.06051_f32
        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Lt, true);
    }

    /// Ordering comparisons of an all-zero array against values around zero, including `-0.0`.
    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        assert_compare_all(
            encoded.as_view(),
            -0.00000001_f32,
            CompareOperator::Gte,
            true,
        );
        assert_compare_all(encoded.as_view(), -0.0_f32, CompareOperator::Gte, true);

        assert_compare_all(
            encoded.as_view(),
            -0.0000000001f32,
            CompareOperator::Gt,
            true,
        );
        assert_compare_all(encoded.as_view(), -0.0_f32, CompareOperator::Gt, true);

        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Lte, true);
        assert_compare_all(encoded.as_view(), 0.06051_f32, CompareOperator::Lt, true);

        assert_compare_all(encoded.as_view(), -0.00001_f32, CompareOperator::Lt, false);
    }

    /// Arrays with patches are not supported: the kernel must decline with `None`.
    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_some());

        // Not supported!
        assert!(
            test_alp_compare(encoded.as_view(), 1_000_000.9_f32, CompareOperator::Eq).is_none()
        );
    }

    /// Comparing against a null constant through the generic `binary` entry point yields nulls.
    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = encoded
            .into_array()
            .binary(other.into_array(), Operator::Eq)
            .unwrap();
        // Comparing to null yields null results
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    /// `array > value` for non-finite scalars: true exactly when the scalar is sign-negative.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    // `-1.0 / 0.0` evaluates to negative infinity.
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        assert_compare_all(encoded.as_view(), value, CompareOperator::Gt, result);
    }

    /// `array < value` for non-finite scalars: true exactly when the scalar is sign-positive.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    // `-1.0 / 0.0` evaluates to negative infinity.
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        assert_compare_all(encoded.as_view(), value, CompareOperator::Lt, result);
    }
}
372}