Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::DynArray;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::Decimal;
24use crate::arrays::Primitive;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::Options;
27use crate::dtype::DType;
28use crate::dtype::DType::Bool;
29use crate::expr::StatsCatalog;
30use crate::expr::expression::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::ScalarFnVTableExt;
38use crate::scalar_fn::fns::binary::Binary;
39use crate::scalar_fn::fns::binary::execute_boolean;
40use crate::scalar_fn::fns::operators::CompareOperator;
41use crate::scalar_fn::fns::operators::Operator;
42
/// Options for the between operation: the strictness of each of the two bounds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the lower bound: `lower < value` when strict, `lower <= value` otherwise.
    pub lower_strict: StrictComparison,
    /// Strictness of the upper bound: `value < upper` when strict, `value <= upper` otherwise.
    pub upper_strict: StrictComparison,
}
48
49impl Display for BetweenOptions {
50    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51        let lower_op = if self.lower_strict.is_strict() {
52            "<"
53        } else {
54            "<="
55        };
56        let upper_op = if self.upper_strict.is_strict() {
57            "<"
58        } else {
59            "<="
60        };
61        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
62    }
63}
64
impl Options for BetweenOptions {
    /// Exposes these options as `&dyn Any`.
    // NOTE(review): presumably this lets generic callers downcast back to `BetweenOptions`;
    // confirm against the `Options` trait definition.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
70
/// Strictness of a single bound comparison: whether the bound itself is excluded (`<`) or
/// included (`<=`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict bound (`<`): a value equal to the bound does not satisfy the comparison.
    Strict,
    /// Non-strict bound (`<=`): a value equal to the bound satisfies the comparison.
    NonStrict,
}
79
80impl StrictComparison {
81    pub const fn to_compare_operator(&self) -> CompareOperator {
82        match self {
83            StrictComparison::Strict => CompareOperator::Lt,
84            StrictComparison::NonStrict => CompareOperator::Lte,
85        }
86    }
87
88    pub const fn to_operator(&self) -> Operator {
89        match self {
90            StrictComparison::Strict => Operator::Lt,
91            StrictComparison::NonStrict => Operator::Lte,
92        }
93    }
94
95    pub const fn is_strict(&self) -> bool {
96        matches!(self, StrictComparison::Strict)
97    }
98}
99
100/// Common preconditions for between operations that apply to all arrays.
101///
102/// Returns `Some(result)` if the precondition short-circuits the between operation
103/// (empty array, null bounds), or `None` if between should proceed with the
104/// encoding-specific implementation.
105pub(super) fn precondition(
106    arr: &ArrayRef,
107    lower: &ArrayRef,
108    upper: &ArrayRef,
109) -> VortexResult<Option<ArrayRef>> {
110    let return_dtype =
111        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
112
113    // Bail early if the array is empty.
114    if arr.is_empty() {
115        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
116    }
117
118    // A quick check to see if either bound is a null constant array.
119    if (lower.is_invalid(0)? || upper.is_invalid(0)?)
120        && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
121        && (c_lower.is_null() || c_upper.is_null())
122    {
123        return Ok(Some(
124            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
125        ));
126    }
127
128    if lower.as_constant().is_some_and(|v| v.is_null())
129        || upper.as_constant().is_some_and(|v| v.is_null())
130    {
131        return Ok(Some(
132            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
133        ));
134    }
135
136    Ok(None)
137}
138
139/// Between on a canonical array by directly dispatching to the appropriate kernel.
140///
141/// Falls back to compare + boolean and if no kernel handles the input.
142fn between_canonical(
143    arr: &ArrayRef,
144    lower: &ArrayRef,
145    upper: &ArrayRef,
146    options: &BetweenOptions,
147    ctx: &mut ExecutionCtx,
148) -> VortexResult<ArrayRef> {
149    if let Some(result) = precondition(arr, lower, upper)? {
150        return Ok(result);
151    }
152
153    // Try type-specific kernels
154    if let Some(prim) = arr.as_opt::<Primitive>()
155        && let Some(result) =
156            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
157    {
158        return Ok(result);
159    }
160    if let Some(dec) = arr.as_opt::<Decimal>()
161        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
162    {
163        return Ok(result);
164    }
165
166    // TODO(joe): return lazy compare once the executor supports this
167    // Fall back to compare + boolean and
168    let lower_cmp = lower.to_array().binary(
169        arr.to_array(),
170        Operator::from(options.lower_strict.to_compare_operator()),
171    )?;
172    let upper_cmp = arr.to_array().binary(
173        upper.to_array(),
174        Operator::from(options.upper_strict.to_compare_operator()),
175    )?;
176    execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
177}
178
/// An optimized scalar expression to compute whether values fall between two bounds.
///
/// This expression takes three children, in this order:
/// 1. The array of values to check.
/// 2. The lower bound.
/// 3. The upper bound.
///
/// Whether each bound is strict (`<`) or non-strict (`<=`) is controlled by the
/// [`BetweenOptions`] carried as the expression's metadata.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
192
193impl ScalarFnVTable for Between {
194    type Options = BetweenOptions;
195
196    fn id(&self) -> ScalarFnId {
197        ScalarFnId::from("vortex.between")
198    }
199
200    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
201        Ok(Some(
202            pb::BetweenOpts {
203                lower_strict: instance.lower_strict.is_strict(),
204                upper_strict: instance.upper_strict.is_strict(),
205            }
206            .encode_to_vec(),
207        ))
208    }
209
210    fn deserialize(
211        &self,
212        _metadata: &[u8],
213        _session: &VortexSession,
214    ) -> VortexResult<Self::Options> {
215        let opts = pb::BetweenOpts::decode(_metadata)?;
216        Ok(BetweenOptions {
217            lower_strict: if opts.lower_strict {
218                StrictComparison::Strict
219            } else {
220                StrictComparison::NonStrict
221            },
222            upper_strict: if opts.upper_strict {
223                StrictComparison::Strict
224            } else {
225                StrictComparison::NonStrict
226            },
227        })
228    }
229
230    fn arity(&self, _options: &Self::Options) -> Arity {
231        Arity::Exact(3)
232    }
233
234    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
235        match child_idx {
236            0 => ChildName::from("array"),
237            1 => ChildName::from("lower"),
238            2 => ChildName::from("upper"),
239            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
240        }
241    }
242
243    fn fmt_sql(
244        &self,
245        options: &Self::Options,
246        expr: &Expression,
247        f: &mut Formatter<'_>,
248    ) -> std::fmt::Result {
249        let lower_op = if options.lower_strict.is_strict() {
250            "<"
251        } else {
252            "<="
253        };
254        let upper_op = if options.upper_strict.is_strict() {
255            "<"
256        } else {
257            "<="
258        };
259        write!(
260            f,
261            "({} {} {} {} {})",
262            expr.child(1),
263            lower_op,
264            expr.child(0),
265            upper_op,
266            expr.child(2)
267        )
268    }
269
270    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
271        let arr_dt = &arg_dtypes[0];
272        let lower_dt = &arg_dtypes[1];
273        let upper_dt = &arg_dtypes[2];
274
275        if !arr_dt.eq_ignore_nullability(lower_dt) {
276            vortex_bail!(
277                "Array dtype {} does not match lower dtype {}",
278                arr_dt,
279                lower_dt
280            );
281        }
282        if !arr_dt.eq_ignore_nullability(upper_dt) {
283            vortex_bail!(
284                "Array dtype {} does not match upper dtype {}",
285                arr_dt,
286                upper_dt
287            );
288        }
289
290        Ok(Bool(
291            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
292        ))
293    }
294
295    fn execute(
296        &self,
297        options: &Self::Options,
298        args: &dyn ExecutionArgs,
299        ctx: &mut ExecutionCtx,
300    ) -> VortexResult<ArrayRef> {
301        let arr = args.get(0)?;
302        let lower = args.get(1)?;
303        let upper = args.get(2)?;
304
305        // canonicalize the arr and we might be able to run a between kernels over that.
306        if !arr.is_canonical() {
307            return arr.execute::<Canonical>(ctx)?.into_array().between(
308                lower,
309                upper,
310                options.clone(),
311            );
312        }
313
314        between_canonical(&arr, &lower, &upper, options, ctx)
315    }
316
317    fn stat_falsification(
318        &self,
319        options: &Self::Options,
320        expr: &Expression,
321        catalog: &dyn StatsCatalog,
322    ) -> Option<Expression> {
323        let arr = expr.child(0).clone();
324        let lower = expr.child(1).clone();
325        let upper = expr.child(2).clone();
326
327        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
328        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
329
330        Binary
331            .new_expr(Operator::And, [lhs, rhs])
332            .stat_falsification(catalog)
333    }
334
335    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
336        false
337    }
338
339    fn is_fallible(&self, _options: &Self::Options) -> bool {
340        false
341    }
342}
343
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// The expression renders as `(lower <op> value <op> upper)` with the configured operators.
    #[test]
    fn test_display() {
        let closed_below = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let on_field = between(get_item("score", root()), lit(10), lit(50), closed_below);
        assert_eq!(on_field.to_string(), "(10i32 <= $.score < 50i32)");

        let closed_above = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let on_root = between(root(), lit(0), lit(100), closed_above);
        assert_eq!(on_root.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every strictness combination over rows `(lower, value, upper)`:
    /// `(0, 1, 2)`, `(0, 0, 1)`, `(0, 1, 1)`, `(0, 0, 0)`, `(2, 1, 0)`.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let mask = between_canonical(&array, &lower, &upper, &options, &mut ctx)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(mask).unwrap(), expected);
    }

    /// Constant bounds: a null constant, a single constant bound, and two constant bounds.
    #[test]
    fn test_constants() {
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };
        // Runs an inclusive between and returns the indices of the matching rows.
        let matching_rows = |values: &ArrayRef, lower: &ArrayRef, upper: &ArrayRef| {
            let mask = between_canonical(
                values,
                lower,
                upper,
                &inclusive,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap()
            .to_bool();
            to_int_indices(mask).unwrap()
        };

        let array = buffer![1, 0, 1, 0, 1].into_array();
        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        assert!(matching_rows(&array, &varying_lower, &null_upper).is_empty());

        // upper is a fixed constant
        let constant_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        assert_eq!(
            matching_rows(&array, &varying_lower, &constant_upper),
            vec![0, 1, 3]
        );

        // lower is also a constant
        let constant_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        assert_eq!(
            matching_rows(&array, &constant_lower, &constant_upper),
            vec![0, 1, 2, 3, 4]
        );
    }

    /// Decimal values against constant decimal bounds equal to the first and last value.
    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Builds a non-nullable constant decimal array as long as `array`.
        let constant_bound = |unscaled: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(unscaled),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        // Strict lower bound, non-strict upper bound
        let open_below = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(open_below, BoolArray::from_iter([false, true, true, true]));

        // Non-strict lower bound, strict upper bound
        let open_above = between_canonical(
            &array,
            &lower,
            &upper,
            &BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
        assert_arrays_eq!(open_above, BoolArray::from_iter([true, true, true, false]));
    }
}