vortex_array/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::VortexError;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_scalar::Scalar;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::compute::BooleanOperator;
22use crate::compute::ComputeFn;
23use crate::compute::ComputeFnVTable;
24use crate::compute::InvocationArgs;
25use crate::compute::Kernel;
26use crate::compute::Operator;
27use crate::compute::Options;
28use crate::compute::Output;
29use crate::compute::boolean;
30use crate::compute::compare;
31use crate::vtable::VTable;
32
33static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
34    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
35    for kernel in inventory::iter::<BetweenKernelRef> {
36        compute.register_kernel(kernel.0.clone());
37    }
38    compute
39});
40
41pub(crate) fn warm_up_vtable() -> usize {
42    BETWEEN_FN.kernels().len()
43}
44
45/// Compute between (a <= x <= b).
46///
47/// This is an optimized implementation that is equivalent to `(a <= x) AND (x <= b)`.
48///
49/// The `BetweenOptions` defines if the lower or upper bounds are strict (exclusive) or non-strict
50/// (inclusive).
51pub fn between(
52    arr: &dyn Array,
53    lower: &dyn Array,
54    upper: &dyn Array,
55    options: &BetweenOptions,
56) -> VortexResult<ArrayRef> {
57    BETWEEN_FN
58        .invoke(&InvocationArgs {
59            inputs: &[arr.into(), lower.into(), upper.into()],
60            options,
61        })?
62        .unwrap_array()
63}
64
65pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
66inventory::collect!(BetweenKernelRef);
67
68pub trait BetweenKernel: VTable {
69    fn between(
70        &self,
71        arr: &Self::Array,
72        lower: &dyn Array,
73        upper: &dyn Array,
74        options: &BetweenOptions,
75    ) -> VortexResult<Option<ArrayRef>>;
76}
77
78#[derive(Debug)]
79pub struct BetweenKernelAdapter<V: VTable>(pub V);
80
81impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
82    pub const fn lift(&'static self) -> BetweenKernelRef {
83        BetweenKernelRef(ArcRef::new_ref(self))
84    }
85}
86
87impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
88    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
89        let inputs = BetweenArgs::try_from(args)?;
90        let Some(array) = inputs.array.as_opt::<V>() else {
91            return Ok(None);
92        };
93        Ok(
94            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
95                .map(|array| array.into()),
96        )
97    }
98}
99
100struct Between;
101
102impl ComputeFnVTable for Between {
103    fn invoke(
104        &self,
105        args: &InvocationArgs,
106        kernels: &[ArcRef<dyn Kernel>],
107    ) -> VortexResult<Output> {
108        let BetweenArgs {
109            array,
110            lower,
111            upper,
112            options,
113        } = BetweenArgs::try_from(args)?;
114
115        let return_dtype = self.return_dtype(args)?;
116
117        // Bail early if the array is empty.
118        if array.is_empty() {
119            return Ok(Canonical::empty(&return_dtype).into_array().into());
120        }
121
122        // A quick check to see if either array might is a null constant array.
123        // Note: Depends on returning early if array is empty for is_invalid check.
124        if (lower.is_invalid(0) || upper.is_invalid(0))
125            && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
126            && (c_lower.is_null() || c_upper.is_null())
127        {
128            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
129                .into_array()
130                .into());
131        }
132
133        if lower.as_constant().is_some_and(|v| v.is_null())
134            || upper.as_constant().is_some_and(|v| v.is_null())
135        {
136            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
137                .into_array()
138                .into());
139        }
140
141        // Try each kernel
142        for kernel in kernels {
143            if let Some(output) = kernel.invoke(args)? {
144                return Ok(output);
145            }
146        }
147        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
148            return Ok(output);
149        }
150
151        // Otherwise, fall back to the default Arrow implementation
152        // TODO(joe): should we try to canonicalize the array and try between
153        Ok(boolean(
154            &compare(lower, array, options.lower_strict.to_operator())?,
155            &compare(array, upper, options.upper_strict.to_operator())?,
156            BooleanOperator::And,
157        )?
158        .into())
159    }
160
161    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
162        let BetweenArgs {
163            array,
164            lower,
165            upper,
166            options: _,
167        } = BetweenArgs::try_from(args)?;
168
169        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
170            vortex_bail!(
171                "Array and lower bound types do not match: {:?} != {:?}",
172                array.dtype(),
173                lower.dtype()
174            );
175        }
176        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
177            vortex_bail!(
178                "Array and upper bound types do not match: {:?} != {:?}",
179                array.dtype(),
180                upper.dtype()
181            );
182        }
183
184        Ok(DType::Bool(
185            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
186        ))
187    }
188
189    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
190        let BetweenArgs {
191            array,
192            lower,
193            upper,
194            options: _,
195        } = BetweenArgs::try_from(args)?;
196        if array.len() != lower.len() || array.len() != upper.len() {
197            vortex_bail!(
198                "Array lengths do not match: array:{} lower:{} upper:{}",
199                array.len(),
200                lower.len(),
201                upper.len()
202            );
203        }
204        Ok(array.len())
205    }
206
207    fn is_elementwise(&self) -> bool {
208        true
209    }
210}
211
212struct BetweenArgs<'a> {
213    array: &'a dyn Array,
214    lower: &'a dyn Array,
215    upper: &'a dyn Array,
216    options: &'a BetweenOptions,
217}
218
219impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
220    type Error = VortexError;
221
222    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
223        if value.inputs.len() != 3 {
224            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
225        }
226        let array = value.inputs[0]
227            .array()
228            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
229        let lower = value.inputs[1]
230            .array()
231            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
232        let upper = value.inputs[2]
233            .array()
234            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
235        let options = value
236            .options
237            .as_any()
238            .downcast_ref::<BetweenOptions>()
239            .vortex_expect("Expected options to be an operator");
240
241        Ok(BetweenArgs {
242            array,
243            lower,
244            upper,
245            options,
246        })
247    }
248}
249
250#[derive(Debug, Clone, PartialEq, Eq, Hash)]
251pub struct BetweenOptions {
252    pub lower_strict: StrictComparison,
253    pub upper_strict: StrictComparison,
254}
255
256impl Options for BetweenOptions {
257    fn as_any(&self) -> &dyn Any {
258        self
259    }
260}
261
262/// Strictness of the comparison.
263#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
264pub enum StrictComparison {
265    /// Strict bound (`<`)
266    Strict,
267    /// Non-strict bound (`<=`)
268    NonStrict,
269}
270
271impl StrictComparison {
272    pub const fn to_operator(&self) -> Operator {
273        match self {
274            StrictComparison::Strict => Operator::Lt,
275            StrictComparison::NonStrict => Operator::Lte,
276        }
277    }
278
279    pub const fn is_strict(&self) -> bool {
280        matches!(self, StrictComparison::Strict)
281    }
282}
283
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Exercises all four strict / non-strict combinations against per-row bounds.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let selected = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(selected).unwrap(), expected);
    }

    /// Exercises constant bounds: a null constant, a single constant, and two constants.
    #[test]
    fn test_constants() {
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };

        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let selected = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive)
            .unwrap()
            .to_bool();
        let indices = to_int_indices(selected).unwrap();
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let selected = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive)
            .unwrap()
            .to_bool();
        let indices = to_int_indices(selected).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let selected = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive)
            .unwrap()
            .to_bool();
        let indices = to_int_indices(selected).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}
384}