Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::arrays::Decimal;
22use crate::arrays::Primitive;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::DType::Bool;
26use crate::expr::StatsCatalog;
27use crate::expr::expression::Expression;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::binary::execute_boolean;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
/// Options for a between comparison: the strictness of each of the two bounds.
///
/// Each bound is either strict (`<`) or non-strict (`<=`), see [`StrictComparison`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BetweenOptions {
    /// Strictness of the comparison between the lower bound and the value.
    pub lower_strict: StrictComparison,
    /// Strictness of the comparison between the value and the upper bound.
    pub upper_strict: StrictComparison,
}
45
46impl Display for BetweenOptions {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        let lower_op = if self.lower_strict.is_strict() {
49            "<"
50        } else {
51            "<="
52        };
53        let upper_op = if self.upper_strict.is_strict() {
54            "<"
55        } else {
56            "<="
57        };
58        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
59    }
60}
61
/// Strictness of the comparison against a single bound.
///
/// A strict bound excludes the bound value itself (`<`), a non-strict bound includes it (`<=`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StrictComparison {
    /// Strict bound (`<`)
    Strict,
    /// Non-strict bound (`<=`)
    NonStrict,
}
70
71impl StrictComparison {
72    pub const fn to_compare_operator(&self) -> CompareOperator {
73        match self {
74            StrictComparison::Strict => CompareOperator::Lt,
75            StrictComparison::NonStrict => CompareOperator::Lte,
76        }
77    }
78
79    pub const fn to_operator(&self) -> Operator {
80        match self {
81            StrictComparison::Strict => Operator::Lt,
82            StrictComparison::NonStrict => Operator::Lte,
83        }
84    }
85
86    pub const fn is_strict(&self) -> bool {
87        matches!(self, StrictComparison::Strict)
88    }
89}
90
91/// Common preconditions for between operations that apply to all arrays.
92///
93/// Returns `Some(result)` if the precondition short-circuits the between operation
94/// (empty array, null bounds), or `None` if between should proceed with the
95/// encoding-specific implementation.
96pub(super) fn precondition(
97    arr: &ArrayRef,
98    lower: &ArrayRef,
99    upper: &ArrayRef,
100) -> VortexResult<Option<ArrayRef>> {
101    let return_dtype =
102        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
103
104    // Bail early if the array is empty.
105    if arr.is_empty() {
106        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
107    }
108
109    // A quick check to see if either bound is a null constant array.
110    if (lower.is_invalid(0)? || upper.is_invalid(0)?)
111        && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
112        && (c_lower.is_null() || c_upper.is_null())
113    {
114        return Ok(Some(
115            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
116        ));
117    }
118
119    if lower.as_constant().is_some_and(|v| v.is_null())
120        || upper.as_constant().is_some_and(|v| v.is_null())
121    {
122        return Ok(Some(
123            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
124        ));
125    }
126
127    Ok(None)
128}
129
130/// Between on a canonical array by directly dispatching to the appropriate kernel.
131///
132/// Falls back to compare + boolean and if no kernel handles the input.
133fn between_canonical(
134    arr: &ArrayRef,
135    lower: &ArrayRef,
136    upper: &ArrayRef,
137    options: &BetweenOptions,
138    ctx: &mut ExecutionCtx,
139) -> VortexResult<ArrayRef> {
140    if let Some(result) = precondition(arr, lower, upper)? {
141        return Ok(result);
142    }
143
144    // Try type-specific kernels
145    if let Some(prim) = arr.as_opt::<Primitive>()
146        && let Some(result) =
147            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
148    {
149        return Ok(result);
150    }
151    if let Some(dec) = arr.as_opt::<Decimal>()
152        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
153    {
154        return Ok(result);
155    }
156
157    // TODO(joe): return lazy compare once the executor supports this
158    // Fall back to compare + boolean and
159    let lower_cmp = lower.clone().binary(
160        arr.clone(),
161        Operator::from(options.lower_strict.to_compare_operator()),
162    )?;
163    let upper_cmp = arr.clone().binary(
164        upper.clone(),
165        Operator::from(options.upper_strict.to_compare_operator()),
166    )?;
167    execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
168}
169
/// An optimized scalar expression to compute whether values fall between two bounds.
///
/// This expression takes exactly three children:
/// 1. `array`: the values to check.
/// 2. `lower`: the lower bound.
/// 3. `upper`: the upper bound.
///
/// The strictness of each comparison is controlled by the [`BetweenOptions`] metadata.
/// The result is a boolean array that is nullable if any of the three children is nullable.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
183
184impl ScalarFnVTable for Between {
185    type Options = BetweenOptions;
186
187    fn id(&self) -> ScalarFnId {
188        ScalarFnId::from("vortex.between")
189    }
190
191    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
192        Ok(Some(
193            pb::BetweenOpts {
194                lower_strict: instance.lower_strict.is_strict(),
195                upper_strict: instance.upper_strict.is_strict(),
196            }
197            .encode_to_vec(),
198        ))
199    }
200
201    fn deserialize(
202        &self,
203        _metadata: &[u8],
204        _session: &VortexSession,
205    ) -> VortexResult<Self::Options> {
206        let opts = pb::BetweenOpts::decode(_metadata)?;
207        Ok(BetweenOptions {
208            lower_strict: if opts.lower_strict {
209                StrictComparison::Strict
210            } else {
211                StrictComparison::NonStrict
212            },
213            upper_strict: if opts.upper_strict {
214                StrictComparison::Strict
215            } else {
216                StrictComparison::NonStrict
217            },
218        })
219    }
220
221    fn arity(&self, _options: &Self::Options) -> Arity {
222        Arity::Exact(3)
223    }
224
225    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
226        match child_idx {
227            0 => ChildName::from("array"),
228            1 => ChildName::from("lower"),
229            2 => ChildName::from("upper"),
230            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
231        }
232    }
233
234    fn fmt_sql(
235        &self,
236        options: &Self::Options,
237        expr: &Expression,
238        f: &mut Formatter<'_>,
239    ) -> std::fmt::Result {
240        let lower_op = if options.lower_strict.is_strict() {
241            "<"
242        } else {
243            "<="
244        };
245        let upper_op = if options.upper_strict.is_strict() {
246            "<"
247        } else {
248            "<="
249        };
250        write!(
251            f,
252            "({} {} {} {} {})",
253            expr.child(1),
254            lower_op,
255            expr.child(0),
256            upper_op,
257            expr.child(2)
258        )
259    }
260
261    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
262        let arr_dt = &arg_dtypes[0];
263        let lower_dt = &arg_dtypes[1];
264        let upper_dt = &arg_dtypes[2];
265
266        if !arr_dt.eq_ignore_nullability(lower_dt) {
267            vortex_bail!(
268                "Array dtype {} does not match lower dtype {}",
269                arr_dt,
270                lower_dt
271            );
272        }
273        if !arr_dt.eq_ignore_nullability(upper_dt) {
274            vortex_bail!(
275                "Array dtype {} does not match upper dtype {}",
276                arr_dt,
277                upper_dt
278            );
279        }
280
281        Ok(Bool(
282            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
283        ))
284    }
285
286    fn execute(
287        &self,
288        options: &Self::Options,
289        args: &dyn ExecutionArgs,
290        ctx: &mut ExecutionCtx,
291    ) -> VortexResult<ArrayRef> {
292        let arr = args.get(0)?;
293        let lower = args.get(1)?;
294        let upper = args.get(2)?;
295
296        // canonicalize the arr and we might be able to run a between kernels over that.
297        if !arr.is_canonical() {
298            return arr.execute::<Canonical>(ctx)?.into_array().between(
299                lower,
300                upper,
301                options.clone(),
302            );
303        }
304
305        between_canonical(&arr, &lower, &upper, options, ctx)
306    }
307
308    fn stat_falsification(
309        &self,
310        options: &Self::Options,
311        expr: &Expression,
312        catalog: &dyn StatsCatalog,
313    ) -> Option<Expression> {
314        let arr = expr.child(0).clone();
315        let lower = expr.child(1).clone();
316        let upper = expr.child(2).clone();
317
318        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
319        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
320
321        Binary
322            .new_expr(Operator::And, [lhs, rhs])
323            .stat_falsification(catalog)
324    }
325
326    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
327        false
328    }
329
330    fn is_fallible(&self, _options: &Self::Options) -> bool {
331        false
332    }
333}
334
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// The SQL-style rendering puts the value between its bounds with the right operators.
    #[test]
    fn test_display() {
        let half_open_above = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::Strict,
        };
        let score_expr = between(get_item("score", root()), lit(10), lit(50), half_open_above);
        assert_eq!(score_expr.to_string(), "(10i32 <= $.score < 50i32)");

        let half_open_below = BetweenOptions {
            lower_strict: StrictComparison::Strict,
            upper_strict: StrictComparison::NonStrict,
        };
        let root_expr = between(root(), lit(0), lit(100), half_open_below);
        assert_eq!(root_expr.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every strict/non-strict combination selects the expected row indices.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };

        let mask = between_canonical(
            &array,
            &lower,
            &upper,
            &options,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap()
        .to_bool();

        assert_eq!(to_int_indices(mask).unwrap(), expected);
    }

    /// Constant bounds, including a null constant, are handled.
    #[test]
    fn test_constants() {
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let inclusive = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };
        // Runs an inclusive between of `array` against the given bounds and returns the
        // indices of the rows that matched.
        let matching = |lower: &ArrayRef, upper: &ArrayRef| {
            let mask = between_canonical(
                &array,
                lower,
                upper,
                &inclusive,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap()
            .to_bool();
            to_int_indices(mask).unwrap()
        };

        let varying_lower = buffer![0, 0, 2, 0, 2].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        assert!(matching(&varying_lower, &null_upper).is_empty());

        // upper is a fixed constant
        let constant_upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        assert_eq!(matching(&varying_lower, &constant_upper), vec![0, 1, 3]);

        // lower is also a constant
        let constant_lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        assert_eq!(
            matching(&constant_lower, &constant_upper),
            vec![0, 1, 2, 3, 4]
        );
    }

    /// Decimal arrays respect bound strictness at both ends.
    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Constant, non-nullable decimal bound as long as `array`.
        let constant_bound = |unscaled: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(unscaled),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        let run = |lower_strict: StrictComparison, upper_strict: StrictComparison| {
            between_canonical(
                &array,
                &lower,
                &upper,
                &BetweenOptions {
                    lower_strict,
                    upper_strict,
                },
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap()
        };

        // Strict lower bound, non-strict upper bound
        let strict_lower = run(StrictComparison::Strict, StrictComparison::NonStrict);
        assert_arrays_eq!(
            strict_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let strict_upper = run(StrictComparison::NonStrict, StrictComparison::Strict);
        assert_arrays_eq!(
            strict_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}