1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::arrays::ConstantArray;
13use crate::compute::{
14 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
15 boolean, compare,
16};
17use crate::vtable::VTable;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20pub fn between(
47 arr: &dyn Array,
48 lower: &dyn Array,
49 upper: &dyn Array,
50 options: &BetweenOptions,
51) -> VortexResult<ArrayRef> {
52 BETWEEN_FN
53 .invoke(&InvocationArgs {
54 inputs: &[arr.into(), lower.into(), upper.into()],
55 options,
56 })?
57 .unwrap_array()
58}
59
60pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
61inventory::collect!(BetweenKernelRef);
62
63pub trait BetweenKernel: VTable {
64 fn between(
65 &self,
66 arr: &Self::Array,
67 lower: &dyn Array,
68 upper: &dyn Array,
69 options: &BetweenOptions,
70 ) -> VortexResult<Option<ArrayRef>>;
71}
72
73#[derive(Debug)]
74pub struct BetweenKernelAdapter<V: VTable>(pub V);
75
76impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
77 pub const fn lift(&'static self) -> BetweenKernelRef {
78 BetweenKernelRef(ArcRef::new_ref(self))
79 }
80}
81
82impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
83 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
84 let inputs = BetweenArgs::try_from(args)?;
85 let Some(array) = inputs.array.as_opt::<V>() else {
86 return Ok(None);
87 };
88 Ok(
89 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
90 .map(|array| array.into()),
91 )
92 }
93}
94
95pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
96 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
97 for kernel in inventory::iter::<BetweenKernelRef> {
98 compute.register_kernel(kernel.0.clone());
99 }
100 compute
101});
102
103struct Between;
104
105impl ComputeFnVTable for Between {
106 fn invoke(
107 &self,
108 args: &InvocationArgs,
109 kernels: &[ArcRef<dyn Kernel>],
110 ) -> VortexResult<Output> {
111 let BetweenArgs {
112 array,
113 lower,
114 upper,
115 options,
116 } = BetweenArgs::try_from(args)?;
117
118 let return_dtype = self.return_dtype(args)?;
119
120 if array.is_empty() {
122 return Ok(Canonical::empty(&return_dtype).into_array().into());
123 }
124
125 if lower.is_invalid(0)? || upper.is_invalid(0)? {
128 if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
129 if c_lower.is_null() || c_upper.is_null() {
130 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
131 .into_array()
132 .into());
133 }
134 }
135 }
136
137 if lower.as_constant().is_some_and(|v| v.is_null())
138 || upper.as_constant().is_some_and(|v| v.is_null())
139 {
140 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
141 .into_array()
142 .into());
143 }
144
145 for kernel in kernels {
147 if let Some(output) = kernel.invoke(args)? {
148 return Ok(output);
149 }
150 }
151 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
152 return Ok(output);
153 }
154
155 Ok(boolean(
158 &compare(lower, array, options.lower_strict.to_operator())?,
159 &compare(array, upper, options.upper_strict.to_operator())?,
160 BooleanOperator::And,
161 )?
162 .into())
163 }
164
165 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
166 let BetweenArgs {
167 array,
168 lower,
169 upper,
170 options: _,
171 } = BetweenArgs::try_from(args)?;
172
173 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
174 vortex_bail!(
175 "Array and lower bound types do not match: {:?} != {:?}",
176 array.dtype(),
177 lower.dtype()
178 );
179 }
180 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
181 vortex_bail!(
182 "Array and upper bound types do not match: {:?} != {:?}",
183 array.dtype(),
184 upper.dtype()
185 );
186 }
187
188 Ok(DType::Bool(
189 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
190 ))
191 }
192
193 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
194 let BetweenArgs {
195 array,
196 lower,
197 upper,
198 options: _,
199 } = BetweenArgs::try_from(args)?;
200 if array.len() != lower.len() || array.len() != upper.len() {
201 vortex_bail!(
202 "Array lengths do not match: array:{} lower:{} upper:{}",
203 array.len(),
204 lower.len(),
205 upper.len()
206 );
207 }
208 Ok(array.len())
209 }
210
211 fn is_elementwise(&self) -> bool {
212 true
213 }
214}
215
216struct BetweenArgs<'a> {
217 array: &'a dyn Array,
218 lower: &'a dyn Array,
219 upper: &'a dyn Array,
220 options: &'a BetweenOptions,
221}
222
223impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
224 type Error = VortexError;
225
226 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
227 if value.inputs.len() != 3 {
228 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
229 }
230 let array = value.inputs[0]
231 .array()
232 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
233 let lower = value.inputs[1]
234 .array()
235 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
236 let upper = value.inputs[2]
237 .array()
238 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
239 let options = value
240 .options
241 .as_any()
242 .downcast_ref::<BetweenOptions>()
243 .vortex_expect("Expected options to be an operator");
244
245 Ok(BetweenArgs {
246 array,
247 lower,
248 upper,
249 options,
250 })
251 }
252}
253
254#[derive(Debug, Clone, PartialEq, Eq, Hash)]
255pub struct BetweenOptions {
256 pub lower_strict: StrictComparison,
257 pub upper_strict: StrictComparison,
258}
259
260impl Options for BetweenOptions {
261 fn as_any(&self) -> &dyn Any {
262 self
263 }
264}
265
266#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
267pub enum StrictComparison {
268 Strict,
269 NonStrict,
270}
271
272impl StrictComparison {
273 pub const fn to_operator(&self) -> Operator {
274 match self {
275 StrictComparison::Strict => Operator::Lt,
276 StrictComparison::NonStrict => Operator::Lte,
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Both bounds inclusive: `lower <= arr <= upper`.
    const INCLUSIVE: BetweenOptions = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };

    /// Evaluates `between` and returns the indices at which the result is `true`.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> Vec<u64> {
        let mask = between(array, lower, upper, options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(mask).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        assert_eq!(
            matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &options),
            expected
        );
    }

    #[test]
    fn test_constants() {
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);

        // A constant null upper bound produces no matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices =
            matching_indices(array.as_ref(), lower.as_ref(), null_upper.as_ref(), &INCLUSIVE);
        assert!(indices.is_empty());

        // Constant upper bound against a varying lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        assert_eq!(
            matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE),
            vec![0, 1, 3]
        );

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5);
        assert_eq!(
            matching_indices(array.as_ref(), lower.as_ref(), upper.as_ref(), &INCLUSIVE),
            vec![0, 1, 2, 3, 4]
        );
    }
}