vortex_expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::Debug;
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::try_join;
11use itertools::Itertools;
12use vortex_array::compute::{BetweenOptions, StrictComparison, between as between_compute};
13use vortex_array::operator::{
14    BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, Operator, OperatorEq,
15    OperatorHash, OperatorId, OperatorRef,
16};
17use vortex_array::{Array, ArrayRef, Canonical, DeserializeMetadata, IntoArray, ProstMetadata};
18use vortex_dtype::DType;
19use vortex_dtype::DType::Bool;
20use vortex_error::{VortexExpect, VortexResult, vortex_bail};
21use vortex_proto::expr as pb;
22
23use crate::display::{DisplayAs, DisplayFormat};
24use crate::{
25    AnalysisExpr, BinaryExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable,
26};
27
28vtable!(Between);
29
/// Expression node that tests whether the values of `arr` lie between `lower` and `upper`.
///
/// Whether each bound is inclusive or exclusive is carried by `options`.
// `Hash` is derived while `PartialEq` is written by hand further down in this file; the
// manual impl compares the same four fields (`arr`, `lower`, `upper`, `options`) that the
// derived `Hash` feeds, which is why the clippy lint is silenced here.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Hash, Eq)]
pub struct BetweenExpr {
    /// Expression producing the values under test.
    arr: ExprRef,
    /// Expression producing the lower bound.
    lower: ExprRef,
    /// Expression producing the upper bound.
    upper: ExprRef,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
38
39impl PartialEq for BetweenExpr {
40    fn eq(&self, other: &Self) -> bool {
41        self.arr.eq(&other.arr)
42            && self.lower.eq(&other.lower)
43            && self.upper.eq(&other.upper)
44            && self.options == other.options
45    }
46}
47
/// Marker type used as the `Encoding` of the `between` expression vtable.
///
/// It is a unit struct and carries no state.
pub struct BetweenExprEncoding;
49
50impl VTable for BetweenVTable {
51    type Expr = BetweenExpr;
52    type Encoding = BetweenExprEncoding;
53    type Metadata = ProstMetadata<pb::BetweenOpts>;
54
55    fn id(_encoding: &Self::Encoding) -> ExprId {
56        ExprId::new_ref("between")
57    }
58
59    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
60        ExprEncodingRef::new_ref(BetweenExprEncoding.as_ref())
61    }
62
63    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
64        Some(ProstMetadata(pb::BetweenOpts {
65            lower_strict: expr.options.lower_strict == StrictComparison::Strict,
66            upper_strict: expr.options.upper_strict == StrictComparison::Strict,
67        }))
68    }
69
70    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71        vec![&expr.arr, &expr.lower, &expr.upper]
72    }
73
74    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75        Ok(BetweenExpr::new(
76            children[0].clone(),
77            children[1].clone(),
78            children[2].clone(),
79            expr.options.clone(),
80        ))
81    }
82
83    fn build(
84        _encoding: &Self::Encoding,
85        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
86        children: Vec<ExprRef>,
87    ) -> VortexResult<Self::Expr> {
88        Ok(BetweenExpr::new(
89            children[0].clone(),
90            children[1].clone(),
91            children[2].clone(),
92            BetweenOptions {
93                lower_strict: if metadata.lower_strict {
94                    StrictComparison::Strict
95                } else {
96                    StrictComparison::NonStrict
97                },
98                upper_strict: if metadata.upper_strict {
99                    StrictComparison::Strict
100                } else {
101                    StrictComparison::NonStrict
102                },
103            },
104        ))
105    }
106
107    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
108        let arr_val = expr.arr.unchecked_evaluate(scope)?;
109        let lower_arr_val = expr.lower.unchecked_evaluate(scope)?;
110        let upper_arr_val = expr.upper.unchecked_evaluate(scope)?;
111
112        between_compute(&arr_val, &lower_arr_val, &upper_arr_val, &expr.options)
113    }
114
115    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
116        let arr_dt = expr.arr.return_dtype(scope)?;
117        let lower_dt = expr.lower.return_dtype(scope)?;
118        let upper_dt = expr.upper.return_dtype(scope)?;
119
120        if !arr_dt.eq_ignore_nullability(&lower_dt) {
121            vortex_bail!(
122                "Array dtype {} does not match lower dtype {}",
123                arr_dt,
124                lower_dt
125            );
126        }
127        if !arr_dt.eq_ignore_nullability(&upper_dt) {
128            vortex_bail!(
129                "Array dtype {} does not match upper dtype {}",
130                arr_dt,
131                upper_dt
132            );
133        }
134
135        Ok(Bool(
136            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
137        ))
138    }
139
140    fn operator(expr: &Self::Expr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
141        let Some(arr) = expr.arr.operator(scope)? else {
142            return Ok(None);
143        };
144        let Some(lower) = expr.lower.operator(scope)? else {
145            return Ok(None);
146        };
147        let Some(upper) = expr.upper.operator(scope)? else {
148            return Ok(None);
149        };
150        Ok(Some(Arc::new(BetweenOperator {
151            children: [arr, lower, upper],
152            dtype: expr.return_dtype(scope.dtype())?,
153            options: expr.options.clone(),
154        })))
155    }
156}
157
158impl BetweenExpr {
159    pub fn new(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> Self {
160        Self {
161            arr,
162            lower,
163            upper,
164            options,
165        }
166    }
167
168    pub fn new_expr(
169        arr: ExprRef,
170        lower: ExprRef,
171        upper: ExprRef,
172        options: BetweenOptions,
173    ) -> ExprRef {
174        Self::new(arr, lower, upper, options).into_expr()
175    }
176
177    pub fn to_binary_expr(&self) -> ExprRef {
178        let lhs = BinaryExpr::new(
179            self.lower.clone(),
180            self.options.lower_strict.to_operator().into(),
181            self.arr.clone(),
182        );
183        let rhs = BinaryExpr::new(
184            self.arr.clone(),
185            self.options.upper_strict.to_operator().into(),
186            self.upper.clone(),
187        );
188        BinaryExpr::new(lhs.into_expr(), crate::Operator::And, rhs.into_expr()).into_expr()
189    }
190}
191
192impl DisplayAs for BetweenExpr {
193    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194        match df {
195            DisplayFormat::Compact => {
196                write!(
197                    f,
198                    "({} {} {} {} {})",
199                    self.lower,
200                    self.options.lower_strict.to_operator(),
201                    self.arr,
202                    self.options.upper_strict.to_operator(),
203                    self.upper
204                )
205            }
206            DisplayFormat::Tree => {
207                write!(f, "Between")
208            }
209        }
210    }
211
212    fn child_names(&self) -> Option<Vec<String>> {
213        // Children are: arr, lower, upper (based on the order in the children() method)
214        Some(vec![
215            "array".to_string(),
216            format!("lower ({:?})", self.options.lower_strict),
217            format!("upper ({:?})", self.options.upper_strict),
218        ])
219    }
220}
221
// No overrides: `BetweenExpr` relies entirely on the default items of `AnalysisExpr`.
impl AnalysisExpr for BetweenExpr {}
223
224/// Creates an expression that checks if values are between two bounds.
225///
226/// Returns a boolean array indicating which values fall within the specified range.
227/// The comparison strictness is controlled by the options parameter.
228///
229/// ```rust
230/// # use vortex_array::compute::BetweenOptions;
231/// # use vortex_array::compute::StrictComparison;
232/// # use vortex_expr::{between, lit, root};
233/// let opts = BetweenOptions {
234///     lower_strict: StrictComparison::NonStrict,
235///     upper_strict: StrictComparison::NonStrict,
236/// };
237/// let expr = between(root(), lit(10), lit(20), opts);
238/// ```
239pub fn between(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> ExprRef {
240    BetweenExpr::new(arr, lower, upper, options).into_expr()
241}
242
/// Operator counterpart of [`BetweenExpr`], built by the expression's `operator` hook.
#[derive(Debug)]
pub struct BetweenOperator {
    /// Child operators in the order `arr`, `lower`, `upper`.
    children: [OperatorRef; 3],
    /// Result dtype, computed from the expression's `return_dtype` at construction.
    dtype: DType,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
249
250impl OperatorHash for BetweenOperator {
251    fn operator_hash<H: Hasher>(&self, state: &mut H) {
252        for child in &self.children {
253            child.operator_hash(state);
254        }
255        self.dtype.hash(state);
256        self.options.hash(state);
257    }
258}
259
260impl OperatorEq for BetweenOperator {
261    fn operator_eq(&self, other: &Self) -> bool {
262        self.children.len() == other.children.len()
263            && self
264                .children
265                .iter()
266                .zip(other.children.iter())
267                .all(|(a, b)| a.operator_eq(b))
268            && self.dtype == other.dtype
269            && self.options == other.options
270    }
271}
272
273impl Operator for BetweenOperator {
274    fn id(&self) -> OperatorId {
275        OperatorId::from("vortex.between")
276    }
277
278    fn as_any(&self) -> &dyn Any {
279        self
280    }
281
282    fn dtype(&self) -> &DType {
283        &self.dtype
284    }
285
286    fn len(&self) -> usize {
287        self.children[0].len()
288    }
289
290    fn children(&self) -> &[OperatorRef] {
291        &self.children
292    }
293
294    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
295        let (arr, lower, upper) = children
296            .into_iter()
297            .tuples()
298            .next()
299            .vortex_expect("expected 3 children");
300
301        Ok(Arc::new(BetweenOperator {
302            children: [arr, lower, upper],
303            dtype: self.dtype.clone(),
304            options: self.options.clone(),
305        }))
306    }
307
308    fn is_selection_target(&self, _child_idx: usize) -> Option<bool> {
309        // All children are position preserving.
310        Some(true)
311    }
312}
313
314impl BatchOperator for BetweenOperator {
315    fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
316        let arr = ctx.child(0)?;
317        let lower = ctx.child(1)?;
318        let upper = ctx.child(2)?;
319        Ok(Box::new(BetweenExecution {
320            arr,
321            lower,
322            upper,
323            options: self.options.clone(),
324        }))
325    }
326}
327
/// Bound execution state for a `BetweenOperator`: one child execution per operand plus
/// the comparison options.
struct BetweenExecution {
    /// Execution producing the values under test.
    arr: BatchExecutionRef,
    /// Execution producing the lower bound.
    lower: BatchExecutionRef,
    /// Execution producing the upper bound.
    upper: BatchExecutionRef,
    /// Strictness of the lower and upper comparisons.
    options: BetweenOptions,
}
334
335#[async_trait]
336impl BatchExecution for BetweenExecution {
337    async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
338        let (arr, lower, upper) = try_join!(
339            self.arr.execute(),
340            self.lower.execute(),
341            self.upper.execute()
342        )?;
343        let result = between_compute(
344            arr.into_array().as_ref(),
345            lower.into_array().as_ref(),
346            upper.into_array().as_ref(),
347            &self.options,
348        )?;
349        Ok(result.to_canonical())
350    }
351}
352
353// TODO(ngates): we need scalar variants for batch execution. Although really it should be
354//  pipelined?
355
#[cfg(test)]
mod tests {
    use vortex_array::compute::{BetweenOptions, StrictComparison};

    use crate::{between, get_item, lit, root};

    /// The compact rendering places the operand between its bounds and picks `<` or `<=`
    /// according to the strictness of each side.
    #[test]
    fn test_display() {
        // Inclusive lower bound, exclusive upper bound, on a field access.
        let closed_open = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_range = between(get_item("score", root()), lit(10), lit(50), closed_open);
        assert_eq!(score_range.to_string(), "(10i32 <= $.score < 50i32)");

        // Exclusive lower bound, inclusive upper bound, on the root itself.
        let open_closed = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_range = between(root(), lit(0), lit(100), open_closed);
        assert_eq!(root_range.to_string(), "(0i32 < $ <= 100i32)");
    }
}
386}