Skip to main content

vortex_alp/alp/compute/between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::NativeDType;
12use vortex_array::dtype::NativePType;
13use vortex_array::dtype::Nullability;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::between::BetweenOptions;
16use vortex_array::scalar_fn::fns::between::BetweenReduce;
17use vortex_array::scalar_fn::fns::between::StrictComparison;
18use vortex_error::VortexResult;
19
20use crate::ALPArray;
21use crate::ALPFloat;
22use crate::ALPVTable;
23use crate::match_each_alp_float_ptype;
24
25impl BetweenReduce for ALPVTable {
26    fn between(
27        array: &ALPArray,
28        lower: &dyn Array,
29        upper: &dyn Array,
30        options: &BetweenOptions,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
33            return Ok(None);
34        };
35
36        if array.patches().is_some() {
37            return Ok(None);
38        }
39
40        let nullability =
41            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
42        match_each_alp_float_ptype!(array.ptype(), |F| {
43            between_impl::<F>(
44                array,
45                F::try_from(&lower)?,
46                F::try_from(&upper)?,
47                nullability,
48                options,
49            )
50        })
51        .map(Some)
52    }
53}
54
55fn between_impl<T: NativePType + ALPFloat>(
56    array: &ALPArray,
57    lower: T,
58    upper: T,
59    nullability: Nullability,
60    options: &BetweenOptions,
61) -> VortexResult<ArrayRef>
62where
63    Scalar: From<T::ALPInt>,
64    <T as ALPFloat>::ALPInt: NativeDType + Debug,
65{
66    let exponents = array.exponents();
67
68    // There are always compared
69    // the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
70    // in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
71    // and we encode into value below enc_below(value) < value < x, in which case the comparison
72    // becomes enc(value) < x. See `alp_scalar_compare` for more details.
73    // note that if the value doesn't encode than value != x, so must use strict comparison.
74    let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
75        .map(|x| (x, options.lower_strict))
76        .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
77
78    // the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
79    let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
80        .map(|x| (x, options.upper_strict))
81        .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
82
83    let options = BetweenOptions {
84        lower_strict,
85        upper_strict,
86    };
87
88    array.encoded().clone().between(
89        ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).into_array(),
90        ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).into_array(),
91        options,
92    )
93}
94
#[cfg(test)]
mod tests {
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::between::BetweenOptions;
    use vortex_array::scalar_fn::fns::between::StrictComparison;

    use crate::ALPArray;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    /// Runs `between_impl` on a single-element array and checks its single (nullable) result.
    fn assert_between(
        arr: &ALPArray,
        lower: f32,
        upper: f32,
        options: &BetweenOptions,
        expected: bool,
    ) {
        let res = between_impl(arr, lower, upper, Nullability::Nullable, options).unwrap();
        assert_arrays_eq!(res, BoolArray::from_iter([Some(expected)]));
    }

    /// Shorthand for building `BetweenOptions`.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// ALP-encodes `[value]` and asserts that no patches were needed, so `between_impl` applies.
    fn encode_one(value: f32) -> ALPArray {
        let array = PrimitiveArray::from_iter([value; 1]);
        let encoded = alp_encode(&array, None).unwrap();
        assert!(encoded.patches().is_none());
        encoded
    }

    #[test]
    fn comparison_range() {
        let encoded = encode_one(0.0605_f32);

        assert_between(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
            true,
        );

        assert_between(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::Strict, StrictComparison::NonStrict),
            false,
        );

        assert_between(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
            false,
        );

        assert_between(
            &encoded,
            0.060499_f32,
            0.06051,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
            true,
        );

        assert_between(
            &encoded,
            0.06_f32,
            0.06051,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
            true,
        );
    }

    /// Bounds that do not encode exactly take the forced-strict path; make sure a value just
    /// outside such a bound is excluded even though the caller asked for a non-strict check.
    #[test]
    fn non_encodable_bounds_exclude_value() {
        let encoded = encode_one(0.0605_f32);

        // 0.06051 > 0.0605: the lower bound rounds down onto the value's own encoding and must
        // become strict, so the value is excluded.
        assert_between(
            &encoded,
            0.06051_f32,
            0.07,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
            false,
        );

        // 0.060499 < 0.0605: the upper bound rounds up onto the value's own encoding and must
        // become strict, so the value is excluded.
        assert_between(
            &encoded,
            0.05_f32,
            0.060499,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
            false,
        );
    }
}