vortex_array/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15    boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22    for kernel in inventory::iter::<BetweenKernelRef> {
23        compute.register_kernel(kernel.0.clone());
24    }
25    compute
26});
27
28pub(crate) fn warm_up_vtable() -> usize {
29    BETWEEN_FN.kernels().len()
30}
31
32/// Compute between (a <= x <= b).
33///
34/// This is an optimized implementation that is equivalent to `(a <= x) AND (x <= b)`.
35///
36/// The `BetweenOptions` defines if the lower or upper bounds are strict (exclusive) or non-strict
37/// (inclusive).
38pub fn between(
39    arr: &dyn Array,
40    lower: &dyn Array,
41    upper: &dyn Array,
42    options: &BetweenOptions,
43) -> VortexResult<ArrayRef> {
44    BETWEEN_FN
45        .invoke(&InvocationArgs {
46            inputs: &[arr.into(), lower.into(), upper.into()],
47            options,
48        })?
49        .unwrap_array()
50}
51
/// Handle to a `between` kernel that is collected through the `inventory` registry.
///
/// Every collected handle is registered on the `between` compute function when that
/// function is first initialised.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
54
/// Encoding-specific implementation of the `between` compute function.
///
/// Implemented on a [`VTable`]; [`BetweenKernelAdapter`] wraps an implementor so that it can
/// be used as a [`Kernel`].
pub trait BetweenKernel: VTable {
    /// Computes `between` for `arr` against the `lower` and `upper` bounds using `options`.
    ///
    /// An `Ok(None)` result is forwarded unchanged by [`BetweenKernelAdapter`]; the `between`
    /// dispatcher then moves on to the next kernel or to its fallback implementation.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
64
/// Adapter that exposes a [`BetweenKernel`] implementation `V` through the generic
/// [`Kernel`] trait.
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);
67
68impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
69    pub const fn lift(&'static self) -> BetweenKernelRef {
70        BetweenKernelRef(ArcRef::new_ref(self))
71    }
72}
73
74impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
75    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
76        let inputs = BetweenArgs::try_from(args)?;
77        let Some(array) = inputs.array.as_opt::<V>() else {
78            return Ok(None);
79        };
80        Ok(
81            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
82                .map(|array| array.into()),
83        )
84    }
85}
86
/// Unit type that carries the [`ComputeFnVTable`] implementation of `between`.
struct Between;
88
89impl ComputeFnVTable for Between {
90    fn invoke(
91        &self,
92        args: &InvocationArgs,
93        kernels: &[ArcRef<dyn Kernel>],
94    ) -> VortexResult<Output> {
95        let BetweenArgs {
96            array,
97            lower,
98            upper,
99            options,
100        } = BetweenArgs::try_from(args)?;
101
102        let return_dtype = self.return_dtype(args)?;
103
104        // Bail early if the array is empty.
105        if array.is_empty() {
106            return Ok(Canonical::empty(&return_dtype).into_array().into());
107        }
108
109        // A quick check to see if either array might is a null constant array.
110        // Note: Depends on returning early if array is empty for is_invalid check.
111        if (lower.is_invalid(0) || upper.is_invalid(0))
112            && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
113            && (c_lower.is_null() || c_upper.is_null())
114        {
115            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
116                .into_array()
117                .into());
118        }
119
120        if lower.as_constant().is_some_and(|v| v.is_null())
121            || upper.as_constant().is_some_and(|v| v.is_null())
122        {
123            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
124                .into_array()
125                .into());
126        }
127
128        // Try each kernel
129        for kernel in kernels {
130            if let Some(output) = kernel.invoke(args)? {
131                return Ok(output);
132            }
133        }
134        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
135            return Ok(output);
136        }
137
138        // Otherwise, fall back to the default Arrow implementation
139        // TODO(joe): should we try to canonicalize the array and try between
140        Ok(boolean(
141            &compare(lower, array, options.lower_strict.to_operator())?,
142            &compare(array, upper, options.upper_strict.to_operator())?,
143            BooleanOperator::And,
144        )?
145        .into())
146    }
147
148    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
149        let BetweenArgs {
150            array,
151            lower,
152            upper,
153            options: _,
154        } = BetweenArgs::try_from(args)?;
155
156        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
157            vortex_bail!(
158                "Array and lower bound types do not match: {:?} != {:?}",
159                array.dtype(),
160                lower.dtype()
161            );
162        }
163        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
164            vortex_bail!(
165                "Array and upper bound types do not match: {:?} != {:?}",
166                array.dtype(),
167                upper.dtype()
168            );
169        }
170
171        Ok(DType::Bool(
172            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
173        ))
174    }
175
176    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
177        let BetweenArgs {
178            array,
179            lower,
180            upper,
181            options: _,
182        } = BetweenArgs::try_from(args)?;
183        if array.len() != lower.len() || array.len() != upper.len() {
184            vortex_bail!(
185                "Array lengths do not match: array:{} lower:{} upper:{}",
186                array.len(),
187                lower.len(),
188                upper.len()
189            );
190        }
191        Ok(array.len())
192    }
193
194    fn is_elementwise(&self) -> bool {
195        true
196    }
197}
198
/// Typed view over the [`InvocationArgs`] of a `between` call.
struct BetweenArgs<'a> {
    /// Input 0: the array being tested.
    array: &'a dyn Array,
    /// Input 1: the lower bound.
    lower: &'a dyn Array,
    /// Input 2: the upper bound.
    upper: &'a dyn Array,
    /// Strictness of the lower and upper bounds.
    options: &'a BetweenOptions,
}
205
206impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
207    type Error = VortexError;
208
209    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
210        if value.inputs.len() != 3 {
211            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
212        }
213        let array = value.inputs[0]
214            .array()
215            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
216        let lower = value.inputs[1]
217            .array()
218            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
219        let upper = value.inputs[2]
220            .array()
221            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
222        let options = value
223            .options
224            .as_any()
225            .downcast_ref::<BetweenOptions>()
226            .vortex_expect("Expected options to be an operator");
227
228        Ok(BetweenArgs {
229            array,
230            lower,
231            upper,
232            options,
233        })
234    }
235}
236
/// Options for the `between` compute function.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness applied when comparing the lower bound with the array.
    pub lower_strict: StrictComparison,
    /// Strictness applied when comparing the array with the upper bound.
    pub upper_strict: StrictComparison,
}
242
impl Options for BetweenOptions {
    /// Returns `self` as [`Any`] so that it can be downcast back to [`BetweenOptions`].
    fn as_any(&self) -> &dyn Any {
        self
    }
}
248
/// Strictness of the comparison.
///
/// Selects the operator a `between` bound is compared with; see
/// [`StrictComparison::to_operator`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict bound (`<`)
    Strict,
    /// Non-strict bound (`<=`)
    NonStrict,
}
257
258impl StrictComparison {
259    pub const fn to_operator(&self) -> Operator {
260        match self {
261            StrictComparison::Strict => Operator::Lt,
262            StrictComparison::NonStrict => Operator::Lte,
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Options with both bounds non-strict, shared by the constant-bound checks.
    const INCLUSIVE: BetweenOptions = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };

    /// Checks every combination of strict and non-strict bounds against the same data.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();

        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, expected);
    }

    /// Checks constant bounds: a null upper bound, a constant upper bound, and both constant.
    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let null_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let upper = ConstantArray::new(Scalar::null(null_dtype), 5);

        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE).unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5);

        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE).unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5);

        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE).unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}
366}