vortex_array/compute/
between.rs

1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::DType;
6use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::arrays::ConstantArray;
10use crate::compute::{
11    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
12    boolean, compare,
13};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17/// Compute between (a <= x <= b), this can be implemented using compare and boolean and but this
18/// will likely have a lower runtime.
19///
20/// This semantics is equivalent to:
21/// ```
22/// use vortex_array::{Array, ArrayRef};
23/// use vortex_array::compute::{boolean, compare, BetweenOptions, BooleanOperator, Operator};///
24/// use vortex_error::VortexResult;
25///
26/// fn between(
27///    arr: &dyn Array,
28///    lower: &dyn Array,
29///    upper: &dyn Array,
30///    options: &BetweenOptions
31/// ) -> VortexResult<ArrayRef> {
32///     boolean(
33///         &compare(lower, arr, options.lower_strict.to_operator())?,
34///         &compare(arr, upper,  options.upper_strict.to_operator())?,
35///         BooleanOperator::And
36///     )
37/// }
38///  ```
39///
40/// The BetweenOptions { lower: StrictComparison, upper: StrictComparison } defines if the
41/// value is < (strict) or <= (non-strict).
42///
43pub fn between(
44    arr: &dyn Array,
45    lower: &dyn Array,
46    upper: &dyn Array,
47    options: &BetweenOptions,
48) -> VortexResult<ArrayRef> {
49    BETWEEN_FN
50        .invoke(&InvocationArgs {
51            inputs: &[arr.into(), lower.into(), upper.into()],
52            options,
53        })?
54        .unwrap_array()
55}
56
57pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
58inventory::collect!(BetweenKernelRef);
59
60pub trait BetweenKernel: VTable {
61    fn between(
62        &self,
63        arr: &Self::Array,
64        lower: &dyn Array,
65        upper: &dyn Array,
66        options: &BetweenOptions,
67    ) -> VortexResult<Option<ArrayRef>>;
68}
69
70#[derive(Debug)]
71pub struct BetweenKernelAdapter<V: VTable>(pub V);
72
73impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
74    pub const fn lift(&'static self) -> BetweenKernelRef {
75        BetweenKernelRef(ArcRef::new_ref(self))
76    }
77}
78
79impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
80    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
81        let inputs = BetweenArgs::try_from(args)?;
82        let Some(array) = inputs.array.as_opt::<V>() else {
83            return Ok(None);
84        };
85        Ok(
86            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
87                .map(|array| array.into()),
88        )
89    }
90}
91
92pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
93    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
94    for kernel in inventory::iter::<BetweenKernelRef> {
95        compute.register_kernel(kernel.0.clone());
96    }
97    compute
98});
99
100struct Between;
101
102impl ComputeFnVTable for Between {
103    fn invoke(
104        &self,
105        args: &InvocationArgs,
106        kernels: &[ArcRef<dyn Kernel>],
107    ) -> VortexResult<Output> {
108        let BetweenArgs {
109            array,
110            lower,
111            upper,
112            options,
113        } = BetweenArgs::try_from(args)?;
114
115        let return_dtype = self.return_dtype(args)?;
116
117        // Bail early if the array is empty.
118        if array.is_empty() {
119            return Ok(Canonical::empty(&return_dtype).into_array().into());
120        }
121
122        // A quick check to see if either array might is a null constant array.
123        // Note: Depends on returning early if array is empty for is_invalid check.
124        if lower.is_invalid(0)? || upper.is_invalid(0)? {
125            if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
126                if c_lower.is_null() || c_upper.is_null() {
127                    return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
128                        .into_array()
129                        .into());
130                }
131            }
132        }
133
134        if lower.as_constant().is_some_and(|v| v.is_null())
135            || upper.as_constant().is_some_and(|v| v.is_null())
136        {
137            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
138                .into_array()
139                .into());
140        }
141
142        // Try each kernel
143        for kernel in kernels {
144            if let Some(output) = kernel.invoke(args)? {
145                return Ok(output);
146            }
147        }
148        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
149            return Ok(output);
150        }
151
152        // Otherwise, fall back to the default Arrow implementation
153        // TODO(joe): should we try to canonicalize the array and try between
154        Ok(boolean(
155            &compare(lower, array, options.lower_strict.to_operator())?,
156            &compare(array, upper, options.upper_strict.to_operator())?,
157            BooleanOperator::And,
158        )?
159        .into())
160    }
161
162    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
163        let BetweenArgs {
164            array,
165            lower,
166            upper,
167            options: _,
168        } = BetweenArgs::try_from(args)?;
169
170        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
171            vortex_bail!(
172                "Array and lower bound types do not match: {:?} != {:?}",
173                array.dtype(),
174                lower.dtype()
175            );
176        }
177        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
178            vortex_bail!(
179                "Array and upper bound types do not match: {:?} != {:?}",
180                array.dtype(),
181                upper.dtype()
182            );
183        }
184
185        Ok(DType::Bool(
186            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
187        ))
188    }
189
190    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
191        let BetweenArgs {
192            array,
193            lower,
194            upper,
195            options: _,
196        } = BetweenArgs::try_from(args)?;
197        if array.len() != lower.len() || array.len() != upper.len() {
198            vortex_bail!(
199                "Array lengths do not match: array:{} lower:{} upper:{}",
200                array.len(),
201                lower.len(),
202                upper.len()
203            );
204        }
205        Ok(array.len())
206    }
207
208    fn is_elementwise(&self) -> bool {
209        true
210    }
211}
212
213struct BetweenArgs<'a> {
214    array: &'a dyn Array,
215    lower: &'a dyn Array,
216    upper: &'a dyn Array,
217    options: &'a BetweenOptions,
218}
219
220impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
221    type Error = VortexError;
222
223    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
224        if value.inputs.len() != 3 {
225            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
226        }
227        let array = value.inputs[0]
228            .array()
229            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
230        let lower = value.inputs[1]
231            .array()
232            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
233        let upper = value.inputs[2]
234            .array()
235            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
236        let options = value
237            .options
238            .as_any()
239            .downcast_ref::<BetweenOptions>()
240            .vortex_expect("Expected options to be an operator");
241
242        Ok(BetweenArgs {
243            array,
244            lower,
245            upper,
246            options,
247        })
248    }
249}
250
251#[derive(Debug, Clone, PartialEq, Eq, Hash)]
252pub struct BetweenOptions {
253    pub lower_strict: StrictComparison,
254    pub upper_strict: StrictComparison,
255}
256
257impl Options for BetweenOptions {
258    fn as_any(&self) -> &dyn Any {
259        self
260    }
261}
262
263#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
264pub enum StrictComparison {
265    Strict,
266    NonStrict,
267}
268
269impl StrictComparison {
270    pub const fn to_operator(&self) -> Operator {
271        match self {
272            StrictComparison::Strict => Operator::Lt,
273            StrictComparison::NonStrict => Operator::Lte,
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs `between` and returns the indices that `to_int_indices` extracts from the
    /// boolean result.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> Vec<u64> {
        let mask = between(array, lower, upper, options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(mask).unwrap()
    }

    /// Every strict/non-strict combination over array-valued bounds.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &options);
        assert_eq!(indices, expected);
    }

    /// Constant bounds: a null upper bound, a non-null upper bound, then both bounds constant.
    #[test]
    fn test_constants() {
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };

        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);

        // upper is null
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive);
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive);
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &inclusive);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}