Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::DType::Bool;
28use crate::expr::StatsCatalog;
29use crate::expr::expression::Expression;
30use crate::scalar::Scalar;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::binary::Binary;
38use crate::scalar_fn::fns::binary::execute_boolean;
39use crate::scalar_fn::fns::operators::CompareOperator;
40use crate::scalar_fn::fns::operators::Operator;
41
42#[derive(Debug, Clone, PartialEq, Eq, Hash)]
43pub struct BetweenOptions {
44    pub lower_strict: StrictComparison,
45    pub upper_strict: StrictComparison,
46}
47
48impl Display for BetweenOptions {
49    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50        let lower_op = if self.lower_strict.is_strict() {
51            "<"
52        } else {
53            "<="
54        };
55        let upper_op = if self.upper_strict.is_strict() {
56            "<"
57        } else {
58            "<="
59        };
60        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
61    }
62}
63
/// Whether a bound of a comparison excludes or includes the bound value itself.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
72
73impl StrictComparison {
74    pub const fn to_compare_operator(&self) -> CompareOperator {
75        match self {
76            StrictComparison::Strict => CompareOperator::Lt,
77            StrictComparison::NonStrict => CompareOperator::Lte,
78        }
79    }
80
81    pub const fn to_operator(&self) -> Operator {
82        match self {
83            StrictComparison::Strict => Operator::Lt,
84            StrictComparison::NonStrict => Operator::Lte,
85        }
86    }
87
88    pub const fn is_strict(&self) -> bool {
89        matches!(self, StrictComparison::Strict)
90    }
91}
92
93/// Common preconditions for between operations that apply to all arrays.
94///
95/// Returns `Some(result)` if the precondition short-circuits the between operation
96/// (empty array, null bounds), or `None` if between should proceed with the
97/// encoding-specific implementation.
98pub(super) fn precondition(
99    arr: &ArrayRef,
100    lower: &ArrayRef,
101    upper: &ArrayRef,
102) -> VortexResult<Option<ArrayRef>> {
103    let return_dtype =
104        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
105
106    // Bail early if the array is empty.
107    if arr.is_empty() {
108        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
109    }
110
111    if lower.as_constant().is_some_and(|v| v.is_null())
112        || upper.as_constant().is_some_and(|v| v.is_null())
113    {
114        return Ok(Some(
115            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
116        ));
117    }
118
119    Ok(None)
120}
121
122/// Between on a canonical array by directly dispatching to the appropriate kernel.
123///
124/// Falls back to compare + boolean and if no kernel handles the input.
125fn between_canonical(
126    arr: &ArrayRef,
127    lower: &ArrayRef,
128    upper: &ArrayRef,
129    options: &BetweenOptions,
130    ctx: &mut ExecutionCtx,
131) -> VortexResult<ArrayRef> {
132    if let Some(result) = precondition(arr, lower, upper)? {
133        return Ok(result);
134    }
135
136    // Try type-specific kernels
137    if let Some(prim) = arr.as_opt::<Primitive>()
138        && let Some(result) =
139            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
140    {
141        return Ok(result);
142    }
143    if let Some(dec) = arr.as_opt::<Decimal>()
144        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
145    {
146        return Ok(result);
147    }
148
149    // TODO(joe): return lazy compare once the executor supports this
150    // Fall back to compare + boolean and
151    let lower_cmp = lower.clone().binary(
152        arr.clone(),
153        Operator::from(options.lower_strict.to_compare_operator()),
154    )?;
155    let upper_cmp = arr.clone().binary(
156        upper.clone(),
157        Operator::from(options.upper_strict.to_compare_operator()),
158    )?;
159    execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
160}
161
/// An optimized scalar expression to compute whether values fall between two bounds.
///
/// This expression takes three children, in this order:
/// 1. The array of values to check.
/// 2. The lower bound.
/// 3. The upper bound.
///
/// Whether each bound is strict (`<`) or non-strict (`<=`) is controlled by the
/// [`BetweenOptions`] metadata.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
175
176impl ScalarFnVTable for Between {
177    type Options = BetweenOptions;
178
179    fn id(&self) -> ScalarFnId {
180        static ID: CachedId = CachedId::new("vortex.between");
181        *ID
182    }
183
184    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
185        Ok(Some(
186            pb::BetweenOpts {
187                lower_strict: instance.lower_strict.is_strict(),
188                upper_strict: instance.upper_strict.is_strict(),
189            }
190            .encode_to_vec(),
191        ))
192    }
193
194    fn deserialize(
195        &self,
196        _metadata: &[u8],
197        _session: &VortexSession,
198    ) -> VortexResult<Self::Options> {
199        let opts = pb::BetweenOpts::decode(_metadata)?;
200        Ok(BetweenOptions {
201            lower_strict: if opts.lower_strict {
202                StrictComparison::Strict
203            } else {
204                StrictComparison::NonStrict
205            },
206            upper_strict: if opts.upper_strict {
207                StrictComparison::Strict
208            } else {
209                StrictComparison::NonStrict
210            },
211        })
212    }
213
214    fn arity(&self, _options: &Self::Options) -> Arity {
215        Arity::Exact(3)
216    }
217
218    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
219        match child_idx {
220            0 => ChildName::from("array"),
221            1 => ChildName::from("lower"),
222            2 => ChildName::from("upper"),
223            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
224        }
225    }
226
227    fn fmt_sql(
228        &self,
229        options: &Self::Options,
230        expr: &Expression,
231        f: &mut Formatter<'_>,
232    ) -> std::fmt::Result {
233        let lower_op = if options.lower_strict.is_strict() {
234            "<"
235        } else {
236            "<="
237        };
238        let upper_op = if options.upper_strict.is_strict() {
239            "<"
240        } else {
241            "<="
242        };
243        write!(
244            f,
245            "({} {} {} {} {})",
246            expr.child(1),
247            lower_op,
248            expr.child(0),
249            upper_op,
250            expr.child(2)
251        )
252    }
253
254    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
255        let arr_dt = &arg_dtypes[0];
256        let lower_dt = &arg_dtypes[1];
257        let upper_dt = &arg_dtypes[2];
258
259        if !arr_dt.eq_ignore_nullability(lower_dt) {
260            vortex_bail!(
261                "Array dtype {} does not match lower dtype {}",
262                arr_dt,
263                lower_dt
264            );
265        }
266        if !arr_dt.eq_ignore_nullability(upper_dt) {
267            vortex_bail!(
268                "Array dtype {} does not match upper dtype {}",
269                arr_dt,
270                upper_dt
271            );
272        }
273
274        Ok(Bool(
275            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
276        ))
277    }
278
279    fn execute(
280        &self,
281        options: &Self::Options,
282        args: &dyn ExecutionArgs,
283        ctx: &mut ExecutionCtx,
284    ) -> VortexResult<ArrayRef> {
285        let arr = args.get(0)?;
286        let lower = args.get(1)?;
287        let upper = args.get(2)?;
288
289        // canonicalize the arr and we might be able to run a between kernels over that.
290        if !arr.is_canonical() {
291            return arr.execute::<Canonical>(ctx)?.into_array().between(
292                lower,
293                upper,
294                options.clone(),
295            );
296        }
297
298        between_canonical(&arr, &lower, &upper, options, ctx)
299    }
300
301    fn stat_falsification(
302        &self,
303        options: &Self::Options,
304        expr: &Expression,
305        catalog: &dyn StatsCatalog,
306    ) -> Option<Expression> {
307        let arr = expr.child(0).clone();
308        let lower = expr.child(1).clone();
309        let upper = expr.child(2).clone();
310
311        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
312        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
313
314        and(lhs, rhs).stat_falsification(catalog)
315    }
316
317    fn validity(
318        &self,
319        _options: &Self::Options,
320        expression: &Expression,
321    ) -> VortexResult<Option<Expression>> {
322        let arr = expression.child(0).validity()?;
323        let lower = expression.child(1).validity()?;
324        let upper = expression.child(2).validity()?;
325        Ok(Some(and(and(arr, lower), upper)))
326    }
327
328    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
329        false
330    }
331
332    fn is_fallible(&self, _options: &Self::Options) -> bool {
333        false
334    }
335}
336
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Runs `between_canonical` with the given strictness for each bound, executes the result
    /// into a `BoolArray`, and returns the indices produced by `to_int_indices`.
    fn matching_indices(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> Vec<u64> {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mask = between_canonical(
            array,
            lower,
            upper,
            &options,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
        .unwrap();

        to_int_indices(mask).unwrap()
    }

    /// Builds a non-nullable constant decimal array of `len` rows holding `value`.
    fn decimal_constant(value: DecimalValue, dtype: DecimalDType, len: usize) -> ArrayRef {
        let scalar = Scalar::decimal(value, dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len).into_array()
    }

    #[test]
    fn test_display() {
        let inclusive_lower = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let expr = between(get_item("score", root()), lit(10), lit(50), inclusive_lower);
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let inclusive_upper = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let expr2 = between(root(), lit(0), lit(100), inclusive_upper);
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let indices = matching_indices(&array, &lower, &upper, lower_strict, upper_strict);
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let non_strict = StrictComparison::NonStrict;
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant null upper bound: no row matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let indices = matching_indices(&array, &lower, &null_upper, non_strict, non_strict);
        assert!(indices.is_empty());

        // A constant, non-null upper bound.
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let indices = matching_indices(&array, &lower, &upper, non_strict, non_strict);
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let indices = matching_indices(&array, &lower, &upper, non_strict, non_strict);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        let lower = decimal_constant(DecimalValue::I128(100i128), decimal_type, array.len());
        let upper = decimal_constant(DecimalValue::I128(400i128), decimal_type, array.len());

        let cases = [
            // Strict lower bound, non-strict upper bound
            (
                StrictComparison::Strict,
                StrictComparison::NonStrict,
                [false, true, true, true],
            ),
            // Non-strict lower bound, strict upper bound
            (
                StrictComparison::NonStrict,
                StrictComparison::Strict,
                [true, true, true, false],
            ),
        ];

        for (lower_strict, upper_strict, expected) in cases {
            let actual = between_canonical(
                &array,
                &lower,
                &upper,
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
                &mut SESSION.create_execution_ctx(),
            )
            .unwrap();
            assert_arrays_eq!(actual, BoolArray::from_iter(expected));
        }
    }

    /// Regression test for a fuzzer crash where a bound scalar used a wider storage type (I32)
    /// than the array's storage type (I16), causing the cast in `between_unpack` to fail.
    ///
    /// The fix casts the bound to the array's storage type and, when the cast fails, uses the
    /// overflow direction to determine the result without falling back to Arrow.
    #[rstest]
    // Upper bound too large (I32 > i16::MAX): upper constraint always satisfied → result from lower only.
    #[case(DecimalValue::I16(1), DecimalValue::I32(82246), vec![0, 1, 2, 3])]
    // Lower bound too large (I32 > i16::MAX): lower constraint never satisfied → all false.
    #[case(DecimalValue::I32(82246), DecimalValue::I16(4), vec![])]
    // Upper bound too small (negative I32 < i16::MIN): upper constraint never satisfied → all false.
    #[case(DecimalValue::I16(1), DecimalValue::I32(-82246), vec![])]
    // Lower bound too small (negative I32 < i16::MIN): lower constraint always satisfied → result from upper only.
    #[case(DecimalValue::I32(-82246), DecimalValue::I16(2), vec![0, 1])]
    fn test_between_decimal_mismatched_storage_types(
        #[case] lower_val: DecimalValue,
        #[case] upper_val: DecimalValue,
        #[case] expected_indices: Vec<u64>,
    ) {
        // Array uses I16 storage with precision=5 (values fit in i16 even though precision=5
        // nominally maps to I32 as the smallest storage type).
        let decimal_type = DecimalDType::new(5, -67);
        let array = DecimalArray::new(
            buffer![1i16, 2i16, 3i16, 4i16],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        let lower = decimal_constant(lower_val, decimal_type, array.len());
        let upper = decimal_constant(upper_val, decimal_type, array.len());

        let indices = matching_indices(
            &array,
            &lower,
            &upper,
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        assert_eq!(indices, expected_indices);
    }
}