1use std::any::Any;
5use std::fmt::Debug;
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::try_join;
11use itertools::Itertools;
12use vortex_array::compute::{BetweenOptions, StrictComparison, between as between_compute};
13use vortex_array::operator::{
14 BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, Operator, OperatorEq,
15 OperatorHash, OperatorId, OperatorRef,
16};
17use vortex_array::{Array, ArrayRef, Canonical, DeserializeMetadata, IntoArray, ProstMetadata};
18use vortex_dtype::DType;
19use vortex_dtype::DType::Bool;
20use vortex_error::{VortexExpect, VortexResult, vortex_bail};
21use vortex_proto::expr as pb;
22
23use crate::display::{DisplayAs, DisplayFormat};
24use crate::{
25 AnalysisExpr, BinaryExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable,
26};
27
// Macro from this crate, invoked with the name `Between`. NOTE(review): presumably it generates the
// `BetweenVTable` type that the `impl VTable for BetweenVTable` block implements — confirm
// against the `vtable!` definition.
vtable!(Between);
29
30#[allow(clippy::derived_hash_with_manual_eq)]
31#[derive(Clone, Debug, Hash, Eq)]
32pub struct BetweenExpr {
33 arr: ExprRef,
34 lower: ExprRef,
35 upper: ExprRef,
36 options: BetweenOptions,
37}
38
39impl PartialEq for BetweenExpr {
40 fn eq(&self, other: &Self) -> bool {
41 self.arr.eq(&other.arr)
42 && self.lower.eq(&other.lower)
43 && self.upper.eq(&other.upper)
44 && self.options == other.options
45 }
46}
47
/// Stateless encoding marker for [`BetweenExpr`]; used as the `Encoding` type of `BetweenVTable`.
pub struct BetweenExprEncoding;
49
50impl VTable for BetweenVTable {
51 type Expr = BetweenExpr;
52 type Encoding = BetweenExprEncoding;
53 type Metadata = ProstMetadata<pb::BetweenOpts>;
54
55 fn id(_encoding: &Self::Encoding) -> ExprId {
56 ExprId::new_ref("between")
57 }
58
59 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
60 ExprEncodingRef::new_ref(BetweenExprEncoding.as_ref())
61 }
62
63 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
64 Some(ProstMetadata(pb::BetweenOpts {
65 lower_strict: expr.options.lower_strict == StrictComparison::Strict,
66 upper_strict: expr.options.upper_strict == StrictComparison::Strict,
67 }))
68 }
69
70 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71 vec![&expr.arr, &expr.lower, &expr.upper]
72 }
73
74 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75 Ok(BetweenExpr::new(
76 children[0].clone(),
77 children[1].clone(),
78 children[2].clone(),
79 expr.options.clone(),
80 ))
81 }
82
83 fn build(
84 _encoding: &Self::Encoding,
85 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
86 children: Vec<ExprRef>,
87 ) -> VortexResult<Self::Expr> {
88 Ok(BetweenExpr::new(
89 children[0].clone(),
90 children[1].clone(),
91 children[2].clone(),
92 BetweenOptions {
93 lower_strict: if metadata.lower_strict {
94 StrictComparison::Strict
95 } else {
96 StrictComparison::NonStrict
97 },
98 upper_strict: if metadata.upper_strict {
99 StrictComparison::Strict
100 } else {
101 StrictComparison::NonStrict
102 },
103 },
104 ))
105 }
106
107 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
108 let arr_val = expr.arr.unchecked_evaluate(scope)?;
109 let lower_arr_val = expr.lower.unchecked_evaluate(scope)?;
110 let upper_arr_val = expr.upper.unchecked_evaluate(scope)?;
111
112 between_compute(&arr_val, &lower_arr_val, &upper_arr_val, &expr.options)
113 }
114
115 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
116 let arr_dt = expr.arr.return_dtype(scope)?;
117 let lower_dt = expr.lower.return_dtype(scope)?;
118 let upper_dt = expr.upper.return_dtype(scope)?;
119
120 if !arr_dt.eq_ignore_nullability(&lower_dt) {
121 vortex_bail!(
122 "Array dtype {} does not match lower dtype {}",
123 arr_dt,
124 lower_dt
125 );
126 }
127 if !arr_dt.eq_ignore_nullability(&upper_dt) {
128 vortex_bail!(
129 "Array dtype {} does not match upper dtype {}",
130 arr_dt,
131 upper_dt
132 );
133 }
134
135 Ok(Bool(
136 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
137 ))
138 }
139
140 fn operator(expr: &Self::Expr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
141 let Some(arr) = expr.arr.operator(scope)? else {
142 return Ok(None);
143 };
144 let Some(lower) = expr.lower.operator(scope)? else {
145 return Ok(None);
146 };
147 let Some(upper) = expr.upper.operator(scope)? else {
148 return Ok(None);
149 };
150 Ok(Some(Arc::new(BetweenOperator {
151 children: [arr, lower, upper],
152 dtype: expr.return_dtype(scope.dtype())?,
153 options: expr.options.clone(),
154 })))
155 }
156}
157
158impl BetweenExpr {
159 pub fn new(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> Self {
160 Self {
161 arr,
162 lower,
163 upper,
164 options,
165 }
166 }
167
168 pub fn new_expr(
169 arr: ExprRef,
170 lower: ExprRef,
171 upper: ExprRef,
172 options: BetweenOptions,
173 ) -> ExprRef {
174 Self::new(arr, lower, upper, options).into_expr()
175 }
176
177 pub fn to_binary_expr(&self) -> ExprRef {
178 let lhs = BinaryExpr::new(
179 self.lower.clone(),
180 self.options.lower_strict.to_operator().into(),
181 self.arr.clone(),
182 );
183 let rhs = BinaryExpr::new(
184 self.arr.clone(),
185 self.options.upper_strict.to_operator().into(),
186 self.upper.clone(),
187 );
188 BinaryExpr::new(lhs.into_expr(), crate::Operator::And, rhs.into_expr()).into_expr()
189 }
190}
191
192impl DisplayAs for BetweenExpr {
193 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194 match df {
195 DisplayFormat::Compact => {
196 write!(
197 f,
198 "({} {} {} {} {})",
199 self.lower,
200 self.options.lower_strict.to_operator(),
201 self.arr,
202 self.options.upper_strict.to_operator(),
203 self.upper
204 )
205 }
206 DisplayFormat::Tree => {
207 write!(f, "Between")
208 }
209 }
210 }
211
212 fn child_names(&self) -> Option<Vec<String>> {
213 Some(vec![
215 "array".to_string(),
216 format!("lower ({:?})", self.options.lower_strict),
217 format!("upper ({:?})", self.options.upper_strict),
218 ])
219 }
220}
221
// Empty impl: `BetweenExpr` overrides none of the `AnalysisExpr` methods and uses the trait's
// defaults.
impl AnalysisExpr for BetweenExpr {}
223
224pub fn between(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> ExprRef {
240 BetweenExpr::new(arr, lower, upper, options).into_expr()
241}
242
/// Operator-tree counterpart of [`BetweenExpr`], produced by `BetweenVTable::operator`.
#[derive(Debug)]
pub struct BetweenOperator {
    /// Operands in the order `[arr, lower, upper]`.
    children: [OperatorRef; 3],
    /// Result dtype, computed by `return_dtype` when the operator is built.
    dtype: DType,
    /// Strictness of the lower and upper bound comparisons.
    options: BetweenOptions,
}
248}
249
250impl OperatorHash for BetweenOperator {
251 fn operator_hash<H: Hasher>(&self, state: &mut H) {
252 for child in &self.children {
253 child.operator_hash(state);
254 }
255 self.dtype.hash(state);
256 self.options.hash(state);
257 }
258}
259
260impl OperatorEq for BetweenOperator {
261 fn operator_eq(&self, other: &Self) -> bool {
262 self.children.len() == other.children.len()
263 && self
264 .children
265 .iter()
266 .zip(other.children.iter())
267 .all(|(a, b)| a.operator_eq(b))
268 && self.dtype == other.dtype
269 && self.options == other.options
270 }
271}
272
273impl Operator for BetweenOperator {
274 fn id(&self) -> OperatorId {
275 OperatorId::from("vortex.between")
276 }
277
278 fn as_any(&self) -> &dyn Any {
279 self
280 }
281
282 fn dtype(&self) -> &DType {
283 &self.dtype
284 }
285
286 fn len(&self) -> usize {
287 self.children[0].len()
288 }
289
290 fn children(&self) -> &[OperatorRef] {
291 &self.children
292 }
293
294 fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
295 let (arr, lower, upper) = children
296 .into_iter()
297 .tuples()
298 .next()
299 .vortex_expect("expected 3 children");
300
301 Ok(Arc::new(BetweenOperator {
302 children: [arr, lower, upper],
303 dtype: self.dtype.clone(),
304 options: self.options.clone(),
305 }))
306 }
307
308 fn is_selection_target(&self, _child_idx: usize) -> Option<bool> {
309 Some(true)
311 }
312}
313
314impl BatchOperator for BetweenOperator {
315 fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
316 let arr = ctx.child(0)?;
317 let lower = ctx.child(1)?;
318 let upper = ctx.child(2)?;
319 Ok(Box::new(BetweenExecution {
320 arr,
321 lower,
322 upper,
323 options: self.options.clone(),
324 }))
325 }
326}
327
/// Bound batch execution for [`BetweenOperator`]: holds the three child executions and the
/// bound strictness.
struct BetweenExecution {
    /// Execution producing the tested values.
    arr: BatchExecutionRef,
    /// Execution producing the lower bound.
    lower: BatchExecutionRef,
    /// Execution producing the upper bound.
    upper: BatchExecutionRef,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
334
335#[async_trait]
336impl BatchExecution for BetweenExecution {
337 async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
338 let (arr, lower, upper) = try_join!(
339 self.arr.execute(),
340 self.lower.execute(),
341 self.upper.execute()
342 )?;
343 let result = between_compute(
344 arr.into_array().as_ref(),
345 lower.into_array().as_ref(),
346 upper.into_array().as_ref(),
347 &self.options,
348 )?;
349 Ok(result.to_canonical())
350 }
351}
352
#[cfg(test)]
mod tests {
    use vortex_array::compute::{BetweenOptions, StrictComparison};

    use crate::{between, get_item, lit, root};

    /// Checks the compact display form for both mixed-strictness configurations.
    #[test]
    fn test_display() {
        let cases = [
            (
                between(
                    get_item("score", root()),
                    lit(10),
                    lit(50),
                    BetweenOptions {
                        lower_strict: StrictComparison::NonStrict,
                        upper_strict: StrictComparison::Strict,
                    },
                ),
                "(10i32 <= $.score < 50i32)",
            ),
            (
                between(
                    root(),
                    lit(0),
                    lit(100),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::NonStrict,
                    },
                ),
                "(0i32 < $ <= 100i32)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}