1mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::ConstantArray;
22use crate::arrays::Decimal;
23use crate::arrays::Primitive;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::dtype::DType::Bool;
27use crate::expr::StatsCatalog;
28use crate::expr::expression::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::binary::execute_boolean;
38use crate::scalar_fn::fns::operators::CompareOperator;
39use crate::scalar_fn::fns::operators::Operator;
40
/// Options for a between comparison: how strictly each bound is compared.
///
/// Each field selects `<` (strict) or `<=` (non-strict) for its side of
/// `lower op value op upper`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the `lower op value` comparison.
    pub lower_strict: StrictComparison,
    /// Strictness of the `value op upper` comparison.
    pub upper_strict: StrictComparison,
}
46
47impl Display for BetweenOptions {
48 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49 let lower_op = if self.lower_strict.is_strict() {
50 "<"
51 } else {
52 "<="
53 };
54 let upper_op = if self.upper_strict.is_strict() {
55 "<"
56 } else {
57 "<="
58 };
59 write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
60 }
61}
62
/// Whether a bound comparison is strict (`<`) or non-strict (`<=`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict comparison: maps to the `Lt` operator.
    Strict,
    /// Non-strict comparison: maps to the `Lte` operator.
    NonStrict,
}
71
72impl StrictComparison {
73 pub const fn to_compare_operator(&self) -> CompareOperator {
74 match self {
75 StrictComparison::Strict => CompareOperator::Lt,
76 StrictComparison::NonStrict => CompareOperator::Lte,
77 }
78 }
79
80 pub const fn to_operator(&self) -> Operator {
81 match self {
82 StrictComparison::Strict => Operator::Lt,
83 StrictComparison::NonStrict => Operator::Lte,
84 }
85 }
86
87 pub const fn is_strict(&self) -> bool {
88 matches!(self, StrictComparison::Strict)
89 }
90}
91
92pub(super) fn precondition(
98 arr: &ArrayRef,
99 lower: &ArrayRef,
100 upper: &ArrayRef,
101) -> VortexResult<Option<ArrayRef>> {
102 let return_dtype =
103 Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
104
105 if arr.is_empty() {
107 return Ok(Some(Canonical::empty(&return_dtype).into_array()));
108 }
109
110 if lower.as_constant().is_some_and(|v| v.is_null())
111 || upper.as_constant().is_some_and(|v| v.is_null())
112 {
113 return Ok(Some(
114 ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
115 ));
116 }
117
118 Ok(None)
119}
120
121fn between_canonical(
125 arr: &ArrayRef,
126 lower: &ArrayRef,
127 upper: &ArrayRef,
128 options: &BetweenOptions,
129 ctx: &mut ExecutionCtx,
130) -> VortexResult<ArrayRef> {
131 if let Some(result) = precondition(arr, lower, upper)? {
132 return Ok(result);
133 }
134
135 if let Some(prim) = arr.as_opt::<Primitive>()
137 && let Some(result) =
138 <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
139 {
140 return Ok(result);
141 }
142 if let Some(dec) = arr.as_opt::<Decimal>()
143 && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
144 {
145 return Ok(result);
146 }
147
148 let lower_cmp = lower.clone().binary(
151 arr.clone(),
152 Operator::from(options.lower_strict.to_compare_operator()),
153 )?;
154 let upper_cmp = arr.clone().binary(
155 upper.clone(),
156 Operator::from(options.upper_strict.to_compare_operator()),
157 )?;
158 execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
159}
160
/// The `vortex.between` scalar function.
///
/// A unit struct whose [`ScalarFnVTable`] implementation takes three children
/// (`array`, `lower`, `upper`) and is configured by [`BetweenOptions`], which
/// selects `<` or `<=` for each bound.
#[derive(Clone)]
pub struct Between;
174
175impl ScalarFnVTable for Between {
176 type Options = BetweenOptions;
177
178 fn id(&self) -> ScalarFnId {
179 ScalarFnId::new("vortex.between")
180 }
181
182 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
183 Ok(Some(
184 pb::BetweenOpts {
185 lower_strict: instance.lower_strict.is_strict(),
186 upper_strict: instance.upper_strict.is_strict(),
187 }
188 .encode_to_vec(),
189 ))
190 }
191
192 fn deserialize(
193 &self,
194 _metadata: &[u8],
195 _session: &VortexSession,
196 ) -> VortexResult<Self::Options> {
197 let opts = pb::BetweenOpts::decode(_metadata)?;
198 Ok(BetweenOptions {
199 lower_strict: if opts.lower_strict {
200 StrictComparison::Strict
201 } else {
202 StrictComparison::NonStrict
203 },
204 upper_strict: if opts.upper_strict {
205 StrictComparison::Strict
206 } else {
207 StrictComparison::NonStrict
208 },
209 })
210 }
211
212 fn arity(&self, _options: &Self::Options) -> Arity {
213 Arity::Exact(3)
214 }
215
216 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
217 match child_idx {
218 0 => ChildName::from("array"),
219 1 => ChildName::from("lower"),
220 2 => ChildName::from("upper"),
221 _ => unreachable!("Invalid child index {} for Between expression", child_idx),
222 }
223 }
224
225 fn fmt_sql(
226 &self,
227 options: &Self::Options,
228 expr: &Expression,
229 f: &mut Formatter<'_>,
230 ) -> std::fmt::Result {
231 let lower_op = if options.lower_strict.is_strict() {
232 "<"
233 } else {
234 "<="
235 };
236 let upper_op = if options.upper_strict.is_strict() {
237 "<"
238 } else {
239 "<="
240 };
241 write!(
242 f,
243 "({} {} {} {} {})",
244 expr.child(1),
245 lower_op,
246 expr.child(0),
247 upper_op,
248 expr.child(2)
249 )
250 }
251
252 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
253 let arr_dt = &arg_dtypes[0];
254 let lower_dt = &arg_dtypes[1];
255 let upper_dt = &arg_dtypes[2];
256
257 if !arr_dt.eq_ignore_nullability(lower_dt) {
258 vortex_bail!(
259 "Array dtype {} does not match lower dtype {}",
260 arr_dt,
261 lower_dt
262 );
263 }
264 if !arr_dt.eq_ignore_nullability(upper_dt) {
265 vortex_bail!(
266 "Array dtype {} does not match upper dtype {}",
267 arr_dt,
268 upper_dt
269 );
270 }
271
272 Ok(Bool(
273 arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
274 ))
275 }
276
277 fn execute(
278 &self,
279 options: &Self::Options,
280 args: &dyn ExecutionArgs,
281 ctx: &mut ExecutionCtx,
282 ) -> VortexResult<ArrayRef> {
283 let arr = args.get(0)?;
284 let lower = args.get(1)?;
285 let upper = args.get(2)?;
286
287 if !arr.is_canonical() {
289 return arr.execute::<Canonical>(ctx)?.into_array().between(
290 lower,
291 upper,
292 options.clone(),
293 );
294 }
295
296 between_canonical(&arr, &lower, &upper, options, ctx)
297 }
298
299 fn stat_falsification(
300 &self,
301 options: &Self::Options,
302 expr: &Expression,
303 catalog: &dyn StatsCatalog,
304 ) -> Option<Expression> {
305 let arr = expr.child(0).clone();
306 let lower = expr.child(1).clone();
307 let upper = expr.child(2).clone();
308
309 let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
310 let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
311
312 and(lhs, rhs).stat_falsification(catalog)
313 }
314
315 fn validity(
316 &self,
317 _options: &Self::Options,
318 expression: &Expression,
319 ) -> VortexResult<Option<Expression>> {
320 let arr = expression.child(0).validity()?;
321 let lower = expression.child(1).validity()?;
322 let upper = expression.child(2).validity()?;
323 Ok(Some(and(and(arr, lower), upper)))
324 }
325
326 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
327 false
328 }
329
330 fn is_fallible(&self, _options: &Self::Options) -> bool {
331 false
332 }
333}
334
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Shorthand for the options under test.
    fn options(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    #[test]
    fn test_display() {
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            options(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        let open_closed = between(
            root(),
            lit(0),
            lit(100),
            options(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(open_closed.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        // Row by row: inside both bounds, equal to lower, equal to upper,
        // equal to both, and lower above upper.
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let opts = options(lower_strict, upper_strict);
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &opts,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();

        assert_eq!(to_int_indices(matches).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = options(StrictComparison::NonStrict, StrictComparison::NonStrict);

        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A null constant upper bound: no row may match.
        let upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();

        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &inclusive,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();
        assert!(to_int_indices(matches).unwrap().is_empty());

        // A non-null constant upper bound against a varying lower bound.
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &inclusive,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let matches = between_canonical(
            &array,
            &lower,
            &upper,
            &inclusive,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        let lower = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(100i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();
        let upper = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(400i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();

        // Strict lower bound: the value equal to `lower` is excluded.
        let lower_exclusive = between_canonical(
            &array,
            &lower,
            &upper,
            &options(StrictComparison::Strict, StrictComparison::NonStrict),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            lower_exclusive,
            BoolArray::from_iter([false, true, true, true])
        );

        // Strict upper bound: the value equal to `upper` is excluded.
        let upper_exclusive = between_canonical(
            &array,
            &lower,
            &upper,
            &options(StrictComparison::NonStrict, StrictComparison::Strict),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(
            upper_exclusive,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}