1mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::DecimalVTable;
24use crate::arrays::PrimitiveVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::Options;
27use crate::dtype::DType;
28use crate::dtype::DType::Bool;
29use crate::expr::StatsCatalog;
30use crate::expr::expression::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::ScalarFnVTableExt;
38use crate::scalar_fn::fns::binary::Binary;
39use crate::scalar_fn::fns::binary::execute_boolean;
40use crate::scalar_fn::fns::operators::CompareOperator;
41use crate::scalar_fn::fns::operators::Operator;
42
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct BetweenOptions {
45 pub lower_strict: StrictComparison,
46 pub upper_strict: StrictComparison,
47}
48
49impl Display for BetweenOptions {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 let lower_op = if self.lower_strict.is_strict() {
52 "<"
53 } else {
54 "<="
55 };
56 let upper_op = if self.upper_strict.is_strict() {
57 "<"
58 } else {
59 "<="
60 };
61 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
62 }
63}
64
65impl Options for BetweenOptions {
66 fn as_any(&self) -> &dyn Any {
67 self
68 }
69}
70
/// Whether one side of a `between` comparison excludes or includes its bound.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// The bound is excluded; compared with `<`.
    Strict,
    /// The bound is included; compared with `<=`.
    NonStrict,
}
79
80impl StrictComparison {
81 pub const fn to_compare_operator(&self) -> CompareOperator {
82 match self {
83 StrictComparison::Strict => CompareOperator::Lt,
84 StrictComparison::NonStrict => CompareOperator::Lte,
85 }
86 }
87
88 pub const fn to_operator(&self) -> Operator {
89 match self {
90 StrictComparison::Strict => Operator::Lt,
91 StrictComparison::NonStrict => Operator::Lte,
92 }
93 }
94
95 pub const fn is_strict(&self) -> bool {
96 matches!(self, StrictComparison::Strict)
97 }
98}
99
100pub(super) fn precondition(
106 arr: &ArrayRef,
107 lower: &ArrayRef,
108 upper: &ArrayRef,
109) -> VortexResult<Option<ArrayRef>> {
110 let return_dtype =
111 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
112
113 if arr.is_empty() {
115 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
116 }
117
118 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
120 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
121 && (c_lower.is_null() || c_upper.is_null())
122 {
123 return Ok(Some(
124 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
125 ));
126 }
127
128 if lower.as_constant().is_some_and(|v| v.is_null())
129 || upper.as_constant().is_some_and(|v| v.is_null())
130 {
131 return Ok(Some(
132 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
133 ));
134 }
135
136 Ok(None)
137}
138
139fn between_canonical(
143 arr: &ArrayRef,
144 lower: &ArrayRef,
145 upper: &ArrayRef,
146 options: &BetweenOptions,
147 ctx: &mut ExecutionCtx,
148) -> VortexResult<ArrayRef> {
149 if let Some(result) = precondition(arr, lower, upper)? {
150 return Ok(result);
151 }
152
153 if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
155 && let Some(result) =
156 <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
157 {
158 return Ok(result);
159 }
160 if let Some(dec) = arr.as_opt::<DecimalVTable>()
161 && let Some(result) =
162 <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
163 {
164 return Ok(result);
165 }
166
167 let lower_cmp = lower.to_array().binary(
170 arr.to_array(),
171 Operator::from(options.lower_strict.to_compare_operator()),
172 )?;
173 let upper_cmp = arr.to_array().binary(
174 upper.to_array(),
175 Operator::from(options.upper_strict.to_compare_operator()),
176 )?;
177 execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
178}
179
/// The `vortex.between` scalar function: `lower <op> array <op> upper`, where each `<op>` is
/// `<` or `<=` according to its [`BetweenOptions`].
#[derive(Clone, Debug)]
pub struct Between;
193
194impl ScalarFnVTable for Between {
195 type Options = BetweenOptions;
196
197 fn id(&self) -> ScalarFnId {
198 ScalarFnId::from("vortex.between")
199 }
200
201 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
202 Ok(Some(
203 pb::BetweenOpts {
204 lower_strict: instance.lower_strict.is_strict(),
205 upper_strict: instance.upper_strict.is_strict(),
206 }
207 .encode_to_vec(),
208 ))
209 }
210
211 fn deserialize(
212 &self,
213 _metadata: &[u8],
214 _session: &VortexSession,
215 ) -> VortexResult<Self::Options> {
216 let opts = pb::BetweenOpts::decode(_metadata)?;
217 Ok(BetweenOptions {
218 lower_strict: if opts.lower_strict {
219 StrictComparison::Strict
220 } else {
221 StrictComparison::NonStrict
222 },
223 upper_strict: if opts.upper_strict {
224 StrictComparison::Strict
225 } else {
226 StrictComparison::NonStrict
227 },
228 })
229 }
230
231 fn arity(&self, _options: &Self::Options) -> Arity {
232 Arity::Exact(3)
233 }
234
235 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
236 match child_idx {
237 0 => ChildName::from("array"),
238 1 => ChildName::from("lower"),
239 2 => ChildName::from("upper"),
240 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
241 }
242 }
243
244 fn fmt_sql(
245 &self,
246 options: &Self::Options,
247 expr: &Expression,
248 f: &mut Formatter<'_>,
249 ) -> std::fmt::Result {
250 let lower_op = if options.lower_strict.is_strict() {
251 "<"
252 } else {
253 "<="
254 };
255 let upper_op = if options.upper_strict.is_strict() {
256 "<"
257 } else {
258 "<="
259 };
260 write!(
261 f,
262 "({} {} {} {} {})",
263 expr.child(1),
264 lower_op,
265 expr.child(0),
266 upper_op,
267 expr.child(2)
268 )
269 }
270
271 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
272 let arr_dt = &arg_dtypes[0];
273 let lower_dt = &arg_dtypes[1];
274 let upper_dt = &arg_dtypes[2];
275
276 if !arr_dt.eq_ignore_nullability(lower_dt) {
277 vortex_bail!(
278 "Array dtype {} does not match lower dtype {}",
279 arr_dt,
280 lower_dt
281 );
282 }
283 if !arr_dt.eq_ignore_nullability(upper_dt) {
284 vortex_bail!(
285 "Array dtype {} does not match upper dtype {}",
286 arr_dt,
287 upper_dt
288 );
289 }
290
291 Ok(Bool(
292 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
293 ))
294 }
295
296 fn execute(
297 &self,
298 options: &Self::Options,
299 args: &dyn ExecutionArgs,
300 ctx: &mut ExecutionCtx,
301 ) -> VortexResult<ArrayRef> {
302 let arr = args.get(0)?;
303 let lower = args.get(1)?;
304 let upper = args.get(2)?;
305
306 if !arr.is_canonical() {
308 return arr.execute::<Canonical>(ctx)?.into_array().between(
309 lower,
310 upper,
311 options.clone(),
312 );
313 }
314
315 between_canonical(&arr, &lower, &upper, options, ctx)
316 }
317
318 fn stat_falsification(
319 &self,
320 options: &Self::Options,
321 expr: &Expression,
322 catalog: &dyn StatsCatalog,
323 ) -> Option<Expression> {
324 let arr = expr.child(0).clone();
325 let lower = expr.child(1).clone();
326 let upper = expr.child(2).clone();
327
328 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
329 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
330
331 Binary
332 .new_expr(Operator::And, [lhs, rhs])
333 .stat_falsification(catalog)
334 }
335
336 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
337 false
338 }
339
340 fn is_fallible(&self, _options: &Self::Options) -> bool {
341 false
342 }
343}
344
// Unit tests for `between`: SQL rendering, bound strictness, constant and null bounds, and
// the decimal kernel path.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    // The expression renders as `(lower <op> array <op> upper)`, using `<` for strict and
    // `<=` for non-strict bounds.
    #[test]
    fn test_display() {
        let expr = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        );
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let expr2 = between(
            root(),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    // Row-wise (lower, array, upper): (0,1,2) (0,0,1) (0,1,1) (0,0,0) (2,1,0).
    // Row 0 is strictly inside both bounds; row 1 touches the lower bound; row 2 touches the
    // upper bound; row 3 touches both; row 4 is below the lower bound and never matches.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict,
                upper_strict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();

        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, expected);
    }

    // Constant bounds: a constant-null upper bound yields no matches; constant non-null
    // bounds behave like their materialized equivalents.
    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();

        // A constant-null bound is short-circuited by `precondition` to an all-null result.
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();

        let indices = to_int_indices(matches).unwrap();
        assert!(indices.is_empty());

        // Upper = 2 everywhere: rows 2 and 4 fail because lower (2) > array (1).
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant (0 and 2): every value in the array lies within [0, 2].
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    // Decimal input (1.00, 2.00, 3.00, 4.00) with bounds 1.00 and 4.00: strictness on each
    // side excludes exactly the element equal to that bound.
    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        let lower = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(100i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();
        let upper = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(400i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();

        // Strict lower bound: 1.00 < x <= 4.00 excludes the first element.
        let between_strict = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([false, true, true, true])
        );

        // Strict upper bound: 1.00 <= x < 4.00 excludes the last element.
        let between_strict = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}