Skip to main content

vortex_alp/alp/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::ArrayRef;
7use vortex_array::IntoArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::dtype::NativeDType;
11use vortex_array::dtype::NativePType;
12use vortex_array::dtype::Nullability;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::between::BetweenOptions;
15use vortex_array::scalar_fn::fns::between::BetweenReduce;
16use vortex_array::scalar_fn::fns::between::StrictComparison;
17use vortex_error::VortexResult;
18
19use crate::ALPArray;
20use crate::ALPFloat;
21use crate::ALPVTable;
22use crate::match_each_alp_float_ptype;
23
24impl BetweenReduce for ALPVTable {
25    fn between(
26        array: &ALPArray,
27        lower: &ArrayRef,
28        upper: &ArrayRef,
29        options: &BetweenOptions,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
32            return Ok(None);
33        };
34
35        if array.patches().is_some() {
36            return Ok(None);
37        }
38
39        let nullability =
40            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
41        match_each_alp_float_ptype!(array.ptype(), |F| {
42            between_impl::<F>(
43                array,
44                F::try_from(&lower)?,
45                F::try_from(&upper)?,
46                nullability,
47                options,
48            )
49        })
50        .map(Some)
51    }
52}
53
54fn between_impl<T: NativePType + ALPFloat>(
55    array: &ALPArray,
56    lower: T,
57    upper: T,
58    nullability: Nullability,
59    options: &BetweenOptions,
60) -> VortexResult<ArrayRef>
61where
62    Scalar: From<T::ALPInt>,
63    <T as ALPFloat>::ALPInt: NativeDType + Debug,
64{
65    let exponents = array.exponents();
66
67    // There are always compared
68    // the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
69    // in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
70    // and we encode into value below enc_below(value) < value < x, in which case the comparison
71    // becomes enc(value) < x. See `alp_scalar_compare` for more details.
72    // note that if the value doesn't encode than value != x, so must use strict comparison.
73    let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
74        .map(|x| (x, options.lower_strict))
75        .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
76
77    // the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
78    let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
79        .map(|x| (x, options.upper_strict))
80        .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
81
82    let options = BetweenOptions {
83        lower_strict,
84        upper_strict,
85    };
86
87    array.encoded().clone().between(
88        ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).into_array(),
89        ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).into_array(),
90        options,
91    )
92}
93
#[cfg(test)]
mod tests {
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::between::BetweenOptions;
    use vortex_array::scalar_fn::fns::between::StrictComparison;

    use crate::ALPArray;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    /// Runs `between_impl` with the given bounds and strictness on a single-element ALP array
    /// and checks that the one (nullable) output value equals `expected`.
    fn check_between(
        encoded: &ALPArray,
        (lower, upper): (f32, f32),
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
        expected: bool,
    ) {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let actual = between_impl(encoded, lower, upper, Nullability::Nullable, &options)
            .expect("between_impl should succeed on an unpatched ALP array");
        assert_arrays_eq!(actual, BoolArray::from_iter([Some(expected)]));
    }

    #[test]
    fn comparison_range() {
        let value = 0.0605_f32;
        let values = PrimitiveArray::from_iter([value]);
        let encoded = alp_encode(&values, None).unwrap();
        assert!(encoded.patches().is_none());

        // ((lower, upper), lower strictness, upper strictness, expected result for `value`)
        let cases = [
            (
                (0.0605_f32, 0.0605),
                StrictComparison::NonStrict,
                StrictComparison::NonStrict,
                true,
            ),
            (
                (0.0605, 0.0605),
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                false,
            ),
            (
                (0.0605, 0.0605),
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                false,
            ),
            (
                (0.060499, 0.06051),
                StrictComparison::NonStrict,
                StrictComparison::NonStrict,
                true,
            ),
            (
                (0.06, 0.06051),
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                true,
            ),
        ];

        for (bounds, lower_strict, upper_strict, expected) in cases {
            check_between(&encoded, bounds, lower_strict, upper_strict, expected);
        }
    }
}