vortex_array/expr/exprs/
between.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_proto::expr as pb;
11
12use crate::ArrayRef;
13use crate::compute::{BetweenOptions, between as between_compute};
14use crate::expr::expression::Expression;
15use crate::expr::exprs::binary::Binary;
16use crate::expr::exprs::operators::Operator;
17use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
18
/// Expression that tests whether a value lies between a lower and an upper bound.
///
/// The expression takes exactly three children, in order: the value being tested
/// (`array`), the `lower` bound and the `upper` bound. Whether each of the two bound
/// comparisons is strict (`<`) or non-strict (`<=`) is carried by the
/// [`BetweenOptions`] instance data.
pub struct Between;

32impl VTable for Between {
33 type Instance = BetweenOptions;
34
35 fn id(&self) -> ExprId {
36 ExprId::from("vortex.between")
37 }
38
39 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
40 Ok(Some(
41 pb::BetweenOpts {
42 lower_strict: instance.lower_strict.is_strict(),
43 upper_strict: instance.upper_strict.is_strict(),
44 }
45 .encode_to_vec(),
46 ))
47 }
48
49 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
50 let opts = pb::BetweenOpts::decode(metadata)?;
51 Ok(Some(BetweenOptions {
52 lower_strict: if opts.lower_strict {
53 crate::compute::StrictComparison::Strict
54 } else {
55 crate::compute::StrictComparison::NonStrict
56 },
57 upper_strict: if opts.upper_strict {
58 crate::compute::StrictComparison::Strict
59 } else {
60 crate::compute::StrictComparison::NonStrict
61 },
62 }))
63 }
64
65 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
66 if expr.children().len() != 3 {
67 vortex_bail!(
68 "Between expression requires exactly 3 children, got {}",
69 expr.children().len()
70 );
71 }
72 Ok(())
73 }
74
75 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
76 match child_idx {
77 0 => ChildName::from("array"),
78 1 => ChildName::from("lower"),
79 2 => ChildName::from("upper"),
80 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
81 }
82 }
83
84 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
85 let options = expr.data();
86 let lower_op = if options.lower_strict.is_strict() {
87 "<"
88 } else {
89 "<="
90 };
91 let upper_op = if options.upper_strict.is_strict() {
92 "<"
93 } else {
94 "<="
95 };
96 write!(
97 f,
98 "({} {} {} {} {})",
99 expr.lower(),
100 lower_op,
101 expr.child(),
102 upper_op,
103 expr.upper()
104 )
105 }
106
107 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
108 let arr_dt = expr.child().return_dtype(scope)?;
109 let lower_dt = expr.lower().return_dtype(scope)?;
110 let upper_dt = expr.upper().return_dtype(scope)?;
111
112 if !arr_dt.eq_ignore_nullability(&lower_dt) {
113 vortex_bail!(
114 "Array dtype {} does not match lower dtype {}",
115 arr_dt,
116 lower_dt
117 );
118 }
119 if !arr_dt.eq_ignore_nullability(&upper_dt) {
120 vortex_bail!(
121 "Array dtype {} does not match upper dtype {}",
122 arr_dt,
123 upper_dt
124 );
125 }
126
127 Ok(Bool(
128 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
129 ))
130 }
131
132 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
133 let arr = expr.child().evaluate(scope)?;
134 let lower = expr.lower().evaluate(scope)?;
135 let upper = expr.upper().evaluate(scope)?;
136 between_compute(&arr, &lower, &upper, expr.data())
137 }
138
139 fn stat_falsification(
140 &self,
141 expr: &ExpressionView<Self>,
142 catalog: &mut dyn StatsCatalog,
143 ) -> Option<Expression> {
144 expr.to_binary_expr().stat_falsification(catalog)
145 }
146}
147
148impl ExpressionView<'_, Between> {
149 pub fn child(&self) -> &Expression {
150 &self.children()[0]
151 }
152
153 pub fn lower(&self) -> &Expression {
154 &self.children()[1]
155 }
156
157 pub fn upper(&self) -> &Expression {
158 &self.children()[2]
159 }
160
161 pub fn to_binary_expr(&self) -> Expression {
162 let options = self.data();
163 let arr = self.children()[0].clone();
164 let lower = self.children()[1].clone();
165 let upper = self.children()[2].clone();
166
167 let lhs = Binary.new_expr(
168 options.lower_strict.to_operator().into(),
169 [lower, arr.clone()],
170 );
171 let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
172 Binary.new_expr(Operator::And, [lhs, rhs])
173 }
174}
175
176pub fn between(
192 arr: Expression,
193 lower: Expression,
194 upper: Expression,
195 options: BetweenOptions,
196) -> Expression {
197 Between
198 .try_new_expr(options, [arr, lower, upper])
199 .vortex_expect("Failed to create Between expression")
200}
201
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::{BetweenOptions, StrictComparison};
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Checks the textual rendering of `between` expressions. The lower bound comes first,
    /// and `<` or `<=` is chosen per bound strictness.
    #[test]
    fn test_display() {
        let closed_open = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_range = between(get_item("score", root()), lit(10), lit(50), closed_open);
        assert_eq!(score_range.to_string(), "(10i32 <= $.score < 50i32)");

        let open_closed = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_range = between(root(), lit(0), lit(100), open_closed);
        assert_eq!(root_range.to_string(), "(0i32 < $ <= 100i32)");
    }
}