Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::dtype::DType;
27use crate::dtype::DType::Bool;
28use crate::expr::expression::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::fns::binary::execute_boolean;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct BetweenOptions {
41    pub lower_strict: StrictComparison,
42    pub upper_strict: StrictComparison,
43}
44
45impl Display for BetweenOptions {
46    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47        let lower_op = if self.lower_strict.is_strict() {
48            "<"
49        } else {
50            "<="
51        };
52        let upper_op = if self.upper_strict.is_strict() {
53            "<"
54        } else {
55            "<="
56        };
57        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
58    }
59}
60
/// Whether a bound of a between comparison is exclusive or inclusive.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// Exclusive bound (`<`).
    Strict,
    /// Inclusive bound (`<=`).
    NonStrict,
}
69
70impl StrictComparison {
71    pub const fn to_compare_operator(&self) -> CompareOperator {
72        match self {
73            StrictComparison::Strict => CompareOperator::Lt,
74            StrictComparison::NonStrict => CompareOperator::Lte,
75        }
76    }
77
78    pub const fn to_operator(&self) -> Operator {
79        match self {
80            StrictComparison::Strict => Operator::Lt,
81            StrictComparison::NonStrict => Operator::Lte,
82        }
83    }
84
85    pub const fn is_strict(&self) -> bool {
86        matches!(self, StrictComparison::Strict)
87    }
88}
89
90/// Common preconditions for between operations that apply to all arrays.
91///
92/// Returns `Some(result)` if the precondition short-circuits the between operation
93/// (empty array, null bounds), or `None` if between should proceed with the
94/// encoding-specific implementation.
95pub(super) fn precondition(
96    arr: &ArrayRef,
97    lower: &ArrayRef,
98    upper: &ArrayRef,
99) -> VortexResult<Option<ArrayRef>> {
100    let return_dtype =
101        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
102
103    // Bail early if the array is empty.
104    if arr.is_empty() {
105        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
106    }
107
108    if lower.as_constant().is_some_and(|v| v.is_null())
109        || upper.as_constant().is_some_and(|v| v.is_null())
110    {
111        return Ok(Some(
112            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
113        ));
114    }
115
116    Ok(None)
117}
118
119/// Between on a canonical array by directly dispatching to the appropriate kernel.
120///
121/// Falls back to compare + boolean and if no kernel handles the input.
122fn between_canonical(
123    arr: &ArrayRef,
124    lower: &ArrayRef,
125    upper: &ArrayRef,
126    options: &BetweenOptions,
127    ctx: &mut ExecutionCtx,
128) -> VortexResult<ArrayRef> {
129    if let Some(result) = precondition(arr, lower, upper)? {
130        return Ok(result);
131    }
132
133    // Try type-specific kernels
134    if let Some(prim) = arr.as_opt::<Primitive>()
135        && let Some(result) =
136            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
137    {
138        return Ok(result);
139    }
140    if let Some(dec) = arr.as_opt::<Decimal>()
141        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
142    {
143        return Ok(result);
144    }
145
146    // TODO(joe): return lazy compare once the executor supports this
147    // Fall back to compare + boolean and
148    let lower_cmp = lower.clone().binary(
149        arr.clone(),
150        Operator::from(options.lower_strict.to_compare_operator()),
151    )?;
152    let upper_cmp = arr.clone().binary(
153        upper.clone(),
154        Operator::from(options.upper_strict.to_compare_operator()),
155    )?;
156    execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
157}
158
/// Scalar function that tests whether each value lies between a lower and an upper bound.
///
/// The expression has three children, in this order:
/// 1. the array of values to check,
/// 2. the lower bound,
/// 3. the upper bound.
///
/// Whether each bound is strict is carried by the function's options.
///
/// NOTE: this expression is slated for removal in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
172
173impl ScalarFnVTable for Between {
174    type Options = BetweenOptions;
175
176    fn id(&self) -> ScalarFnId {
177        static ID: CachedId = CachedId::new("vortex.between");
178        *ID
179    }
180
181    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
182        Ok(Some(
183            pb::BetweenOpts {
184                lower_strict: instance.lower_strict.is_strict(),
185                upper_strict: instance.upper_strict.is_strict(),
186            }
187            .encode_to_vec(),
188        ))
189    }
190
191    fn deserialize(
192        &self,
193        _metadata: &[u8],
194        _session: &VortexSession,
195    ) -> VortexResult<Self::Options> {
196        let opts = pb::BetweenOpts::decode(_metadata)?;
197        Ok(BetweenOptions {
198            lower_strict: if opts.lower_strict {
199                StrictComparison::Strict
200            } else {
201                StrictComparison::NonStrict
202            },
203            upper_strict: if opts.upper_strict {
204                StrictComparison::Strict
205            } else {
206                StrictComparison::NonStrict
207            },
208        })
209    }
210
211    fn arity(&self, _options: &Self::Options) -> Arity {
212        Arity::Exact(3)
213    }
214
215    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
216        match child_idx {
217            0 => ChildName::from("array"),
218            1 => ChildName::from("lower"),
219            2 => ChildName::from("upper"),
220            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
221        }
222    }
223
224    fn fmt_sql(
225        &self,
226        options: &Self::Options,
227        expr: &Expression,
228        f: &mut Formatter<'_>,
229    ) -> std::fmt::Result {
230        let lower_op = if options.lower_strict.is_strict() {
231            "<"
232        } else {
233            "<="
234        };
235        let upper_op = if options.upper_strict.is_strict() {
236            "<"
237        } else {
238            "<="
239        };
240        write!(
241            f,
242            "({} {} {} {} {})",
243            expr.child(1),
244            lower_op,
245            expr.child(0),
246            upper_op,
247            expr.child(2)
248        )
249    }
250
251    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
252        let arr_dt = &arg_dtypes[0];
253        let lower_dt = &arg_dtypes[1];
254        let upper_dt = &arg_dtypes[2];
255
256        if !arr_dt.eq_ignore_nullability(lower_dt) {
257            vortex_bail!(
258                "Array dtype {} does not match lower dtype {}",
259                arr_dt,
260                lower_dt
261            );
262        }
263        if !arr_dt.eq_ignore_nullability(upper_dt) {
264            vortex_bail!(
265                "Array dtype {} does not match upper dtype {}",
266                arr_dt,
267                upper_dt
268            );
269        }
270
271        Ok(Bool(
272            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
273        ))
274    }
275
276    fn execute(
277        &self,
278        options: &Self::Options,
279        args: &dyn ExecutionArgs,
280        ctx: &mut ExecutionCtx,
281    ) -> VortexResult<ArrayRef> {
282        let arr = args.get(0)?;
283        let lower = args.get(1)?;
284        let upper = args.get(2)?;
285
286        // canonicalize the arr and we might be able to run a between kernels over that.
287        if !arr.is_canonical() {
288            return arr.execute::<Canonical>(ctx)?.into_array().between(
289                lower,
290                upper,
291                options.clone(),
292            );
293        }
294
295        between_canonical(&arr, &lower, &upper, options, ctx)
296    }
297
298    fn validity(
299        &self,
300        _options: &Self::Options,
301        expression: &Expression,
302    ) -> VortexResult<Option<Expression>> {
303        let arr = expression.child(0).validity()?;
304        let lower = expression.child(1).validity()?;
305        let upper = expression.child(2).validity()?;
306        Ok(Some(and(and(arr, lower), upper)))
307    }
308
309    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
310        false
311    }
312
313    fn is_fallible(&self, _options: &Self::Options) -> bool {
314        false
315    }
316}
317
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Shorthand for building [`BetweenOptions`] from the two bound strictness values.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Builds a constant, non-nullable decimal array holding `len` copies of `value`.
    fn decimal_constant(value: DecimalValue, dtype: DecimalDType, len: usize) -> ArrayRef {
        let scalar = Scalar::decimal(value, dtype, Nullability::NonNullable);
        ConstantArray::new(scalar, len).into_array()
    }

    #[test]
    fn test_display() {
        let score_range = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(score_range.to_string(), "(10i32 <= $.score < 50i32)");

        let root_range = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(root_range.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let ctx = &mut SESSION.create_execution_ctx();
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let options = opts(lower_strict, upper_strict);

        let result = between_canonical(&array, &lower, &upper, &options, ctx).unwrap();
        let matches = result.execute::<BoolArray>(ctx).unwrap();

        let indices = to_int_indices(matches, ctx).unwrap();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let ctx = &mut SESSION.create_execution_ctx();
        let inclusive = opts(StrictComparison::NonStrict, StrictComparison::NonStrict);
        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A constant-null upper bound: nothing matches.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let matches = between_canonical(&array, &varying_lower, &null_upper, &inclusive, ctx)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .unwrap();
        let indices = to_int_indices(matches, ctx).unwrap();
        assert!(indices.is_empty());

        // A constant, non-null upper bound.
        let constant_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let matches = between_canonical(&array, &varying_lower, &constant_upper, &inclusive, ctx)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .unwrap();
        let indices = to_int_indices(matches, ctx).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // Both bounds constant.
        let constant_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let matches = between_canonical(&array, &constant_lower, &constant_upper, &inclusive, ctx)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .unwrap();
        let indices = to_int_indices(matches, ctx).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let ctx = &mut SESSION.create_execution_ctx();
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        let lower = decimal_constant(DecimalValue::I128(100i128), decimal_type, array.len());
        let upper = decimal_constant(DecimalValue::I128(400i128), decimal_type, array.len());

        // Exclusive lower bound, inclusive upper bound.
        let exclusive_lower = opts(StrictComparison::Strict, StrictComparison::NonStrict);
        let between_strict =
            between_canonical(&array, &lower, &upper, &exclusive_lower, ctx).unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([false, true, true, true]),
            ctx
        );

        // Inclusive lower bound, exclusive upper bound.
        let exclusive_upper = opts(StrictComparison::NonStrict, StrictComparison::Strict);
        let between_strict =
            between_canonical(&array, &lower, &upper, &exclusive_upper, ctx).unwrap();
        assert_arrays_eq!(
            between_strict,
            BoolArray::from_iter([true, true, true, false]),
            ctx
        );
    }

    /// Regression test for a fuzzer crash: a bound scalar stored as I32 was compared against an
    /// array stored as I16, and the cast in `between_unpack` failed.
    ///
    /// The fix casts the bound to the array's storage type; when that cast fails, the overflow
    /// direction decides the result without falling back to Arrow.
    #[rstest]
    // Upper bound too large (I32 > i16::MAX): upper constraint always satisfied → result from lower only.
    #[case(DecimalValue::I16(1), DecimalValue::I32(82246), vec![0, 1, 2, 3])]
    // Lower bound too large (I32 > i16::MAX): lower constraint never satisfied → all false.
    #[case(DecimalValue::I32(82246), DecimalValue::I16(4), vec![])]
    // Upper bound too small (negative I32 < i16::MIN): upper constraint never satisfied → all false.
    #[case(DecimalValue::I16(1), DecimalValue::I32(-82246), vec![])]
    // Lower bound too small (negative I32 < i16::MIN): lower constraint always satisfied → result from upper only.
    #[case(DecimalValue::I32(-82246), DecimalValue::I16(2), vec![0, 1])]
    fn test_between_decimal_mismatched_storage_types(
        #[case] lower_val: DecimalValue,
        #[case] upper_val: DecimalValue,
        #[case] expected_indices: Vec<u64>,
    ) {
        let ctx = &mut SESSION.create_execution_ctx();
        // The array is stored as I16 although its precision is 5, which nominally maps to I32
        // as the smallest storage type; the values themselves fit in i16.
        let decimal_type = DecimalDType::new(5, -67);
        let storage = buffer![1i16, 2i16, 3i16, 4i16];
        let array = DecimalArray::new(storage, decimal_type, Validity::NonNullable).into_array();

        let lower = decimal_constant(lower_val, decimal_type, array.len());
        let upper = decimal_constant(upper_val, decimal_type, array.len());
        let inclusive = opts(StrictComparison::NonStrict, StrictComparison::NonStrict);

        let result = between_canonical(&array, &lower, &upper, &inclusive, ctx)
            .unwrap()
            .execute::<BoolArray>(ctx)
            .unwrap();

        assert_eq!(to_int_indices(result, ctx).unwrap(), expected_indices);
    }
}