1mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_session::VortexSession;
17
18use crate::Array;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::arrays::ConstantArray;
24use crate::arrays::DecimalVTable;
25use crate::arrays::PrimitiveVTable;
26use crate::builtins::ArrayBuiltins;
27use crate::compute::Options;
28use crate::dtype::DType;
29use crate::dtype::DType::Bool;
30use crate::expr::StatsCatalog;
31use crate::expr::expression::Expression;
32use crate::scalar::Scalar;
33use crate::scalar_fn::Arity;
34use crate::scalar_fn::ChildName;
35use crate::scalar_fn::ExecutionArgs;
36use crate::scalar_fn::ScalarFnId;
37use crate::scalar_fn::ScalarFnVTable;
38use crate::scalar_fn::ScalarFnVTableExt;
39use crate::scalar_fn::fns::binary::Binary;
40use crate::scalar_fn::fns::binary::execute_boolean;
41use crate::scalar_fn::fns::operators::CompareOperator;
42use crate::scalar_fn::fns::operators::Operator;
43
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45pub struct BetweenOptions {
46 pub lower_strict: StrictComparison,
47 pub upper_strict: StrictComparison,
48}
49
50impl Display for BetweenOptions {
51 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52 let lower_op = if self.lower_strict.is_strict() {
53 "<"
54 } else {
55 "<="
56 };
57 let upper_op = if self.upper_strict.is_strict() {
58 "<"
59 } else {
60 "<="
61 };
62 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
63 }
64}
65
66impl Options for BetweenOptions {
67 fn as_any(&self) -> &dyn Any {
68 self
69 }
70}
71
/// Whether a bound comparison excludes or includes equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Exclusive comparison (`<`): a value equal to the bound does not match.
    Strict,
    /// Inclusive comparison (`<=`): a value equal to the bound matches.
    NonStrict,
}
80
81impl StrictComparison {
82 pub const fn to_compare_operator(&self) -> CompareOperator {
83 match self {
84 StrictComparison::Strict => CompareOperator::Lt,
85 StrictComparison::NonStrict => CompareOperator::Lte,
86 }
87 }
88
89 pub const fn to_operator(&self) -> Operator {
90 match self {
91 StrictComparison::Strict => Operator::Lt,
92 StrictComparison::NonStrict => Operator::Lte,
93 }
94 }
95
96 pub const fn is_strict(&self) -> bool {
97 matches!(self, StrictComparison::Strict)
98 }
99}
100
101pub(super) fn precondition(
107 arr: &dyn Array,
108 lower: &dyn Array,
109 upper: &dyn Array,
110) -> VortexResult<Option<ArrayRef>> {
111 let return_dtype =
112 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
113
114 if arr.is_empty() {
116 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
117 }
118
119 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
121 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
122 && (c_lower.is_null() || c_upper.is_null())
123 {
124 return Ok(Some(
125 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
126 ));
127 }
128
129 if lower.as_constant().is_some_and(|v| v.is_null())
130 || upper.as_constant().is_some_and(|v| v.is_null())
131 {
132 return Ok(Some(
133 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
134 ));
135 }
136
137 Ok(None)
138}
139
140fn between_canonical(
144 arr: &dyn Array,
145 lower: &dyn Array,
146 upper: &dyn Array,
147 options: &BetweenOptions,
148 ctx: &mut ExecutionCtx,
149) -> VortexResult<ArrayRef> {
150 if let Some(result) = precondition(arr, lower, upper)? {
151 return Ok(result);
152 }
153
154 if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
156 && let Some(result) =
157 <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
158 {
159 return Ok(result);
160 }
161 if let Some(dec) = arr.as_opt::<DecimalVTable>()
162 && let Some(result) =
163 <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
164 {
165 return Ok(result);
166 }
167
168 let lower_cmp = lower.to_array().binary(
171 arr.to_array(),
172 Operator::from(options.lower_strict.to_compare_operator()),
173 )?;
174 let upper_cmp = arr.to_array().binary(
175 upper.to_array(),
176 Operator::from(options.upper_strict.to_compare_operator()),
177 )?;
178 execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
179}
180
/// Scalar function that tests whether a value lies between a lower and an upper bound.
///
/// It takes three children, in order: the values to test, the lower bound, and the upper
/// bound. Whether each bound is inclusive or exclusive is carried by the function's options.
#[derive(Clone)]
pub struct Between;
194
195impl ScalarFnVTable for Between {
196 type Options = BetweenOptions;
197
198 fn id(&self) -> ScalarFnId {
199 ScalarFnId::from("vortex.between")
200 }
201
202 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
203 Ok(Some(
204 pb::BetweenOpts {
205 lower_strict: instance.lower_strict.is_strict(),
206 upper_strict: instance.upper_strict.is_strict(),
207 }
208 .encode_to_vec(),
209 ))
210 }
211
212 fn deserialize(
213 &self,
214 _metadata: &[u8],
215 _session: &VortexSession,
216 ) -> VortexResult<Self::Options> {
217 let opts = pb::BetweenOpts::decode(_metadata)?;
218 Ok(BetweenOptions {
219 lower_strict: if opts.lower_strict {
220 StrictComparison::Strict
221 } else {
222 StrictComparison::NonStrict
223 },
224 upper_strict: if opts.upper_strict {
225 StrictComparison::Strict
226 } else {
227 StrictComparison::NonStrict
228 },
229 })
230 }
231
232 fn arity(&self, _options: &Self::Options) -> Arity {
233 Arity::Exact(3)
234 }
235
236 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
237 match child_idx {
238 0 => ChildName::from("array"),
239 1 => ChildName::from("lower"),
240 2 => ChildName::from("upper"),
241 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
242 }
243 }
244
245 fn fmt_sql(
246 &self,
247 options: &Self::Options,
248 expr: &Expression,
249 f: &mut Formatter<'_>,
250 ) -> std::fmt::Result {
251 let lower_op = if options.lower_strict.is_strict() {
252 "<"
253 } else {
254 "<="
255 };
256 let upper_op = if options.upper_strict.is_strict() {
257 "<"
258 } else {
259 "<="
260 };
261 write!(
262 f,
263 "({} {} {} {} {})",
264 expr.child(1),
265 lower_op,
266 expr.child(0),
267 upper_op,
268 expr.child(2)
269 )
270 }
271
272 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
273 let arr_dt = &arg_dtypes[0];
274 let lower_dt = &arg_dtypes[1];
275 let upper_dt = &arg_dtypes[2];
276
277 if !arr_dt.eq_ignore_nullability(lower_dt) {
278 vortex_bail!(
279 "Array dtype {} does not match lower dtype {}",
280 arr_dt,
281 lower_dt
282 );
283 }
284 if !arr_dt.eq_ignore_nullability(upper_dt) {
285 vortex_bail!(
286 "Array dtype {} does not match upper dtype {}",
287 arr_dt,
288 upper_dt
289 );
290 }
291
292 Ok(Bool(
293 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
294 ))
295 }
296
297 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
298 let [arr, lower, upper]: [ArrayRef; _] = args
299 .inputs
300 .try_into()
301 .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
302
303 if !arr.is_canonical() {
305 return arr.execute::<Canonical>(args.ctx)?.into_array().between(
306 lower,
307 upper,
308 options.clone(),
309 );
310 }
311
312 between_canonical(
313 arr.as_ref(),
314 lower.as_ref(),
315 upper.as_ref(),
316 options,
317 args.ctx,
318 )
319 }
320
321 fn stat_falsification(
322 &self,
323 options: &Self::Options,
324 expr: &Expression,
325 catalog: &dyn StatsCatalog,
326 ) -> Option<Expression> {
327 let arr = expr.child(0).clone();
328 let lower = expr.child(1).clone();
329 let upper = expr.child(2).clone();
330
331 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
332 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
333
334 Binary
335 .new_expr(Operator::And, [lhs, rhs])
336 .stat_falsification(catalog)
337 }
338
339 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
340 false
341 }
342
343 fn is_fallible(&self, _options: &Self::Options) -> bool {
344 false
345 }
346}
347
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs [`between_canonical`] with a fresh execution context and unwraps the result.
    fn run_between(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    #[test]
    fn test_display() {
        let inclusive_lower = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let expr = between(get_item("score", root()), lit(10), lit(50), inclusive_lower);
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let inclusive_upper = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let expr2 = between(root(), lit(0), lit(100), inclusive_upper);
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let matches = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        )
        .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A null constant upper bound matches no row.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let matches = run_between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        )
        .to_bool();
        assert!(to_int_indices(matches).unwrap().is_empty());

        // Constant upper bound, per-row lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let matches = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        )
        .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let matches = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        )
        .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable);

        // Builds a non-nullable constant decimal bound with the same length as `array`.
        let constant_bound = |unscaled: i128| {
            let scalar = Scalar::decimal(
                DecimalValue::I128(unscaled),
                decimal_type,
                Nullability::NonNullable,
            );
            ConstantArray::new(scalar, array.len())
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        let exclusive_lower = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        let exclusive_upper = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}