// vortex_array/compute/between.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15    boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20/// Compute between (a <= x <= b), this can be implemented using compare and boolean and but this
21/// will likely have a lower runtime.
22///
23/// This semantics is equivalent to:
24/// ```
25/// use vortex_array::{Array, ArrayRef};
26/// use vortex_array::compute::{boolean, compare, BetweenOptions, BooleanOperator, Operator};///
27/// use vortex_error::VortexResult;
28///
29/// fn between(
30///    arr: &dyn Array,
31///    lower: &dyn Array,
32///    upper: &dyn Array,
33///    options: &BetweenOptions
34/// ) -> VortexResult<ArrayRef> {
35///     boolean(
36///         &compare(lower, arr, options.lower_strict.to_operator())?,
37///         &compare(arr, upper,  options.upper_strict.to_operator())?,
38///         BooleanOperator::And
39///     )
40/// }
41///  ```
42///
43/// The BetweenOptions { lower: StrictComparison, upper: StrictComparison } defines if the
44/// value is < (strict) or <= (non-strict).
45///
46pub fn between(
47    arr: &dyn Array,
48    lower: &dyn Array,
49    upper: &dyn Array,
50    options: &BetweenOptions,
51) -> VortexResult<ArrayRef> {
52    BETWEEN_FN
53        .invoke(&InvocationArgs {
54            inputs: &[arr.into(), lower.into(), upper.into()],
55            options,
56        })?
57        .unwrap_array()
58}
59
60pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
61inventory::collect!(BetweenKernelRef);
62
63pub trait BetweenKernel: VTable {
64    fn between(
65        &self,
66        arr: &Self::Array,
67        lower: &dyn Array,
68        upper: &dyn Array,
69        options: &BetweenOptions,
70    ) -> VortexResult<Option<ArrayRef>>;
71}
72
73#[derive(Debug)]
74pub struct BetweenKernelAdapter<V: VTable>(pub V);
75
76impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
77    pub const fn lift(&'static self) -> BetweenKernelRef {
78        BetweenKernelRef(ArcRef::new_ref(self))
79    }
80}
81
82impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
83    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
84        let inputs = BetweenArgs::try_from(args)?;
85        let Some(array) = inputs.array.as_opt::<V>() else {
86            return Ok(None);
87        };
88        Ok(
89            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
90                .map(|array| array.into()),
91        )
92    }
93}
94
95pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
96    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
97    for kernel in inventory::iter::<BetweenKernelRef> {
98        compute.register_kernel(kernel.0.clone());
99    }
100    compute
101});
102
103struct Between;
104
105impl ComputeFnVTable for Between {
106    fn invoke(
107        &self,
108        args: &InvocationArgs,
109        kernels: &[ArcRef<dyn Kernel>],
110    ) -> VortexResult<Output> {
111        let BetweenArgs {
112            array,
113            lower,
114            upper,
115            options,
116        } = BetweenArgs::try_from(args)?;
117
118        let return_dtype = self.return_dtype(args)?;
119
120        // Bail early if the array is empty.
121        if array.is_empty() {
122            return Ok(Canonical::empty(&return_dtype).into_array().into());
123        }
124
125        // A quick check to see if either array might is a null constant array.
126        // Note: Depends on returning early if array is empty for is_invalid check.
127        if lower.is_invalid(0)? || upper.is_invalid(0)? {
128            if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
129                if c_lower.is_null() || c_upper.is_null() {
130                    return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
131                        .into_array()
132                        .into());
133                }
134            }
135        }
136
137        if lower.as_constant().is_some_and(|v| v.is_null())
138            || upper.as_constant().is_some_and(|v| v.is_null())
139        {
140            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
141                .into_array()
142                .into());
143        }
144
145        // Try each kernel
146        for kernel in kernels {
147            if let Some(output) = kernel.invoke(args)? {
148                return Ok(output);
149            }
150        }
151        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
152            return Ok(output);
153        }
154
155        // Otherwise, fall back to the default Arrow implementation
156        // TODO(joe): should we try to canonicalize the array and try between
157        Ok(boolean(
158            &compare(lower, array, options.lower_strict.to_operator())?,
159            &compare(array, upper, options.upper_strict.to_operator())?,
160            BooleanOperator::And,
161        )?
162        .into())
163    }
164
165    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
166        let BetweenArgs {
167            array,
168            lower,
169            upper,
170            options: _,
171        } = BetweenArgs::try_from(args)?;
172
173        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
174            vortex_bail!(
175                "Array and lower bound types do not match: {:?} != {:?}",
176                array.dtype(),
177                lower.dtype()
178            );
179        }
180        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
181            vortex_bail!(
182                "Array and upper bound types do not match: {:?} != {:?}",
183                array.dtype(),
184                upper.dtype()
185            );
186        }
187
188        Ok(DType::Bool(
189            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
190        ))
191    }
192
193    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
194        let BetweenArgs {
195            array,
196            lower,
197            upper,
198            options: _,
199        } = BetweenArgs::try_from(args)?;
200        if array.len() != lower.len() || array.len() != upper.len() {
201            vortex_bail!(
202                "Array lengths do not match: array:{} lower:{} upper:{}",
203                array.len(),
204                lower.len(),
205                upper.len()
206            );
207        }
208        Ok(array.len())
209    }
210
211    fn is_elementwise(&self) -> bool {
212        true
213    }
214}
215
216struct BetweenArgs<'a> {
217    array: &'a dyn Array,
218    lower: &'a dyn Array,
219    upper: &'a dyn Array,
220    options: &'a BetweenOptions,
221}
222
223impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
224    type Error = VortexError;
225
226    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
227        if value.inputs.len() != 3 {
228            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
229        }
230        let array = value.inputs[0]
231            .array()
232            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
233        let lower = value.inputs[1]
234            .array()
235            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
236        let upper = value.inputs[2]
237            .array()
238            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
239        let options = value
240            .options
241            .as_any()
242            .downcast_ref::<BetweenOptions>()
243            .vortex_expect("Expected options to be an operator");
244
245        Ok(BetweenArgs {
246            array,
247            lower,
248            upper,
249            options,
250        })
251    }
252}
253
254#[derive(Debug, Clone, PartialEq, Eq, Hash)]
255pub struct BetweenOptions {
256    pub lower_strict: StrictComparison,
257    pub upper_strict: StrictComparison,
258}
259
260impl Options for BetweenOptions {
261    fn as_any(&self) -> &dyn Any {
262        self
263    }
264}
265
266#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
267pub enum StrictComparison {
268    Strict,
269    NonStrict,
270}
271
272impl StrictComparison {
273    pub const fn to_operator(&self) -> Operator {
274        match self {
275            StrictComparison::Strict => Operator::Lt,
276            StrictComparison::NonStrict => Operator::Lte,
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs [`between`] with the given strictness and returns the indices it marks as true.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<u64> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let matches = between(array, lower, upper, &options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(matches).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = StrictComparison::NonStrict;
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);

        // Null upper bound: no element matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert!(indices.is_empty());

        // Fixed constant upper bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let constant_lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = matching_indices(
            array.as_ref(),
            constant_lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}
382        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
383    }
384}