vortex_expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::Debug;
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::try_join;
11use itertools::Itertools;
12use vortex_array::compute::{BetweenOptions, StrictComparison, between as between_compute};
13use vortex_array::operator::{
14    BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, LengthBounds, Operator,
15    OperatorEq, OperatorHash, OperatorId, OperatorRef,
16};
17use vortex_array::{Array, ArrayRef, Canonical, DeserializeMetadata, IntoArray, ProstMetadata};
18use vortex_dtype::DType;
19use vortex_dtype::DType::Bool;
20use vortex_error::{VortexExpect, VortexResult, vortex_bail};
21use vortex_proto::expr as pb;
22
23use crate::display::{DisplayAs, DisplayFormat};
24use crate::{
25    AnalysisExpr, BinaryExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog,
26    VTable, vtable,
27};
28
// NOTE(review): presumably declares the `BetweenVTable` type implemented below (plus related
// glue for `BetweenExpr`) — confirm against the `vtable!` macro definition.
vtable!(Between);
30
/// Expression that tests whether the values of `arr` lie between `lower` and `upper`.
///
/// Whether each bound is compared strictly or non-strictly is carried by `options`.
// `Hash` is derived while `PartialEq` is implemented by hand, hence the clippy allow.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Hash, Eq)]
pub struct BetweenExpr {
    /// Expression producing the values being tested.
    arr: ExprRef,
    /// Expression producing the lower bound.
    lower: ExprRef,
    /// Expression producing the upper bound.
    upper: ExprRef,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
39
40impl PartialEq for BetweenExpr {
41    fn eq(&self, other: &Self) -> bool {
42        self.arr.eq(&other.arr)
43            && self.lower.eq(&other.lower)
44            && self.upper.eq(&other.upper)
45            && self.options == other.options
46    }
47}
48
/// Marker type used as the `VTable::Encoding` of the between expression.
pub struct BetweenExprEncoding;
50
51impl VTable for BetweenVTable {
52    type Expr = BetweenExpr;
53    type Encoding = BetweenExprEncoding;
54    type Metadata = ProstMetadata<pb::BetweenOpts>;
55
56    fn id(_encoding: &Self::Encoding) -> ExprId {
57        ExprId::new_ref("between")
58    }
59
60    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
61        ExprEncodingRef::new_ref(BetweenExprEncoding.as_ref())
62    }
63
64    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
65        Some(ProstMetadata(pb::BetweenOpts {
66            lower_strict: expr.options.lower_strict == StrictComparison::Strict,
67            upper_strict: expr.options.upper_strict == StrictComparison::Strict,
68        }))
69    }
70
71    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
72        vec![&expr.arr, &expr.lower, &expr.upper]
73    }
74
75    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
76        Ok(BetweenExpr::new(
77            children[0].clone(),
78            children[1].clone(),
79            children[2].clone(),
80            expr.options.clone(),
81        ))
82    }
83
84    fn build(
85        _encoding: &Self::Encoding,
86        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
87        children: Vec<ExprRef>,
88    ) -> VortexResult<Self::Expr> {
89        Ok(BetweenExpr::new(
90            children[0].clone(),
91            children[1].clone(),
92            children[2].clone(),
93            BetweenOptions {
94                lower_strict: if metadata.lower_strict {
95                    StrictComparison::Strict
96                } else {
97                    StrictComparison::NonStrict
98                },
99                upper_strict: if metadata.upper_strict {
100                    StrictComparison::Strict
101                } else {
102                    StrictComparison::NonStrict
103                },
104            },
105        ))
106    }
107
108    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
109        let arr_val = expr.arr.unchecked_evaluate(scope)?;
110        let lower_arr_val = expr.lower.unchecked_evaluate(scope)?;
111        let upper_arr_val = expr.upper.unchecked_evaluate(scope)?;
112
113        between_compute(&arr_val, &lower_arr_val, &upper_arr_val, &expr.options)
114    }
115
116    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
117        let arr_dt = expr.arr.return_dtype(scope)?;
118        let lower_dt = expr.lower.return_dtype(scope)?;
119        let upper_dt = expr.upper.return_dtype(scope)?;
120
121        if !arr_dt.eq_ignore_nullability(&lower_dt) {
122            vortex_bail!(
123                "Array dtype {} does not match lower dtype {}",
124                arr_dt,
125                lower_dt
126            );
127        }
128        if !arr_dt.eq_ignore_nullability(&upper_dt) {
129            vortex_bail!(
130                "Array dtype {} does not match upper dtype {}",
131                arr_dt,
132                upper_dt
133            );
134        }
135
136        Ok(Bool(
137            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
138        ))
139    }
140
141    fn operator(expr: &Self::Expr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
142        let Some(arr) = expr.arr.operator(scope)? else {
143            return Ok(None);
144        };
145        let Some(lower) = expr.lower.operator(scope)? else {
146            return Ok(None);
147        };
148        let Some(upper) = expr.upper.operator(scope)? else {
149            return Ok(None);
150        };
151        Ok(Some(Arc::new(BetweenOperator {
152            children: [arr, lower, upper],
153            dtype: expr.return_dtype(scope.dtype())?,
154            options: expr.options.clone(),
155        })))
156    }
157}
158
159impl BetweenExpr {
160    pub fn new(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> Self {
161        Self {
162            arr,
163            lower,
164            upper,
165            options,
166        }
167    }
168
169    pub fn new_expr(
170        arr: ExprRef,
171        lower: ExprRef,
172        upper: ExprRef,
173        options: BetweenOptions,
174    ) -> ExprRef {
175        Self::new(arr, lower, upper, options).into_expr()
176    }
177
178    pub fn to_binary_expr(&self) -> ExprRef {
179        let lhs = BinaryExpr::new(
180            self.lower.clone(),
181            self.options.lower_strict.to_operator().into(),
182            self.arr.clone(),
183        );
184        let rhs = BinaryExpr::new(
185            self.arr.clone(),
186            self.options.upper_strict.to_operator().into(),
187            self.upper.clone(),
188        );
189        BinaryExpr::new(lhs.into_expr(), crate::Operator::And, rhs.into_expr()).into_expr()
190    }
191}
192
193impl DisplayAs for BetweenExpr {
194    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
195        match df {
196            DisplayFormat::Compact => {
197                write!(
198                    f,
199                    "({} {} {} {} {})",
200                    self.lower,
201                    self.options.lower_strict.to_operator(),
202                    self.arr,
203                    self.options.upper_strict.to_operator(),
204                    self.upper
205                )
206            }
207            DisplayFormat::Tree => {
208                write!(f, "Between")
209            }
210        }
211    }
212
213    fn child_names(&self) -> Option<Vec<String>> {
214        // Children are: arr, lower, upper (based on the order in the children() method)
215        Some(vec![
216            "array".to_string(),
217            format!("lower ({:?})", self.options.lower_strict),
218            format!("upper ({:?})", self.options.upper_strict),
219        ])
220    }
221}
222
223impl AnalysisExpr for BetweenExpr {
224    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
225        self.to_binary_expr().stat_falsification(catalog)
226    }
227}
228
229/// Creates an expression that checks if values are between two bounds.
230///
231/// Returns a boolean array indicating which values fall within the specified range.
232/// The comparison strictness is controlled by the options parameter.
233///
234/// ```rust
235/// # use vortex_array::compute::BetweenOptions;
236/// # use vortex_array::compute::StrictComparison;
237/// # use vortex_expr::{between, lit, root};
238/// let opts = BetweenOptions {
239///     lower_strict: StrictComparison::NonStrict,
240///     upper_strict: StrictComparison::NonStrict,
241/// };
242/// let expr = between(root(), lit(10), lit(20), opts);
243/// ```
244pub fn between(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> ExprRef {
245    BetweenExpr::new(arr, lower, upper, options).into_expr()
246}
247
/// Operator form of a between expression.
#[derive(Debug)]
pub struct BetweenOperator {
    /// Child operators in the order `[arr, lower, upper]`.
    children: [OperatorRef; 3],
    /// Result dtype of the between expression.
    dtype: DType,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
254
255impl OperatorHash for BetweenOperator {
256    fn operator_hash<H: Hasher>(&self, state: &mut H) {
257        for child in &self.children {
258            child.operator_hash(state);
259        }
260        self.dtype.hash(state);
261        self.options.hash(state);
262    }
263}
264
265impl OperatorEq for BetweenOperator {
266    fn operator_eq(&self, other: &Self) -> bool {
267        self.children.len() == other.children.len()
268            && self
269                .children
270                .iter()
271                .zip(other.children.iter())
272                .all(|(a, b)| a.operator_eq(b))
273            && self.dtype == other.dtype
274            && self.options == other.options
275    }
276}
277
278impl Operator for BetweenOperator {
279    fn id(&self) -> OperatorId {
280        OperatorId::from("vortex.between")
281    }
282
283    fn as_any(&self) -> &dyn Any {
284        self
285    }
286
287    fn dtype(&self) -> &DType {
288        &self.dtype
289    }
290
291    fn bounds(&self) -> LengthBounds {
292        self.children[0].bounds()
293    }
294
295    fn children(&self) -> &[OperatorRef] {
296        &self.children
297    }
298
299    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
300        let (arr, lower, upper) = children
301            .into_iter()
302            .tuples()
303            .next()
304            .vortex_expect("expected 3 children");
305
306        Ok(Arc::new(BetweenOperator {
307            children: [arr, lower, upper],
308            dtype: self.dtype.clone(),
309            options: self.options.clone(),
310        }))
311    }
312
313    fn is_selection_target(&self, _child_idx: usize) -> Option<bool> {
314        // All children are position preserving.
315        Some(true)
316    }
317}
318
319impl BatchOperator for BetweenOperator {
320    fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
321        let arr = ctx.child(0)?;
322        let lower = ctx.child(1)?;
323        let upper = ctx.child(2)?;
324        Ok(Box::new(BetweenExecution {
325            arr,
326            lower,
327            upper,
328            options: self.options.clone(),
329        }))
330    }
331}
332
/// Bound batch execution of a [`BetweenOperator`].
struct BetweenExecution {
    /// Execution producing the values being tested (child 0).
    arr: BatchExecutionRef,
    /// Execution producing the lower bound (child 1).
    lower: BatchExecutionRef,
    /// Execution producing the upper bound (child 2).
    upper: BatchExecutionRef,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
339
340#[async_trait]
341impl BatchExecution for BetweenExecution {
342    async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
343        let (arr, lower, upper) = try_join!(
344            self.arr.execute(),
345            self.lower.execute(),
346            self.upper.execute()
347        )?;
348        let result = between_compute(
349            arr.into_array().as_ref(),
350            lower.into_array().as_ref(),
351            upper.into_array().as_ref(),
352            &self.options,
353        )?;
354        Ok(result.to_canonical())
355    }
356}
357
358// TODO(ngates): we need scalar variants for batch execution. Although really it should be
359//  pipelined?
360
#[cfg(test)]
mod tests {
    use vortex_array::compute::{BetweenOptions, StrictComparison};

    use crate::{between, get_item, lit, root};

    /// The compact display form renders as `(lower OP arr OP upper)`.
    #[test]
    fn test_display() {
        // Inclusive lower bound, exclusive upper bound.
        let half_open = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_in_range = between(get_item("score", root()), lit(10), lit(50), half_open);
        assert_eq!(score_in_range.to_string(), "(10i32 <= $.score < 50i32)");

        // Exclusive lower bound, inclusive upper bound.
        let open_closed = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_in_range = between(root(), lit(0), lit(100), open_closed);
        assert_eq!(root_in_range.to_string(), "(0i32 < $ <= 100i32)");
    }
}