Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_session::VortexSession;
17
18use crate::Array;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::arrays::ConstantArray;
24use crate::arrays::DecimalVTable;
25use crate::arrays::PrimitiveVTable;
26use crate::builtins::ArrayBuiltins;
27use crate::compute::Options;
28use crate::dtype::DType;
29use crate::dtype::DType::Bool;
30use crate::expr::StatsCatalog;
31use crate::expr::expression::Expression;
32use crate::scalar::Scalar;
33use crate::scalar_fn::Arity;
34use crate::scalar_fn::ChildName;
35use crate::scalar_fn::ExecutionArgs;
36use crate::scalar_fn::ScalarFnId;
37use crate::scalar_fn::ScalarFnVTable;
38use crate::scalar_fn::ScalarFnVTableExt;
39use crate::scalar_fn::fns::binary::Binary;
40use crate::scalar_fn::fns::binary::execute_boolean;
41use crate::scalar_fn::fns::operators::CompareOperator;
42use crate::scalar_fn::fns::operators::Operator;
43
/// Options for a between comparison.
///
/// Each bound is independently either strict (`<`) or non-strict (`<=`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the comparison against the lower bound.
    pub lower_strict: StrictComparison,
    /// Strictness of the comparison against the upper bound.
    pub upper_strict: StrictComparison,
}
49
50impl Display for BetweenOptions {
51    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52        let lower_op = if self.lower_strict.is_strict() {
53            "<"
54        } else {
55            "<="
56        };
57        let upper_op = if self.upper_strict.is_strict() {
58            "<"
59        } else {
60            "<="
61        };
62        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
63    }
64}
65
impl Options for BetweenOptions {
    /// Returns these options as a `&dyn Any`.
    // NOTE(review): presumably used by generic compute code to downcast back to
    // `BetweenOptions` — confirm against the `Options` trait's callers.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
71
/// Strictness of the comparison.
///
/// Selects whether a bound is compared using `<` or `<=`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict bound (`<`): a value equal to the bound does not satisfy it.
    Strict,
    /// Non-strict bound (`<=`): a value equal to the bound satisfies it.
    NonStrict,
}
80
81impl StrictComparison {
82    pub const fn to_compare_operator(&self) -> CompareOperator {
83        match self {
84            StrictComparison::Strict => CompareOperator::Lt,
85            StrictComparison::NonStrict => CompareOperator::Lte,
86        }
87    }
88
89    pub const fn to_operator(&self) -> Operator {
90        match self {
91            StrictComparison::Strict => Operator::Lt,
92            StrictComparison::NonStrict => Operator::Lte,
93        }
94    }
95
96    pub const fn is_strict(&self) -> bool {
97        matches!(self, StrictComparison::Strict)
98    }
99}
100
101/// Common preconditions for between operations that apply to all arrays.
102///
103/// Returns `Some(result)` if the precondition short-circuits the between operation
104/// (empty array, null bounds), or `None` if between should proceed with the
105/// encoding-specific implementation.
106pub(super) fn precondition(
107    arr: &dyn Array,
108    lower: &dyn Array,
109    upper: &dyn Array,
110) -> VortexResult<Option<ArrayRef>> {
111    let return_dtype =
112        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
113
114    // Bail early if the array is empty.
115    if arr.is_empty() {
116        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
117    }
118
119    // A quick check to see if either bound is a null constant array.
120    if (lower.is_invalid(0)? || upper.is_invalid(0)?)
121        && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
122        && (c_lower.is_null() || c_upper.is_null())
123    {
124        return Ok(Some(
125            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
126        ));
127    }
128
129    if lower.as_constant().is_some_and(|v| v.is_null())
130        || upper.as_constant().is_some_and(|v| v.is_null())
131    {
132        return Ok(Some(
133            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
134        ));
135    }
136
137    Ok(None)
138}
139
140/// Between on a canonical array by directly dispatching to the appropriate kernel.
141///
142/// Falls back to compare + boolean and if no kernel handles the input.
143fn between_canonical(
144    arr: &dyn Array,
145    lower: &dyn Array,
146    upper: &dyn Array,
147    options: &BetweenOptions,
148    ctx: &mut ExecutionCtx,
149) -> VortexResult<ArrayRef> {
150    if let Some(result) = precondition(arr, lower, upper)? {
151        return Ok(result);
152    }
153
154    // Try type-specific kernels
155    if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
156        && let Some(result) =
157            <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
158    {
159        return Ok(result);
160    }
161    if let Some(dec) = arr.as_opt::<DecimalVTable>()
162        && let Some(result) =
163            <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
164    {
165        return Ok(result);
166    }
167
168    // TODO(joe): return lazy compare once the executor supports this
169    // Fall back to compare + boolean and
170    let lower_cmp = lower.to_array().binary(
171        arr.to_array(),
172        Operator::from(options.lower_strict.to_compare_operator()),
173    )?;
174    let upper_cmp = arr.to_array().binary(
175        upper.to_array(),
176        Operator::from(options.upper_strict.to_compare_operator()),
177    )?;
178    execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
179}
180
/// An optimized scalar expression to compute whether values fall between two bounds.
///
/// This expression takes three children, in this order:
/// 1. The array of values to check.
/// 2. The lower bound.
/// 3. The upper bound.
///
/// The strictness of each comparison (`<` vs `<=`) is controlled by the
/// expression's [`BetweenOptions`], so the evaluated predicate is
/// `lower (<|<=) array (<|<=) upper`.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
194
195impl ScalarFnVTable for Between {
196    type Options = BetweenOptions;
197
198    fn id(&self) -> ScalarFnId {
199        ScalarFnId::from("vortex.between")
200    }
201
202    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
203        Ok(Some(
204            pb::BetweenOpts {
205                lower_strict: instance.lower_strict.is_strict(),
206                upper_strict: instance.upper_strict.is_strict(),
207            }
208            .encode_to_vec(),
209        ))
210    }
211
212    fn deserialize(
213        &self,
214        _metadata: &[u8],
215        _session: &VortexSession,
216    ) -> VortexResult<Self::Options> {
217        let opts = pb::BetweenOpts::decode(_metadata)?;
218        Ok(BetweenOptions {
219            lower_strict: if opts.lower_strict {
220                StrictComparison::Strict
221            } else {
222                StrictComparison::NonStrict
223            },
224            upper_strict: if opts.upper_strict {
225                StrictComparison::Strict
226            } else {
227                StrictComparison::NonStrict
228            },
229        })
230    }
231
232    fn arity(&self, _options: &Self::Options) -> Arity {
233        Arity::Exact(3)
234    }
235
236    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
237        match child_idx {
238            0 => ChildName::from("array"),
239            1 => ChildName::from("lower"),
240            2 => ChildName::from("upper"),
241            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
242        }
243    }
244
245    fn fmt_sql(
246        &self,
247        options: &Self::Options,
248        expr: &Expression,
249        f: &mut Formatter<'_>,
250    ) -> std::fmt::Result {
251        let lower_op = if options.lower_strict.is_strict() {
252            "<"
253        } else {
254            "<="
255        };
256        let upper_op = if options.upper_strict.is_strict() {
257            "<"
258        } else {
259            "<="
260        };
261        write!(
262            f,
263            "({} {} {} {} {})",
264            expr.child(1),
265            lower_op,
266            expr.child(0),
267            upper_op,
268            expr.child(2)
269        )
270    }
271
272    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
273        let arr_dt = &arg_dtypes[0];
274        let lower_dt = &arg_dtypes[1];
275        let upper_dt = &arg_dtypes[2];
276
277        if !arr_dt.eq_ignore_nullability(lower_dt) {
278            vortex_bail!(
279                "Array dtype {} does not match lower dtype {}",
280                arr_dt,
281                lower_dt
282            );
283        }
284        if !arr_dt.eq_ignore_nullability(upper_dt) {
285            vortex_bail!(
286                "Array dtype {} does not match upper dtype {}",
287                arr_dt,
288                upper_dt
289            );
290        }
291
292        Ok(Bool(
293            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
294        ))
295    }
296
297    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
298        let [arr, lower, upper]: [ArrayRef; _] = args
299            .inputs
300            .try_into()
301            .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
302
303        // canonicalize the arr and we might be able to run a between kernels over that.
304        if !arr.is_canonical() {
305            return arr.execute::<Canonical>(args.ctx)?.into_array().between(
306                lower,
307                upper,
308                options.clone(),
309            );
310        }
311
312        between_canonical(
313            arr.as_ref(),
314            lower.as_ref(),
315            upper.as_ref(),
316            options,
317            args.ctx,
318        )
319    }
320
321    fn stat_falsification(
322        &self,
323        options: &Self::Options,
324        expr: &Expression,
325        catalog: &dyn StatsCatalog,
326    ) -> Option<Expression> {
327        let arr = expr.child(0).clone();
328        let lower = expr.child(1).clone();
329        let upper = expr.child(2).clone();
330
331        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
332        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
333
334        Binary
335            .new_expr(Operator::And, [lhs, rhs])
336            .stat_falsification(catalog)
337    }
338
339    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
340        false
341    }
342
343    fn is_fallible(&self, _options: &Self::Options) -> bool {
344        false
345    }
346}
347
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs `between_canonical` with a fresh execution context and unwraps the result.
    fn run_between(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    /// The SQL rendering places the value between its bounds with the right operators.
    #[test]
    fn test_display() {
        let inclusive_lower = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let on_field = between(get_item("score", root()), lit(10), lit(50), inclusive_lower);
        assert_eq!(on_field.to_string(), "(10i32 <= $.score < 50i32)");

        let inclusive_upper = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let on_root = between(root(), lit(0), lit(100), inclusive_upper);
        assert_eq!(on_root.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every strict/non-strict combination selects the expected rows.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let mask = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        )
        .to_bool();

        assert_eq!(to_int_indices(mask).unwrap(), expected);
    }

    /// Constant bounds (null and non-null) are handled on either side.
    #[test]
    fn test_constants() {
        let inclusive = StrictComparison::NonStrict;
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let mask = run_between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert!(to_int_indices(mask).unwrap().is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let mask = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert_eq!(to_int_indices(mask).unwrap(), vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5);
        let mask = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        )
        .to_bool();
        assert_eq!(to_int_indices(mask).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    /// Decimal arrays honor strictness on each bound independently.
    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        let lower = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(100i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        );
        let upper = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(400i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        );

        // Strict lower bound, non-strict upper bound
        let exclusive_lower = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let exclusive_upper = run_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}