Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_array::expr::and;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::ConstantArray;
22use crate::arrays::Decimal;
23use crate::arrays::Primitive;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::dtype::DType::Bool;
27use crate::expr::StatsCatalog;
28use crate::expr::expression::Expression;
29use crate::scalar::Scalar;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::binary::execute_boolean;
38use crate::scalar_fn::fns::operators::CompareOperator;
39use crate::scalar_fn::fns::operators::Operator;
40
41#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct BetweenOptions {
43    pub lower_strict: StrictComparison,
44    pub upper_strict: StrictComparison,
45}
46
47impl Display for BetweenOptions {
48    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49        let lower_op = if self.lower_strict.is_strict() {
50            "<"
51        } else {
52            "<="
53        };
54        let upper_op = if self.upper_strict.is_strict() {
55            "<"
56        } else {
57            "<="
58        };
59        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
60    }
61}
62
/// Whether a bound of a between comparison is exclusive or inclusive.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
71
72impl StrictComparison {
73    pub const fn to_compare_operator(&self) -> CompareOperator {
74        match self {
75            StrictComparison::Strict => CompareOperator::Lt,
76            StrictComparison::NonStrict => CompareOperator::Lte,
77        }
78    }
79
80    pub const fn to_operator(&self) -> Operator {
81        match self {
82            StrictComparison::Strict => Operator::Lt,
83            StrictComparison::NonStrict => Operator::Lte,
84        }
85    }
86
87    pub const fn is_strict(&self) -> bool {
88        matches!(self, StrictComparison::Strict)
89    }
90}
91
92/// Common preconditions for between operations that apply to all arrays.
93///
94/// Returns `Some(result)` if the precondition short-circuits the between operation
95/// (empty array, null bounds), or `None` if between should proceed with the
96/// encoding-specific implementation.
97pub(super) fn precondition(
98    arr: &ArrayRef,
99    lower: &ArrayRef,
100    upper: &ArrayRef,
101) -> VortexResult<Option<ArrayRef>> {
102    let return_dtype =
103        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
104
105    // Bail early if the array is empty.
106    if arr.is_empty() {
107        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
108    }
109
110    if lower.as_constant().is_some_and(|v| v.is_null())
111        || upper.as_constant().is_some_and(|v| v.is_null())
112    {
113        return Ok(Some(
114            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
115        ));
116    }
117
118    Ok(None)
119}
120
121/// Between on a canonical array by directly dispatching to the appropriate kernel.
122///
123/// Falls back to compare + boolean and if no kernel handles the input.
124fn between_canonical(
125    arr: &ArrayRef,
126    lower: &ArrayRef,
127    upper: &ArrayRef,
128    options: &BetweenOptions,
129    ctx: &mut ExecutionCtx,
130) -> VortexResult<ArrayRef> {
131    if let Some(result) = precondition(arr, lower, upper)? {
132        return Ok(result);
133    }
134
135    // Try type-specific kernels
136    if let Some(prim) = arr.as_opt::<Primitive>()
137        && let Some(result) =
138            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
139    {
140        return Ok(result);
141    }
142    if let Some(dec) = arr.as_opt::<Decimal>()
143        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
144    {
145        return Ok(result);
146    }
147
148    // TODO(joe): return lazy compare once the executor supports this
149    // Fall back to compare + boolean and
150    let lower_cmp = lower.clone().binary(
151        arr.clone(),
152        Operator::from(options.lower_strict.to_compare_operator()),
153    )?;
154    let upper_cmp = arr.clone().binary(
155        upper.clone(),
156        Operator::from(options.upper_strict.to_compare_operator()),
157    )?;
158    execute_boolean(&lower_cmp, &upper_cmp, Operator::And, ctx)
159}
160
/// Scalar expression that tests whether each value lies between a lower and an upper bound.
///
/// The expression has exactly three children, in this order:
/// 1. the array of values being tested,
/// 2. the lower bound,
/// 3. the upper bound.
///
/// Whether each bound is strict (`<`) or non-strict (`<=`) is carried in the expression's
/// metadata.
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
174
175impl ScalarFnVTable for Between {
176    type Options = BetweenOptions;
177
178    fn id(&self) -> ScalarFnId {
179        ScalarFnId::new("vortex.between")
180    }
181
182    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
183        Ok(Some(
184            pb::BetweenOpts {
185                lower_strict: instance.lower_strict.is_strict(),
186                upper_strict: instance.upper_strict.is_strict(),
187            }
188            .encode_to_vec(),
189        ))
190    }
191
192    fn deserialize(
193        &self,
194        _metadata: &[u8],
195        _session: &VortexSession,
196    ) -> VortexResult<Self::Options> {
197        let opts = pb::BetweenOpts::decode(_metadata)?;
198        Ok(BetweenOptions {
199            lower_strict: if opts.lower_strict {
200                StrictComparison::Strict
201            } else {
202                StrictComparison::NonStrict
203            },
204            upper_strict: if opts.upper_strict {
205                StrictComparison::Strict
206            } else {
207                StrictComparison::NonStrict
208            },
209        })
210    }
211
212    fn arity(&self, _options: &Self::Options) -> Arity {
213        Arity::Exact(3)
214    }
215
216    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
217        match child_idx {
218            0 => ChildName::from("array"),
219            1 => ChildName::from("lower"),
220            2 => ChildName::from("upper"),
221            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
222        }
223    }
224
225    fn fmt_sql(
226        &self,
227        options: &Self::Options,
228        expr: &Expression,
229        f: &mut Formatter<'_>,
230    ) -> std::fmt::Result {
231        let lower_op = if options.lower_strict.is_strict() {
232            "<"
233        } else {
234            "<="
235        };
236        let upper_op = if options.upper_strict.is_strict() {
237            "<"
238        } else {
239            "<="
240        };
241        write!(
242            f,
243            "({} {} {} {} {})",
244            expr.child(1),
245            lower_op,
246            expr.child(0),
247            upper_op,
248            expr.child(2)
249        )
250    }
251
252    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
253        let arr_dt = &arg_dtypes[0];
254        let lower_dt = &arg_dtypes[1];
255        let upper_dt = &arg_dtypes[2];
256
257        if !arr_dt.eq_ignore_nullability(lower_dt) {
258            vortex_bail!(
259                "Array dtype {} does not match lower dtype {}",
260                arr_dt,
261                lower_dt
262            );
263        }
264        if !arr_dt.eq_ignore_nullability(upper_dt) {
265            vortex_bail!(
266                "Array dtype {} does not match upper dtype {}",
267                arr_dt,
268                upper_dt
269            );
270        }
271
272        Ok(Bool(
273            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
274        ))
275    }
276
277    fn execute(
278        &self,
279        options: &Self::Options,
280        args: &dyn ExecutionArgs,
281        ctx: &mut ExecutionCtx,
282    ) -> VortexResult<ArrayRef> {
283        let arr = args.get(0)?;
284        let lower = args.get(1)?;
285        let upper = args.get(2)?;
286
287        // canonicalize the arr and we might be able to run a between kernels over that.
288        if !arr.is_canonical() {
289            return arr.execute::<Canonical>(ctx)?.into_array().between(
290                lower,
291                upper,
292                options.clone(),
293            );
294        }
295
296        between_canonical(&arr, &lower, &upper, options, ctx)
297    }
298
299    fn stat_falsification(
300        &self,
301        options: &Self::Options,
302        expr: &Expression,
303        catalog: &dyn StatsCatalog,
304    ) -> Option<Expression> {
305        let arr = expr.child(0).clone();
306        let lower = expr.child(1).clone();
307        let upper = expr.child(2).clone();
308
309        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
310        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
311
312        and(lhs, rhs).stat_falsification(catalog)
313    }
314
315    fn validity(
316        &self,
317        _options: &Self::Options,
318        expression: &Expression,
319    ) -> VortexResult<Option<Expression>> {
320        let arr = expression.child(0).validity()?;
321        let lower = expression.child(1).validity()?;
322        let upper = expression.child(2).validity()?;
323        Ok(Some(and(and(arr, lower), upper)))
324    }
325
326    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
327        false
328    }
329
330    fn is_fallible(&self, _options: &Self::Options) -> bool {
331        false
332    }
333}
334
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Shorthand for building [`BetweenOptions`] from the two strictness settings.
    fn opts(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs `between_canonical` with a fresh execution context and unwraps the result.
    fn run_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        options: &BetweenOptions,
    ) -> ArrayRef {
        between_canonical(
            array,
            lower,
            upper,
            options,
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
    }

    /// Runs between, executes the result into a `BoolArray` and returns the indices that
    /// `to_int_indices` reports for it.
    fn matching_indices(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        options: &BetweenOptions,
    ) -> Vec<u64> {
        let matches = run_between(array, lower, upper, options)
            .execute::<BoolArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        to_int_indices(matches).unwrap()
    }

    #[test]
    fn test_display() {
        let score_between = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(score_between.to_string(), "(10i32 <= $.score < 50i32)");

        let root_between = between(
            root(),
            lit(0),
            lit(100),
            opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(root_between.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let indices = matching_indices(&array, &lower, &upper, &opts(lower_strict, upper_strict));
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = opts(StrictComparison::NonStrict, StrictComparison::NonStrict);

        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();
        let indices = matching_indices(&array, &lower, &null_upper, &inclusive);
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let indices = matching_indices(&array, &lower, &upper, &inclusive);
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let indices = matching_indices(&array, &lower, &upper, &inclusive);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        )
        .into_array();

        // Constant decimal bound of the same length as `array`.
        let constant_bound = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
            .into_array()
        };
        let lower = constant_bound(100i128);
        let upper = constant_bound(400i128);

        // Strict lower bound, non-strict upper bound
        let exclusive_lower = run_between(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_arrays_eq!(
            exclusive_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let exclusive_upper = run_between(
            &array,
            &lower,
            &upper,
            &opts(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_arrays_eq!(
            exclusive_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}