1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15 boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22 for kernel in inventory::iter::<BetweenKernelRef> {
23 compute.register_kernel(kernel.0.clone());
24 }
25 compute
26});
27
28pub(crate) fn warm_up_vtable() -> usize {
29 BETWEEN_FN.kernels().len()
30}
31
32pub fn between(
39 arr: &dyn Array,
40 lower: &dyn Array,
41 upper: &dyn Array,
42 options: &BetweenOptions,
43) -> VortexResult<ArrayRef> {
44 BETWEEN_FN
45 .invoke(&InvocationArgs {
46 inputs: &[arr.into(), lower.into(), upper.into()],
47 options,
48 })?
49 .unwrap_array()
50}
51
/// Type-erased handle to a `between` kernel.
///
/// Instances are collected via `inventory`. `BETWEEN_FN` iterates over all of them on first use
/// and registers each wrapped kernel.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
// Declares the `inventory` collection that `BETWEEN_FN` iterates over.
inventory::collect!(BetweenKernelRef);
54
/// Encoding-specific implementation of [`between`] for arrays of this vtable's encoding.
pub trait BetweenKernel: VTable {
    /// Evaluates `between` for `arr`, which is already downcast to this encoding.
    ///
    /// Return `Ok(None)` to decline. The dispatcher then tries the next registered kernel, the
    /// array's own implementation, and finally the generic compare-and-AND fallback.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
64
/// Adapts a [`VTable`] that implements [`BetweenKernel`] into a generic [`Kernel`].
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);

impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`BetweenKernelRef`]. That is the element type of the
    /// `inventory` collection registered on `BETWEEN_FN`.
    pub const fn lift(&'static self) -> BetweenKernelRef {
        BetweenKernelRef(ArcRef::new_ref(self))
    }
}
73
74impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
75 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
76 let inputs = BetweenArgs::try_from(args)?;
77 let Some(array) = inputs.array.as_opt::<V>() else {
78 return Ok(None);
79 };
80 Ok(
81 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
82 .map(|array| array.into()),
83 )
84 }
85}
86
87struct Between;
88
89impl ComputeFnVTable for Between {
90 fn invoke(
91 &self,
92 args: &InvocationArgs,
93 kernels: &[ArcRef<dyn Kernel>],
94 ) -> VortexResult<Output> {
95 let BetweenArgs {
96 array,
97 lower,
98 upper,
99 options,
100 } = BetweenArgs::try_from(args)?;
101
102 let return_dtype = self.return_dtype(args)?;
103
104 if array.is_empty() {
106 return Ok(Canonical::empty(&return_dtype).into_array().into());
107 }
108
109 if (lower.is_invalid(0) || upper.is_invalid(0))
112 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
113 && (c_lower.is_null() || c_upper.is_null())
114 {
115 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
116 .into_array()
117 .into());
118 }
119
120 if lower.as_constant().is_some_and(|v| v.is_null())
121 || upper.as_constant().is_some_and(|v| v.is_null())
122 {
123 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
124 .into_array()
125 .into());
126 }
127
128 for kernel in kernels {
130 if let Some(output) = kernel.invoke(args)? {
131 return Ok(output);
132 }
133 }
134 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
135 return Ok(output);
136 }
137
138 Ok(boolean(
141 &compare(lower, array, options.lower_strict.to_operator())?,
142 &compare(array, upper, options.upper_strict.to_operator())?,
143 BooleanOperator::And,
144 )?
145 .into())
146 }
147
148 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
149 let BetweenArgs {
150 array,
151 lower,
152 upper,
153 options: _,
154 } = BetweenArgs::try_from(args)?;
155
156 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
157 vortex_bail!(
158 "Array and lower bound types do not match: {:?} != {:?}",
159 array.dtype(),
160 lower.dtype()
161 );
162 }
163 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
164 vortex_bail!(
165 "Array and upper bound types do not match: {:?} != {:?}",
166 array.dtype(),
167 upper.dtype()
168 );
169 }
170
171 Ok(DType::Bool(
172 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
173 ))
174 }
175
176 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
177 let BetweenArgs {
178 array,
179 lower,
180 upper,
181 options: _,
182 } = BetweenArgs::try_from(args)?;
183 if array.len() != lower.len() || array.len() != upper.len() {
184 vortex_bail!(
185 "Array lengths do not match: array:{} lower:{} upper:{}",
186 array.len(),
187 lower.len(),
188 upper.len()
189 );
190 }
191 Ok(array.len())
192 }
193
194 fn is_elementwise(&self) -> bool {
195 true
196 }
197}
198
/// Typed view over the [`InvocationArgs`] of a `between` call.
struct BetweenArgs<'a> {
    // The array being tested (input 0).
    array: &'a dyn Array,
    // Lower bound (input 1).
    lower: &'a dyn Array,
    // Upper bound (input 2).
    upper: &'a dyn Array,
    options: &'a BetweenOptions,
}
205
206impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
207 type Error = VortexError;
208
209 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
210 if value.inputs.len() != 3 {
211 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
212 }
213 let array = value.inputs[0]
214 .array()
215 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
216 let lower = value.inputs[1]
217 .array()
218 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
219 let upper = value.inputs[2]
220 .array()
221 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
222 let options = value
223 .options
224 .as_any()
225 .downcast_ref::<BetweenOptions>()
226 .vortex_expect("Expected options to be an operator");
227
228 Ok(BetweenArgs {
229 array,
230 lower,
231 upper,
232 options,
233 })
234 }
235}
236
/// Options for [`between`], selecting whether each bound is exclusive or inclusive.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// How `lower` is compared to the array: `Strict` is `lower < array`, `NonStrict` is
    /// `lower <= array`.
    pub lower_strict: StrictComparison,
    /// How the array is compared to `upper`: `Strict` is `array < upper`, `NonStrict` is
    /// `array <= upper`.
    pub upper_strict: StrictComparison,
}

impl Options for BetweenOptions {
    /// Exposes `self` as [`Any`] so `BetweenArgs` can downcast the generic options back to
    /// `BetweenOptions`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
248
/// Strictness of a single bound in a [`between`] comparison.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Exclusive bound; maps to `Operator::Lt`.
    Strict,
    /// Inclusive bound; maps to `Operator::Lte`.
    NonStrict,
}
257
258impl StrictComparison {
259 pub const fn to_operator(&self) -> Operator {
260 match self {
261 StrictComparison::Strict => Operator::Lt,
262 StrictComparison::NonStrict => Operator::Lte,
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Both bounds inclusive: `lower <= array <= upper`.
    const INCLUSIVE: BetweenOptions = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };

    /// Runs [`between`] and returns the indices that `to_int_indices` reports as set.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> Vec<u64> {
        let matches = between(array, lower, upper, options).unwrap().to_bool();
        to_int_indices(matches).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &options);
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A null constant upper bound makes every row null, so nothing matches.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert!(indices.is_empty());

        // Non-null constant upper bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_null_constant_lower() {
        let lower = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE).unwrap();
        // The result is nullable and entirely null, not merely all-false.
        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
        assert!((0..result.len()).all(|i| result.is_invalid(i)));
        assert!(to_int_indices(result.to_bool()).unwrap().is_empty());
    }
}