Skip to main content

vortex_alp/alp/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::ArrayView;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::NativeDType;
12use vortex_array::dtype::NativePType;
13use vortex_array::dtype::Nullability;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::between::BetweenOptions;
16use vortex_array::scalar_fn::fns::between::BetweenReduce;
17use vortex_array::scalar_fn::fns::between::StrictComparison;
18use vortex_error::VortexResult;
19
20use crate::ALP;
21use crate::ALPFloat;
22use crate::alp::array::ALPArrayExt;
23use crate::alp::array::ALPArraySlotsExt;
24use crate::match_each_alp_float_ptype;
25
26impl BetweenReduce for ALP {
27    fn between(
28        array: ArrayView<'_, Self>,
29        lower: &ArrayRef,
30        upper: &ArrayRef,
31        options: &BetweenOptions,
32    ) -> VortexResult<Option<ArrayRef>> {
33        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
34            return Ok(None);
35        };
36
37        if array.patches().is_some() {
38            return Ok(None);
39        }
40
41        let nullability =
42            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
43        match_each_alp_float_ptype!(array.dtype().as_ptype(), |F| {
44            between_impl::<F>(
45                array,
46                F::try_from(&lower)?,
47                F::try_from(&upper)?,
48                nullability,
49                options,
50            )
51        })
52        .map(Some)
53    }
54}
55
56fn between_impl<T: NativePType + ALPFloat>(
57    array: ArrayView<'_, ALP>,
58    lower: T,
59    upper: T,
60    nullability: Nullability,
61    options: &BetweenOptions,
62) -> VortexResult<ArrayRef>
63where
64    Scalar: From<T::ALPInt>,
65    <T as ALPFloat>::ALPInt: NativeDType + Debug,
66{
67    let exponents = array.exponents();
68
69    // There are always compared
70    // the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
71    // in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
72    // and we encode into value below enc_below(value) < value < x, in which case the comparison
73    // becomes enc(value) < x. See `alp_scalar_compare` for more details.
74    // note that if the value doesn't encode than value != x, so must use strict comparison.
75    let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
76        .map(|x| (x, options.lower_strict))
77        .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
78
79    // the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
80    let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
81        .map(|x| (x, options.upper_strict))
82        .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
83
84    let options = BetweenOptions {
85        lower_strict,
86        upper_strict,
87    };
88
89    array.encoded().clone().between(
90        ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).into_array(),
91        ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).into_array(),
92        options,
93    )
94}
95
#[cfg(test)]
mod tests {
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::between::BetweenOptions;
    use vortex_array::scalar_fn::fns::between::StrictComparison;

    use crate::ALPArray;
    use crate::alp::array::ALPArrayExt;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    /// Builds `BetweenOptions` from the strictness of the lower and upper bound.
    fn bounds(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_impl` on a single-element ALP array with nullable bounds and checks that
    /// the one resulting value equals `expected`.
    fn assert_between(
        encoded: &ALPArray,
        lower: f32,
        upper: f32,
        options: BetweenOptions,
        expected: bool,
    ) {
        let actual = between_impl(
            encoded.as_view(),
            lower,
            upper,
            Nullability::Nullable,
            &options,
        )
        .unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([Some(expected)]));
    }

    #[test]
    fn comparison_range() {
        let value = 0.0605_f32;
        let array = PrimitiveArray::from_iter([value]);
        let encoded = alp_encode(&array, None).unwrap();
        // The encoded-domain comparison is only exercised when there are no patches.
        assert!(encoded.patches().is_none());

        // Both bounds equal to the value: only the fully inclusive range contains it.
        assert_between(
            &encoded,
            0.0605_f32,
            0.0605_f32,
            bounds(StrictComparison::NonStrict, StrictComparison::NonStrict),
            true,
        );
        assert_between(
            &encoded,
            0.0605_f32,
            0.0605_f32,
            bounds(StrictComparison::Strict, StrictComparison::NonStrict),
            false,
        );
        assert_between(
            &encoded,
            0.0605_f32,
            0.0605_f32,
            bounds(StrictComparison::NonStrict, StrictComparison::Strict),
            false,
        );

        // Bounds on either side of the value.
        assert_between(
            &encoded,
            0.060499_f32,
            0.06051_f32,
            bounds(StrictComparison::NonStrict, StrictComparison::NonStrict),
            true,
        );
        assert_between(
            &encoded,
            0.06_f32,
            0.06051_f32,
            bounds(StrictComparison::NonStrict, StrictComparison::Strict),
            true,
        );
    }
}