vortex_array/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15    boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22    for kernel in inventory::iter::<BetweenKernelRef> {
23        compute.register_kernel(kernel.0.clone());
24    }
25    compute
26});
27
28/// Compute between (a <= x <= b).
29///
30/// This is an optimized implementation that is equivalent to `(a <= x) AND (x <= b)`.
31///
32/// The `BetweenOptions` defines if the lower or upper bounds are strict (exclusive) or non-strict
33/// (inclusive).
34pub fn between(
35    arr: &dyn Array,
36    lower: &dyn Array,
37    upper: &dyn Array,
38    options: &BetweenOptions,
39) -> VortexResult<ArrayRef> {
40    BETWEEN_FN
41        .invoke(&InvocationArgs {
42            inputs: &[arr.into(), lower.into(), upper.into()],
43            options,
44        })?
45        .unwrap_array()
46}
47
48pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
49inventory::collect!(BetweenKernelRef);
50
51pub trait BetweenKernel: VTable {
52    fn between(
53        &self,
54        arr: &Self::Array,
55        lower: &dyn Array,
56        upper: &dyn Array,
57        options: &BetweenOptions,
58    ) -> VortexResult<Option<ArrayRef>>;
59}
60
61#[derive(Debug)]
62pub struct BetweenKernelAdapter<V: VTable>(pub V);
63
64impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
65    pub const fn lift(&'static self) -> BetweenKernelRef {
66        BetweenKernelRef(ArcRef::new_ref(self))
67    }
68}
69
70impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
71    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
72        let inputs = BetweenArgs::try_from(args)?;
73        let Some(array) = inputs.array.as_opt::<V>() else {
74            return Ok(None);
75        };
76        Ok(
77            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
78                .map(|array| array.into()),
79        )
80    }
81}
82
83struct Between;
84
85impl ComputeFnVTable for Between {
86    fn invoke(
87        &self,
88        args: &InvocationArgs,
89        kernels: &[ArcRef<dyn Kernel>],
90    ) -> VortexResult<Output> {
91        let BetweenArgs {
92            array,
93            lower,
94            upper,
95            options,
96        } = BetweenArgs::try_from(args)?;
97
98        let return_dtype = self.return_dtype(args)?;
99
100        // Bail early if the array is empty.
101        if array.is_empty() {
102            return Ok(Canonical::empty(&return_dtype).into_array().into());
103        }
104
105        // A quick check to see if either array might is a null constant array.
106        // Note: Depends on returning early if array is empty for is_invalid check.
107        if (lower.is_invalid(0)? || upper.is_invalid(0)?)
108            && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
109            && (c_lower.is_null() || c_upper.is_null())
110        {
111            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
112                .into_array()
113                .into());
114        }
115
116        if lower.as_constant().is_some_and(|v| v.is_null())
117            || upper.as_constant().is_some_and(|v| v.is_null())
118        {
119            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
120                .into_array()
121                .into());
122        }
123
124        // Try each kernel
125        for kernel in kernels {
126            if let Some(output) = kernel.invoke(args)? {
127                return Ok(output);
128            }
129        }
130        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
131            return Ok(output);
132        }
133
134        // Otherwise, fall back to the default Arrow implementation
135        // TODO(joe): should we try to canonicalize the array and try between
136        Ok(boolean(
137            &compare(lower, array, options.lower_strict.to_operator())?,
138            &compare(array, upper, options.upper_strict.to_operator())?,
139            BooleanOperator::And,
140        )?
141        .into())
142    }
143
144    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
145        let BetweenArgs {
146            array,
147            lower,
148            upper,
149            options: _,
150        } = BetweenArgs::try_from(args)?;
151
152        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
153            vortex_bail!(
154                "Array and lower bound types do not match: {:?} != {:?}",
155                array.dtype(),
156                lower.dtype()
157            );
158        }
159        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
160            vortex_bail!(
161                "Array and upper bound types do not match: {:?} != {:?}",
162                array.dtype(),
163                upper.dtype()
164            );
165        }
166
167        Ok(DType::Bool(
168            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
169        ))
170    }
171
172    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
173        let BetweenArgs {
174            array,
175            lower,
176            upper,
177            options: _,
178        } = BetweenArgs::try_from(args)?;
179        if array.len() != lower.len() || array.len() != upper.len() {
180            vortex_bail!(
181                "Array lengths do not match: array:{} lower:{} upper:{}",
182                array.len(),
183                lower.len(),
184                upper.len()
185            );
186        }
187        Ok(array.len())
188    }
189
190    fn is_elementwise(&self) -> bool {
191        true
192    }
193}
194
195struct BetweenArgs<'a> {
196    array: &'a dyn Array,
197    lower: &'a dyn Array,
198    upper: &'a dyn Array,
199    options: &'a BetweenOptions,
200}
201
202impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
203    type Error = VortexError;
204
205    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
206        if value.inputs.len() != 3 {
207            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
208        }
209        let array = value.inputs[0]
210            .array()
211            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
212        let lower = value.inputs[1]
213            .array()
214            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
215        let upper = value.inputs[2]
216            .array()
217            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
218        let options = value
219            .options
220            .as_any()
221            .downcast_ref::<BetweenOptions>()
222            .vortex_expect("Expected options to be an operator");
223
224        Ok(BetweenArgs {
225            array,
226            lower,
227            upper,
228            options,
229        })
230    }
231}
232
233#[derive(Debug, Clone, PartialEq, Eq, Hash)]
234pub struct BetweenOptions {
235    pub lower_strict: StrictComparison,
236    pub upper_strict: StrictComparison,
237}
238
239impl Options for BetweenOptions {
240    fn as_any(&self) -> &dyn Any {
241        self
242    }
243}
244
245/// Strictness of the comparison.
246#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
247pub enum StrictComparison {
248    /// Strict bound (`<`)
249    Strict,
250    /// Non-strict bound (`<=`)
251    NonStrict,
252}
253
254impl StrictComparison {
255    pub const fn to_operator(&self) -> Operator {
256        match self {
257            StrictComparison::Strict => Operator::Lt,
258            StrictComparison::NonStrict => Operator::Lte,
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs `between` with the given strictness and returns the indices of the matching
    /// positions, as reported by `to_int_indices`.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<u64> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let result = between(array, lower, upper, &options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(result).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let actual = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_constants() {
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);

        // Every check in this test uses inclusive bounds on both sides.
        let inclusive = |lo: &dyn Array, hi: &dyn Array| {
            matching_indices(
                array.as_ref(),
                lo,
                hi,
                StrictComparison::NonStrict,
                StrictComparison::NonStrict,
            )
        };

        // A null upper bound: nothing matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        assert!(inclusive(lower.as_ref(), null_upper.as_ref()).is_empty());

        // A fixed constant upper bound against a varying lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        assert_eq!(inclusive(lower.as_ref(), upper.as_ref()), vec![0, 1, 3]);

        // Both bounds constant.
        let const_lower = ConstantArray::new(Scalar::from(0), 5);
        assert_eq!(
            inclusive(const_lower.as_ref(), upper.as_ref()),
            vec![0, 1, 2, 3, 4]
        );
    }
}