1mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::arrays::Decimal;
22use crate::arrays::Primitive;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::DType::Bool;
26use crate::expr::StatsCatalog;
27use crate::expr::expression::Expression;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::binary::execute_boolean;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct BetweenOptions {
42 pub lower_strict: StrictComparison,
43 pub upper_strict: StrictComparison,
44}
45
46impl Display for BetweenOptions {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 let lower_op = if self.lower_strict.is_strict() {
49 "<"
50 } else {
51 "<="
52 };
53 let upper_op = if self.upper_strict.is_strict() {
54 "<"
55 } else {
56 "<="
57 };
58 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
59 }
60}
61
/// Whether a bound comparison excludes or includes equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict comparison (`<`): equality with the bound does not match.
    Strict,
    /// Non-strict comparison (`<=`): equality with the bound matches.
    NonStrict,
}
70
71impl StrictComparison {
72 pub const fn to_compare_operator(&self) -> CompareOperator {
73 match self {
74 StrictComparison::Strict => CompareOperator::Lt,
75 StrictComparison::NonStrict => CompareOperator::Lte,
76 }
77 }
78
79 pub const fn to_operator(&self) -> Operator {
80 match self {
81 StrictComparison::Strict => Operator::Lt,
82 StrictComparison::NonStrict => Operator::Lte,
83 }
84 }
85
86 pub const fn is_strict(&self) -> bool {
87 matches!(self, StrictComparison::Strict)
88 }
89}
90
91pub(super) fn precondition(
97 arr: &ArrayRef,
98 lower: &ArrayRef,
99 upper: &ArrayRef,
100) -> VortexResult<Option<ArrayRef>> {
101 let return_dtype =
102 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
103
104 if arr.is_empty() {
106 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
107 }
108
109 if lower.as_constant().is_some_and(|v| v.is_null())
110 || upper.as_constant().is_some_and(|v| v.is_null())
111 {
112 return Ok(Some(
113 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
114 ));
115 }
116
117 Ok(None)
118}
119
120fn between_canonical(
124 arr: &ArrayRef,
125 lower: &ArrayRef,
126 upper: &ArrayRef,
127 options: &BetweenOptions,
128 ctx: &mut ExecutionCtx,
129) -> VortexResult<ArrayRef> {
130 if let Some(result) = precondition(arr, lower, upper)? {
131 return Ok(result);
132 }
133
134 if let Some(prim) = arr.as_opt::<Primitive>()
136 && let Some(result) =
137 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
138 {
139 return Ok(result);
140 }
141 if let Some(dec) = arr.as_opt::<Decimal>()
142 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
143 {
144 return Ok(result);
145 }
146
147 let lower_cmp = lower.clone().binary(
150 arr.clone(),
151 Operator::from(options.lower_strict.to_compare_operator()),
152 )?;
153 let upper_cmp = arr.clone().binary(
154 upper.clone(),
155 Operator::from(options.upper_strict.to_compare_operator()),
156 )?;
157 execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
158}
159
/// The `vortex.between` scalar function.
///
/// It takes three children (`array`, `lower`, `upper`) and is configured by
/// [`BetweenOptions`], which sets the strictness of each bound.
#[derive(Clone)]
pub struct Between;
173
174impl ScalarFnVTable for Between {
175 type Options = BetweenOptions;
176
177 fn id(&self) -> ScalarFnId {
178 ScalarFnId::new("vortex.between")
179 }
180
181 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
182 Ok(Some(
183 pb::BetweenOpts {
184 lower_strict: instance.lower_strict.is_strict(),
185 upper_strict: instance.upper_strict.is_strict(),
186 }
187 .encode_to_vec(),
188 ))
189 }
190
191 fn deserialize(
192 &self,
193 _metadata: &[u8],
194 _session: &VortexSession,
195 ) -> VortexResult<Self::Options> {
196 let opts = pb::BetweenOpts::decode(_metadata)?;
197 Ok(BetweenOptions {
198 lower_strict: if opts.lower_strict {
199 StrictComparison::Strict
200 } else {
201 StrictComparison::NonStrict
202 },
203 upper_strict: if opts.upper_strict {
204 StrictComparison::Strict
205 } else {
206 StrictComparison::NonStrict
207 },
208 })
209 }
210
211 fn arity(&self, _options: &Self::Options) -> Arity {
212 Arity::Exact(3)
213 }
214
215 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
216 match child_idx {
217 0 => ChildName::from("array"),
218 1 => ChildName::from("lower"),
219 2 => ChildName::from("upper"),
220 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
221 }
222 }
223
224 fn fmt_sql(
225 &self,
226 options: &Self::Options,
227 expr: &Expression,
228 f: &mut Formatter<'_>,
229 ) -> std::fmt::Result {
230 let lower_op = if options.lower_strict.is_strict() {
231 "<"
232 } else {
233 "<="
234 };
235 let upper_op = if options.upper_strict.is_strict() {
236 "<"
237 } else {
238 "<="
239 };
240 write!(
241 f,
242 "({} {} {} {} {})",
243 expr.child(1),
244 lower_op,
245 expr.child(0),
246 upper_op,
247 expr.child(2)
248 )
249 }
250
251 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
252 let arr_dt = &arg_dtypes[0];
253 let lower_dt = &arg_dtypes[1];
254 let upper_dt = &arg_dtypes[2];
255
256 if !arr_dt.eq_ignore_nullability(lower_dt) {
257 vortex_bail!(
258 "Array dtype {} does not match lower dtype {}",
259 arr_dt,
260 lower_dt
261 );
262 }
263 if !arr_dt.eq_ignore_nullability(upper_dt) {
264 vortex_bail!(
265 "Array dtype {} does not match upper dtype {}",
266 arr_dt,
267 upper_dt
268 );
269 }
270
271 Ok(Bool(
272 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
273 ))
274 }
275
276 fn execute(
277 &self,
278 options: &Self::Options,
279 args: &dyn ExecutionArgs,
280 ctx: &mut ExecutionCtx,
281 ) -> VortexResult<ArrayRef> {
282 let arr = args.get(0)?;
283 let lower = args.get(1)?;
284 let upper = args.get(2)?;
285
286 if !arr.is_canonical() {
288 return arr.execute::<Canonical>(ctx)?.into_array().between(
289 lower,
290 upper,
291 options.clone(),
292 );
293 }
294
295 between_canonical(&arr, &lower, &upper, options, ctx)
296 }
297
298 fn stat_falsification(
299 &self,
300 options: &Self::Options,
301 expr: &Expression,
302 catalog: &dyn StatsCatalog,
303 ) -> Option<Expression> {
304 let arr = expr.child(0).clone();
305 let lower = expr.child(1).clone();
306 let upper = expr.child(2).clone();
307
308 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
309 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
310
311 Binary
312 .new_expr(Operator::And, [lhs, rhs])
313 .stat_falsification(catalog)
314 }
315
316 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
317 false
318 }
319
320 fn is_fallible(&self, _options: &Self::Options) -> bool {
321 false
322 }
323}
324
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs `between_canonical` with a fresh execution context and unwraps the result.
    fn run_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    #[test]
    fn test_display() {
        let inclusive_lower = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let expr = between(get_item("score", root()), lit(10), lit(50), inclusive_lower);
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let inclusive_upper = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let expr2 = between(root(), lit(0), lit(100), inclusive_upper);
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = run_between(&array, &lower, &upper, lower_strict, upper_strict);

        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = StrictComparison::NonStrict;

        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant-null upper bound yields an all-null result, so no index matches.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let result = run_between(&array, &lower, &upper, inclusive, inclusive);
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert!(indices.is_empty());

        // Constant upper bound, per-row lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let result = run_between(&array, &lower, &upper, inclusive, inclusive);
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let result = run_between(&array, &lower, &upper, inclusive, inclusive);
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        // Builds a non-nullable decimal constant with the same length as `array`.
        let decimal_constant = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = decimal_constant(100i128);
        let upper = decimal_constant(400i128);

        // Strict lower bound excludes the first value; the upper bound is inclusive.
        let lower_exclusive = run_between(
            &array,
            &lower,
            &upper,
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            lower_exclusive,
            BoolArray::from_iter([false, true, true, true])
        );

        // Strict upper bound excludes the last value; the lower bound is inclusive.
        let upper_exclusive = run_between(
            &array,
            &lower,
            &upper,
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            upper_exclusive,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}