1mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::DType::Bool;
28use crate::expr::expression::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::fns::binary::execute_boolean;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct BetweenOptions {
41 pub lower_strict: StrictComparison,
42 pub upper_strict: StrictComparison,
43}
44
45impl Display for BetweenOptions {
46 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47 let lower_op = if self.lower_strict.is_strict() {
48 "<"
49 } else {
50 "<="
51 };
52 let upper_op = if self.upper_strict.is_strict() {
53 "<"
54 } else {
55 "<="
56 };
57 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
58 }
59}
60
/// Whether a bound comparison excludes or includes the bound itself.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// The bound is exclusive (`<`).
    Strict,
    /// The bound is inclusive (`<=`).
    NonStrict,
}
69
70impl StrictComparison {
71 pub const fn to_compare_operator(&self) -> CompareOperator {
72 match self {
73 StrictComparison::Strict => CompareOperator::Lt,
74 StrictComparison::NonStrict => CompareOperator::Lte,
75 }
76 }
77
78 pub const fn to_operator(&self) -> Operator {
79 match self {
80 StrictComparison::Strict => Operator::Lt,
81 StrictComparison::NonStrict => Operator::Lte,
82 }
83 }
84
85 pub const fn is_strict(&self) -> bool {
86 matches!(self, StrictComparison::Strict)
87 }
88}
89
90pub(super) fn precondition(
96 arr: &ArrayRef,
97 lower: &ArrayRef,
98 upper: &ArrayRef,
99) -> VortexResult<Option<ArrayRef>> {
100 let return_dtype =
101 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
102
103 if arr.is_empty() {
105 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
106 }
107
108 if lower.as_constant().is_some_and(|v| v.is_null())
109 || upper.as_constant().is_some_and(|v| v.is_null())
110 {
111 return Ok(Some(
112 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
113 ));
114 }
115
116 Ok(None)
117}
118
119fn between_canonical(
123 arr: &ArrayRef,
124 lower: &ArrayRef,
125 upper: &ArrayRef,
126 options: &BetweenOptions,
127 ctx: &mut ExecutionCtx,
128) -> VortexResult<ArrayRef> {
129 if let Some(result) = precondition(arr, lower, upper)? {
130 return Ok(result);
131 }
132
133 if let Some(prim) = arr.as_opt::<Primitive>()
135 && let Some(result) =
136 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
137 {
138 return Ok(result);
139 }
140 if let Some(dec) = arr.as_opt::<Decimal>()
141 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
142 {
143 return Ok(result);
144 }
145
146 let lower_cmp = lower.clone().binary(
149 arr.clone(),
150 Operator::from(options.lower_strict.to_compare_operator()),
151 )?;
152 let upper_cmp = arr.clone().binary(
153 upper.clone(),
154 Operator::from(options.upper_strict.to_compare_operator()),
155 )?;
156 execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
157}
158
/// Marker type for the `between` scalar function.
///
/// It carries no state; per-call configuration lives in [`BetweenOptions`].
#[derive(Clone)]
pub struct Between;
172
173impl ScalarFnVTable for Between {
174 type Options = BetweenOptions;
175
176 fn id(&self) -> ScalarFnId {
177 static ID: CachedId = CachedId::new("vortex.between");
178 *ID
179 }
180
181 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
182 Ok(Some(
183 pb::BetweenOpts {
184 lower_strict: instance.lower_strict.is_strict(),
185 upper_strict: instance.upper_strict.is_strict(),
186 }
187 .encode_to_vec(),
188 ))
189 }
190
191 fn deserialize(
192 &self,
193 _metadata: &[u8],
194 _session: &VortexSession,
195 ) -> VortexResult<Self::Options> {
196 let opts = pb::BetweenOpts::decode(_metadata)?;
197 Ok(BetweenOptions {
198 lower_strict: if opts.lower_strict {
199 StrictComparison::Strict
200 } else {
201 StrictComparison::NonStrict
202 },
203 upper_strict: if opts.upper_strict {
204 StrictComparison::Strict
205 } else {
206 StrictComparison::NonStrict
207 },
208 })
209 }
210
211 fn arity(&self, _options: &Self::Options) -> Arity {
212 Arity::Exact(3)
213 }
214
215 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
216 match child_idx {
217 0 => ChildName::from("array"),
218 1 => ChildName::from("lower"),
219 2 => ChildName::from("upper"),
220 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
221 }
222 }
223
224 fn fmt_sql(
225 &self,
226 options: &Self::Options,
227 expr: &Expression,
228 f: &mut Formatter<'_>,
229 ) -> std::fmt::Result {
230 let lower_op = if options.lower_strict.is_strict() {
231 "<"
232 } else {
233 "<="
234 };
235 let upper_op = if options.upper_strict.is_strict() {
236 "<"
237 } else {
238 "<="
239 };
240 write!(
241 f,
242 "({} {} {} {} {})",
243 expr.child(1),
244 lower_op,
245 expr.child(0),
246 upper_op,
247 expr.child(2)
248 )
249 }
250
251 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
252 let arr_dt = &arg_dtypes[0];
253 let lower_dt = &arg_dtypes[1];
254 let upper_dt = &arg_dtypes[2];
255
256 if !arr_dt.eq_ignore_nullability(lower_dt) {
257 vortex_bail!(
258 "Array dtype {} does not match lower dtype {}",
259 arr_dt,
260 lower_dt
261 );
262 }
263 if !arr_dt.eq_ignore_nullability(upper_dt) {
264 vortex_bail!(
265 "Array dtype {} does not match upper dtype {}",
266 arr_dt,
267 upper_dt
268 );
269 }
270
271 Ok(Bool(
272 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
273 ))
274 }
275
276 fn execute(
277 &self,
278 options: &Self::Options,
279 args: &dyn ExecutionArgs,
280 ctx: &mut ExecutionCtx,
281 ) -> VortexResult<ArrayRef> {
282 let arr = args.get(0)?;
283 let lower = args.get(1)?;
284 let upper = args.get(2)?;
285
286 if !arr.is_canonical() {
288 return arr.execute::<Canonical>(ctx)?.into_array().between(
289 lower,
290 upper,
291 options.clone(),
292 );
293 }
294
295 between_canonical(&arr, &lower, &upper, options, ctx)
296 }
297
298 fn validity(
299 &self,
300 _options: &Self::Options,
301 expression: &Expression,
302 ) -> VortexResult<Option<Expression>> {
303 let arr = expression.child(0).validity()?;
304 let lower = expression.child(1).validity()?;
305 let upper = expression.child(2).validity()?;
306 Ok(Some(and(and(arr, lower), upper)))
307 }
308
309 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
310 false
311 }
312
313 fn is_fallible(&self, _options: &Self::Options) -> bool {
314 false
315 }
316}
317
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Session shared by all tests; each test creates its own execution context from it.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Bundles the two bound settings into a `BetweenOptions`.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_canonical` and executes its result into a `BoolArray`.
    fn eval_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        options: &BetweenOptions,
        ctx: &mut ExecutionCtx,
    ) -> BoolArray {
        between_canonical(array, lower, upper, options, ctx)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .unwrap()
    }

    #[test]
    fn test_display() {
        let field_expr = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(field_expr.to_string(), "(10i32 <= $.score < 50i32)");

        let root_expr = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(root_expr.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let ctx = &mut SESSION.create_execution_ctx();

        let mask = eval_between(
            &array,
            &lower,
            &upper,
            &opts(lower_strict, upper_strict),
            ctx,
        );

        let indices = to_int_indices(mask, ctx).unwrap();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = opts(StrictComparison::NonStrict, StrictComparison::NonStrict);
        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let ctx = &mut SESSION.create_execution_ctx();

        // A constant-null upper bound: nothing matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let mask = eval_between(&array, &varying_lower, &null_upper, &inclusive, ctx);
        let indices = to_int_indices(mask, ctx).unwrap();
        assert!(indices.is_empty());

        // A constant non-null upper bound with a varying lower bound.
        let const_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let mask = eval_between(&array, &varying_lower, &const_upper, &inclusive, ctx);
        let indices = to_int_indices(mask, ctx).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let const_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let mask = eval_between(&array, &const_lower, &const_upper, &inclusive, ctx);
        let indices = to_int_indices(mask, ctx).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let ctx = &mut SESSION.create_execution_ctx();
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Constant bound with the same decimal type and length as `array`.
        let constant_bound = |unscaled: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(unscaled),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        // Exclusive lower bound, inclusive upper bound.
        let exclusive_lower = between_canonical(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::Strict, StrictComparison::NonStrict),
            ctx,
        )
        .unwrap();
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true]),
            ctx
        );

        // Inclusive lower bound, exclusive upper bound.
        let exclusive_upper = between_canonical(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
            ctx,
        )
        .unwrap();
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false]),
            ctx
        );
    }

    /// The bounds use a different `DecimalValue` width than the `i16` values
    /// stored in the array.
    #[rstest]
    #[case(DecimalValue::I16(1), DecimalValue::I32(82246), vec![0, 1, 2, 3])]
    #[case(DecimalValue::I32(82246), DecimalValue::I16(4), vec![])]
    #[case(DecimalValue::I16(1), DecimalValue::I32(-82246), vec![])]
    #[case(DecimalValue::I32(-82246), DecimalValue::I16(2), vec![0, 1])]
    fn test_between_decimal_mismatched_storage_types(
        #[case] lower_val: DecimalValue,
        #[case] upper_val: DecimalValue,
        #[case] expected_indices: Vec<u64>,
    ) {
        let ctx = &mut SESSION.create_execution_ctx();
        let decimal_type = DecimalDType::new(5, -67);
        let array = DecimalArray::new(
            buffer![1i16, 2i16, 3i16, 4i16],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Constant bound with the same decimal type and length as `array`.
        let constant_bound = |value: DecimalValue| {
            ConstantArray::new(
                Scalar::decimal(value, decimal_type, Nullability::NonNullable),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(lower_val);
        let upper = constant_bound(upper_val);

        let mask = eval_between(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::NonStrict, StrictComparison::NonStrict),
            ctx,
        );

        assert_eq!(to_int_indices(mask, ctx).unwrap(), expected_indices);
    }
}