1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15 boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22 for kernel in inventory::iter::<BetweenKernelRef> {
23 compute.register_kernel(kernel.0.clone());
24 }
25 compute
26});
27
28pub fn between(
35 arr: &dyn Array,
36 lower: &dyn Array,
37 upper: &dyn Array,
38 options: &BetweenOptions,
39) -> VortexResult<ArrayRef> {
40 BETWEEN_FN
41 .invoke(&InvocationArgs {
42 inputs: &[arr.into(), lower.into(), upper.into()],
43 options,
44 })?
45 .unwrap_array()
46}
47
/// A type-erased [`BetweenKernel`] registration.
///
/// Instances are collected with `inventory` and registered on the `between` compute function
/// when that function is first constructed.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
50
/// Encoding-specific implementation of [`between`].
///
/// Implemented on an array [`VTable`]. Through [`BetweenKernelAdapter`] it is only invoked
/// when the input array downcasts to `Self::Array`.
pub trait BetweenKernel: VTable {
    /// Computes `between` for `arr` against the given bounds.
    ///
    /// Return `Ok(None)` to decline. The dispatcher then tries the next registered kernel and
    /// finally falls back to a generic `compare`/`boolean` implementation.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
60
/// Adapts a [`BetweenKernel`] implementation for vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);

impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
    /// Wraps this `'static` adapter as a [`BetweenKernelRef`] for registration via `inventory`.
    pub const fn lift(&'static self) -> BetweenKernelRef {
        // `&'static Self` is unsized to `&'static dyn Kernel` at this argument position.
        BetweenKernelRef(ArcRef::new_ref(self))
    }
}
69
70impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
71 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
72 let inputs = BetweenArgs::try_from(args)?;
73 let Some(array) = inputs.array.as_opt::<V>() else {
74 return Ok(None);
75 };
76 Ok(
77 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
78 .map(|array| array.into()),
79 )
80 }
81}
82
83struct Between;
84
85impl ComputeFnVTable for Between {
86 fn invoke(
87 &self,
88 args: &InvocationArgs,
89 kernels: &[ArcRef<dyn Kernel>],
90 ) -> VortexResult<Output> {
91 let BetweenArgs {
92 array,
93 lower,
94 upper,
95 options,
96 } = BetweenArgs::try_from(args)?;
97
98 let return_dtype = self.return_dtype(args)?;
99
100 if array.is_empty() {
102 return Ok(Canonical::empty(&return_dtype).into_array().into());
103 }
104
105 if lower.is_invalid(0)? || upper.is_invalid(0)? {
108 if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
109 if c_lower.is_null() || c_upper.is_null() {
110 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
111 .into_array()
112 .into());
113 }
114 }
115 }
116
117 if lower.as_constant().is_some_and(|v| v.is_null())
118 || upper.as_constant().is_some_and(|v| v.is_null())
119 {
120 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
121 .into_array()
122 .into());
123 }
124
125 for kernel in kernels {
127 if let Some(output) = kernel.invoke(args)? {
128 return Ok(output);
129 }
130 }
131 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
132 return Ok(output);
133 }
134
135 Ok(boolean(
138 &compare(lower, array, options.lower_strict.to_operator())?,
139 &compare(array, upper, options.upper_strict.to_operator())?,
140 BooleanOperator::And,
141 )?
142 .into())
143 }
144
145 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
146 let BetweenArgs {
147 array,
148 lower,
149 upper,
150 options: _,
151 } = BetweenArgs::try_from(args)?;
152
153 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
154 vortex_bail!(
155 "Array and lower bound types do not match: {:?} != {:?}",
156 array.dtype(),
157 lower.dtype()
158 );
159 }
160 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
161 vortex_bail!(
162 "Array and upper bound types do not match: {:?} != {:?}",
163 array.dtype(),
164 upper.dtype()
165 );
166 }
167
168 Ok(DType::Bool(
169 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
170 ))
171 }
172
173 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
174 let BetweenArgs {
175 array,
176 lower,
177 upper,
178 options: _,
179 } = BetweenArgs::try_from(args)?;
180 if array.len() != lower.len() || array.len() != upper.len() {
181 vortex_bail!(
182 "Array lengths do not match: array:{} lower:{} upper:{}",
183 array.len(),
184 lower.len(),
185 upper.len()
186 );
187 }
188 Ok(array.len())
189 }
190
191 fn is_elementwise(&self) -> bool {
192 true
193 }
194}
195
196struct BetweenArgs<'a> {
197 array: &'a dyn Array,
198 lower: &'a dyn Array,
199 upper: &'a dyn Array,
200 options: &'a BetweenOptions,
201}
202
203impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
204 type Error = VortexError;
205
206 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
207 if value.inputs.len() != 3 {
208 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
209 }
210 let array = value.inputs[0]
211 .array()
212 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
213 let lower = value.inputs[1]
214 .array()
215 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
216 let upper = value.inputs[2]
217 .array()
218 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
219 let options = value
220 .options
221 .as_any()
222 .downcast_ref::<BetweenOptions>()
223 .vortex_expect("Expected options to be an operator");
224
225 Ok(BetweenArgs {
226 array,
227 lower,
228 upper,
229 options,
230 })
231 }
232}
233
/// Options for [`between`]: the strictness of each bound.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower-bound comparison (`lower < value` vs `lower <= value`).
    pub lower_strict: StrictComparison,
    /// Strictness of the upper-bound comparison (`value < upper` vs `value <= upper`).
    pub upper_strict: StrictComparison,
}

impl Options for BetweenOptions {
    // Lets `BetweenArgs::try_from` downcast the type-erased options back to `BetweenOptions`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
245
246#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
248pub enum StrictComparison {
249 Strict,
251 NonStrict,
253}
254
255impl StrictComparison {
256 pub const fn to_operator(&self) -> Operator {
257 match self {
258 StrictComparison::Strict => Operator::Lt,
259 StrictComparison::NonStrict => Operator::Lte,
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    // Exercises every strict/non-strict combination against per-element (non-constant) bounds.
    // Per index (lower, array, upper):
    //   0: (0, 1, 2) - inside both bounds, matches in every mode;
    //   1: (0, 0, 1) - equals the lower bound, fails only a strict lower bound;
    //   2: (0, 1, 1) - equals the upper bound, fails only a strict upper bound;
    //   3: (0, 0, 0) - equals both bounds, matches only when both are non-strict;
    //   4: (2, 1, 0) - below the lower bound, never matches.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict,
                upper_strict,
            },
        )
        .unwrap()
        .to_bool()
        .unwrap();

        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, expected);
    }

    // Exercises constant bounds: a null constant bound, a constant upper bound, and both
    // bounds constant.
    #[test]
    fn test_constants() {
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);

        // A null constant upper bound: no index is reported as a match.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool()
        .unwrap();

        let indices = to_int_indices(matches).unwrap();
        assert!(indices.is_empty());

        // Constant upper bound 2: indices 2 and 4 fail the lower bound (2 <= 1 is false).
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool()
        .unwrap();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant (0 and 2): every value in `array` lies within [0, 2].
        let lower = ConstantArray::new(Scalar::from(0), 5);

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool()
        .unwrap();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}