vortex_array/compute/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use vortex_dtype::DType;
11use vortex_error::VortexError;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_scalar::Scalar;
17
18use crate::Array;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::compute::BooleanOperator;
24use crate::compute::ComputeFn;
25use crate::compute::ComputeFnVTable;
26use crate::compute::InvocationArgs;
27use crate::compute::Kernel;
28use crate::compute::Operator;
29use crate::compute::Options;
30use crate::compute::Output;
31use crate::compute::boolean;
32use crate::compute::compare;
33use crate::vtable::VTable;
34
35static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
36    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
37    for kernel in inventory::iter::<BetweenKernelRef> {
38        compute.register_kernel(kernel.0.clone());
39    }
40    compute
41});
42
43pub(crate) fn warm_up_vtable() -> usize {
44    BETWEEN_FN.kernels().len()
45}
46
47/// Compute between (a <= x <= b).
48///
49/// This is an optimized implementation that is equivalent to `(a <= x) AND (x <= b)`.
50///
51/// The `BetweenOptions` defines if the lower or upper bounds are strict (exclusive) or non-strict
52/// (inclusive).
53pub fn between(
54    arr: &dyn Array,
55    lower: &dyn Array,
56    upper: &dyn Array,
57    options: &BetweenOptions,
58) -> VortexResult<ArrayRef> {
59    BETWEEN_FN
60        .invoke(&InvocationArgs {
61            inputs: &[arr.into(), lower.into(), upper.into()],
62            options,
63        })?
64        .unwrap_array()
65}
66
/// A type-erased `between` kernel.
///
/// Values of this type are collected through `inventory` and registered on the `between`
/// compute function when it is first initialized.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
69
/// An encoding-specific implementation of the `between` compute function.
pub trait BetweenKernel: VTable {
    /// Computes whether each element of `arr` lies between `lower` and `upper`, with bound
    /// strictness given by `options`.
    ///
    /// Returning `Ok(None)` signals that this kernel did not produce a result, in which case
    /// the dispatcher moves on to the next candidate implementation.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
79
/// Adapts a [`BetweenKernel`] implementation on a vtable `V` to the generic [`Kernel`] trait.
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);

impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`BetweenKernelRef`] without allocating.
    pub const fn lift(&'static self) -> BetweenKernelRef {
        BetweenKernelRef(ArcRef::new_ref(self))
    }
}
88
89impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
90    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
91        let inputs = BetweenArgs::try_from(args)?;
92        let Some(array) = inputs.array.as_opt::<V>() else {
93            return Ok(None);
94        };
95        Ok(
96            V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
97                .map(|array| array.into()),
98        )
99    }
100}
101
102struct Between;
103
104impl ComputeFnVTable for Between {
105    fn invoke(
106        &self,
107        args: &InvocationArgs,
108        kernels: &[ArcRef<dyn Kernel>],
109    ) -> VortexResult<Output> {
110        let BetweenArgs {
111            array,
112            lower,
113            upper,
114            options,
115        } = BetweenArgs::try_from(args)?;
116
117        let return_dtype = self.return_dtype(args)?;
118
119        // Bail early if the array is empty.
120        if array.is_empty() {
121            return Ok(Canonical::empty(&return_dtype).into_array().into());
122        }
123
124        // A quick check to see if either array might is a null constant array.
125        // Note: Depends on returning early if array is empty for is_invalid check.
126        if (lower.is_invalid(0) || upper.is_invalid(0))
127            && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
128            && (c_lower.is_null() || c_upper.is_null())
129        {
130            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
131                .into_array()
132                .into());
133        }
134
135        if lower.as_constant().is_some_and(|v| v.is_null())
136            || upper.as_constant().is_some_and(|v| v.is_null())
137        {
138            return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
139                .into_array()
140                .into());
141        }
142
143        // Try each kernel
144        for kernel in kernels {
145            if let Some(output) = kernel.invoke(args)? {
146                return Ok(output);
147            }
148        }
149        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
150            return Ok(output);
151        }
152
153        // Otherwise, fall back to the default Arrow implementation
154        // TODO(joe): should we try to canonicalize the array and try between
155        Ok(boolean(
156            &compare(lower, array, options.lower_strict.to_operator())?,
157            &compare(array, upper, options.upper_strict.to_operator())?,
158            BooleanOperator::And,
159        )?
160        .into())
161    }
162
163    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
164        let BetweenArgs {
165            array,
166            lower,
167            upper,
168            options: _,
169        } = BetweenArgs::try_from(args)?;
170
171        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
172            vortex_bail!(
173                "Array and lower bound types do not match: {:?} != {:?}",
174                array.dtype(),
175                lower.dtype()
176            );
177        }
178        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
179            vortex_bail!(
180                "Array and upper bound types do not match: {:?} != {:?}",
181                array.dtype(),
182                upper.dtype()
183            );
184        }
185
186        Ok(DType::Bool(
187            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
188        ))
189    }
190
191    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
192        let BetweenArgs {
193            array,
194            lower,
195            upper,
196            options: _,
197        } = BetweenArgs::try_from(args)?;
198        if array.len() != lower.len() || array.len() != upper.len() {
199            vortex_bail!(
200                "Array lengths do not match: array:{} lower:{} upper:{}",
201                array.len(),
202                lower.len(),
203                upper.len()
204            );
205        }
206        Ok(array.len())
207    }
208
209    fn is_elementwise(&self) -> bool {
210        true
211    }
212}
213
/// Typed view over the [`InvocationArgs`] passed to the `between` compute function.
struct BetweenArgs<'a> {
    /// The array whose elements are tested (input 0).
    array: &'a dyn Array,
    /// The lower bound (input 1).
    lower: &'a dyn Array,
    /// The upper bound (input 2).
    upper: &'a dyn Array,
    /// Strictness of the lower and upper comparisons.
    options: &'a BetweenOptions,
}
220
221impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
222    type Error = VortexError;
223
224    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
225        if value.inputs.len() != 3 {
226            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
227        }
228        let array = value.inputs[0]
229            .array()
230            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
231        let lower = value.inputs[1]
232            .array()
233            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
234        let upper = value.inputs[2]
235            .array()
236            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
237        let options = value
238            .options
239            .as_any()
240            .downcast_ref::<BetweenOptions>()
241            .vortex_expect("Expected options to be an operator");
242
243        Ok(BetweenArgs {
244            array,
245            lower,
246            upper,
247            options,
248        })
249    }
250}
251
/// Options for the `between` compute function, controlling whether each bound is strict.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the comparison against the lower bound.
    pub lower_strict: StrictComparison,
    /// Strictness of the comparison against the upper bound.
    pub upper_strict: StrictComparison,
}
257
258impl Display for BetweenOptions {
259    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
260        let lower_op = if self.lower_strict.is_strict() {
261            "<"
262        } else {
263            "<="
264        };
265        let upper_op = if self.upper_strict.is_strict() {
266            "<"
267        } else {
268            "<="
269        };
270        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
271    }
272}
273
impl Options for BetweenOptions {
    /// Exposes the options as `&dyn Any` so they can be downcast back to `BetweenOptions`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
279
280/// Strictness of the comparison.
281#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
282pub enum StrictComparison {
283    /// Strict bound (`<`)
284    Strict,
285    /// Non-strict bound (`<=`)
286    NonStrict,
287}
288
289impl StrictComparison {
290    pub const fn to_operator(&self) -> Operator {
291        match self {
292            StrictComparison::Strict => Operator::Lt,
293            StrictComparison::NonStrict => Operator::Lte,
294        }
295    }
296
297    pub const fn is_strict(&self) -> bool {
298        matches!(self, StrictComparison::Strict)
299    }
300}
301
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Options where both the lower and the upper bound are non-strict.
    fn inclusive() -> BetweenOptions {
        BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        }
    }

    /// Every combination of strict and non-strict bounds selects the expected positions.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options).unwrap();

        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, expected);
    }

    /// Constant bounds: a null constant matches nothing, value constants behave like arrays.
    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let result = between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            &inclusive(),
        )
        .unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert!(indices.is_empty());

        // upper is a fixed constant
        let const_upper = ConstantArray::new(Scalar::from(2), 5);
        let result = between(
            array.as_ref(),
            lower.as_ref(),
            const_upper.as_ref(),
            &inclusive(),
        )
        .unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let const_lower = ConstantArray::new(Scalar::from(0), 5);
        let result = between(
            array.as_ref(),
            const_lower.as_ref(),
            const_upper.as_ref(),
            &inclusive(),
        )
        .unwrap();
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}