vortex_alp/alp/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::Array;
7use vortex_array::ArrayRef;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::compute::BetweenKernel;
10use vortex_array::compute::BetweenKernelAdapter;
11use vortex_array::compute::BetweenOptions;
12use vortex_array::compute::StrictComparison;
13use vortex_array::compute::between;
14use vortex_array::register_kernel;
15use vortex_dtype::NativeDType;
16use vortex_dtype::NativePType;
17use vortex_dtype::Nullability;
18use vortex_error::VortexResult;
19use vortex_scalar::Scalar;
20
21use crate::ALPArray;
22use crate::ALPFloat;
23use crate::ALPVTable;
24use crate::match_each_alp_float_ptype;
25
26impl BetweenKernel for ALPVTable {
27    fn between(
28        &self,
29        array: &ALPArray,
30        lower: &dyn Array,
31        upper: &dyn Array,
32        options: &BetweenOptions,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let (Some(lower), Some(upper)) = (lower.as_constant(), upper.as_constant()) else {
35            return Ok(None);
36        };
37
38        if array.patches().is_some() {
39            return Ok(None);
40        }
41
42        let nullability =
43            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
44
45        match_each_alp_float_ptype!(array.ptype(), |F| {
46            between_impl::<F>(
47                array,
48                F::try_from(lower)?,
49                F::try_from(upper)?,
50                nullability,
51                options,
52            )
53        })
54        .map(Some)
55    }
56}
57
58register_kernel!(BetweenKernelAdapter(ALPVTable).lift());
59
60fn between_impl<T: NativePType + ALPFloat>(
61    array: &ALPArray,
62    lower: T,
63    upper: T,
64    nullability: Nullability,
65    options: &BetweenOptions,
66) -> VortexResult<ArrayRef>
67where
68    Scalar: From<T::ALPInt>,
69    <T as ALPFloat>::ALPInt: NativeDType + Debug,
70{
71    let exponents = array.exponents();
72
73    // There are always compared
74    // the below bound is `value {< | <=} x`, either value encodes into the ALPInt domain
75    // in which case we can leave the comparison unchanged `enc(value) {< | <=} x` or it doesn't
76    // and we encode into value below enc_below(value) < value < x, in which case the comparison
77    // becomes enc(value) < x. See `alp_scalar_compare` for more details.
78    // note that if the value doesn't encode than value != x, so must use strict comparison.
79    let (lower_enc, lower_strict) = T::encode_single(lower, exponents)
80        .map(|x| (x, options.lower_strict))
81        .unwrap_or_else(|| (T::encode_below(lower, exponents), StrictComparison::Strict));
82
83    // the upper value `x { < | <= } value` similarly encodes or `x < value < enc_above(value())`
84    let (upper_enc, upper_strict) = T::encode_single(upper, exponents)
85        .map(|x| (x, options.upper_strict))
86        .unwrap_or_else(|| (T::encode_above(upper, exponents), StrictComparison::Strict));
87
88    let options = BetweenOptions {
89        lower_strict,
90        upper_strict,
91    };
92
93    between(
94        array.encoded(),
95        ConstantArray::new(Scalar::primitive(lower_enc, nullability), array.len()).as_ref(),
96        ConstantArray::new(Scalar::primitive(upper_enc, nullability), array.len()).as_ref(),
97        &options,
98    )
99}
100
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::BetweenOptions;
    use vortex_array::compute::StrictComparison;
    use vortex_dtype::Nullability;

    use crate::ALPArray;
    use crate::alp::compute::between::between_impl;
    use crate::alp_encode;

    /// Builds `BetweenOptions` from the strictness of the lower and upper comparison.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_impl` on a single-element array and returns its only result bit.
    fn between_test(arr: &ALPArray, lower: f32, upper: f32, options: &BetweenOptions) -> bool {
        let outcome = between_impl(arr, lower, upper, Nullability::Nullable, options).unwrap();
        let mask = outcome.to_bool();
        let bits = mask.bit_buffer().iter().collect_vec();
        assert_eq!(bits.len(), 1);

        bits[0]
    }

    #[test]
    fn comparison_range() {
        let value = 0.0605_f32;
        let array = PrimitiveArray::from_iter([value]);
        let encoded = alp_encode(&array, None).unwrap();

        // The value must encode exactly (no patches) to the integer 605.
        assert!(encoded.patches().is_none());
        let ints = encoded.encoded().to_primitive();
        assert_eq!(ints.as_slice::<i32>(), [605_i32]);

        // Both bounds equal the value: only the fully inclusive range contains it.
        assert!(between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
        ));
        assert!(!between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::Strict, StrictComparison::NonStrict),
        ));
        assert!(!between_test(
            &encoded,
            0.0605_f32,
            0.0605,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
        ));

        // Bounds on either side of the value contain it.
        assert!(between_test(
            &encoded,
            0.060499_f32,
            0.06051,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
        ));

        // A strict upper bound above the value still contains it.
        assert!(between_test(
            &encoded,
            0.06_f32,
            0.06051,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
        ));
    }
}