vortex_array/compute/
between.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15    boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22    for kernel in inventory::iter::<BetweenKernelRef> {
23        compute.register_kernel(kernel.0.clone());
24    }
25    compute
26});
27
28/// Compute between (a <= x <= b).
29///
30/// This is an optimized implementation that is equivalent to `(a <= x) AND (x <= b)`.
31///
32/// The `BetweenOptions` defines if the lower or upper bounds are strict (exclusive) or non-strict
33/// (inclusive).
34pub fn between(
35    arr: &dyn Array,
36    lower: &dyn Array,
37    upper: &dyn Array,
38    options: &BetweenOptions,
39) -> VortexResult<ArrayRef> {
40    BETWEEN_FN
41        .invoke(&InvocationArgs {
42            inputs: &[arr.into(), lower.into(), upper.into()],
43            options,
44        })?
45        .unwrap_array()
46}
47
48pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
49inventory::collect!(BetweenKernelRef);
50
51pub trait BetweenKernel: VTable {
52    fn between(
53        &self,
54        arr: &Self::Array,
55        lower: &dyn Array,
56        upper: &dyn Array,
57        options: &BetweenOptions,
58    ) -> VortexResult<Option<ArrayRef>>;
59}
60
61#[derive(Debug)]
62pub struct BetweenKernelAdapter<V: VTable>(pub V);
63
64impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
65    pub const fn lift(&'static self) -> BetweenKernelRef {
66        BetweenKernelRef(ArcRef::new_ref(self))
67    }
68}
69
70impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
71    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
72        let inputs = BetweenArgs::try_from(args)?;
73        let Some(array) = inputs.array.as_opt::<V>() else {
74            return Ok(None);
75        };
76        Ok(
77            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
78                .map(|array| array.into()),
79        )
80    }
81}
82
83struct Between;
84
85impl ComputeFnVTable for Between {
86    fn invoke(
87        &self,
88        args: &InvocationArgs,
89        kernels: &[ArcRef<dyn Kernel>],
90    ) -> VortexResult<Output> {
91        let BetweenArgs {
92            array,
93            lower,
94            upper,
95            options,
96        } = BetweenArgs::try_from(args)?;
97
98        let return_dtype = self.return_dtype(args)?;
99
100        // Bail early if the array is empty.
101        if array.is_empty() {
102            return Ok(Canonical::empty(&return_dtype).into_array().into());
103        }
104
105        // A quick check to see if either array might is a null constant array.
106        // Note: Depends on returning early if array is empty for is_invalid check.
107        if lower.is_invalid(0)? || upper.is_invalid(0)? {
108            if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
109                if c_lower.is_null() || c_upper.is_null() {
110                    return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
111                        .into_array()
112                        .into());
113                }
114            }
115        }
116
117        if lower.as_constant().is_some_and(|v| v.is_null())
118            || upper.as_constant().is_some_and(|v| v.is_null())
119        {
120            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
121                .into_array()
122                .into());
123        }
124
125        // Try each kernel
126        for kernel in kernels {
127            if let Some(output) = kernel.invoke(args)? {
128                return Ok(output);
129            }
130        }
131        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
132            return Ok(output);
133        }
134
135        // Otherwise, fall back to the default Arrow implementation
136        // TODO(joe): should we try to canonicalize the array and try between
137        Ok(boolean(
138            &compare(lower, array, options.lower_strict.to_operator())?,
139            &compare(array, upper, options.upper_strict.to_operator())?,
140            BooleanOperator::And,
141        )?
142        .into())
143    }
144
145    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
146        let BetweenArgs {
147            array,
148            lower,
149            upper,
150            options: _,
151        } = BetweenArgs::try_from(args)?;
152
153        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
154            vortex_bail!(
155                "Array and lower bound types do not match: {:?} != {:?}",
156                array.dtype(),
157                lower.dtype()
158            );
159        }
160        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
161            vortex_bail!(
162                "Array and upper bound types do not match: {:?} != {:?}",
163                array.dtype(),
164                upper.dtype()
165            );
166        }
167
168        Ok(DType::Bool(
169            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
170        ))
171    }
172
173    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
174        let BetweenArgs {
175            array,
176            lower,
177            upper,
178            options: _,
179        } = BetweenArgs::try_from(args)?;
180        if array.len() != lower.len() || array.len() != upper.len() {
181            vortex_bail!(
182                "Array lengths do not match: array:{} lower:{} upper:{}",
183                array.len(),
184                lower.len(),
185                upper.len()
186            );
187        }
188        Ok(array.len())
189    }
190
191    fn is_elementwise(&self) -> bool {
192        true
193    }
194}
195
196struct BetweenArgs<'a> {
197    array: &'a dyn Array,
198    lower: &'a dyn Array,
199    upper: &'a dyn Array,
200    options: &'a BetweenOptions,
201}
202
203impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
204    type Error = VortexError;
205
206    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
207        if value.inputs.len() != 3 {
208            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
209        }
210        let array = value.inputs[0]
211            .array()
212            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
213        let lower = value.inputs[1]
214            .array()
215            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
216        let upper = value.inputs[2]
217            .array()
218            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
219        let options = value
220            .options
221            .as_any()
222            .downcast_ref::<BetweenOptions>()
223            .vortex_expect("Expected options to be an operator");
224
225        Ok(BetweenArgs {
226            array,
227            lower,
228            upper,
229            options,
230        })
231    }
232}
233
234#[derive(Debug, Clone, PartialEq, Eq, Hash)]
235pub struct BetweenOptions {
236    pub lower_strict: StrictComparison,
237    pub upper_strict: StrictComparison,
238}
239
240impl Options for BetweenOptions {
241    fn as_any(&self) -> &dyn Any {
242        self
243    }
244}
245
246/// Strictness of the comparison.
247#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
248pub enum StrictComparison {
249    /// Strict bound (`<`)
250    Strict,
251    /// Non-strict bound (`<=`)
252    NonStrict,
253}
254
255impl StrictComparison {
256    pub const fn to_operator(&self) -> Operator {
257        match self {
258            StrictComparison::Strict => Operator::Lt,
259            StrictComparison::NonStrict => Operator::Lte,
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Inclusive bounds on both sides, for tests that do not exercise strictness.
    const INCLUSIVE: BetweenOptions = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };

    /// Runs [`between`], canonicalizes the result to a bool array and returns its
    /// `to_int_indices`.
    fn between_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> Vec<u64> {
        let matches = between(array, lower, upper, options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(matches).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let indices = between_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &options);
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);

        // upper is null
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = between_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = between_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = between_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_null_lower_constant() {
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);
        let lower = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );

        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE).unwrap();
        // A null lower bound makes the whole result null, so the output must be nullable.
        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(result.len(), 5);
        assert!(to_int_indices(result.to_bool().unwrap()).unwrap().is_empty());
    }

    #[test]
    fn test_empty() {
        let empty = PrimitiveArray::from_iter(Vec::<i32>::new());

        let result = between(empty.as_ref(), empty.as_ref(), empty.as_ref(), &INCLUSIVE).unwrap();
        assert!(result.is_empty());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));
    }
}