1mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::DynArray;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::Options;
27use crate::dtype::DType;
28use crate::dtype::DType::Bool;
29use crate::expr::StatsCatalog;
30use crate::expr::expression::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::ScalarFnVTableExt;
38use crate::scalar_fn::fns::binary::Binary;
39use crate::scalar_fn::fns::binary::execute_boolean;
40use crate::scalar_fn::fns::operators::CompareOperator;
41use crate::scalar_fn::fns::operators::Operator;
42
/// Options of the `between` comparison: how strictly each of the two bounds is compared.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower bound: `lower < value` when strict, `lower <= value`
    /// otherwise.
    pub lower_strict: StrictComparison,
    /// Strictness of the upper bound: `value < upper` when strict, `value <= upper`
    /// otherwise.
    pub upper_strict: StrictComparison,
}
48
49impl Display for BetweenOptions {
50 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51 let lower_op = if self.lower_strict.is_strict() {
52 "<"
53 } else {
54 "<="
55 };
56 let upper_op = if self.upper_strict.is_strict() {
57 "<"
58 } else {
59 "<="
60 };
61 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
62 }
63}
64
/// Lets [`BetweenOptions`] be used wherever compute [`Options`] are expected.
impl Options for BetweenOptions {
    /// Returns `self` as a `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
70
/// Whether a bound of a `between` comparison is exclusive or inclusive.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
79
80impl StrictComparison {
81 pub const fn to_compare_operator(&self) -> CompareOperator {
82 match self {
83 StrictComparison::Strict => CompareOperator::Lt,
84 StrictComparison::NonStrict => CompareOperator::Lte,
85 }
86 }
87
88 pub const fn to_operator(&self) -> Operator {
89 match self {
90 StrictComparison::Strict => Operator::Lt,
91 StrictComparison::NonStrict => Operator::Lte,
92 }
93 }
94
95 pub const fn is_strict(&self) -> bool {
96 matches!(self, StrictComparison::Strict)
97 }
98}
99
100pub(super) fn precondition(
106 arr: &ArrayRef,
107 lower: &ArrayRef,
108 upper: &ArrayRef,
109) -> VortexResult<Option<ArrayRef>> {
110 let return_dtype =
111 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
112
113 if arr.is_empty() {
115 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
116 }
117
118 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
120 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
121 && (c_lower.is_null() || c_upper.is_null())
122 {
123 return Ok(Some(
124 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
125 ));
126 }
127
128 if lower.as_constant().is_some_and(|v| v.is_null())
129 || upper.as_constant().is_some_and(|v| v.is_null())
130 {
131 return Ok(Some(
132 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
133 ));
134 }
135
136 Ok(None)
137}
138
139fn between_canonical(
143 arr: &ArrayRef,
144 lower: &ArrayRef,
145 upper: &ArrayRef,
146 options: &BetweenOptions,
147 ctx: &mut ExecutionCtx,
148) -> VortexResult<ArrayRef> {
149 if let Some(result) = precondition(arr, lower, upper)? {
150 return Ok(result);
151 }
152
153 if let Some(prim) = arr.as_opt::<Primitive>()
155 && let Some(result) =
156 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
157 {
158 return Ok(result);
159 }
160 if let Some(dec) = arr.as_opt::<Decimal>()
161 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
162 {
163 return Ok(result);
164 }
165
166 let lower_cmp = lower.to_array().binary(
169 arr.to_array(),
170 Operator::from(options.lower_strict.to_compare_operator()),
171 )?;
172 let upper_cmp = arr.to_array().binary(
173 upper.to_array(),
174 Operator::from(options.upper_strict.to_compare_operator()),
175 )?;
176 execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
177}
178
/// Scalar function `vortex.between`: tests whether a value lies between a lower and an upper
/// bound, each of which may be strict or non-strict (see [`BetweenOptions`]).
#[derive(Clone)]
pub struct Between;
192
193impl ScalarFnVTable for Between {
194 type Options = BetweenOptions;
195
196 fn id(&self) -> ScalarFnId {
197 ScalarFnId::from("vortex.between")
198 }
199
200 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
201 Ok(Some(
202 pb::BetweenOpts {
203 lower_strict: instance.lower_strict.is_strict(),
204 upper_strict: instance.upper_strict.is_strict(),
205 }
206 .encode_to_vec(),
207 ))
208 }
209
210 fn deserialize(
211 &self,
212 _metadata: &[u8],
213 _session: &VortexSession,
214 ) -> VortexResult<Self::Options> {
215 let opts = pb::BetweenOpts::decode(_metadata)?;
216 Ok(BetweenOptions {
217 lower_strict: if opts.lower_strict {
218 StrictComparison::Strict
219 } else {
220 StrictComparison::NonStrict
221 },
222 upper_strict: if opts.upper_strict {
223 StrictComparison::Strict
224 } else {
225 StrictComparison::NonStrict
226 },
227 })
228 }
229
230 fn arity(&self, _options: &Self::Options) -> Arity {
231 Arity::Exact(3)
232 }
233
234 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
235 match child_idx {
236 0 => ChildName::from("array"),
237 1 => ChildName::from("lower"),
238 2 => ChildName::from("upper"),
239 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
240 }
241 }
242
243 fn fmt_sql(
244 &self,
245 options: &Self::Options,
246 expr: &Expression,
247 f: &mut Formatter<'_>,
248 ) -> std::fmt::Result {
249 let lower_op = if options.lower_strict.is_strict() {
250 "<"
251 } else {
252 "<="
253 };
254 let upper_op = if options.upper_strict.is_strict() {
255 "<"
256 } else {
257 "<="
258 };
259 write!(
260 f,
261 "({} {} {} {} {})",
262 expr.child(1),
263 lower_op,
264 expr.child(0),
265 upper_op,
266 expr.child(2)
267 )
268 }
269
270 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
271 let arr_dt = &arg_dtypes[0];
272 let lower_dt = &arg_dtypes[1];
273 let upper_dt = &arg_dtypes[2];
274
275 if !arr_dt.eq_ignore_nullability(lower_dt) {
276 vortex_bail!(
277 "Array dtype {} does not match lower dtype {}",
278 arr_dt,
279 lower_dt
280 );
281 }
282 if !arr_dt.eq_ignore_nullability(upper_dt) {
283 vortex_bail!(
284 "Array dtype {} does not match upper dtype {}",
285 arr_dt,
286 upper_dt
287 );
288 }
289
290 Ok(Bool(
291 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
292 ))
293 }
294
295 fn execute(
296 &self,
297 options: &Self::Options,
298 args: &dyn ExecutionArgs,
299 ctx: &mut ExecutionCtx,
300 ) -> VortexResult<ArrayRef> {
301 let arr = args.get(0)?;
302 let lower = args.get(1)?;
303 let upper = args.get(2)?;
304
305 if !arr.is_canonical() {
307 return arr.execute::<Canonical>(ctx)?.into_array().between(
308 lower,
309 upper,
310 options.clone(),
311 );
312 }
313
314 between_canonical(&arr, &lower, &upper, options, ctx)
315 }
316
317 fn stat_falsification(
318 &self,
319 options: &Self::Options,
320 expr: &Expression,
321 catalog: &dyn StatsCatalog,
322 ) -> Option<Expression> {
323 let arr = expr.child(0).clone();
324 let lower = expr.child(1).clone();
325 let upper = expr.child(2).clone();
326
327 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
328 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
329
330 Binary
331 .new_expr(Operator::And, [lhs, rhs])
332 .stat_falsification(catalog)
333 }
334
335 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
336 false
337 }
338
339 fn is_fallible(&self, _options: &Self::Options) -> bool {
340 false
341 }
342}
343
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// The expression prints as `(lower op value op upper)` with the operator of each bound.
    #[test]
    fn test_display() {
        let strict_upper = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_expr = between(get_item("score", root()), lit(10), lit(50), strict_upper);
        assert_eq!(score_expr.to_string(), "(10i32 <= $.score < 50i32)");

        let strict_lower = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_expr = between(root(), lit(0), lit(100), strict_lower);
        assert_eq!(root_expr.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every combination of strict and non-strict bounds against element-wise bound arrays.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let selected = between_canonical(&array, &lower, &upper, &options, &mut ctx)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(selected).unwrap(), expected);
    }

    /// Constant bounds: a null constant, then a constant upper, then both bounds constant.
    #[test]
    fn test_constants() {
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };
        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // Upper bound is a constant null: no index may be selected.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let selected = between_canonical(
            &array,
            &varying_lower,
            &null_upper,
            &inclusive,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();
        assert!(to_int_indices(selected).unwrap().is_empty());

        // Upper bound is the constant 2, lower bound still varies per element.
        let const_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let selected = between_canonical(
            &array,
            &varying_lower,
            &const_upper,
            &inclusive,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();
        assert_eq!(to_int_indices(selected).unwrap(), vec![0, 1, 3]);

        // Both bounds constant: 0 <= x <= 2 holds for every element.
        let const_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let selected = between_canonical(
            &array,
            &const_lower,
            &const_upper,
            &inclusive,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();
        assert_eq!(to_int_indices(selected).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    /// Decimal input with constant decimal bounds, one strict bound at a time.
    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        // Builds a non-nullable constant decimal bound as long as `array`.
        let constant_bound = |unscaled: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(unscaled),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        // Strict lower bound excludes the first value.
        let exclusive_lower = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Strict upper bound excludes the last value.
        let exclusive_upper = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}