1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15 boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
21 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
22 for kernel in inventory::iter::<BetweenKernelRef> {
23 compute.register_kernel(kernel.0.clone());
24 }
25 compute
26});
27
28pub fn between(
35 arr: &dyn Array,
36 lower: &dyn Array,
37 upper: &dyn Array,
38 options: &BetweenOptions,
39) -> VortexResult<ArrayRef> {
40 BETWEEN_FN
41 .invoke(&InvocationArgs {
42 inputs: &[arr.into(), lower.into(), upper.into()],
43 options,
44 })?
45 .unwrap_array()
46}
47
48pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
49inventory::collect!(BetweenKernelRef);
50
51pub trait BetweenKernel: VTable {
52 fn between(
53 &self,
54 arr: &Self::Array,
55 lower: &dyn Array,
56 upper: &dyn Array,
57 options: &BetweenOptions,
58 ) -> VortexResult<Option<ArrayRef>>;
59}
60
61#[derive(Debug)]
62pub struct BetweenKernelAdapter<V: VTable>(pub V);
63
64impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
65 pub const fn lift(&'static self) -> BetweenKernelRef {
66 BetweenKernelRef(ArcRef::new_ref(self))
67 }
68}
69
70impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
71 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
72 let inputs = BetweenArgs::try_from(args)?;
73 let Some(array) = inputs.array.as_opt::<V>() else {
74 return Ok(None);
75 };
76 Ok(
77 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
78 .map(|array| array.into()),
79 )
80 }
81}
82
/// Marker type implementing the `between` compute function's vtable.
struct Between;
84
85impl ComputeFnVTable for Between {
86 fn invoke(
87 &self,
88 args: &InvocationArgs,
89 kernels: &[ArcRef<dyn Kernel>],
90 ) -> VortexResult<Output> {
91 let BetweenArgs {
92 array,
93 lower,
94 upper,
95 options,
96 } = BetweenArgs::try_from(args)?;
97
98 let return_dtype = self.return_dtype(args)?;
99
100 if array.is_empty() {
102 return Ok(Canonical::empty(&return_dtype).into_array().into());
103 }
104
105 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
108 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
109 && (c_lower.is_null() || c_upper.is_null())
110 {
111 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
112 .into_array()
113 .into());
114 }
115
116 if lower.as_constant().is_some_and(|v| v.is_null())
117 || upper.as_constant().is_some_and(|v| v.is_null())
118 {
119 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
120 .into_array()
121 .into());
122 }
123
124 for kernel in kernels {
126 if let Some(output) = kernel.invoke(args)? {
127 return Ok(output);
128 }
129 }
130 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
131 return Ok(output);
132 }
133
134 Ok(boolean(
137 &compare(lower, array, options.lower_strict.to_operator())?,
138 &compare(array, upper, options.upper_strict.to_operator())?,
139 BooleanOperator::And,
140 )?
141 .into())
142 }
143
144 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
145 let BetweenArgs {
146 array,
147 lower,
148 upper,
149 options: _,
150 } = BetweenArgs::try_from(args)?;
151
152 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
153 vortex_bail!(
154 "Array and lower bound types do not match: {:?} != {:?}",
155 array.dtype(),
156 lower.dtype()
157 );
158 }
159 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
160 vortex_bail!(
161 "Array and upper bound types do not match: {:?} != {:?}",
162 array.dtype(),
163 upper.dtype()
164 );
165 }
166
167 Ok(DType::Bool(
168 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
169 ))
170 }
171
172 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
173 let BetweenArgs {
174 array,
175 lower,
176 upper,
177 options: _,
178 } = BetweenArgs::try_from(args)?;
179 if array.len() != lower.len() || array.len() != upper.len() {
180 vortex_bail!(
181 "Array lengths do not match: array:{} lower:{} upper:{}",
182 array.len(),
183 lower.len(),
184 upper.len()
185 );
186 }
187 Ok(array.len())
188 }
189
190 fn is_elementwise(&self) -> bool {
191 true
192 }
193}
194
195struct BetweenArgs<'a> {
196 array: &'a dyn Array,
197 lower: &'a dyn Array,
198 upper: &'a dyn Array,
199 options: &'a BetweenOptions,
200}
201
202impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
203 type Error = VortexError;
204
205 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
206 if value.inputs.len() != 3 {
207 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
208 }
209 let array = value.inputs[0]
210 .array()
211 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
212 let lower = value.inputs[1]
213 .array()
214 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
215 let upper = value.inputs[2]
216 .array()
217 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
218 let options = value
219 .options
220 .as_any()
221 .downcast_ref::<BetweenOptions>()
222 .vortex_expect("Expected options to be an operator");
223
224 Ok(BetweenArgs {
225 array,
226 lower,
227 upper,
228 options,
229 })
230 }
231}
232
233#[derive(Debug, Clone, PartialEq, Eq, Hash)]
234pub struct BetweenOptions {
235 pub lower_strict: StrictComparison,
236 pub upper_strict: StrictComparison,
237}
238
239impl Options for BetweenOptions {
240 fn as_any(&self) -> &dyn Any {
241 self
242 }
243}
244
/// Whether a bound comparison excludes or includes equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict bound; maps to `Operator::Lt`.
    Strict,
    /// Non-strict bound; maps to `Operator::Lte`.
    NonStrict,
}
253
254impl StrictComparison {
255 pub const fn to_operator(&self) -> Operator {
256 match self {
257 StrictComparison::Strict => Operator::Lt,
258 StrictComparison::NonStrict => Operator::Lte,
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs `between` and returns the indices at which the result is true.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<u64> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mask = between(array, lower, upper, &options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(mask).unwrap()
    }

    /// Same as [`matching_indices`] with both bounds inclusive.
    fn inclusive_matches(array: &dyn Array, lower: &dyn Array, upper: &dyn Array) -> Vec<u64> {
        matching_indices(
            array,
            lower,
            upper,
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        )
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);

        // A constant-null upper bound: no index may be reported as matching.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = inclusive_matches(array.as_ref(), lower.as_ref(), upper.as_ref());
        assert!(indices.is_empty());

        // A non-null constant upper bound with a varying lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = inclusive_matches(array.as_ref(), lower.as_ref(), upper.as_ref());
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = inclusive_matches(array.as_ref(), lower.as_ref(), upper.as_ref());
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}