vortex_expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use vortex_array::compute::{BetweenOptions, StrictComparison, between as between_compute};
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_dtype::DType::Bool;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{
15    AnalysisExpr, BinaryExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable,
16};
17
18vtable!(Between);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Clone, Debug, Hash, Eq)]
22pub struct BetweenExpr {
23    arr: ExprRef,
24    lower: ExprRef,
25    upper: ExprRef,
26    options: BetweenOptions,
27}
28
29impl PartialEq for BetweenExpr {
30    fn eq(&self, other: &Self) -> bool {
31        self.arr.eq(&other.arr)
32            && self.lower.eq(&other.lower)
33            && self.upper.eq(&other.upper)
34            && self.options == other.options
35    }
36}
37
38pub struct BetweenExprEncoding;
39
40impl VTable for BetweenVTable {
41    type Expr = BetweenExpr;
42    type Encoding = BetweenExprEncoding;
43    type Metadata = ProstMetadata<pb::BetweenOpts>;
44
45    fn id(_encoding: &Self::Encoding) -> ExprId {
46        ExprId::new_ref("between")
47    }
48
49    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
50        ExprEncodingRef::new_ref(BetweenExprEncoding.as_ref())
51    }
52
53    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
54        Some(ProstMetadata(pb::BetweenOpts {
55            lower_strict: expr.options.lower_strict == StrictComparison::Strict,
56            upper_strict: expr.options.upper_strict == StrictComparison::Strict,
57        }))
58    }
59
60    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
61        vec![&expr.arr, &expr.lower, &expr.upper]
62    }
63
64    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
65        Ok(BetweenExpr::new(
66            children[0].clone(),
67            children[1].clone(),
68            children[2].clone(),
69            expr.options.clone(),
70        ))
71    }
72
73    fn build(
74        _encoding: &Self::Encoding,
75        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
76        children: Vec<ExprRef>,
77    ) -> VortexResult<Self::Expr> {
78        Ok(BetweenExpr::new(
79            children[0].clone(),
80            children[1].clone(),
81            children[2].clone(),
82            BetweenOptions {
83                lower_strict: if metadata.lower_strict {
84                    StrictComparison::Strict
85                } else {
86                    StrictComparison::NonStrict
87                },
88                upper_strict: if metadata.upper_strict {
89                    StrictComparison::Strict
90                } else {
91                    StrictComparison::NonStrict
92                },
93            },
94        ))
95    }
96
97    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
98        let arr_val = expr.arr.unchecked_evaluate(scope)?;
99        let lower_arr_val = expr.lower.unchecked_evaluate(scope)?;
100        let upper_arr_val = expr.upper.unchecked_evaluate(scope)?;
101
102        between_compute(&arr_val, &lower_arr_val, &upper_arr_val, &expr.options)
103    }
104
105    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
106        let arr_dt = expr.arr.return_dtype(scope)?;
107        let lower_dt = expr.lower.return_dtype(scope)?;
108        let upper_dt = expr.upper.return_dtype(scope)?;
109
110        if !arr_dt.eq_ignore_nullability(&lower_dt) {
111            vortex_bail!(
112                "Array dtype {} does not match lower dtype {}",
113                arr_dt,
114                lower_dt
115            );
116        }
117        if !arr_dt.eq_ignore_nullability(&upper_dt) {
118            vortex_bail!(
119                "Array dtype {} does not match upper dtype {}",
120                arr_dt,
121                upper_dt
122            );
123        }
124
125        Ok(Bool(
126            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
127        ))
128    }
129}
130
131impl BetweenExpr {
132    pub fn new(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> Self {
133        Self {
134            arr,
135            lower,
136            upper,
137            options,
138        }
139    }
140
141    pub fn new_expr(
142        arr: ExprRef,
143        lower: ExprRef,
144        upper: ExprRef,
145        options: BetweenOptions,
146    ) -> ExprRef {
147        Self::new(arr, lower, upper, options).into_expr()
148    }
149
150    pub fn to_binary_expr(&self) -> ExprRef {
151        let lhs = BinaryExpr::new(
152            self.lower.clone(),
153            self.options.lower_strict.to_operator().into(),
154            self.arr.clone(),
155        );
156        let rhs = BinaryExpr::new(
157            self.arr.clone(),
158            self.options.upper_strict.to_operator().into(),
159            self.upper.clone(),
160        );
161        BinaryExpr::new(lhs.into_expr(), crate::Operator::And, rhs.into_expr()).into_expr()
162    }
163}
164
165impl DisplayAs for BetweenExpr {
166    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
167        match df {
168            DisplayFormat::Compact => {
169                write!(
170                    f,
171                    "({} {} {} {} {})",
172                    self.lower,
173                    self.options.lower_strict.to_operator(),
174                    self.arr,
175                    self.options.upper_strict.to_operator(),
176                    self.upper
177                )
178            }
179            DisplayFormat::Tree => {
180                write!(f, "Between")
181            }
182        }
183    }
184
185    fn child_names(&self) -> Option<Vec<String>> {
186        // Children are: arr, lower, upper (based on the order in the children() method)
187        Some(vec![
188            "array".to_string(),
189            format!("lower ({:?})", self.options.lower_strict),
190            format!("upper ({:?})", self.options.upper_strict),
191        ])
192    }
193}
194
195impl AnalysisExpr for BetweenExpr {}
196
197/// Creates an expression that checks if values are between two bounds.
198///
199/// Returns a boolean array indicating which values fall within the specified range.
200/// The comparison strictness is controlled by the options parameter.
201///
202/// ```rust
203/// # use vortex_array::compute::BetweenOptions;
204/// # use vortex_array::compute::StrictComparison;
205/// # use vortex_expr::{between, lit, root};
206/// let opts = BetweenOptions {
207///     lower_strict: StrictComparison::NonStrict,
208///     upper_strict: StrictComparison::NonStrict,
209/// };
210/// let expr = between(root(), lit(10), lit(20), opts);
211/// ```
212pub fn between(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> ExprRef {
213    BetweenExpr::new(arr, lower, upper, options).into_expr()
214}
215
#[cfg(test)]
mod tests {
    use vortex_array::compute::{BetweenOptions, StrictComparison};

    use crate::{between, get_item, lit, root};

    /// Compact display renders `(lower OP arr OP upper)`, using `<=` for non-strict
    /// and `<` for strict bounds.
    #[test]
    fn test_display() {
        let cases = [
            (
                get_item("score", root()),
                lit(10),
                lit(50),
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                "(10i32 <= $.score < 50i32)",
            ),
            (
                root(),
                lit(0),
                lit(100),
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                "(0i32 < $ <= 100i32)",
            ),
        ];

        for (arr, lower, upper, lower_strict, upper_strict, expected) in cases {
            let expr = between(
                arr,
                lower,
                upper,
                BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
            );
            assert_eq!(expr.to_string(), expected);
        }
    }
}
246}