Skip to main content

vortex_alp/alp/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use num_traits::Bounded;
7use vortex_array::ArrayRef;
8use vortex_array::ArrayView;
9use vortex_array::IntoArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::dtype::NativeDType;
13use vortex_array::dtype::NativePType;
14use vortex_array::dtype::Nullability;
15use vortex_array::scalar::Scalar;
16use vortex_array::scalar_fn::fns::between::BetweenOptions;
17use vortex_array::scalar_fn::fns::between::BetweenReduce;
18use vortex_array::scalar_fn::fns::between::StrictComparison;
19use vortex_error::VortexResult;
20
21use crate::ALP;
22use crate::ALPFloat;
23use crate::Exponents;
24use crate::alp::array::ALPArrayExt;
25use crate::alp::array::ALPArraySlotsExt;
26use crate::match_each_alp_float_ptype;
27
28impl BetweenReduce for ALP {
29    fn between(
30        array: ArrayView<'_, Self>,
31        lower: &ArrayRef,
32        upper: &ArrayRef,
33        options: &BetweenOptions,
34    ) -> VortexResult<Option<ArrayRef>> {
35        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
36            return Ok(None);
37        };
38
39        if array.patches().is_some() {
40            return Ok(None);
41        }
42
43        let nullability =
44            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
45        match_each_alp_float_ptype!(array.dtype().as_ptype(), |F| {
46            between_impl::<F>(
47                array,
48                F::try_from(&lower)?,
49                F::try_from(&upper)?,
50                nullability,
51                options,
52            )
53        })
54        .map(Some)
55    }
56}
57
58fn between_impl<T: NativePType + ALPFloat>(
59    array: ArrayView<'_, ALP>,
60    lower: T,
61    upper: T,
62    nullability: Nullability,
63    options: &BetweenOptions,
64) -> VortexResult<ArrayRef>
65where
66    Scalar: From<T::ALPInt>,
67    <T as ALPFloat>::ALPInt: NativeDType + Debug,
68{
69    let exponents = array.exponents();
70
71    // There are always compared
72    // the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
73    // in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
74    // and we encode into value below enc_below(value) < value < x, in which case the comparison
75    // becomes enc(value) < x. See `alp_scalar_compare` for more details.
76    // note that if the value doesn't encode than value != x, so must use strict comparison.
77    let Some((lower_enc, lower_strict)) =
78        encode_lower_bound::<T>(lower, exponents, options.lower_strict)
79    else {
80        return Ok(ConstantArray::new(Scalar::bool(false, nullability), array.len()).into_array());
81    };
82
83    // the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
84    let Some((upper_enc, upper_strict)) =
85        encode_upper_bound::<T>(upper, exponents, options.upper_strict)
86    else {
87        return Ok(ConstantArray::new(Scalar::bool(false, nullability), array.len()).into_array());
88    };
89
90    let options = BetweenOptions {
91        lower_strict,
92        upper_strict,
93    };
94
95    array.encoded().clone().between(
96        ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).into_array(),
97        ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).into_array(),
98        options,
99    )
100}
101
102fn encode_lower_bound<T: ALPFloat>(
103    lower: T,
104    exponents: Exponents,
105    strict: StrictComparison,
106) -> Option<(T::ALPInt, StrictComparison)> {
107    if NativePType::is_nan(lower) || NativePType::is_infinite(lower) {
108        return NativePType::is_lt(lower, T::zero())
109            .then_some((T::ALPInt::min_value(), StrictComparison::NonStrict));
110    }
111
112    Some(
113        T::encode_single(lower, exponents)
114            .map(|x| (x, strict))
115            .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict)),
116    )
117}
118
119fn encode_upper_bound<T: ALPFloat>(
120    upper: T,
121    exponents: Exponents,
122    strict: StrictComparison,
123) -> Option<(T::ALPInt, StrictComparison)> {
124    if NativePType::is_nan(upper) || NativePType::is_infinite(upper) {
125        return NativePType::is_gt(upper, T::zero())
126            .then_some((T::ALPInt::max_value(), StrictComparison::NonStrict));
127    }
128
129    Some(
130        T::encode_single(upper, exponents)
131            .map(|x| (x, strict))
132            .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict)),
133    )
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::between::BetweenOptions;
    use vortex_array::scalar_fn::fns::between::StrictComparison;

    use crate::ALPArray;
    use crate::alp::array::ALPArrayExt;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    const NON_STRICT: StrictComparison = StrictComparison::NonStrict;
    const STRICT: StrictComparison = StrictComparison::Strict;

    /// Builds `BetweenOptions` from the lower and upper strictness flags.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_impl` on `arr` with a nullable result and asserts that every row equals
    /// `expected`.
    fn assert_between(
        arr: &ALPArray,
        lower: f32,
        upper: f32,
        options: &BetweenOptions,
        expected: bool,
    ) {
        let actual =
            between_impl(arr.as_view(), lower, upper, Nullability::Nullable, options).unwrap();
        let nullability = actual.dtype().nullability();
        let want = ConstantArray::new(Scalar::bool(expected, nullability), arr.len()).into_array();
        assert_arrays_eq!(actual, want);
    }

    #[test]
    fn comparison_range() {
        let value = 0.0605_f32;
        let array = PrimitiveArray::from_iter([value; 1]);
        let encoded = alp_encode(
            array.as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert!(encoded.patches().is_none());

        // (lower, upper, lower strictness, upper strictness, expected)
        let cases = [
            (0.0605_f32, 0.0605, NON_STRICT, NON_STRICT, true),
            (0.0605_f32, 0.0605, STRICT, NON_STRICT, false),
            (0.0605_f32, 0.0605, NON_STRICT, STRICT, false),
            (0.060499_f32, 0.06051, NON_STRICT, NON_STRICT, true),
            (0.06_f32, 0.06051, NON_STRICT, STRICT, true),
        ];
        for (lower, upper, lower_strict, upper_strict, expected) in cases {
            assert_between(
                &encoded,
                lower,
                upper,
                &opts(lower_strict, upper_strict),
                expected,
            );
        }
    }

    #[test]
    fn non_finite_bounds_use_total_order() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array = PrimitiveArray::from_iter([1.234f32; 10]);
        let encoded = alp_encode(array.as_view(), None, &mut ctx).unwrap();
        assert!(encoded.patches().is_none());

        let options = opts(NON_STRICT, STRICT);
        // A NaN with the sign bit set; IEEE 754 total order places it below every other value.
        let negative_nan = f32::from_bits(0xffffff5e);

        // (lower, upper, expected)
        let cases = [
            (negative_nan, 2.0, true),
            (f32::NAN, 2.0, false),
            (f32::NEG_INFINITY, 2.0, true),
            (f32::INFINITY, 2.0, false),
            (0.0, f32::NAN, true),
            (0.0, negative_nan, false),
            (0.0, f32::INFINITY, true),
            (0.0, f32::NEG_INFINITY, false),
        ];
        for (lower, upper, expected) in cases {
            assert_between(&encoded, lower, upper, &options, expected);
        }
    }
}
260}