1use std::any::Any;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use vortex_dtype::DType;
11use vortex_error::VortexError;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_scalar::Scalar;
17
18use crate::Array;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::compute::BooleanOperator;
24use crate::compute::ComputeFn;
25use crate::compute::ComputeFnVTable;
26use crate::compute::InvocationArgs;
27use crate::compute::Kernel;
28use crate::compute::Operator;
29use crate::compute::Options;
30use crate::compute::Output;
31use crate::compute::boolean;
32use crate::compute::compare;
33use crate::vtable::VTable;
34
35static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
36 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
37 for kernel in inventory::iter::<BetweenKernelRef> {
38 compute.register_kernel(kernel.0.clone());
39 }
40 compute
41});
42
43pub(crate) fn warm_up_vtable() -> usize {
44 BETWEEN_FN.kernels().len()
45}
46
47pub fn between(
54 arr: &dyn Array,
55 lower: &dyn Array,
56 upper: &dyn Array,
57 options: &BetweenOptions,
58) -> VortexResult<ArrayRef> {
59 BETWEEN_FN
60 .invoke(&InvocationArgs {
61 inputs: &[arr.into(), lower.into(), upper.into()],
62 options,
63 })?
64 .unwrap_array()
65}
66
67pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
68inventory::collect!(BetweenKernelRef);
69
70pub trait BetweenKernel: VTable {
71 fn between(
72 &self,
73 arr: &Self::Array,
74 lower: &dyn Array,
75 upper: &dyn Array,
76 options: &BetweenOptions,
77 ) -> VortexResult<Option<ArrayRef>>;
78}
79
80#[derive(Debug)]
81pub struct BetweenKernelAdapter<V: VTable>(pub V);
82
83impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
84 pub const fn lift(&'static self) -> BetweenKernelRef {
85 BetweenKernelRef(ArcRef::new_ref(self))
86 }
87}
88
89impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
90 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
91 let inputs = BetweenArgs::try_from(args)?;
92 let Some(array) = inputs.array.as_opt::<V>() else {
93 return Ok(None);
94 };
95 Ok(
96 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
97 .map(|array| array.into()),
98 )
99 }
100}
101
102struct Between;
103
104impl ComputeFnVTable for Between {
105 fn invoke(
106 &self,
107 args: &InvocationArgs,
108 kernels: &[ArcRef<dyn Kernel>],
109 ) -> VortexResult<Output> {
110 let BetweenArgs {
111 array,
112 lower,
113 upper,
114 options,
115 } = BetweenArgs::try_from(args)?;
116
117 let return_dtype = self.return_dtype(args)?;
118
119 if array.is_empty() {
121 return Ok(Canonical::empty(&return_dtype).into_array().into());
122 }
123
124 if (lower.is_invalid(0) || upper.is_invalid(0))
127 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
128 && (c_lower.is_null() || c_upper.is_null())
129 {
130 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
131 .into_array()
132 .into());
133 }
134
135 if lower.as_constant().is_some_and(|v| v.is_null())
136 || upper.as_constant().is_some_and(|v| v.is_null())
137 {
138 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
139 .into_array()
140 .into());
141 }
142
143 for kernel in kernels {
145 if let Some(output) = kernel.invoke(args)? {
146 return Ok(output);
147 }
148 }
149 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
150 return Ok(output);
151 }
152
153 Ok(boolean(
156 &compare(lower, array, options.lower_strict.to_operator())?,
157 &compare(array, upper, options.upper_strict.to_operator())?,
158 BooleanOperator::And,
159 )?
160 .into())
161 }
162
163 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
164 let BetweenArgs {
165 array,
166 lower,
167 upper,
168 options: _,
169 } = BetweenArgs::try_from(args)?;
170
171 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
172 vortex_bail!(
173 "Array and lower bound types do not match: {:?} != {:?}",
174 array.dtype(),
175 lower.dtype()
176 );
177 }
178 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
179 vortex_bail!(
180 "Array and upper bound types do not match: {:?} != {:?}",
181 array.dtype(),
182 upper.dtype()
183 );
184 }
185
186 Ok(DType::Bool(
187 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
188 ))
189 }
190
191 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
192 let BetweenArgs {
193 array,
194 lower,
195 upper,
196 options: _,
197 } = BetweenArgs::try_from(args)?;
198 if array.len() != lower.len() || array.len() != upper.len() {
199 vortex_bail!(
200 "Array lengths do not match: array:{} lower:{} upper:{}",
201 array.len(),
202 lower.len(),
203 upper.len()
204 );
205 }
206 Ok(array.len())
207 }
208
209 fn is_elementwise(&self) -> bool {
210 true
211 }
212}
213
214struct BetweenArgs<'a> {
215 array: &'a dyn Array,
216 lower: &'a dyn Array,
217 upper: &'a dyn Array,
218 options: &'a BetweenOptions,
219}
220
221impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
222 type Error = VortexError;
223
224 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
225 if value.inputs.len() != 3 {
226 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
227 }
228 let array = value.inputs[0]
229 .array()
230 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
231 let lower = value.inputs[1]
232 .array()
233 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
234 let upper = value.inputs[2]
235 .array()
236 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
237 let options = value
238 .options
239 .as_any()
240 .downcast_ref::<BetweenOptions>()
241 .vortex_expect("Expected options to be an operator");
242
243 Ok(BetweenArgs {
244 array,
245 lower,
246 upper,
247 options,
248 })
249 }
250}
251
252#[derive(Debug, Clone, PartialEq, Eq, Hash)]
253pub struct BetweenOptions {
254 pub lower_strict: StrictComparison,
255 pub upper_strict: StrictComparison,
256}
257
258impl Display for BetweenOptions {
259 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
260 let lower_op = if self.lower_strict.is_strict() {
261 "<"
262 } else {
263 "<="
264 };
265 let upper_op = if self.upper_strict.is_strict() {
266 "<"
267 } else {
268 "<="
269 };
270 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
271 }
272}
273
274impl Options for BetweenOptions {
275 fn as_any(&self) -> &dyn Any {
276 self
277 }
278}
279
/// Whether a bound excludes (`Strict`) or includes (`NonStrict`) the bound value itself.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
288
289impl StrictComparison {
290 pub const fn to_operator(&self) -> Operator {
291 match self {
292 StrictComparison::Strict => Operator::Lt,
293 StrictComparison::NonStrict => Operator::Lte,
294 }
295 }
296
297 pub const fn is_strict(&self) -> bool {
298 matches!(self, StrictComparison::Strict)
299 }
300}
301
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::ToCanonical;
    use crate::compute::conformance::search_sorted::rstest;
    use crate::test_harness::to_int_indices;

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict,
                upper_strict,
            },
        )
        .unwrap()
        .to_bool();

        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant null upper bound makes every result null, so nothing matches.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool();

        let indices = to_int_indices(matches).unwrap();
        assert!(indices.is_empty());

        let upper = ConstantArray::new(Scalar::from(2), 5);

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        let lower = ConstantArray::new(Scalar::from(0), 5);

        let matches = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap()
        .to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    /// A constant null *lower* bound must also short-circuit to an all-null, nullable result.
    #[test]
    fn test_null_lower_bound() {
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let lower = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );

        let result = between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        )
        .unwrap();

        // The nullable lower bound makes the result dtype nullable.
        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

        let indices = to_int_indices(result.to_bool()).unwrap();
        assert!(indices.is_empty());
    }
}