1mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_dtype::DType;
13use vortex_dtype::DType::Bool;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr as pb;
19use vortex_session::VortexSession;
20
21use crate::Array;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::ExecutionCtx;
25use crate::IntoArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::DecimalVTable;
28use crate::arrays::PrimitiveVTable;
29use crate::builtins::ArrayBuiltins;
30use crate::compute::BooleanOperator;
31use crate::compute::Options;
32use crate::compute::compare;
33use crate::expr::Arity;
34use crate::expr::ChildName;
35use crate::expr::ExecutionArgs;
36use crate::expr::ExprId;
37use crate::expr::StatsCatalog;
38use crate::expr::VTable;
39use crate::expr::VTableExt;
40use crate::expr::execute_boolean;
41use crate::expr::expression::Expression;
42use crate::expr::exprs::binary::Binary;
43use crate::expr::exprs::operators::Operator;
44use crate::scalar::Scalar;
45
/// Strictness of each bound of a `between` comparison.
///
/// `lower_strict` chooses between `lower < arr` and `lower <= arr`; `upper_strict` chooses
/// between `arr < upper` and `arr <= upper`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower-bound comparison (`lower < arr` vs `lower <= arr`).
    pub lower_strict: StrictComparison,
    /// Strictness of the upper-bound comparison (`arr < upper` vs `arr <= upper`).
    pub upper_strict: StrictComparison,
}
51
52impl Display for BetweenOptions {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 let lower_op = if self.lower_strict.is_strict() {
55 "<"
56 } else {
57 "<="
58 };
59 let upper_op = if self.upper_strict.is_strict() {
60 "<"
61 } else {
62 "<="
63 };
64 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
65 }
66}
67
68impl Options for BetweenOptions {
69 fn as_any(&self) -> &dyn Any {
70 self
71 }
72}
73
/// Whether one bound of a `between` comparison excludes (`Strict`, compared with `<`)
/// or includes (`NonStrict`, compared with `<=`) the bound value itself.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Exclusive bound: compared with `<`.
    Strict,
    /// Inclusive bound: compared with `<=`.
    NonStrict,
}
82
83impl StrictComparison {
84 pub const fn to_operator(&self) -> crate::compute::Operator {
85 match self {
86 StrictComparison::Strict => crate::compute::Operator::Lt,
87 StrictComparison::NonStrict => crate::compute::Operator::Lte,
88 }
89 }
90
91 pub const fn is_strict(&self) -> bool {
92 matches!(self, StrictComparison::Strict)
93 }
94}
95
96pub(super) fn precondition(
102 arr: &dyn Array,
103 lower: &dyn Array,
104 upper: &dyn Array,
105) -> VortexResult<Option<ArrayRef>> {
106 let return_dtype =
107 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
108
109 if arr.is_empty() {
111 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
112 }
113
114 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
116 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
117 && (c_lower.is_null() || c_upper.is_null())
118 {
119 return Ok(Some(
120 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
121 ));
122 }
123
124 if lower.as_constant().is_some_and(|v| v.is_null())
125 || upper.as_constant().is_some_and(|v| v.is_null())
126 {
127 return Ok(Some(
128 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
129 ));
130 }
131
132 Ok(None)
133}
134
135fn between_canonical(
139 arr: &dyn Array,
140 lower: &dyn Array,
141 upper: &dyn Array,
142 options: &BetweenOptions,
143 ctx: &mut ExecutionCtx,
144) -> VortexResult<ArrayRef> {
145 if let Some(result) = precondition(arr, lower, upper)? {
146 return Ok(result);
147 }
148
149 if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
151 && let Some(result) =
152 <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
153 {
154 return Ok(result);
155 }
156 if let Some(dec) = arr.as_opt::<DecimalVTable>()
157 && let Some(result) =
158 <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
159 {
160 return Ok(result);
161 }
162
163 execute_boolean(
166 &compare(lower, arr, options.lower_strict.to_operator())?,
167 &compare(arr, upper, options.upper_strict.to_operator())?,
168 BooleanOperator::AndKleene,
169 )
170}
171
/// Expression vtable for the ternary `between` predicate (id `vortex.between`).
///
/// Its children are ordered `[array, lower, upper]`, and it is rendered as
/// `(lower <op> array <op> upper)`.
pub struct Between;
184
185impl VTable for Between {
186 type Options = BetweenOptions;
187
188 fn id(&self) -> ExprId {
189 ExprId::from("vortex.between")
190 }
191
192 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
193 Ok(Some(
194 pb::BetweenOpts {
195 lower_strict: instance.lower_strict.is_strict(),
196 upper_strict: instance.upper_strict.is_strict(),
197 }
198 .encode_to_vec(),
199 ))
200 }
201
202 fn deserialize(
203 &self,
204 _metadata: &[u8],
205 _session: &VortexSession,
206 ) -> VortexResult<Self::Options> {
207 let opts = pb::BetweenOpts::decode(_metadata)?;
208 Ok(BetweenOptions {
209 lower_strict: if opts.lower_strict {
210 StrictComparison::Strict
211 } else {
212 StrictComparison::NonStrict
213 },
214 upper_strict: if opts.upper_strict {
215 StrictComparison::Strict
216 } else {
217 StrictComparison::NonStrict
218 },
219 })
220 }
221
222 fn arity(&self, _options: &Self::Options) -> Arity {
223 Arity::Exact(3)
224 }
225
226 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
227 match child_idx {
228 0 => ChildName::from("array"),
229 1 => ChildName::from("lower"),
230 2 => ChildName::from("upper"),
231 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
232 }
233 }
234
235 fn fmt_sql(
236 &self,
237 options: &Self::Options,
238 expr: &Expression,
239 f: &mut Formatter<'_>,
240 ) -> std::fmt::Result {
241 let lower_op = if options.lower_strict.is_strict() {
242 "<"
243 } else {
244 "<="
245 };
246 let upper_op = if options.upper_strict.is_strict() {
247 "<"
248 } else {
249 "<="
250 };
251 write!(
252 f,
253 "({} {} {} {} {})",
254 expr.child(1),
255 lower_op,
256 expr.child(0),
257 upper_op,
258 expr.child(2)
259 )
260 }
261
262 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
263 let arr_dt = &arg_dtypes[0];
264 let lower_dt = &arg_dtypes[1];
265 let upper_dt = &arg_dtypes[2];
266
267 if !arr_dt.eq_ignore_nullability(lower_dt) {
268 vortex_bail!(
269 "Array dtype {} does not match lower dtype {}",
270 arr_dt,
271 lower_dt
272 );
273 }
274 if !arr_dt.eq_ignore_nullability(upper_dt) {
275 vortex_bail!(
276 "Array dtype {} does not match upper dtype {}",
277 arr_dt,
278 upper_dt
279 );
280 }
281
282 Ok(Bool(
283 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
284 ))
285 }
286
287 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
288 let [arr, lower, upper]: [ArrayRef; _] = args
289 .inputs
290 .try_into()
291 .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
292
293 if !arr.is_canonical() {
295 return arr.execute::<Canonical>(args.ctx)?.into_array().between(
296 lower,
297 upper,
298 options.clone(),
299 );
300 }
301
302 between_canonical(
303 arr.as_ref(),
304 lower.as_ref(),
305 upper.as_ref(),
306 options,
307 args.ctx,
308 )
309 }
310
311 fn stat_falsification(
312 &self,
313 options: &Self::Options,
314 expr: &Expression,
315 catalog: &dyn StatsCatalog,
316 ) -> Option<Expression> {
317 let arr = expr.child(0).clone();
318 let lower = expr.child(1).clone();
319 let upper = expr.child(2).clone();
320
321 let lhs = Binary.new_expr(
322 options.lower_strict.to_operator().into(),
323 [lower, arr.clone()],
324 );
325 let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
326
327 Binary
328 .new_expr(Operator::And, [lhs, rhs])
329 .stat_falsification(catalog)
330 }
331
332 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
333 false
334 }
335
336 fn is_fallible(&self, _options: &Self::Options) -> bool {
337 false
338 }
339}
340
341pub fn between(
357 arr: Expression,
358 lower: Expression,
359 upper: Expression,
360 options: BetweenOptions,
361) -> Expression {
362 Between
363 .try_new_expr(options, [arr, lower, upper])
364 .vortex_expect("Failed to create Between expression")
365}
366
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs `between_canonical` with a fresh execution context and unwraps the result.
    fn eval_between(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    #[test]
    fn test_display() {
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        let open_closed = between(
            root(),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_eq!(open_closed.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let matches = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        )
        .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = StrictComparison::NonStrict;
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant-null upper bound matches nothing.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let matches = eval_between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert!(to_int_indices(matches).unwrap().is_empty());

        // Constant non-null upper bound with an array-valued lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let matches = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let matches = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable);

        let decimal_const = |raw: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(raw),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
        };
        let lower = decimal_const(100i128);
        let upper = decimal_const(400i128);

        let exclusive_lower = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        let exclusive_upper = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}