vortex_array/expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_proto::expr as pb;
13
14use crate::ArrayRef;
15use crate::compute::BetweenOptions;
16use crate::compute::between as between_compute;
17use crate::expr::ChildName;
18use crate::expr::ExprId;
19use crate::expr::ExpressionView;
20use crate::expr::StatsCatalog;
21use crate::expr::VTable;
22use crate::expr::VTableExt;
23use crate::expr::expression::Expression;
24use crate::expr::exprs::binary::Binary;
25use crate::expr::exprs::operators::Operator;
26
/// Scalar expression testing whether each value lies between a lower and an upper bound.
///
/// A `Between` node has exactly three children, in this order:
/// 1. the values to test,
/// 2. the lower bound,
/// 3. the upper bound.
///
/// Whether each bound is strict or non-strict is carried by the expression's
/// [`BetweenOptions`] instance data.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
pub struct Between;
39
40impl VTable for Between {
41    type Instance = BetweenOptions;
42
43    fn id(&self) -> ExprId {
44        ExprId::from("vortex.between")
45    }
46
47    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
48        Ok(Some(
49            pb::BetweenOpts {
50                lower_strict: instance.lower_strict.is_strict(),
51                upper_strict: instance.upper_strict.is_strict(),
52            }
53            .encode_to_vec(),
54        ))
55    }
56
57    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
58        let opts = pb::BetweenOpts::decode(metadata)?;
59        Ok(Some(BetweenOptions {
60            lower_strict: if opts.lower_strict {
61                crate::compute::StrictComparison::Strict
62            } else {
63                crate::compute::StrictComparison::NonStrict
64            },
65            upper_strict: if opts.upper_strict {
66                crate::compute::StrictComparison::Strict
67            } else {
68                crate::compute::StrictComparison::NonStrict
69            },
70        }))
71    }
72
73    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
74        if expr.children().len() != 3 {
75            vortex_bail!(
76                "Between expression requires exactly 3 children, got {}",
77                expr.children().len()
78            );
79        }
80        Ok(())
81    }
82
83    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
84        match child_idx {
85            0 => ChildName::from("array"),
86            1 => ChildName::from("lower"),
87            2 => ChildName::from("upper"),
88            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
89        }
90    }
91
92    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
93        let options = expr.data();
94        let lower_op = if options.lower_strict.is_strict() {
95            "<"
96        } else {
97            "<="
98        };
99        let upper_op = if options.upper_strict.is_strict() {
100            "<"
101        } else {
102            "<="
103        };
104        write!(
105            f,
106            "({} {} {} {} {})",
107            expr.lower(),
108            lower_op,
109            expr.child(),
110            upper_op,
111            expr.upper()
112        )
113    }
114
115    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
116        let arr_dt = expr.child().return_dtype(scope)?;
117        let lower_dt = expr.lower().return_dtype(scope)?;
118        let upper_dt = expr.upper().return_dtype(scope)?;
119
120        if !arr_dt.eq_ignore_nullability(&lower_dt) {
121            vortex_bail!(
122                "Array dtype {} does not match lower dtype {}",
123                arr_dt,
124                lower_dt
125            );
126        }
127        if !arr_dt.eq_ignore_nullability(&upper_dt) {
128            vortex_bail!(
129                "Array dtype {} does not match upper dtype {}",
130                arr_dt,
131                upper_dt
132            );
133        }
134
135        Ok(Bool(
136            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
137        ))
138    }
139
140    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
141        let arr = expr.child().evaluate(scope)?;
142        let lower = expr.lower().evaluate(scope)?;
143        let upper = expr.upper().evaluate(scope)?;
144        between_compute(&arr, &lower, &upper, expr.data())
145    }
146
147    fn stat_falsification(
148        &self,
149        expr: &ExpressionView<Self>,
150        catalog: &dyn StatsCatalog,
151    ) -> Option<Expression> {
152        expr.to_binary_expr().stat_falsification(catalog)
153    }
154
155    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
156        false
157    }
158}
159
160impl ExpressionView<'_, Between> {
161    pub fn child(&self) -> &Expression {
162        &self.children()[0]
163    }
164
165    pub fn lower(&self) -> &Expression {
166        &self.children()[1]
167    }
168
169    pub fn upper(&self) -> &Expression {
170        &self.children()[2]
171    }
172
173    pub fn to_binary_expr(&self) -> Expression {
174        let options = self.data();
175        let arr = self.children()[0].clone();
176        let lower = self.children()[1].clone();
177        let upper = self.children()[2].clone();
178
179        let lhs = Binary.new_expr(
180            options.lower_strict.to_operator().into(),
181            [lower, arr.clone()],
182        );
183        let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
184        Binary.new_expr(Operator::And, [lhs, rhs])
185    }
186}
187
188/// Creates an expression that checks if values are between two bounds.
189///
190/// Returns a boolean array indicating which values fall within the specified range.
191/// The comparison strictness is controlled by the options parameter.
192///
193/// ```rust
194/// # use vortex_array::compute::BetweenOptions;
195/// # use vortex_array::compute::StrictComparison;
196/// # use vortex_array::expr::{between, lit, root};
197/// let opts = BetweenOptions {
198///     lower_strict: StrictComparison::NonStrict,
199///     upper_strict: StrictComparison::NonStrict,
200/// };
201/// let expr = between(root(), lit(10), lit(20), opts);
202/// ```
203pub fn between(
204    arr: Expression,
205    lower: Expression,
206    upper: Expression,
207    options: BetweenOptions,
208) -> Expression {
209    Between
210        .try_new_expr(options, [arr, lower, upper])
211        .vortex_expect("Failed to create Between expression")
212}
213
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Each bound must render with the operator matching its strictness.
    #[test]
    fn test_display() {
        let cases = [
            (
                between(
                    get_item("score", root()),
                    lit(10),
                    lit(50),
                    BetweenOptions {
                        lower_strict: StrictComparison::NonStrict,
                        upper_strict: StrictComparison::Strict,
                    },
                ),
                "(10i32 <= $.score < 50i32)",
            ),
            (
                between(
                    root(),
                    lit(0),
                    lit(100),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::NonStrict,
                    },
                ),
                "(0i32 < $ <= 100i32)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}