1mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::arrays::Decimal;
22use crate::arrays::Primitive;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::DType::Bool;
26use crate::expr::StatsCatalog;
27use crate::expr::expression::Expression;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::binary::execute_boolean;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct BetweenOptions {
42 pub lower_strict: StrictComparison,
43 pub upper_strict: StrictComparison,
44}
45
46impl Display for BetweenOptions {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 let lower_op = if self.lower_strict.is_strict() {
49 "<"
50 } else {
51 "<="
52 };
53 let upper_op = if self.upper_strict.is_strict() {
54 "<"
55 } else {
56 "<="
57 };
58 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
59 }
60}
61
/// Whether one side of a `between` check excludes its bound.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
70
71impl StrictComparison {
72 pub const fn to_compare_operator(&self) -> CompareOperator {
73 match self {
74 StrictComparison::Strict => CompareOperator::Lt,
75 StrictComparison::NonStrict => CompareOperator::Lte,
76 }
77 }
78
79 pub const fn to_operator(&self) -> Operator {
80 match self {
81 StrictComparison::Strict => Operator::Lt,
82 StrictComparison::NonStrict => Operator::Lte,
83 }
84 }
85
86 pub const fn is_strict(&self) -> bool {
87 matches!(self, StrictComparison::Strict)
88 }
89}
90
91pub(super) fn precondition(
97 arr: &ArrayRef,
98 lower: &ArrayRef,
99 upper: &ArrayRef,
100) -> VortexResult<Option<ArrayRef>> {
101 let return_dtype =
102 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
103
104 if arr.is_empty() {
106 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
107 }
108
109 if (lower.is_invalid(0)? || upper.is_invalid(0)?)
111 && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
112 && (c_lower.is_null() || c_upper.is_null())
113 {
114 return Ok(Some(
115 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
116 ));
117 }
118
119 if lower.as_constant().is_some_and(|v| v.is_null())
120 || upper.as_constant().is_some_and(|v| v.is_null())
121 {
122 return Ok(Some(
123 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
124 ));
125 }
126
127 Ok(None)
128}
129
130fn between_canonical(
134 arr: &ArrayRef,
135 lower: &ArrayRef,
136 upper: &ArrayRef,
137 options: &BetweenOptions,
138 ctx: &mut ExecutionCtx,
139) -> VortexResult<ArrayRef> {
140 if let Some(result) = precondition(arr, lower, upper)? {
141 return Ok(result);
142 }
143
144 if let Some(prim) = arr.as_opt::<Primitive>()
146 && let Some(result) =
147 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
148 {
149 return Ok(result);
150 }
151 if let Some(dec) = arr.as_opt::<Decimal>()
152 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
153 {
154 return Ok(result);
155 }
156
157 let lower_cmp = lower.clone().binary(
160 arr.clone(),
161 Operator::from(options.lower_strict.to_compare_operator()),
162 )?;
163 let upper_cmp = arr.clone().binary(
164 upper.clone(),
165 Operator::from(options.upper_strict.to_compare_operator()),
166 )?;
167 execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
168}
169
/// The `vortex.between` scalar function.
///
/// Takes three children, `array`, `lower` and `upper`, and evaluates
/// `lower (< | <=) array (< | <=) upper` according to its [`BetweenOptions`],
/// producing a `Bool` array.
#[derive(Clone)]
pub struct Between;
183
184impl ScalarFnVTable for Between {
185 type Options = BetweenOptions;
186
187 fn id(&self) -> ScalarFnId {
188 ScalarFnId::from("vortex.between")
189 }
190
191 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
192 Ok(Some(
193 pb::BetweenOpts {
194 lower_strict: instance.lower_strict.is_strict(),
195 upper_strict: instance.upper_strict.is_strict(),
196 }
197 .encode_to_vec(),
198 ))
199 }
200
201 fn deserialize(
202 &self,
203 _metadata: &[u8],
204 _session: &VortexSession,
205 ) -> VortexResult<Self::Options> {
206 let opts = pb::BetweenOpts::decode(_metadata)?;
207 Ok(BetweenOptions {
208 lower_strict: if opts.lower_strict {
209 StrictComparison::Strict
210 } else {
211 StrictComparison::NonStrict
212 },
213 upper_strict: if opts.upper_strict {
214 StrictComparison::Strict
215 } else {
216 StrictComparison::NonStrict
217 },
218 })
219 }
220
221 fn arity(&self, _options: &Self::Options) -> Arity {
222 Arity::Exact(3)
223 }
224
225 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
226 match child_idx {
227 0 => ChildName::from("array"),
228 1 => ChildName::from("lower"),
229 2 => ChildName::from("upper"),
230 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
231 }
232 }
233
234 fn fmt_sql(
235 &self,
236 options: &Self::Options,
237 expr: &Expression,
238 f: &mut Formatter<'_>,
239 ) -> std::fmt::Result {
240 let lower_op = if options.lower_strict.is_strict() {
241 "<"
242 } else {
243 "<="
244 };
245 let upper_op = if options.upper_strict.is_strict() {
246 "<"
247 } else {
248 "<="
249 };
250 write!(
251 f,
252 "({} {} {} {} {})",
253 expr.child(1),
254 lower_op,
255 expr.child(0),
256 upper_op,
257 expr.child(2)
258 )
259 }
260
261 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
262 let arr_dt = &arg_dtypes[0];
263 let lower_dt = &arg_dtypes[1];
264 let upper_dt = &arg_dtypes[2];
265
266 if !arr_dt.eq_ignore_nullability(lower_dt) {
267 vortex_bail!(
268 "Array dtype {} does not match lower dtype {}",
269 arr_dt,
270 lower_dt
271 );
272 }
273 if !arr_dt.eq_ignore_nullability(upper_dt) {
274 vortex_bail!(
275 "Array dtype {} does not match upper dtype {}",
276 arr_dt,
277 upper_dt
278 );
279 }
280
281 Ok(Bool(
282 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
283 ))
284 }
285
286 fn execute(
287 &self,
288 options: &Self::Options,
289 args: &dyn ExecutionArgs,
290 ctx: &mut ExecutionCtx,
291 ) -> VortexResult<ArrayRef> {
292 let arr = args.get(0)?;
293 let lower = args.get(1)?;
294 let upper = args.get(2)?;
295
296 if !arr.is_canonical() {
298 return arr.execute::<Canonical>(ctx)?.into_array().between(
299 lower,
300 upper,
301 options.clone(),
302 );
303 }
304
305 between_canonical(&arr, &lower, &upper, options, ctx)
306 }
307
308 fn stat_falsification(
309 &self,
310 options: &Self::Options,
311 expr: &Expression,
312 catalog: &dyn StatsCatalog,
313 ) -> Option<Expression> {
314 let arr = expr.child(0).clone();
315 let lower = expr.child(1).clone();
316 let upper = expr.child(2).clone();
317
318 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
319 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
320
321 Binary
322 .new_expr(Operator::And, [lhs, rhs])
323 .stat_falsification(catalog)
324 }
325
326 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
327 false
328 }
329
330 fn is_fallible(&self, _options: &Self::Options) -> bool {
331 false
332 }
333}
334
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Builds `BetweenOptions` from the two strictness settings.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_canonical` with a fresh execution context, panicking on error.
    fn eval_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        options: &BetweenOptions,
    ) -> ArrayRef {
        between_canonical(
            array,
            lower,
            upper,
            options,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
    }

    #[test]
    fn test_display() {
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        let open_closed = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(open_closed.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = eval_between(&array, &lower, &upper, &opts(lower_strict, upper_strict));
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = opts(StrictComparison::NonStrict, StrictComparison::NonStrict);
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let lower = buffer![0, 0, 2, 0, 2].into_array();

        // A constant null upper bound matches nothing.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let indices =
            to_int_indices(eval_between(&array, &lower, &null_upper, &inclusive).to_bool())
                .unwrap();
        assert!(indices.is_empty());

        // Constant upper bound, per-row lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let indices =
            to_int_indices(eval_between(&array, &lower, &upper, &inclusive).to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let indices =
            to_int_indices(eval_between(&array, &lower, &upper, &inclusive).to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Non-nullable decimal constant broadcast to the array's length.
        let constant = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant(100i128);
        let upper = constant(400i128);

        let lower_exclusive = eval_between(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_arrays_eq!(
            lower_exclusive,
            BoolArray::from_iter([false, true, true, true])
        );

        let upper_exclusive = eval_between(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_arrays_eq!(
            upper_exclusive,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}