1mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::DType::Bool;
28use crate::expr::StatsCatalog;
29use crate::expr::expression::Expression;
30use crate::scalar::Scalar;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::binary::Binary;
38use crate::scalar_fn::fns::binary::execute_boolean;
39use crate::scalar_fn::fns::operators::CompareOperator;
40use crate::scalar_fn::fns::operators::Operator;
41
/// Options of a `between` expression: how each of the two bounds is compared
/// against the tested value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower-bound comparison (`<` when strict, `<=` otherwise).
    pub lower_strict: StrictComparison,
    /// Strictness of the upper-bound comparison (`<` when strict, `<=` otherwise).
    pub upper_strict: StrictComparison,
}
47
48impl Display for BetweenOptions {
49 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50 let lower_op = if self.lower_strict.is_strict() {
51 "<"
52 } else {
53 "<="
54 };
55 let upper_op = if self.upper_strict.is_strict() {
56 "<"
57 } else {
58 "<="
59 };
60 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
61 }
62}
63
/// Whether a bound comparison excludes or includes equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict comparison: maps to the `<` operator.
    Strict,
    /// Non-strict comparison: maps to the `<=` operator.
    NonStrict,
}
72
73impl StrictComparison {
74 pub const fn to_compare_operator(&self) -> CompareOperator {
75 match self {
76 StrictComparison::Strict => CompareOperator::Lt,
77 StrictComparison::NonStrict => CompareOperator::Lte,
78 }
79 }
80
81 pub const fn to_operator(&self) -> Operator {
82 match self {
83 StrictComparison::Strict => Operator::Lt,
84 StrictComparison::NonStrict => Operator::Lte,
85 }
86 }
87
88 pub const fn is_strict(&self) -> bool {
89 matches!(self, StrictComparison::Strict)
90 }
91}
92
93pub(super) fn precondition(
99 arr: &ArrayRef,
100 lower: &ArrayRef,
101 upper: &ArrayRef,
102) -> VortexResult<Option<ArrayRef>> {
103 let return_dtype =
104 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
105
106 if arr.is_empty() {
108 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
109 }
110
111 if lower.as_constant().is_some_and(|v| v.is_null())
112 || upper.as_constant().is_some_and(|v| v.is_null())
113 {
114 return Ok(Some(
115 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
116 ));
117 }
118
119 Ok(None)
120}
121
122fn between_canonical(
126 arr: &ArrayRef,
127 lower: &ArrayRef,
128 upper: &ArrayRef,
129 options: &BetweenOptions,
130 ctx: &mut ExecutionCtx,
131) -> VortexResult<ArrayRef> {
132 if let Some(result) = precondition(arr, lower, upper)? {
133 return Ok(result);
134 }
135
136 if let Some(prim) = arr.as_opt::<Primitive>()
138 && let Some(result) =
139 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
140 {
141 return Ok(result);
142 }
143 if let Some(dec) = arr.as_opt::<Decimal>()
144 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
145 {
146 return Ok(result);
147 }
148
149 let lower_cmp = lower.clone().binary(
152 arr.clone(),
153 Operator::from(options.lower_strict.to_compare_operator()),
154 )?;
155 let upper_cmp = arr.clone().binary(
156 upper.clone(),
157 Operator::from(options.upper_strict.to_compare_operator()),
158 )?;
159 execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
160}
161
/// The `between` scalar function, registered under the id `vortex.between`.
///
/// It takes three children — the tested value, the lower bound and the upper
/// bound — and is configured through [`BetweenOptions`].
#[derive(Clone)]
pub struct Between;
175
176impl ScalarFnVTable for Between {
177 type Options = BetweenOptions;
178
179 fn id(&self) -> ScalarFnId {
180 static ID: CachedId = CachedId::new("vortex.between");
181 *ID
182 }
183
184 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
185 Ok(Some(
186 pb::BetweenOpts {
187 lower_strict: instance.lower_strict.is_strict(),
188 upper_strict: instance.upper_strict.is_strict(),
189 }
190 .encode_to_vec(),
191 ))
192 }
193
194 fn deserialize(
195 &self,
196 _metadata: &[u8],
197 _session: &VortexSession,
198 ) -> VortexResult<Self::Options> {
199 let opts = pb::BetweenOpts::decode(_metadata)?;
200 Ok(BetweenOptions {
201 lower_strict: if opts.lower_strict {
202 StrictComparison::Strict
203 } else {
204 StrictComparison::NonStrict
205 },
206 upper_strict: if opts.upper_strict {
207 StrictComparison::Strict
208 } else {
209 StrictComparison::NonStrict
210 },
211 })
212 }
213
214 fn arity(&self, _options: &Self::Options) -> Arity {
215 Arity::Exact(3)
216 }
217
218 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
219 match child_idx {
220 0 => ChildName::from("array"),
221 1 => ChildName::from("lower"),
222 2 => ChildName::from("upper"),
223 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
224 }
225 }
226
227 fn fmt_sql(
228 &self,
229 options: &Self::Options,
230 expr: &Expression,
231 f: &mut Formatter<'_>,
232 ) -> std::fmt::Result {
233 let lower_op = if options.lower_strict.is_strict() {
234 "<"
235 } else {
236 "<="
237 };
238 let upper_op = if options.upper_strict.is_strict() {
239 "<"
240 } else {
241 "<="
242 };
243 write!(
244 f,
245 "({} {} {} {} {})",
246 expr.child(1),
247 lower_op,
248 expr.child(0),
249 upper_op,
250 expr.child(2)
251 )
252 }
253
254 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
255 let arr_dt = &arg_dtypes[0];
256 let lower_dt = &arg_dtypes[1];
257 let upper_dt = &arg_dtypes[2];
258
259 if !arr_dt.eq_ignore_nullability(lower_dt) {
260 vortex_bail!(
261 "Array dtype {} does not match lower dtype {}",
262 arr_dt,
263 lower_dt
264 );
265 }
266 if !arr_dt.eq_ignore_nullability(upper_dt) {
267 vortex_bail!(
268 "Array dtype {} does not match upper dtype {}",
269 arr_dt,
270 upper_dt
271 );
272 }
273
274 Ok(Bool(
275 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
276 ))
277 }
278
279 fn execute(
280 &self,
281 options: &Self::Options,
282 args: &dyn ExecutionArgs,
283 ctx: &mut ExecutionCtx,
284 ) -> VortexResult<ArrayRef> {
285 let arr = args.get(0)?;
286 let lower = args.get(1)?;
287 let upper = args.get(2)?;
288
289 if !arr.is_canonical() {
291 return arr.execute::<Canonical>(ctx)?.into_array().between(
292 lower,
293 upper,
294 options.clone(),
295 );
296 }
297
298 between_canonical(&arr, &lower, &upper, options, ctx)
299 }
300
301 fn stat_falsification(
302 &self,
303 options: &Self::Options,
304 expr: &Expression,
305 catalog: &dyn StatsCatalog,
306 ) -> Option<Expression> {
307 let arr = expr.child(0).clone();
308 let lower = expr.child(1).clone();
309 let upper = expr.child(2).clone();
310
311 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
312 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
313
314 and(lhs, rhs).stat_falsification(catalog)
315 }
316
317 fn validity(
318 &self,
319 _options: &Self::Options,
320 expression: &Expression,
321 ) -> VortexResult<Option<Expression>> {
322 let arr = expression.child(0).validity()?;
323 let lower = expression.child(1).validity()?;
324 let upper = expression.child(2).validity()?;
325 Ok(Some(and(and(arr, lower), upper)))
326 }
327
328 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
329 false
330 }
331
332 fn is_fallible(&self, _options: &Self::Options) -> bool {
333 false
334 }
335}
336
// Unit tests for `between`: SQL rendering, bound strictness, constant bounds
// (including a null constant) and decimal inputs.
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // Session shared by every test; execution contexts are created from it on demand.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// The expression renders as `(lower OP value OP upper)` with the operator
    /// symbols following the configured strictness.
    #[test]
    fn test_display() {
        let expr = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        );
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let expr2 = between(
            root(),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Each case lists the indices expected to match for one combination of
    /// lower/upper strictness over the same three arrays.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        // Rows cover: strictly inside, equal to lower, equal to upper, equal to
        // both, and lower > upper.
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict,
                upper_strict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();

        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, expected);
    }

    /// Exercises constant bounds: a null constant, a constant upper bound and
    /// finally constant lower and upper bounds together.
    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A null constant upper bound: no index is reported as matching.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();

        let indices = to_int_indices(matches).unwrap();
        assert!(indices.is_empty());

        // Constant (non-null) upper bound against a non-constant lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant: every value lies within [0, 2].
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    /// Decimal array with constant decimal bounds equal to its first and last
    /// values, so strictness decides whether the end values match.
    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        let lower = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(100i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();
        let upper = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(400i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();

        // Strict lower bound excludes the first value; the upper bound is inclusive.
        let between_strict = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([false, true, true, true])
        );

        // Strict upper bound excludes the last value; the lower bound is inclusive.
        let between_strict = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([true, true, true, false])
        );
    }

    /// The array holds `i16` values while the bounds are given as a mix of
    /// `DecimalValue::I16` and `DecimalValue::I32`, including `I32` values
    /// outside the `i16` range.
    #[rstest]
    // Upper bound above every array value.
    #[case(DecimalValue::I16(1), DecimalValue::I32(82246), vec![0, 1, 2, 3])]
    // Lower bound above every array value.
    #[case(DecimalValue::I32(82246), DecimalValue::I16(4), vec![])]
    // Upper bound below every array value.
    #[case(DecimalValue::I16(1), DecimalValue::I32(-82246), vec![])]
    // Lower bound below every array value.
    #[case(DecimalValue::I32(-82246), DecimalValue::I16(2), vec![0, 1])]
    fn test_between_decimal_mismatched_storage_types(
        #[case] lower_val: DecimalValue,
        #[case] upper_val: DecimalValue,
        #[case] expected_indices: Vec<u64>,
    ) {
        let decimal_type = DecimalDType::new(5, -67);
        let array = DecimalArray::new(
            buffer![1i16, 2i16, 3i16, 4i16],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        let lower = ConstantArray::new(
            Scalar::decimal(lower_val, decimal_type, Nullability::NonNullable),
            array.len(),
        )
        .into_array();
        let upper = ConstantArray::new(
            Scalar::decimal(upper_val, decimal_type, Nullability::NonNullable),
            array.len(),
        )
        .into_array();

        let result = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();

        assert_eq!(to_int_indices(result).unwrap(), expected_indices);
    }
}