1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::DType;
6use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::arrays::ConstantArray;
10use crate::compute::{
11 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
12 boolean, compare,
13};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17pub fn between(
44 arr: &dyn Array,
45 lower: &dyn Array,
46 upper: &dyn Array,
47 options: &BetweenOptions,
48) -> VortexResult<ArrayRef> {
49 BETWEEN_FN
50 .invoke(&InvocationArgs {
51 inputs: &[arr.into(), lower.into(), upper.into()],
52 options,
53 })?
54 .unwrap_array()
55}
56
57pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
58inventory::collect!(BetweenKernelRef);
59
60pub trait BetweenKernel: VTable {
61 fn between(
62 &self,
63 arr: &Self::Array,
64 lower: &dyn Array,
65 upper: &dyn Array,
66 options: &BetweenOptions,
67 ) -> VortexResult<Option<ArrayRef>>;
68}
69
70#[derive(Debug)]
71pub struct BetweenKernelAdapter<V: VTable>(pub V);
72
73impl<V: VTable + BetweenKernel> BetweenKernelAdapter<V> {
74 pub const fn lift(&'static self) -> BetweenKernelRef {
75 BetweenKernelRef(ArcRef::new_ref(self))
76 }
77}
78
79impl<V: VTable + BetweenKernel> Kernel for BetweenKernelAdapter<V> {
80 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
81 let inputs = BetweenArgs::try_from(args)?;
82 let Some(array) = inputs.array.as_opt::<V>() else {
83 return Ok(None);
84 };
85 Ok(
86 V::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
87 .map(|array| array.into()),
88 )
89 }
90}
91
92pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
93 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
94 for kernel in inventory::iter::<BetweenKernelRef> {
95 compute.register_kernel(kernel.0.clone());
96 }
97 compute
98});
99
100struct Between;
101
102impl ComputeFnVTable for Between {
103 fn invoke(
104 &self,
105 args: &InvocationArgs,
106 kernels: &[ArcRef<dyn Kernel>],
107 ) -> VortexResult<Output> {
108 let BetweenArgs {
109 array,
110 lower,
111 upper,
112 options,
113 } = BetweenArgs::try_from(args)?;
114
115 let return_dtype = self.return_dtype(args)?;
116
117 if array.is_empty() {
119 return Ok(Canonical::empty(&return_dtype).into_array().into());
120 }
121
122 if lower.is_invalid(0)? || upper.is_invalid(0)? {
125 if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
126 if c_lower.is_null() || c_upper.is_null() {
127 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
128 .into_array()
129 .into());
130 }
131 }
132 }
133
134 if lower.as_constant().is_some_and(|v| v.is_null())
135 || upper.as_constant().is_some_and(|v| v.is_null())
136 {
137 return Ok(Canonical::empty(&return_dtype).into_array().into());
138 }
139
140 for kernel in kernels {
142 if let Some(output) = kernel.invoke(args)? {
143 return Ok(output);
144 }
145 }
146 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
147 return Ok(output);
148 }
149
150 Ok(boolean(
153 &compare(lower, array, options.lower_strict.to_operator())?,
154 &compare(array, upper, options.upper_strict.to_operator())?,
155 BooleanOperator::And,
156 )?
157 .into())
158 }
159
160 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
161 let BetweenArgs {
162 array,
163 lower,
164 upper,
165 options: _,
166 } = BetweenArgs::try_from(args)?;
167
168 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
169 vortex_bail!(
170 "Array and lower bound types do not match: {:?} != {:?}",
171 array.dtype(),
172 lower.dtype()
173 );
174 }
175 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
176 vortex_bail!(
177 "Array and upper bound types do not match: {:?} != {:?}",
178 array.dtype(),
179 upper.dtype()
180 );
181 }
182
183 Ok(DType::Bool(
184 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
185 ))
186 }
187
188 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
189 let BetweenArgs {
190 array,
191 lower,
192 upper,
193 options: _,
194 } = BetweenArgs::try_from(args)?;
195 if array.len() != lower.len() || array.len() != upper.len() {
196 vortex_bail!(
197 "Array lengths do not match: array:{} lower:{} upper:{}",
198 array.len(),
199 lower.len(),
200 upper.len()
201 );
202 }
203 Ok(array.len())
204 }
205
206 fn is_elementwise(&self) -> bool {
207 true
208 }
209}
210
211struct BetweenArgs<'a> {
212 array: &'a dyn Array,
213 lower: &'a dyn Array,
214 upper: &'a dyn Array,
215 options: &'a BetweenOptions,
216}
217
218impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
219 type Error = VortexError;
220
221 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
222 if value.inputs.len() != 3 {
223 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
224 }
225 let array = value.inputs[0]
226 .array()
227 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
228 let lower = value.inputs[1]
229 .array()
230 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
231 let upper = value.inputs[2]
232 .array()
233 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
234 let options = value
235 .options
236 .as_any()
237 .downcast_ref::<BetweenOptions>()
238 .vortex_expect("Expected options to be an operator");
239
240 Ok(BetweenArgs {
241 array,
242 lower,
243 upper,
244 options,
245 })
246 }
247}
248
249#[derive(Debug, Clone, PartialEq, Eq, Hash)]
250pub struct BetweenOptions {
251 pub lower_strict: StrictComparison,
252 pub upper_strict: StrictComparison,
253}
254
255impl Options for BetweenOptions {
256 fn as_any(&self) -> &dyn Any {
257 self
258 }
259}
260
261#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
262pub enum StrictComparison {
263 Strict,
264 NonStrict,
265}
266
267impl StrictComparison {
268 pub const fn to_operator(&self) -> Operator {
269 match self {
270 StrictComparison::Strict => Operator::Lt,
271 StrictComparison::NonStrict => Operator::Lte,
272 }
273 }
274}