1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::DType::Bool;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_vector::Datum;
15
16use crate::ArrayRef;
17use crate::compute::BetweenOptions;
18use crate::compute::between as between_compute;
19use crate::expr::Arity;
20use crate::expr::ChildName;
21use crate::expr::ExecutionArgs;
22use crate::expr::ExprId;
23use crate::expr::StatsCatalog;
24use crate::expr::VTable;
25use crate::expr::VTableExt;
26use crate::expr::expression::Expression;
27use crate::expr::exprs::binary::Binary;
28use crate::expr::exprs::operators::Operator;
29
/// Expression vtable for the `vortex.between` expression.
///
/// The expression takes exactly three children, named `array`, `lower` and `upper`, and is
/// configured by [`BetweenOptions`], which records whether each bound is strict.
pub struct Between;
42
43impl VTable for Between {
44 type Options = BetweenOptions;
45
46 fn id(&self) -> ExprId {
47 ExprId::from("vortex.between")
48 }
49
50 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 Ok(Some(
52 pb::BetweenOpts {
53 lower_strict: instance.lower_strict.is_strict(),
54 upper_strict: instance.upper_strict.is_strict(),
55 }
56 .encode_to_vec(),
57 ))
58 }
59
60 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
61 let opts = pb::BetweenOpts::decode(metadata)?;
62 Ok(BetweenOptions {
63 lower_strict: if opts.lower_strict {
64 crate::compute::StrictComparison::Strict
65 } else {
66 crate::compute::StrictComparison::NonStrict
67 },
68 upper_strict: if opts.upper_strict {
69 crate::compute::StrictComparison::Strict
70 } else {
71 crate::compute::StrictComparison::NonStrict
72 },
73 })
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Exact(3)
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("array"),
83 1 => ChildName::from("lower"),
84 2 => ChildName::from("upper"),
85 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
86 }
87 }
88
89 fn fmt_sql(
90 &self,
91 options: &Self::Options,
92 expr: &Expression,
93 f: &mut Formatter<'_>,
94 ) -> std::fmt::Result {
95 let lower_op = if options.lower_strict.is_strict() {
96 "<"
97 } else {
98 "<="
99 };
100 let upper_op = if options.upper_strict.is_strict() {
101 "<"
102 } else {
103 "<="
104 };
105 write!(
106 f,
107 "({} {} {} {} {})",
108 expr.child(1),
109 lower_op,
110 expr.child(0),
111 upper_op,
112 expr.child(2)
113 )
114 }
115
116 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
117 let arr_dt = &arg_dtypes[0];
118 let lower_dt = &arg_dtypes[1];
119 let upper_dt = &arg_dtypes[2];
120
121 if !arr_dt.eq_ignore_nullability(lower_dt) {
122 vortex_bail!(
123 "Array dtype {} does not match lower dtype {}",
124 arr_dt,
125 lower_dt
126 );
127 }
128 if !arr_dt.eq_ignore_nullability(upper_dt) {
129 vortex_bail!(
130 "Array dtype {} does not match upper dtype {}",
131 arr_dt,
132 upper_dt
133 );
134 }
135
136 Ok(Bool(
137 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
138 ))
139 }
140
141 fn evaluate(
142 &self,
143 options: &Self::Options,
144 expr: &Expression,
145 scope: &ArrayRef,
146 ) -> VortexResult<ArrayRef> {
147 let arr = expr.child(0).evaluate(scope)?;
148 let lower = expr.child(1).evaluate(scope)?;
149 let upper = expr.child(2).evaluate(scope)?;
150 between_compute(&arr, &lower, &upper, options)
151 }
152
153 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
154 let [arr, lower, upper]: [Datum; _] = args
155 .datums
156 .try_into()
157 .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
158 let [arr_dt, lower_dt, upper_dt]: [DType; _] = args
159 .dtypes
160 .try_into()
161 .map_err(|_| vortex_err!("Expected 3 dtypes for Between expression",))?;
162
163 let lower_bound = Binary
164 .bind(options.lower_strict.to_operator().into())
165 .execute(ExecutionArgs {
166 datums: vec![lower, arr.clone()],
167 dtypes: vec![lower_dt, arr_dt.clone()],
168 row_count: args.row_count,
169 return_dtype: args.return_dtype.clone(),
170 })?;
171 let upper_bound = Binary
172 .bind(options.upper_strict.to_operator().into())
173 .execute(ExecutionArgs {
174 datums: vec![arr, upper],
175 dtypes: vec![arr_dt, upper_dt],
176 row_count: args.row_count,
177 return_dtype: args.return_dtype.clone(),
178 })?;
179
180 Binary.bind(Operator::And).execute(ExecutionArgs {
181 datums: vec![lower_bound, upper_bound],
182 dtypes: vec![args.return_dtype.clone(), args.return_dtype.clone()],
183 row_count: args.row_count,
184 return_dtype: args.return_dtype,
185 })
186 }
187
188 fn stat_falsification(
189 &self,
190 options: &Self::Options,
191 expr: &Expression,
192 catalog: &dyn StatsCatalog,
193 ) -> Option<Expression> {
194 let arr = expr.child(0).clone();
195 let lower = expr.child(1).clone();
196 let upper = expr.child(2).clone();
197
198 let lhs = Binary.new_expr(
199 options.lower_strict.to_operator().into(),
200 [lower, arr.clone()],
201 );
202 let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
203
204 Binary
205 .new_expr(Operator::And, [lhs, rhs])
206 .stat_falsification(catalog)
207 }
208
209 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
210 false
211 }
212
213 fn is_fallible(&self, _options: &Self::Options) -> bool {
214 false
215 }
216}
217
218pub fn between(
234 arr: Expression,
235 lower: Expression,
236 upper: Expression,
237 options: BetweenOptions,
238) -> Expression {
239 Between
240 .try_new_expr(options, [arr, lower, upper])
241 .vortex_expect("Failed to create Between expression")
242}
243
#[cfg(test)]
mod tests {
    use super::between;
    use crate::compute::BetweenOptions;
    use crate::compute::StrictComparison;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// Checks the textual rendering for both mixes of strict and non-strict bounds.
    #[test]
    fn test_display() {
        // Inclusive lower bound, exclusive upper bound.
        let lower_inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_range = between(get_item("score", root()), lit(10), lit(50), lower_inclusive);
        assert_eq!(score_range.to_string(), "(10i32 <= $.score < 50i32)");

        // Exclusive lower bound, inclusive upper bound.
        let upper_inclusive = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_range = between(root(), lit(0), lit(100), upper_inclusive);
        assert_eq!(root_range.to_string(), "(0i32 < $ <= 100i32)");
    }
}