vortex_array/expr/exprs/
between.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_proto::expr as pb;
13
14use crate::ArrayRef;
15use crate::compute::BetweenOptions;
16use crate::compute::between as between_compute;
17use crate::expr::ChildName;
18use crate::expr::ExprId;
19use crate::expr::ExpressionView;
20use crate::expr::StatsCatalog;
21use crate::expr::VTable;
22use crate::expr::VTableExt;
23use crate::expr::expression::Expression;
24use crate::expr::exprs::binary::Binary;
25use crate::expr::exprs::operators::Operator;
26
/// Expression that tests whether a value lies between a lower and an upper bound.
///
/// The expression has three children, in order: the value (`array`), the lower bound
/// (`lower`) and the upper bound (`upper`). Whether each bound is inclusive or exclusive
/// is carried by the [`BetweenOptions`] instance attached to the expression.
pub struct Between;
39
40impl VTable for Between {
41 type Instance = BetweenOptions;
42
43 fn id(&self) -> ExprId {
44 ExprId::from("vortex.between")
45 }
46
47 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
48 Ok(Some(
49 pb::BetweenOpts {
50 lower_strict: instance.lower_strict.is_strict(),
51 upper_strict: instance.upper_strict.is_strict(),
52 }
53 .encode_to_vec(),
54 ))
55 }
56
57 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
58 let opts = pb::BetweenOpts::decode(metadata)?;
59 Ok(Some(BetweenOptions {
60 lower_strict: if opts.lower_strict {
61 crate::compute::StrictComparison::Strict
62 } else {
63 crate::compute::StrictComparison::NonStrict
64 },
65 upper_strict: if opts.upper_strict {
66 crate::compute::StrictComparison::Strict
67 } else {
68 crate::compute::StrictComparison::NonStrict
69 },
70 }))
71 }
72
73 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
74 if expr.children().len() != 3 {
75 vortex_bail!(
76 "Between expression requires exactly 3 children, got {}",
77 expr.children().len()
78 );
79 }
80 Ok(())
81 }
82
83 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
84 match child_idx {
85 0 => ChildName::from("array"),
86 1 => ChildName::from("lower"),
87 2 => ChildName::from("upper"),
88 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
89 }
90 }
91
92 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
93 let options = expr.data();
94 let lower_op = if options.lower_strict.is_strict() {
95 "<"
96 } else {
97 "<="
98 };
99 let upper_op = if options.upper_strict.is_strict() {
100 "<"
101 } else {
102 "<="
103 };
104 write!(
105 f,
106 "({} {} {} {} {})",
107 expr.lower(),
108 lower_op,
109 expr.child(),
110 upper_op,
111 expr.upper()
112 )
113 }
114
115 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
116 let arr_dt = expr.child().return_dtype(scope)?;
117 let lower_dt = expr.lower().return_dtype(scope)?;
118 let upper_dt = expr.upper().return_dtype(scope)?;
119
120 if !arr_dt.eq_ignore_nullability(&lower_dt) {
121 vortex_bail!(
122 "Array dtype {} does not match lower dtype {}",
123 arr_dt,
124 lower_dt
125 );
126 }
127 if !arr_dt.eq_ignore_nullability(&upper_dt) {
128 vortex_bail!(
129 "Array dtype {} does not match upper dtype {}",
130 arr_dt,
131 upper_dt
132 );
133 }
134
135 Ok(Bool(
136 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
137 ))
138 }
139
140 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
141 let arr = expr.child().evaluate(scope)?;
142 let lower = expr.lower().evaluate(scope)?;
143 let upper = expr.upper().evaluate(scope)?;
144 between_compute(&arr, &lower, &upper, expr.data())
145 }
146
147 fn stat_falsification(
148 &self,
149 expr: &ExpressionView<Self>,
150 catalog: &dyn StatsCatalog,
151 ) -> Option<Expression> {
152 expr.to_binary_expr().stat_falsification(catalog)
153 }
154
155 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
156 false
157 }
158}
159
160impl ExpressionView<'_, Between> {
161 pub fn child(&self) -> &Expression {
162 &self.children()[0]
163 }
164
165 pub fn lower(&self) -> &Expression {
166 &self.children()[1]
167 }
168
169 pub fn upper(&self) -> &Expression {
170 &self.children()[2]
171 }
172
173 pub fn to_binary_expr(&self) -> Expression {
174 let options = self.data();
175 let arr = self.children()[0].clone();
176 let lower = self.children()[1].clone();
177 let upper = self.children()[2].clone();
178
179 let lhs = Binary.new_expr(
180 options.lower_strict.to_operator().into(),
181 [lower, arr.clone()],
182 );
183 let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
184 Binary.new_expr(Operator::And, [lhs, rhs])
185 }
186}
187
188pub fn between(
204 arr: Expression,
205 lower: Expression,
206 upper: Expression,
207 options: BetweenOptions,
208) -> Expression {
209 Between
210 .try_new_expr(options, [arr, lower, upper])
211 .vortex_expect("Failed to create Between expression")
212}
213
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Shorthand for building `BetweenOptions` from the two strictness flags.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    #[test]
    fn test_display() {
        // Inclusive lower bound, exclusive upper bound, on a struct field.
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        // Exclusive lower bound, inclusive upper bound, on the root scope.
        let left_open = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(left_open.to_string(), "(0i32 < $ <= 100i32)");
    }
}