vortex_array/expr/exprs/between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_vector::Datum;
15
16use crate::ArrayRef;
17use crate::compute::BetweenOptions;
18use crate::compute::between as between_compute;
19use crate::expr::Arity;
20use crate::expr::ChildName;
21use crate::expr::ExecutionArgs;
22use crate::expr::ExprId;
23use crate::expr::StatsCatalog;
24use crate::expr::VTable;
25use crate::expr::VTableExt;
26use crate::expr::expression::Expression;
27use crate::expr::exprs::binary::Binary;
28use crate::expr::exprs::operators::Operator;
29
/// Scalar expression that tests whether each value lies between a lower and an upper bound.
///
/// Children, in order:
/// 1. `array` — the values being tested,
/// 2. `lower` — the lower bound,
/// 3. `upper` — the upper bound.
///
/// Whether each side is compared strictly (`<`) or inclusively (`<=`) is configured through
/// the expression's [`BetweenOptions`].
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
pub struct Between;
42
43impl VTable for Between {
44    type Options = BetweenOptions;
45
46    fn id(&self) -> ExprId {
47        ExprId::from("vortex.between")
48    }
49
50    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        Ok(Some(
52            pb::BetweenOpts {
53                lower_strict: instance.lower_strict.is_strict(),
54                upper_strict: instance.upper_strict.is_strict(),
55            }
56            .encode_to_vec(),
57        ))
58    }
59
60    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
61        let opts = pb::BetweenOpts::decode(metadata)?;
62        Ok(BetweenOptions {
63            lower_strict: if opts.lower_strict {
64                crate::compute::StrictComparison::Strict
65            } else {
66                crate::compute::StrictComparison::NonStrict
67            },
68            upper_strict: if opts.upper_strict {
69                crate::compute::StrictComparison::Strict
70            } else {
71                crate::compute::StrictComparison::NonStrict
72            },
73        })
74    }
75
76    fn arity(&self, _options: &Self::Options) -> Arity {
77        Arity::Exact(3)
78    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("array"),
83            1 => ChildName::from("lower"),
84            2 => ChildName::from("upper"),
85            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
86        }
87    }
88
89    fn fmt_sql(
90        &self,
91        options: &Self::Options,
92        expr: &Expression,
93        f: &mut Formatter<'_>,
94    ) -> std::fmt::Result {
95        let lower_op = if options.lower_strict.is_strict() {
96            "<"
97        } else {
98            "<="
99        };
100        let upper_op = if options.upper_strict.is_strict() {
101            "<"
102        } else {
103            "<="
104        };
105        write!(
106            f,
107            "({} {} {} {} {})",
108            expr.child(1),
109            lower_op,
110            expr.child(0),
111            upper_op,
112            expr.child(2)
113        )
114    }
115
116    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
117        let arr_dt = &arg_dtypes[0];
118        let lower_dt = &arg_dtypes[1];
119        let upper_dt = &arg_dtypes[2];
120
121        if !arr_dt.eq_ignore_nullability(lower_dt) {
122            vortex_bail!(
123                "Array dtype {} does not match lower dtype {}",
124                arr_dt,
125                lower_dt
126            );
127        }
128        if !arr_dt.eq_ignore_nullability(upper_dt) {
129            vortex_bail!(
130                "Array dtype {} does not match upper dtype {}",
131                arr_dt,
132                upper_dt
133            );
134        }
135
136        Ok(Bool(
137            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
138        ))
139    }
140
141    fn evaluate(
142        &self,
143        options: &Self::Options,
144        expr: &Expression,
145        scope: &ArrayRef,
146    ) -> VortexResult<ArrayRef> {
147        let arr = expr.child(0).evaluate(scope)?;
148        let lower = expr.child(1).evaluate(scope)?;
149        let upper = expr.child(2).evaluate(scope)?;
150        between_compute(&arr, &lower, &upper, options)
151    }
152
153    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
154        let [arr, lower, upper]: [Datum; _] = args
155            .datums
156            .try_into()
157            .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
158        let [arr_dt, lower_dt, upper_dt]: [DType; _] = args
159            .dtypes
160            .try_into()
161            .map_err(|_| vortex_err!("Expected 3 dtypes for Between expression",))?;
162
163        let lower_bound = Binary
164            .bind(options.lower_strict.to_operator().into())
165            .execute(ExecutionArgs {
166                datums: vec![lower, arr.clone()],
167                dtypes: vec![lower_dt, arr_dt.clone()],
168                row_count: args.row_count,
169                return_dtype: args.return_dtype.clone(),
170            })?;
171        let upper_bound = Binary
172            .bind(options.upper_strict.to_operator().into())
173            .execute(ExecutionArgs {
174                datums: vec![arr, upper],
175                dtypes: vec![arr_dt, upper_dt],
176                row_count: args.row_count,
177                return_dtype: args.return_dtype.clone(),
178            })?;
179
180        Binary.bind(Operator::And).execute(ExecutionArgs {
181            datums: vec![lower_bound, upper_bound],
182            dtypes: vec![args.return_dtype.clone(), args.return_dtype.clone()],
183            row_count: args.row_count,
184            return_dtype: args.return_dtype,
185        })
186    }
187
188    fn stat_falsification(
189        &self,
190        options: &Self::Options,
191        expr: &Expression,
192        catalog: &dyn StatsCatalog,
193    ) -> Option<Expression> {
194        let arr = expr.child(0).clone();
195        let lower = expr.child(1).clone();
196        let upper = expr.child(2).clone();
197
198        let lhs = Binary.new_expr(
199            options.lower_strict.to_operator().into(),
200            [lower, arr.clone()],
201        );
202        let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
203
204        Binary
205            .new_expr(Operator::And, [lhs, rhs])
206            .stat_falsification(catalog)
207    }
208
209    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
210        false
211    }
212
213    fn is_fallible(&self, _options: &Self::Options) -> bool {
214        false
215    }
216}
217
218/// Creates an expression that checks if values are between two bounds.
219///
220/// Returns a boolean array indicating which values fall within the specified range.
221/// The comparison strictness is controlled by the options parameter.
222///
223/// ```rust
224/// # use vortex_array::compute::BetweenOptions;
225/// # use vortex_array::compute::StrictComparison;
226/// # use vortex_array::expr::{between, lit, root};
227/// let opts = BetweenOptions {
228///     lower_strict: StrictComparison::NonStrict,
229///     upper_strict: StrictComparison::NonStrict,
230/// };
231/// let expr = between(root(), lit(10), lit(20), opts);
232/// ```
233pub fn between(
234    arr: Expression,
235    lower: Expression,
236    upper: Expression,
237    options: BetweenOptions,
238) -> Expression {
239    Between
240        .try_new_expr(options, [arr, lower, upper])
241        .vortex_expect("Failed to create Between expression")
242}
243
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Shorthand for building [`BetweenOptions`] from the two per-bound strictness settings.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    #[test]
    fn test_display() {
        // Inclusive lower bound, exclusive upper bound, applied to a nested field.
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        // Exclusive lower bound, inclusive upper bound, applied to the root scope.
        let open_closed = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(open_closed.to_string(), "(0i32 < $ <= 100i32)");
    }
}