vortex_array/compute/
between.rs

1use std::any::Any;
2use std::sync::LazyLock;
3
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::arcref::ArcRef;
9use crate::arrays::ConstantArray;
10use crate::compute::{
11    BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
12    boolean, compare,
13};
14use crate::{Array, ArrayRef, Canonical, Encoding, IntoArray};
15
16/// Compute between (a <= x <= b), this can be implemented using compare and boolean and but this
17/// will likely have a lower runtime.
18///
19/// This semantics is equivalent to:
20/// ```
21/// use vortex_array::{Array, ArrayRef};
22/// use vortex_array::compute::{boolean, compare, BetweenOptions, BooleanOperator, Operator};///
23/// use vortex_error::VortexResult;
24///
25/// fn between(
26///    arr: &dyn Array,
27///    lower: &dyn Array,
28///    upper: &dyn Array,
29///    options: &BetweenOptions
30/// ) -> VortexResult<ArrayRef> {
31///     boolean(
32///         &compare(lower, arr, options.lower_strict.to_operator())?,
33///         &compare(arr, upper,  options.upper_strict.to_operator())?,
34///         BooleanOperator::And
35///     )
36/// }
37///  ```
38///
39/// The BetweenOptions { lower: StrictComparison, upper: StrictComparison } defines if the
40/// value is < (strict) or <= (non-strict).
41///
42pub fn between(
43    arr: &dyn Array,
44    lower: &dyn Array,
45    upper: &dyn Array,
46    options: &BetweenOptions,
47) -> VortexResult<ArrayRef> {
48    BETWEEN_FN
49        .invoke(&InvocationArgs {
50            inputs: &[arr.into(), lower.into(), upper.into()],
51            options,
52        })?
53        .unwrap_array()
54}
55
/// Type-erased reference to a [`Kernel`] that implements the `between` compute function.
///
/// Values of this type are collected through `inventory` and registered on [`BETWEEN_FN`] when
/// that static is first initialised.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
58
/// Encoding-specific implementation of the `between` compute function.
pub trait BetweenKernel: Encoding {
    /// Evaluates `between` for `arr` (an array of this encoding) against the `lower` and `upper`
    /// bounds, using the strictness given in `options`.
    ///
    /// Returning `Ok(None)` means this kernel produced no result; the dispatcher in this file
    /// then moves on to its next candidate.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
68
/// Adapts an encoding's [`BetweenKernel`] implementation to the type-erased [`Kernel`] interface.
#[derive(Debug)]
pub struct BetweenKernelAdapter<E: Encoding>(pub E);
71
impl<E: Encoding + BetweenKernel> BetweenKernelAdapter<E> {
    /// Wraps a `'static` adapter in a [`BetweenKernelRef`].
    ///
    /// NOTE(review): presumably meant to be called on a static adapter when submitting it to
    /// `inventory` — confirm against the call sites.
    pub const fn lift(&'static self) -> BetweenKernelRef {
        BetweenKernelRef(ArcRef::new_ref(self))
    }
}
77
78impl<E: Encoding + BetweenKernel> Kernel for BetweenKernelAdapter<E> {
79    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
80        let inputs = BetweenArgs::try_from(args)?;
81        let Some(array) = inputs.array.as_any().downcast_ref::<E::Array>() else {
82            return Ok(None);
83        };
84        Ok(
85            E::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
86                .map(|array| array.into()),
87        )
88    }
89}
90
91pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
92    let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
93    for kernel in inventory::iter::<BetweenKernelRef> {
94        compute.register_kernel(kernel.0.clone());
95    }
96    compute
97});
98
/// Unit type carrying the [`ComputeFnVTable`] implementation behind [`BETWEEN_FN`].
struct Between;
100
101impl ComputeFnVTable for Between {
102    fn invoke(
103        &self,
104        args: &InvocationArgs,
105        kernels: &[ArcRef<dyn Kernel>],
106    ) -> VortexResult<Output> {
107        let BetweenArgs {
108            array,
109            lower,
110            upper,
111            options,
112        } = BetweenArgs::try_from(args)?;
113
114        let return_dtype = self.return_dtype(args)?;
115
116        // A quick check to see if either array might is a null constant array.
117        if lower.is_invalid(0)? || upper.is_invalid(0)? {
118            if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
119                if c_lower.is_null() || c_upper.is_null() {
120                    return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
121                        .into_array()
122                        .into());
123                }
124            }
125        }
126
127        if lower.as_constant().is_some_and(|v| v.is_null())
128            || upper.as_constant().is_some_and(|v| v.is_null())
129        {
130            return Ok(Canonical::empty(&return_dtype).into_array().into());
131        }
132
133        // Try each kernel
134        for kernel in kernels {
135            if let Some(output) = kernel.invoke(args)? {
136                return Ok(output);
137            }
138        }
139        if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
140            return Ok(output);
141        }
142
143        // Otherwise, fall back to the default Arrow implementation
144        // TODO(joe): should we try to canonicalize the array and try between
145        Ok(boolean(
146            &compare(lower, array, options.lower_strict.to_operator())?,
147            &compare(array, upper, options.upper_strict.to_operator())?,
148            BooleanOperator::And,
149        )?
150        .into())
151    }
152
153    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
154        let BetweenArgs {
155            array,
156            lower,
157            upper,
158            options: _,
159        } = BetweenArgs::try_from(args)?;
160
161        if !array.dtype().eq_ignore_nullability(lower.dtype()) {
162            vortex_bail!(
163                "Array and lower bound types do not match: {:?} != {:?}",
164                array.dtype(),
165                lower.dtype()
166            );
167        }
168        if !array.dtype().eq_ignore_nullability(upper.dtype()) {
169            vortex_bail!(
170                "Array and upper bound types do not match: {:?} != {:?}",
171                array.dtype(),
172                upper.dtype()
173            );
174        }
175
176        Ok(DType::Bool(
177            array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
178        ))
179    }
180
181    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
182        let BetweenArgs {
183            array,
184            lower,
185            upper,
186            options: _,
187        } = BetweenArgs::try_from(args)?;
188        if array.len() != lower.len() || array.len() != upper.len() {
189            vortex_bail!(
190                "Array lengths do not match: array:{} lower:{} upper:{}",
191                array.len(),
192                lower.len(),
193                upper.len()
194            );
195        }
196        Ok(array.len())
197    }
198
199    fn is_elementwise(&self) -> bool {
200        true
201    }
202}
203
/// Typed view over the [`InvocationArgs`] of a `between` call.
struct BetweenArgs<'a> {
    /// Input 0: the array whose values are tested.
    array: &'a dyn Array,
    /// Input 1: the lower bound.
    lower: &'a dyn Array,
    /// Input 2: the upper bound.
    upper: &'a dyn Array,
    /// Strictness of the two bound comparisons.
    options: &'a BetweenOptions,
}
210
211impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
212    type Error = VortexError;
213
214    fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
215        if value.inputs.len() != 3 {
216            vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
217        }
218        let array = value.inputs[0]
219            .array()
220            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
221        let lower = value.inputs[1]
222            .array()
223            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
224        let upper = value.inputs[2]
225            .array()
226            .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
227        let options = value
228            .options
229            .as_any()
230            .downcast_ref::<BetweenOptions>()
231            .vortex_expect("Expected options to be an operator");
232
233        Ok(BetweenArgs {
234            array,
235            lower,
236            upper,
237            options,
238        })
239    }
240}
241
/// Options for the `between` compute function: the strictness of each bound comparison.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the comparison between the lower bound and the array.
    pub lower_strict: StrictComparison,
    /// Strictness of the comparison between the array and the upper bound.
    pub upper_strict: StrictComparison,
}
247
impl Options for BetweenOptions {
    /// Exposes the options as `Any`, so they can be downcast back to [`BetweenOptions`].
    fn as_any(&self) -> &dyn Any {
        self
    }
}
253
/// Whether a bound comparison excludes (`<`) or includes (`<=`) equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict comparison; maps to `Operator::Lt`.
    Strict,
    /// Non-strict comparison; maps to `Operator::Lte`.
    NonStrict,
}
259
260impl StrictComparison {
261    pub const fn to_operator(&self) -> Operator {
262        match self {
263            StrictComparison::Strict => Operator::Lt,
264            StrictComparison::NonStrict => Operator::Lte,
265        }
266    }
267}