Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::DecimalVTable;
24use crate::arrays::PrimitiveVTable;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::Options;
27use crate::dtype::DType;
28use crate::dtype::DType::Bool;
29use crate::expr::StatsCatalog;
30use crate::expr::expression::Expression;
31use crate::scalar::Scalar;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::ScalarFnVTableExt;
38use crate::scalar_fn::fns::binary::Binary;
39use crate::scalar_fn::fns::binary::execute_boolean;
40use crate::scalar_fn::fns::operators::CompareOperator;
41use crate::scalar_fn::fns::operators::Operator;
42
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct BetweenOptions {
45    pub lower_strict: StrictComparison,
46    pub upper_strict: StrictComparison,
47}
48
49impl Display for BetweenOptions {
50    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
51        let lower_op = if self.lower_strict.is_strict() {
52            "<"
53        } else {
54            "<="
55        };
56        let upper_op = if self.upper_strict.is_strict() {
57            "<"
58        } else {
59            "<="
60        };
61        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
62    }
63}
64
65impl Options for BetweenOptions {
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69}
70
/// Whether a bound comparison excludes or includes the bound value itself.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// The bound is exclusive: compare with `<`.
    Strict,
    /// The bound is inclusive: compare with `<=`.
    NonStrict,
}
79
80impl StrictComparison {
81    pub const fn to_compare_operator(&self) -> CompareOperator {
82        match self {
83            StrictComparison::Strict => CompareOperator::Lt,
84            StrictComparison::NonStrict => CompareOperator::Lte,
85        }
86    }
87
88    pub const fn to_operator(&self) -> Operator {
89        match self {
90            StrictComparison::Strict => Operator::Lt,
91            StrictComparison::NonStrict => Operator::Lte,
92        }
93    }
94
95    pub const fn is_strict(&self) -> bool {
96        matches!(self, StrictComparison::Strict)
97    }
98}
99
100/// Common preconditions for between operations that apply to all arrays.
101///
102/// Returns `Some(result)` if the precondition short-circuits the between operation
103/// (empty array, null bounds), or `None` if between should proceed with the
104/// encoding-specific implementation.
105pub(super) fn precondition(
106    arr: &ArrayRef,
107    lower: &ArrayRef,
108    upper: &ArrayRef,
109) -> VortexResult<Option<ArrayRef>> {
110    let return_dtype =
111        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
112
113    // Bail early if the array is empty.
114    if arr.is_empty() {
115        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
116    }
117
118    // A quick check to see if either bound is a null constant array.
119    if (lower.is_invalid(0)? || upper.is_invalid(0)?)
120        && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
121        && (c_lower.is_null() || c_upper.is_null())
122    {
123        return Ok(Some(
124            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
125        ));
126    }
127
128    if lower.as_constant().is_some_and(|v| v.is_null())
129        || upper.as_constant().is_some_and(|v| v.is_null())
130    {
131        return Ok(Some(
132            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
133        ));
134    }
135
136    Ok(None)
137}
138
139/// Between on a canonical array by directly dispatching to the appropriate kernel.
140///
141/// Falls back to compare + boolean and if no kernel handles the input.
142fn between_canonical(
143    arr: &ArrayRef,
144    lower: &ArrayRef,
145    upper: &ArrayRef,
146    options: &BetweenOptions,
147    ctx: &mut ExecutionCtx,
148) -> VortexResult<ArrayRef> {
149    if let Some(result) = precondition(arr, lower, upper)? {
150        return Ok(result);
151    }
152
153    // Try type-specific kernels
154    if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
155        && let Some(result) =
156            <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
157    {
158        return Ok(result);
159    }
160    if let Some(dec) = arr.as_opt::<DecimalVTable>()
161        && let Some(result) =
162            <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
163    {
164        return Ok(result);
165    }
166
167    // TODO(joe): return lazy compare once the executor supports this
168    // Fall back to compare + boolean and
169    let lower_cmp = lower.to_array().binary(
170        arr.to_array(),
171        Operator::from(options.lower_strict.to_compare_operator()),
172    )?;
173    let upper_cmp = arr.to_array().binary(
174        upper.to_array(),
175        Operator::from(options.upper_strict.to_compare_operator()),
176    )?;
177    execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
178}
179
/// Scalar function that tests whether each value lies between a lower and an upper bound.
///
/// The expression has exactly three children, in this order:
/// 1. the array of values to test,
/// 2. the lower bound,
/// 3. the upper bound.
///
/// Whether each bound is inclusive or exclusive is carried in the function's options
/// ([`BetweenOptions`]).
///
/// NOTE: this expression is slated for removal shortly, to be replaced by pipelined
/// computation of two separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
193
194impl ScalarFnVTable for Between {
195    type Options = BetweenOptions;
196
197    fn id(&self) -> ScalarFnId {
198        ScalarFnId::from("vortex.between")
199    }
200
201    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
202        Ok(Some(
203            pb::BetweenOpts {
204                lower_strict: instance.lower_strict.is_strict(),
205                upper_strict: instance.upper_strict.is_strict(),
206            }
207            .encode_to_vec(),
208        ))
209    }
210
211    fn deserialize(
212        &self,
213        _metadata: &[u8],
214        _session: &VortexSession,
215    ) -> VortexResult<Self::Options> {
216        let opts = pb::BetweenOpts::decode(_metadata)?;
217        Ok(BetweenOptions {
218            lower_strict: if opts.lower_strict {
219                StrictComparison::Strict
220            } else {
221                StrictComparison::NonStrict
222            },
223            upper_strict: if opts.upper_strict {
224                StrictComparison::Strict
225            } else {
226                StrictComparison::NonStrict
227            },
228        })
229    }
230
231    fn arity(&self, _options: &Self::Options) -> Arity {
232        Arity::Exact(3)
233    }
234
235    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
236        match child_idx {
237            0 => ChildName::from("array"),
238            1 => ChildName::from("lower"),
239            2 => ChildName::from("upper"),
240            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
241        }
242    }
243
244    fn fmt_sql(
245        &self,
246        options: &Self::Options,
247        expr: &Expression,
248        f: &mut Formatter<'_>,
249    ) -> std::fmt::Result {
250        let lower_op = if options.lower_strict.is_strict() {
251            "<"
252        } else {
253            "<="
254        };
255        let upper_op = if options.upper_strict.is_strict() {
256            "<"
257        } else {
258            "<="
259        };
260        write!(
261            f,
262            "({} {} {} {} {})",
263            expr.child(1),
264            lower_op,
265            expr.child(0),
266            upper_op,
267            expr.child(2)
268        )
269    }
270
271    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
272        let arr_dt = &arg_dtypes[0];
273        let lower_dt = &arg_dtypes[1];
274        let upper_dt = &arg_dtypes[2];
275
276        if !arr_dt.eq_ignore_nullability(lower_dt) {
277            vortex_bail!(
278                "Array dtype {} does not match lower dtype {}",
279                arr_dt,
280                lower_dt
281            );
282        }
283        if !arr_dt.eq_ignore_nullability(upper_dt) {
284            vortex_bail!(
285                "Array dtype {} does not match upper dtype {}",
286                arr_dt,
287                upper_dt
288            );
289        }
290
291        Ok(Bool(
292            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
293        ))
294    }
295
296    fn execute(
297        &self,
298        options: &Self::Options,
299        args: &dyn ExecutionArgs,
300        ctx: &mut ExecutionCtx,
301    ) -> VortexResult<ArrayRef> {
302        let arr = args.get(0)?;
303        let lower = args.get(1)?;
304        let upper = args.get(2)?;
305
306        // canonicalize the arr and we might be able to run a between kernels over that.
307        if !arr.is_canonical() {
308            return arr.execute::<Canonical>(ctx)?.into_array().between(
309                lower,
310                upper,
311                options.clone(),
312            );
313        }
314
315        between_canonical(&arr, &lower, &upper, options, ctx)
316    }
317
318    fn stat_falsification(
319        &self,
320        options: &Self::Options,
321        expr: &Expression,
322        catalog: &dyn StatsCatalog,
323    ) -> Option<Expression> {
324        let arr = expr.child(0).clone();
325        let lower = expr.child(1).clone();
326        let upper = expr.child(2).clone();
327
328        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
329        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
330
331        Binary
332            .new_expr(Operator::And, [lhs, rhs])
333            .stat_falsification(catalog)
334    }
335
336    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
337        false
338    }
339
340    fn is_fallible(&self, _options: &Self::Options) -> bool {
341        false
342    }
343}
344
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs `between_canonical` with a fresh execution context and unwraps the result.
    fn run_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    /// The SQL-style rendering puts the lower bound first and uses `<` / `<=` per bound.
    #[test]
    fn test_display() {
        let inclusive_lower = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let expr = between(get_item("score", root()), lit(10), lit(50), inclusive_lower);
        assert_eq!(expr.to_string(), "(10i32 <= $.score < 50i32)");

        let inclusive_upper = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let expr2 = between(root(), lit(0), lit(100), inclusive_upper);
        assert_eq!(expr2.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every combination of strict / non-strict bounds over array-valued bounds.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = run_between(&array, &lower, &upper, lower_strict, upper_strict);

        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, expected);
    }

    /// Constant bounds: a null constant, a non-null constant upper, and both bounds constant.
    #[test]
    fn test_constants() {
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let result = run_between(
            &array,
            &varying_lower,
            &null_upper,
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert!(indices.is_empty());

        // upper is a fixed constant
        let constant_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let result = run_between(
            &array,
            &varying_lower,
            &constant_upper,
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let constant_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let result = run_between(
            &array,
            &constant_lower,
            &constant_upper,
            StrictComparison::NonStrict,
            StrictComparison::NonStrict,
        );
        let indices = to_int_indices(result.to_bool()).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    /// Decimal values against constant decimal bounds, with each bound strict in turn.
    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        // Builds a non-nullable constant decimal bound as long as `array`.
        let decimal_constant = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = decimal_constant(100i128);
        let upper = decimal_constant(400i128);

        // Strict lower bound, non-strict upper bound
        let strict_lower = run_between(
            &array,
            &lower,
            &upper,
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            strict_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let strict_upper = run_between(
            &array,
            &lower,
            &upper,
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            strict_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}