vortex_array/expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_proto::expr as pb;
11
12use crate::ArrayRef;
13use crate::compute::{BetweenOptions, between as between_compute};
14use crate::expr::expression::Expression;
15use crate::expr::exprs::binary::Binary;
16use crate::expr::exprs::operators::Operator;
17use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
18
/// An optimized scalar expression to compute whether values fall between two bounds.
///
/// This expression takes exactly three children, in this order:
/// 1. `array`: the values to check.
/// 2. `lower`: the lower bound.
/// 3. `upper`: the upper bound.
///
/// Whether each bound is compared strictly is controlled by the expression's metadata, a
/// [`BetweenOptions`] carrying one strictness setting per bound.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
pub struct Between;
31
32impl VTable for Between {
33    type Instance = BetweenOptions;
34
35    fn id(&self) -> ExprId {
36        ExprId::from("vortex.between")
37    }
38
39    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
40        Ok(Some(
41            pb::BetweenOpts {
42                lower_strict: instance.lower_strict.is_strict(),
43                upper_strict: instance.upper_strict.is_strict(),
44            }
45            .encode_to_vec(),
46        ))
47    }
48
49    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
50        let opts = pb::BetweenOpts::decode(metadata)?;
51        Ok(Some(BetweenOptions {
52            lower_strict: if opts.lower_strict {
53                crate::compute::StrictComparison::Strict
54            } else {
55                crate::compute::StrictComparison::NonStrict
56            },
57            upper_strict: if opts.upper_strict {
58                crate::compute::StrictComparison::Strict
59            } else {
60                crate::compute::StrictComparison::NonStrict
61            },
62        }))
63    }
64
65    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
66        if expr.children().len() != 3 {
67            vortex_bail!(
68                "Between expression requires exactly 3 children, got {}",
69                expr.children().len()
70            );
71        }
72        Ok(())
73    }
74
75    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
76        match child_idx {
77            0 => ChildName::from("array"),
78            1 => ChildName::from("lower"),
79            2 => ChildName::from("upper"),
80            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
81        }
82    }
83
84    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
85        let options = expr.data();
86        let lower_op = if options.lower_strict.is_strict() {
87            "<"
88        } else {
89            "<="
90        };
91        let upper_op = if options.upper_strict.is_strict() {
92            "<"
93        } else {
94            "<="
95        };
96        write!(
97            f,
98            "({} {} {} {} {})",
99            expr.lower(),
100            lower_op,
101            expr.child(),
102            upper_op,
103            expr.upper()
104        )
105    }
106
107    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
108        let arr_dt = expr.child().return_dtype(scope)?;
109        let lower_dt = expr.lower().return_dtype(scope)?;
110        let upper_dt = expr.upper().return_dtype(scope)?;
111
112        if !arr_dt.eq_ignore_nullability(&lower_dt) {
113            vortex_bail!(
114                "Array dtype {} does not match lower dtype {}",
115                arr_dt,
116                lower_dt
117            );
118        }
119        if !arr_dt.eq_ignore_nullability(&upper_dt) {
120            vortex_bail!(
121                "Array dtype {} does not match upper dtype {}",
122                arr_dt,
123                upper_dt
124            );
125        }
126
127        Ok(Bool(
128            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
129        ))
130    }
131
132    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
133        let arr = expr.child().evaluate(scope)?;
134        let lower = expr.lower().evaluate(scope)?;
135        let upper = expr.upper().evaluate(scope)?;
136        between_compute(&arr, &lower, &upper, expr.data())
137    }
138
139    fn stat_falsification(
140        &self,
141        expr: &ExpressionView<Self>,
142        catalog: &mut dyn StatsCatalog,
143    ) -> Option<Expression> {
144        expr.to_binary_expr().stat_falsification(catalog)
145    }
146}
147
148impl ExpressionView<'_, Between> {
149    pub fn child(&self) -> &Expression {
150        &self.children()[0]
151    }
152
153    pub fn lower(&self) -> &Expression {
154        &self.children()[1]
155    }
156
157    pub fn upper(&self) -> &Expression {
158        &self.children()[2]
159    }
160
161    pub fn to_binary_expr(&self) -> Expression {
162        let options = self.data();
163        let arr = self.children()[0].clone();
164        let lower = self.children()[1].clone();
165        let upper = self.children()[2].clone();
166
167        let lhs = Binary.new_expr(
168            options.lower_strict.to_operator().into(),
169            [lower, arr.clone()],
170        );
171        let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
172        Binary.new_expr(Operator::And, [lhs, rhs])
173    }
174}
175
176/// Creates an expression that checks if values are between two bounds.
177///
178/// Returns a boolean array indicating which values fall within the specified range.
179/// The comparison strictness is controlled by the options parameter.
180///
181/// ```rust
182/// # use vortex_array::compute::BetweenOptions;
183/// # use vortex_array::compute::StrictComparison;
184/// # use vortex_array::expr::{between, lit, root};
185/// let opts = BetweenOptions {
186///     lower_strict: StrictComparison::NonStrict,
187///     upper_strict: StrictComparison::NonStrict,
188/// };
189/// let expr = between(root(), lit(10), lit(20), opts);
190/// ```
191pub fn between(
192    arr: Expression,
193    lower: Expression,
194    upper: Expression,
195    options: BetweenOptions,
196) -> Expression {
197    Between
198        .try_new_expr(options, [arr, lower, upper])
199        .vortex_expect("Failed to create Between expression")
200}
201
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::{BetweenOptions, StrictComparison};
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// The display form places the value between its bounds and picks `<` or `<=`
    /// per bound according to the strictness options.
    #[test]
    fn test_display() {
        // Non-strict lower bound, strict upper bound, applied to a field access.
        let strict_upper = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let scored = between(get_item("score", root()), lit(10), lit(50), strict_upper);
        assert_eq!(scored.to_string(), "(10i32 <= $.score < 50i32)");

        // Strict lower bound, non-strict upper bound, applied to the root.
        let strict_lower = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let rooted = between(root(), lit(0), lit(100), strict_lower);
        assert_eq!(rooted.to_string(), "(0i32 < $ <= 100i32)");
    }
}