Skip to main content

vortex_alp/alp/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::ExecutionCtx;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::compute::Operator;
12use vortex_array::compute::compare;
13use vortex_array::expr::CompareKernel;
14use vortex_array::scalar::Scalar;
15use vortex_dtype::NativePType;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_err;
19
20use crate::ALPArray;
21use crate::ALPFloat;
22use crate::ALPVTable;
23use crate::match_each_alp_float_ptype;
24
25// TODO(joe): add fuzzing.
26
27impl CompareKernel for ALPVTable {
28    fn compare(
29        lhs: &ALPArray,
30        rhs: &dyn Array,
31        operator: Operator,
32        _ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        if lhs.patches().is_some() {
35            // TODO(joe): support patches
36            return Ok(None);
37        }
38        if lhs.dtype().is_nullable() || rhs.dtype().is_nullable() {
39            // TODO(joe): support nullability
40            return Ok(None);
41        }
42
43        if let Some(const_scalar) = rhs.as_constant() {
44            let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
45                vortex_err!(
46                    "ALP Compare RHS had the wrong type {}, expected {}",
47                    const_scalar,
48                    const_scalar.dtype()
49                )
50            })?;
51
52            match_each_alp_float_ptype!(pscalar.ptype(), |T| {
53                match pscalar.typed_value::<T>() {
54                    Some(value) => return alp_scalar_compare(lhs, value, operator),
55                    None => vortex_bail!(
56                        "Failed to convert scalar {:?} to ALP type {:?}",
57                        pscalar,
58                        pscalar.ptype()
59                    ),
60                }
61            });
62        }
63
64        Ok(None)
65    }
66}
67
68/// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
69/// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
70/// encode into the ALP domain.
71fn alp_scalar_compare<F: ALPFloat + Into<Scalar>>(
72    alp: &ALPArray,
73    value: F,
74    operator: Operator,
75) -> VortexResult<Option<ArrayRef>>
76where
77    F::ALPInt: Into<Scalar>,
78    <F as ALPFloat>::ALPInt: Debug,
79{
80    // TODO(joe): support patches, this is checked above.
81    if alp.patches().is_some() {
82        return Ok(None);
83    }
84
85    let exponents = alp.exponents();
86    // If the scalar doesn't fit into the ALP domain,
87    // it cannot be equal to any values in the encoded array.
88    let encoded = F::encode_single(value, alp.exponents());
89    match encoded {
90        Some(encoded) => {
91            let s = ConstantArray::new(encoded, alp.len());
92            Ok(Some(compare(alp.encoded(), s.as_ref(), operator)?))
93        }
94        None => match operator {
95            // Since this value is not encodable it cannot be equal to any value in the encoded
96            // array.
97            Operator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())),
98            // Since this value is not encodable it cannot be equal to any value in the encoded
99            // array, hence != to all values in the encoded array.
100            Operator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())),
101            Operator::Gt | Operator::Gte => {
102                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
103                // All values in the encoded array are definitely finite
104                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
105                if is_not_finite {
106                    Ok(Some(
107                        ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(),
108                    ))
109                } else {
110                    Ok(Some(compare(
111                        alp.encoded(),
112                        ConstantArray::new(F::encode_above(value, exponents), alp.len()).as_ref(),
113                        // Since the encoded value is unencodable gte is equivalent to gt.
114                        // Consider a value v, between two encodable values v_l (just less) and
115                        // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u
116                        Operator::Gte,
117                    )?))
118                }
119            }
120            Operator::Lt | Operator::Lte => {
121                // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan.
122                // All values in the encoded array are definitely finite
123                let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value);
124                if is_not_finite {
125                    Ok(Some(
126                        ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(),
127                    ))
128                } else {
129                    Ok(Some(compare(
130                        alp.encoded(),
131                        ConstantArray::new(F::encode_below(value, exponents), alp.len()).as_ref(),
132                        // Since the encoded values unencodable lt is equivalent to lte.
133                        // See Gt | Gte for further explanation.
134                        Operator::Lte,
135                    )?))
136                }
137            }
138        },
139    }
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_array::scalar::Scalar;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::alp_encode;

    /// Runs [`alp_scalar_compare`] and unwraps the `VortexResult`, leaving the `Option` that is
    /// `None` when the kernel declines to handle the comparison.
    fn test_alp_compare<F: ALPFloat + Into<Scalar>>(
        alp: &ALPArray,
        value: F,
        operator: Operator,
    ) -> Option<ArrayRef>
    where
        F::ALPInt: Into<Scalar>,
        <F as ALPFloat>::ALPInt: Debug,
    {
        alp_scalar_compare(alp, value, operator).unwrap()
    }

    #[test]
    fn basic_comparison_test() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        let r = alp_scalar_compare(&encoded, 1.3_f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r, expected);

        let r = alp_scalar_compare(&encoded, 1.234f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r, expected);
    }

    #[test]
    fn comparison_with_unencodable_value() {
        let array = PrimitiveArray::from_iter([1.234f32; 1025]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![1234; 1025]
        );

        // The literal intentionally carries more digits than the array's exponents can encode.
        #[allow(clippy::excessive_precision)]
        let r_eq = alp_scalar_compare(&encoded, 1.234444_f32, Operator::Eq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 1025]);
        assert_arrays_eq!(r_eq, expected);

        // The literal intentionally carries more digits than the array's exponents can encode.
        #[allow(clippy::excessive_precision)]
        let r_neq = alp_scalar_compare(&encoded, 1.234444f32, Operator::NotEq)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 1025]);
        assert_arrays_eq!(r_neq, expected);
    }

    #[test]
    fn comparison_range() {
        let array = PrimitiveArray::from_iter([0.0605_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![605; 10]
        );

        // !(0.0605_f32 >= 0.06051_f32)
        let r_gte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gte, expected);

        // !(0.0605_f32 > 0.06051_f32)
        let r_gt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_gt, expected);

        // 0.0605_f32 <= 0.06051_f32
        let r_lte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lte)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        // 0.0605_f32 < 0.06051_f32
        let r_lt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lt)
            .unwrap()
            .unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn comparison_zeroes() {
        let array = PrimitiveArray::from_iter([0.0_f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        assert_eq!(
            encoded.encoded().to_primitive().as_slice::<i32>(),
            vec![0; 10]
        );

        let r_gte = test_alp_compare(&encoded, -0.00000001_f32, Operator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gte = test_alp_compare(&encoded, -0.0_f32, Operator::Gte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gte, expected);

        let r_gt = test_alp_compare(&encoded, -0.0000000001f32, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_gt = test_alp_compare(&encoded, -0.0_f32, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_gt, expected);

        let r_lte = test_alp_compare(&encoded, 0.06051_f32, Operator::Lte).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lte, expected);

        let r_lt = test_alp_compare(&encoded, 0.06051_f32, Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([true; 10]);
        assert_arrays_eq!(r_lt, expected);

        let r_lt = test_alp_compare(&encoded, -0.00001_f32, Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([false; 10]);
        assert_arrays_eq!(r_lt, expected);
    }

    #[test]
    fn compare_with_patches() {
        let array =
            PrimitiveArray::from_iter([1.234f32, 1.5, 19.0, std::f32::consts::E, 1_000_000.9]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_some());

        // Not supported!
        assert!(
            alp_scalar_compare(&encoded, 1_000_000.9_f32, Operator::Eq)
                .unwrap()
                .is_none()
        )
    }

    #[test]
    fn compare_to_null() {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let other = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            array.len(),
        );

        let r = compare(encoded.as_ref(), other.as_ref(), Operator::Eq).unwrap();
        // Comparing to null yields null results
        let expected = BoolArray::from_iter([None::<bool>; 10]);
        assert_arrays_eq!(r, expected);
    }

    // Under IEEE 754 totalOrder: -NaN < -Inf < finite < +Inf < +NaN.
    #[rstest]
    #[case(f32::NAN, false)]
    #[case(-f32::NAN, true)]
    #[case(-1.0f32 / 0.0f32, true)]
    #[case(f32::INFINITY, false)]
    #[case(f32::NEG_INFINITY, true)]
    fn compare_to_non_finite_gt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, Operator::Gt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }

    // Under IEEE 754 totalOrder: -NaN < -Inf < finite < +Inf < +NaN.
    #[rstest]
    #[case(f32::NAN, true)]
    #[case(-f32::NAN, false)]
    #[case(-1.0f32 / 0.0f32, false)]
    #[case(f32::INFINITY, true)]
    #[case(f32::NEG_INFINITY, false)]
    fn compare_to_non_finite_lt(#[case] value: f32, #[case] result: bool) {
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(&array, None).unwrap();

        let r = test_alp_compare(&encoded, value, Operator::Lt).unwrap();
        let expected = BoolArray::from_iter([result; 10]);
        assert_arrays_eq!(r, expected);
    }
}