1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::DType;
6use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::arrays::ConstantArray;
10use crate::compute::{
11 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
12 boolean, compare,
13};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17pub fn between(
44 arr: &dyn Array,
45 lower: &dyn Array,
46 upper: &dyn Array,
47 options: &BetweenOptions,
48) -> VortexResult<ArrayRef> {
49 BETWEEN_FN
50 .invoke(&InvocationArgs {
51 inputs: &[arr.into(), lower.into(), upper.into()],
52 options,
53 })?
54 .unwrap_array()
55}
56
57pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
58inventory::collect!(BetweenKernelRef);
59
60pub trait BetweenKernel: VTable {
61 fn between(
62 &self,
63 arr: &Self::Array,
64 lower: &dyn Array,
65 upper: &dyn Array,
66 options: &BetweenOptions,
67 ) -> VortexResult<Option<ArrayRef>>;
68}
69
70#[derive(Debug)]
71pub struct BetweenKernelAdapter<V: VTable>(pub V);
72
73impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
74 pub const fn lift(&'static self) -> BetweenKernelRef {
75 BetweenKernelRef(ArcRef::new_ref(self))
76 }
77}
78
79impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
80 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
81 let inputs = BetweenArgs::try_from(args)?;
82 let Some(array) = inputs.array.as_opt::<V>() else {
83 return Ok(None);
84 };
85 Ok(
86 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
87 .map(|array| array.into()),
88 )
89 }
90}
91
92pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
93 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
94 for kernel in inventory::iter::<BetweenKernelRef> {
95 compute.register_kernel(kernel.0.clone());
96 }
97 compute
98});
99
100struct Between;
101
102impl ComputeFnVTable for Between {
103 fn invoke(
104 &self,
105 args: &InvocationArgs,
106 kernels: &[ArcRef<dyn Kernel>],
107 ) -> VortexResult<Output> {
108 let BetweenArgs {
109 array,
110 lower,
111 upper,
112 options,
113 } = BetweenArgs::try_from(args)?;
114
115 let return_dtype = self.return_dtype(args)?;
116
117 if array.is_empty() {
119 return Ok(Canonical::empty(&return_dtype).into_array().into());
120 }
121
122 if lower.is_invalid(0)? || upper.is_invalid(0)? {
125 if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
126 if c_lower.is_null() || c_upper.is_null() {
127 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
128 .into_array()
129 .into());
130 }
131 }
132 }
133
134 if lower.as_constant().is_some_and(|v| v.is_null())
135 || upper.as_constant().is_some_and(|v| v.is_null())
136 {
137 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
138 .into_array()
139 .into());
140 }
141
142 for kernel in kernels {
144 if let Some(output) = kernel.invoke(args)? {
145 return Ok(output);
146 }
147 }
148 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
149 return Ok(output);
150 }
151
152 Ok(boolean(
155 &compare(lower, array, options.lower_strict.to_operator())?,
156 &compare(array, upper, options.upper_strict.to_operator())?,
157 BooleanOperator::And,
158 )?
159 .into())
160 }
161
162 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
163 let BetweenArgs {
164 array,
165 lower,
166 upper,
167 options: _,
168 } = BetweenArgs::try_from(args)?;
169
170 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
171 vortex_bail!(
172 "Array and lower bound types do not match: {:?} != {:?}",
173 array.dtype(),
174 lower.dtype()
175 );
176 }
177 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
178 vortex_bail!(
179 "Array and upper bound types do not match: {:?} != {:?}",
180 array.dtype(),
181 upper.dtype()
182 );
183 }
184
185 Ok(DType::Bool(
186 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
187 ))
188 }
189
190 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
191 let BetweenArgs {
192 array,
193 lower,
194 upper,
195 options: _,
196 } = BetweenArgs::try_from(args)?;
197 if array.len() != lower.len() || array.len() != upper.len() {
198 vortex_bail!(
199 "Array lengths do not match: array:{} lower:{} upper:{}",
200 array.len(),
201 lower.len(),
202 upper.len()
203 );
204 }
205 Ok(array.len())
206 }
207
208 fn is_elementwise(&self) -> bool {
209 true
210 }
211}
212
213struct BetweenArgs<'a> {
214 array: &'a dyn Array,
215 lower: &'a dyn Array,
216 upper: &'a dyn Array,
217 options: &'a BetweenOptions,
218}
219
220impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
221 type Error = VortexError;
222
223 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
224 if value.inputs.len() != 3 {
225 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
226 }
227 let array = value.inputs[0]
228 .array()
229 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
230 let lower = value.inputs[1]
231 .array()
232 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
233 let upper = value.inputs[2]
234 .array()
235 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
236 let options = value
237 .options
238 .as_any()
239 .downcast_ref::<BetweenOptions>()
240 .vortex_expect("Expected options to be an operator");
241
242 Ok(BetweenArgs {
243 array,
244 lower,
245 upper,
246 options,
247 })
248 }
249}
250
251#[derive(Debug, Clone, PartialEq, Eq, Hash)]
252pub struct BetweenOptions {
253 pub lower_strict: StrictComparison,
254 pub upper_strict: StrictComparison,
255}
256
257impl Options for BetweenOptions {
258 fn as_any(&self) -> &dyn Any {
259 self
260 }
261}
262
263#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
264pub enum StrictComparison {
265 Strict,
266 NonStrict,
267}
268
269impl StrictComparison {
270 pub const fn to_operator(&self) -> Operator {
271 match self {
272 StrictComparison::Strict => Operator::Lt,
273 StrictComparison::NonStrict => Operator::Lte,
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    /// Runs `between` with the given strictness and returns the indices where it is `true`.
    fn matching_indices(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<u64> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mask = between(array, lower, upper, &options)
            .unwrap()
            .to_bool()
            .unwrap();
        to_int_indices(mask).unwrap()
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let lower = PrimitiveArray::from_iter([0, 0, 0, 0, 2]);
        let upper = PrimitiveArray::from_iter([2, 1, 1, 0, 0]);

        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let array = PrimitiveArray::from_iter([1, 0, 1, 0, 1]);
        let lower = PrimitiveArray::from_iter([0, 0, 2, 0, 2]);

        // A constant null upper bound nulls out every result, so nothing matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        assert!(indices.is_empty());

        // A non-null constant upper bound with a varying lower bound.
        let const_upper = ConstantArray::new(Scalar::from(2), 5);
        let indices = matching_indices(
            array.as_ref(),
            lower.as_ref(),
            const_upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant: every value in [0, 2] matches.
        let const_lower = ConstantArray::new(Scalar::from(0), 5);
        let indices = matching_indices(
            array.as_ref(),
            const_lower.as_ref(),
            const_upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }
}
381}