1use std::any::Any;
2use std::sync::LazyLock;
3
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::arcref::ArcRef;
9use crate::arrays::ConstantArray;
10use crate::compute::{
11 BooleanOperator, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output,
12 boolean, compare,
13};
14use crate::{Array, ArrayRef, Canonical, Encoding, IntoArray};
15
16pub fn between(
43 arr: &dyn Array,
44 lower: &dyn Array,
45 upper: &dyn Array,
46 options: &BetweenOptions,
47) -> VortexResult<ArrayRef> {
48 BETWEEN_FN
49 .invoke(&InvocationArgs {
50 inputs: &[arr.into(), lower.into(), upper.into()],
51 options,
52 })?
53 .unwrap_array()
54}
55
/// A type-erased [`BetweenKernel`] registration.
///
/// Instances are gathered through `inventory`. Every collected kernel is registered on
/// [`BETWEEN_FN`] when that static is first initialised.
pub struct BetweenKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(BetweenKernelRef);
58
/// Encoding-specific implementation of [`between`].
pub trait BetweenKernel: Encoding {
    /// Evaluates `between` for an array of this encoding.
    ///
    /// Return `Ok(None)` to decline. The dispatcher then tries the remaining kernels and
    /// finally the generic `compare` + `boolean` fallback.
    fn between(
        &self,
        arr: &Self::Array,
        lower: &dyn Array,
        upper: &dyn Array,
        options: &BetweenOptions,
    ) -> VortexResult<Option<ArrayRef>>;
}
68
69#[derive(Debug)]
70pub struct BetweenKernelAdapter<E: Encoding>(pub E);
71
72impl<E: Encoding + BetweenKernel> BetweenKernelAdapter<E> {
73 pub const fn lift(&'static self) -> BetweenKernelRef {
74 BetweenKernelRef(ArcRef::new_ref(self))
75 }
76}
77
78impl<E: Encoding + BetweenKernel> Kernel for BetweenKernelAdapter<E> {
79 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
80 let inputs = BetweenArgs::try_from(args)?;
81 let Some(array) = inputs.array.as_any().downcast_ref::<E::Array>() else {
82 return Ok(None);
83 };
84 Ok(
85 E::between(&self.0, array, inputs.lower, inputs.upper, inputs.options)?
86 .map(|array| array.into()),
87 )
88 }
89}
90
91pub static BETWEEN_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
92 let compute = ComputeFn::new("between".into(), ArcRef::new_ref(&Between));
93 for kernel in inventory::iter::<BetweenKernelRef> {
94 compute.register_kernel(kernel.0.clone());
95 }
96 compute
97});
98
/// Vtable backing [`BETWEEN_FN`].
///
/// It validates operand types and lengths and dispatches to registered kernels. Otherwise it
/// falls back to combining two `compare` results with `boolean`.
struct Between;
100
101impl ComputeFnVTable for Between {
102 fn invoke(
103 &self,
104 args: &InvocationArgs,
105 kernels: &[ArcRef<dyn Kernel>],
106 ) -> VortexResult<Output> {
107 let BetweenArgs {
108 array,
109 lower,
110 upper,
111 options,
112 } = BetweenArgs::try_from(args)?;
113
114 let return_dtype = self.return_dtype(args)?;
115
116 if lower.is_invalid(0)? || upper.is_invalid(0)? {
118 if let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant()) {
119 if c_lower.is_null() || c_upper.is_null() {
120 return Ok(ConstantArray::new(Scalar::null(return_dtype), array.len())
121 .into_array()
122 .into());
123 }
124 }
125 }
126
127 if lower.as_constant().is_some_and(|v| v.is_null())
128 || upper.as_constant().is_some_and(|v| v.is_null())
129 {
130 return Ok(Canonical::empty(&return_dtype).into_array().into());
131 }
132
133 for kernel in kernels {
135 if let Some(output) = kernel.invoke(args)? {
136 return Ok(output);
137 }
138 }
139 if let Some(output) = array.invoke(&BETWEEN_FN, args)? {
140 return Ok(output);
141 }
142
143 Ok(boolean(
146 &compare(lower, array, options.lower_strict.to_operator())?,
147 &compare(array, upper, options.upper_strict.to_operator())?,
148 BooleanOperator::And,
149 )?
150 .into())
151 }
152
153 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
154 let BetweenArgs {
155 array,
156 lower,
157 upper,
158 options: _,
159 } = BetweenArgs::try_from(args)?;
160
161 if !array.dtype().eq_ignore_nullability(lower.dtype()) {
162 vortex_bail!(
163 "Array and lower bound types do not match: {:?} != {:?}",
164 array.dtype(),
165 lower.dtype()
166 );
167 }
168 if !array.dtype().eq_ignore_nullability(upper.dtype()) {
169 vortex_bail!(
170 "Array and upper bound types do not match: {:?} != {:?}",
171 array.dtype(),
172 upper.dtype()
173 );
174 }
175
176 Ok(DType::Bool(
177 array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability(),
178 ))
179 }
180
181 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
182 let BetweenArgs {
183 array,
184 lower,
185 upper,
186 options: _,
187 } = BetweenArgs::try_from(args)?;
188 if array.len() != lower.len() || array.len() != upper.len() {
189 vortex_bail!(
190 "Array lengths do not match: array:{} lower:{} upper:{}",
191 array.len(),
192 lower.len(),
193 upper.len()
194 );
195 }
196 Ok(array.len())
197 }
198
199 fn is_elementwise(&self) -> bool {
200 true
201 }
202}
203
/// Typed, borrowed view of the arguments to [`BETWEEN_FN`], unpacked from [`InvocationArgs`].
struct BetweenArgs<'a> {
    /// The array whose values are range-checked (input 0).
    array: &'a dyn Array,
    /// Lower bound (input 1), compared against `array` per `options.lower_strict`.
    lower: &'a dyn Array,
    /// Upper bound (input 2), compared against `array` per `options.upper_strict`.
    upper: &'a dyn Array,
    /// Strictness of each bound comparison.
    options: &'a BetweenOptions,
}
210
211impl<'a> TryFrom<&InvocationArgs<'a>> for BetweenArgs<'a> {
212 type Error = VortexError;
213
214 fn try_from(value: &InvocationArgs<'a>) -> VortexResult<Self> {
215 if value.inputs.len() != 3 {
216 vortex_bail!("Expected 3 inputs, found {}", value.inputs.len());
217 }
218 let array = value.inputs[0]
219 .array()
220 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
221 let lower = value.inputs[1]
222 .array()
223 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
224 let upper = value.inputs[2]
225 .array()
226 .ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
227 let options = value
228 .options
229 .as_any()
230 .downcast_ref::<BetweenOptions>()
231 .vortex_expect("Expected options to be an operator");
232
233 Ok(BetweenArgs {
234 array,
235 lower,
236 upper,
237 options,
238 })
239 }
240}
241
242#[derive(Debug, Clone, PartialEq, Eq, Hash)]
243pub struct BetweenOptions {
244 pub lower_strict: StrictComparison,
245 pub upper_strict: StrictComparison,
246}
247
248impl Options for BetweenOptions {
249 fn as_any(&self) -> &dyn Any {
250 self
251 }
252}
253
254#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
255pub enum StrictComparison {
256 Strict,
257 NonStrict,
258}
259
260impl StrictComparison {
261 pub const fn to_operator(&self) -> Operator {
262 match self {
263 StrictComparison::Strict => Operator::Lt,
264 StrictComparison::NonStrict => Operator::Lte,
265 }
266 }
267}