1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::VortexError;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_scalar::Scalar;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::compute::BooleanOperator;
22use crate::compute::ComputeFn;
23use crate::compute::ComputeFnVTable;
24use crate::compute::InvocationArgs;
25use crate::compute::Kernel;
26use crate::compute::Operator;
27use crate::compute::Options;
28use crate::compute::Output;
29use crate::compute::boolean;
30use crate::compute::compare;
31use crate::vtable::VTable;
32
33static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
34 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
35 for kernel in inventory::iter::<BetweenKernelRef> {
36 compute.register_kernel(kernel.0.clone());
37 }
38 compute
39});
40
41pub(crate) fn warm_up_vtable() -> usize {
42 BETWEEN_FN.kernels().len()
43}
44
45pub fn between(
52 arr: &dyn Array,
53 lower: &dyn Array,
54 upper: &dyn Array,
55 options: &BetweenOptions,
56) -> VortexResult<ArrayRef> {
57 BETWEEN_FN
58 .invoke(&InvocationArgs {
59 inputs: &[arr.into(), lower.into(), upper.into()],
60 options,
61 })?
62 .unwrap_array()
63}
64
/// Registration handle wrapping a type-erased [`Kernel`] for the `between` compute function.
///
/// Values of this type are collected through `inventory` (see the `collect!` invocation
/// below) so that they can be iterated and registered when the compute function is built.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
67
/// Encoding-specific implementation of the `between` compute function.
pub trait BetweenKernel: VTable {
    /// Computes `between` for `arr` against the `lower` and `upper` bounds, using the
    /// strictness settings in `options`.
    ///
    /// Returns `Ok(Some(result))` when the kernel produced a result, and `Ok(None)` when it
    /// did not, in which case the dispatcher moves on to its other strategies.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
77
/// Adapter that wraps a vtable `V` so that its [`BetweenKernel`] implementation can be used
/// as a type-erased [`Kernel`].
#[derive(Debug)]
pub struct BetweenKernelAdapter<V: VTable>(pub V);
80
81impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
82 pub const fn lift(&'static self) -> BetweenKernelRef {
83 BetweenKernelRef(ArcRef::new_ref(self))
84 }
85}
86
87impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
88 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
89 let inputs = BetweenArgs::try_from(args)?;
90 let Some(array) = inputs.array.as_opt::<V>() else {
91 return Ok(None);
92 };
93 Ok(
94 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
95 .map(|array| array.into()),
96 )
97 }
98}
99
/// Zero-sized vtable type backing the `between` compute function.
struct Between;
101
102impl ComputeFnVTable for Between {
103 fn invoke(
104 &self,
105 args: &InvocationArgs,
106 kernels: &[ArcRef<dyn Kernel>],
107 ) -> VortexResult<Output> {
108 let BetweenArgs {
109 array,
110 lower,
111 upper,
112 options,
113 } = BetweenArgs::try_from(args)?;
114
115 let return_dtype = self.return_dtype(args)?;
116
117 if array.is_empty() {
119 return Ok(Canonical::empty(&return_dtype).into_array().into());
120 }
121
122 if (lower.is_invalid(0) || upper.is_invalid(0))
125 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
126 && (c_lower.is_null() || c_upper.is_null())
127 {
128 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
129 .into_array()
130 .into());
131 }
132
133 if lower.as_constant().is_some_and(|v| v.is_null())
134 || upper.as_constant().is_some_and(|v| v.is_null())
135 {
136 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
137 .into_array()
138 .into());
139 }
140
141 for kernel in kernels {
143 if let Some(output) = kernel.invoke(args)? {
144 return Ok(output);
145 }
146 }
147 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
148 return Ok(output);
149 }
150
151 Ok(boolean(
154 &compare(lower, array, options.lower_strict.to_operator())?,
155 &compare(array, upper, options.upper_strict.to_operator())?,
156 BooleanOperator::And,
157 )?
158 .into())
159 }
160
161 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
162 let BetweenArgs {
163 array,
164 lower,
165 upper,
166 options: _,
167 } = BetweenArgs::try_from(args)?;
168
169 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
170 vortex_bail!(
171 "Array and lower bound types do not match: {:?} != {:?}",
172 array.dtype(),
173 lower.dtype()
174 );
175 }
176 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
177 vortex_bail!(
178 "Array and upper bound types do not match: {:?} != {:?}",
179 array.dtype(),
180 upper.dtype()
181 );
182 }
183
184 Ok(DType::Bool(
185 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
186 ))
187 }
188
189 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
190 let BetweenArgs {
191 array,
192 lower,
193 upper,
194 options: _,
195 } = BetweenArgs::try_from(args)?;
196 if array.len() != lower.len() || array.len() != upper.len() {
197 vortex_bail!(
198 "Array lengths do not match: array:{} lower:{} upper:{}",
199 array.len(),
200 lower.len(),
201 upper.len()
202 );
203 }
204 Ok(array.len())
205 }
206
207 fn is_elementwise(&self) -> bool {
208 true
209 }
210}
211
/// Typed view over the [`InvocationArgs`] of a `between` call.
struct BetweenArgs<'a> {
    /// The array whose values are tested (input 0).
    array: &'a dyn Array,
    /// The lower bound (input 1).
    lower: &'a dyn Array,
    /// The upper bound (input 2).
    upper: &'a dyn Array,
    /// Strictness settings for the two bounds.
    options: &'a BetweenOptions,
}
218
219impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
220 type Error = VortexError;
221
222 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
223 if value.inputs.len() != 3 {
224 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
225 }
226 let array = value.inputs[0]
227 .array()
228 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
229 let lower = value.inputs[1]
230 .array()
231 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
232 let upper = value.inputs[2]
233 .array()
234 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
235 let options = value
236 .options
237 .as_any()
238 .downcast_ref::<BetweenOptions>()
239 .vortex_expect("Expected options to be an operator");
240
241 Ok(BetweenArgs {
242 array,
243 lower,
244 upper,
245 options,
246 })
247 }
248}
249
/// Options for the `between` compute function: the strictness of each bound.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the comparison against the lower bound.
    pub lower_strict: StrictComparison,
    /// Strictness of the comparison against the upper bound.
    pub upper_strict: StrictComparison,
}
255
impl Options for BetweenOptions {
    /// Exposes the options as `&dyn Any` so they can be downcast back to `BetweenOptions`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
261
/// Whether a bound comparison is strict or non-strict.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict comparison; maps to `Operator::Lt`.
    Strict,
    /// Non-strict comparison; maps to `Operator::Lte`.
    NonStrict,
}
270
271impl StrictComparison {
272 pub const fn to_operator(&self) -> Operator {
273 match self {
274 StrictComparison::Strict => Operator::Lt,
275 StrictComparison::NonStrict => Operator::Lte,
276 }
277 }
278
279 pub const fn is_strict(&self) -> bool {
280 matches!(self, StrictComparison::Strict)
281 }
282}
283
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Options with both bounds non-strict, shared by the constant-bound checks.
    const INCLUSIVE: BetweenOptions = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };

    /// Checks every combination of strict/non-strict bounds against per-element bounds.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        let mask = between(array.as_ref(), lower.as_ref(), upper.as_ref(), &options)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(mask).unwrap(), expected);
    }

    /// Checks constant bounds: a null upper bound, then a constant upper bound, then
    /// constant bounds on both sides.
    #[test]
    fn test_constants() {
        let lower_values = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant-null upper bound: no index is reported as matching.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let mask = between(
            array.as_ref(),
            lower_values.as_ref(),
            null_upper.as_ref(),
            &INCLUSIVE,
        )
        .unwrap()
        .to_bool();
        assert!(to_int_indices(mask).unwrap().is_empty());

        // Constant upper bound of 2 with per-element lower bounds.
        let upper_two = ConstantArray::new(Scalar::from(2), 5);
        let mask = between(
            array.as_ref(),
            lower_values.as_ref(),
            upper_two.as_ref(),
            &INCLUSIVE,
        )
        .unwrap()
        .to_bool();
        assert_eq!(to_int_indices(mask).unwrap(), vec![0, 1, 3]);

        // Constant bounds on both sides: every element lies within [0, 2].
        let lower_zero = ConstantArray::new(Scalar::from(0), 5);
        let mask = between(
            array.as_ref(),
            lower_zero.as_ref(),
            upper_two.as_ref(),
            &INCLUSIVE,
        )
        .unwrap()
        .to_bool();
        assert_eq!(to_int_indices(mask).unwrap(), vec![0, 1, 2, 3, 4]);
    }
}