Skip to main content

vortex_array/expr/exprs/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::any::Any;
7use std::fmt::Display;
8use std::fmt::Formatter;
9
10pub use kernel::*;
11use prost::Message;
12use vortex_dtype::DType;
13use vortex_dtype::DType::Bool;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr as pb;
19use vortex_session::VortexSession;
20
21use crate::Array;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::ExecutionCtx;
25use crate::IntoArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::DecimalVTable;
28use crate::arrays::PrimitiveVTable;
29use crate::builtins::ArrayBuiltins;
30use crate::compute::BooleanOperator;
31use crate::compute::Options;
32use crate::compute::compare;
33use crate::expr::Arity;
34use crate::expr::ChildName;
35use crate::expr::ExecutionArgs;
36use crate::expr::ExprId;
37use crate::expr::StatsCatalog;
38use crate::expr::VTable;
39use crate::expr::VTableExt;
40use crate::expr::execute_boolean;
41use crate::expr::expression::Expression;
42use crate::expr::exprs::binary::Binary;
43use crate::expr::exprs::operators::Operator;
44use crate::scalar::Scalar;
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47pub struct BetweenOptions {
48    pub lower_strict: StrictComparison,
49    pub upper_strict: StrictComparison,
50}
51
52impl Display for BetweenOptions {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        let lower_op = if self.lower_strict.is_strict() {
55            "<"
56        } else {
57            "<="
58        };
59        let upper_op = if self.upper_strict.is_strict() {
60            "<"
61        } else {
62            "<="
63        };
64        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
65    }
66}
67
68impl Options for BetweenOptions {
69    fn as_any(&self) -> &dyn Any {
70        self
71    }
72}
73
74/// Strictness of the comparison.
75#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
76pub enum StrictComparison {
77    /// Strict bound (`<`)
78    Strict,
79    /// Non-strict bound (`<=`)
80    NonStrict,
81}
82
83impl StrictComparison {
84    pub const fn to_operator(&self) -> crate::compute::Operator {
85        match self {
86            StrictComparison::Strict => crate::compute::Operator::Lt,
87            StrictComparison::NonStrict => crate::compute::Operator::Lte,
88        }
89    }
90
91    pub const fn is_strict(&self) -> bool {
92        matches!(self, StrictComparison::Strict)
93    }
94}
95
96/// Common preconditions for between operations that apply to all arrays.
97///
98/// Returns `Some(result)` if the precondition short-circuits the between operation
99/// (empty array, null bounds), or `None` if between should proceed with the
100/// encoding-specific implementation.
101pub(super) fn precondition(
102    arr: &dyn Array,
103    lower: &dyn Array,
104    upper: &dyn Array,
105) -> VortexResult<Option<ArrayRef>> {
106    let return_dtype =
107        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
108
109    // Bail early if the array is empty.
110    if arr.is_empty() {
111        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
112    }
113
114    // A quick check to see if either bound is a null constant array.
115    if (lower.is_invalid(0)? || upper.is_invalid(0)?)
116        && let (Some(c_lower), Some(c_upper)) = (lower.as_constant(), upper.as_constant())
117        && (c_lower.is_null() || c_upper.is_null())
118    {
119        return Ok(Some(
120            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
121        ));
122    }
123
124    if lower.as_constant().is_some_and(|v| v.is_null())
125        || upper.as_constant().is_some_and(|v| v.is_null())
126    {
127        return Ok(Some(
128            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
129        ));
130    }
131
132    Ok(None)
133}
134
135/// Between on a canonical array by directly dispatching to the appropriate kernel.
136///
137/// Falls back to compare + boolean and if no kernel handles the input.
138fn between_canonical(
139    arr: &dyn Array,
140    lower: &dyn Array,
141    upper: &dyn Array,
142    options: &BetweenOptions,
143    ctx: &mut ExecutionCtx,
144) -> VortexResult<ArrayRef> {
145    if let Some(result) = precondition(arr, lower, upper)? {
146        return Ok(result);
147    }
148
149    // Try type-specific kernels
150    if let Some(prim) = arr.as_opt::<PrimitiveVTable>()
151        && let Some(result) =
152            <PrimitiveVTable as BetweenKernel>::between(prim, lower, upper, options, ctx)?
153    {
154        return Ok(result);
155    }
156    if let Some(dec) = arr.as_opt::<DecimalVTable>()
157        && let Some(result) =
158            <DecimalVTable as BetweenKernel>::between(dec, lower, upper, options, ctx)?
159    {
160        return Ok(result);
161    }
162
163    // TODO(joe): return lazy compare once the executor supports this
164    // Fall back to compare + boolean and
165    execute_boolean(
166        &compare(lower, arr, options.lower_strict.to_operator())?,
167        &compare(arr, upper, options.upper_strict.to_operator())?,
168        BooleanOperator::AndKleene,
169    )
170}
171
172/// An optimized scalar expression to compute whether values fall between two bounds.
173///
174/// This expression takes three children:
175/// 1. The array of values to check.
176/// 2. The lower bound.
177/// 3. The upper bound.
178///
179/// The comparison strictness is controlled by the metadata.
180///
181/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
182/// separate comparisons combined with a logical AND.
183pub struct Between;
184
185impl VTable for Between {
186    type Options = BetweenOptions;
187
188    fn id(&self) -> ExprId {
189        ExprId::from("vortex.between")
190    }
191
192    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
193        Ok(Some(
194            pb::BetweenOpts {
195                lower_strict: instance.lower_strict.is_strict(),
196                upper_strict: instance.upper_strict.is_strict(),
197            }
198            .encode_to_vec(),
199        ))
200    }
201
202    fn deserialize(
203        &self,
204        _metadata: &[u8],
205        _session: &VortexSession,
206    ) -> VortexResult<Self::Options> {
207        let opts = pb::BetweenOpts::decode(_metadata)?;
208        Ok(BetweenOptions {
209            lower_strict: if opts.lower_strict {
210                StrictComparison::Strict
211            } else {
212                StrictComparison::NonStrict
213            },
214            upper_strict: if opts.upper_strict {
215                StrictComparison::Strict
216            } else {
217                StrictComparison::NonStrict
218            },
219        })
220    }
221
222    fn arity(&self, _options: &Self::Options) -> Arity {
223        Arity::Exact(3)
224    }
225
226    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
227        match child_idx {
228            0 => ChildName::from("array"),
229            1 => ChildName::from("lower"),
230            2 => ChildName::from("upper"),
231            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
232        }
233    }
234
235    fn fmt_sql(
236        &self,
237        options: &Self::Options,
238        expr: &Expression,
239        f: &mut Formatter<'_>,
240    ) -> std::fmt::Result {
241        let lower_op = if options.lower_strict.is_strict() {
242            "<"
243        } else {
244            "<="
245        };
246        let upper_op = if options.upper_strict.is_strict() {
247            "<"
248        } else {
249            "<="
250        };
251        write!(
252            f,
253            "({} {} {} {} {})",
254            expr.child(1),
255            lower_op,
256            expr.child(0),
257            upper_op,
258            expr.child(2)
259        )
260    }
261
262    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
263        let arr_dt = &arg_dtypes[0];
264        let lower_dt = &arg_dtypes[1];
265        let upper_dt = &arg_dtypes[2];
266
267        if !arr_dt.eq_ignore_nullability(lower_dt) {
268            vortex_bail!(
269                "Array dtype {} does not match lower dtype {}",
270                arr_dt,
271                lower_dt
272            );
273        }
274        if !arr_dt.eq_ignore_nullability(upper_dt) {
275            vortex_bail!(
276                "Array dtype {} does not match upper dtype {}",
277                arr_dt,
278                upper_dt
279            );
280        }
281
282        Ok(Bool(
283            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
284        ))
285    }
286
287    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
288        let [arr, lower, upper]: [ArrayRef; _] = args
289            .inputs
290            .try_into()
291            .map_err(|_| vortex_err!("Expected 3 arguments for Between expression",))?;
292
293        // canonicalize the arr and we might be able to run a between kernels over that.
294        if !arr.is_canonical() {
295            return arr.execute::<Canonical>(args.ctx)?.into_array().between(
296                lower,
297                upper,
298                options.clone(),
299            );
300        }
301
302        between_canonical(
303            arr.as_ref(),
304            lower.as_ref(),
305            upper.as_ref(),
306            options,
307            args.ctx,
308        )
309    }
310
311    fn stat_falsification(
312        &self,
313        options: &Self::Options,
314        expr: &Expression,
315        catalog: &dyn StatsCatalog,
316    ) -> Option<Expression> {
317        let arr = expr.child(0).clone();
318        let lower = expr.child(1).clone();
319        let upper = expr.child(2).clone();
320
321        let lhs = Binary.new_expr(
322            options.lower_strict.to_operator().into(),
323            [lower, arr.clone()],
324        );
325        let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]);
326
327        Binary
328            .new_expr(Operator::And, [lhs, rhs])
329            .stat_falsification(catalog)
330    }
331
332    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
333        false
334    }
335
336    fn is_fallible(&self, _options: &Self::Options) -> bool {
337        false
338    }
339}
340
341/// Creates an expression that checks if values are between two bounds.
342///
343/// Returns a boolean array indicating which values fall within the specified range.
344/// The comparison strictness is controlled by the options parameter.
345///
346/// ```rust
347/// # use vortex_array::expr::BetweenOptions;
348/// # use vortex_array::expr::StrictComparison;
349/// # use vortex_array::expr::{between, lit, root};
350/// let opts = BetweenOptions {
351///     lower_strict: StrictComparison::NonStrict,
352///     upper_strict: StrictComparison::NonStrict,
353/// };
354/// let expr = between(root(), lit(10), lit(20), opts);
355/// ```
356pub fn between(
357    arr: Expression,
358    lower: Expression,
359    upper: Expression,
360    options: BetweenOptions,
361) -> Expression {
362    Between
363        .try_new_expr(options, [arr, lower, upper])
364        .vortex_expect("Failed to create Between expression")
365}
366
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs `between_canonical` with a fresh execution context, panicking on failure.
    fn eval_between(
        array: &dyn Array,
        lower: &dyn Array,
        upper: &dyn Array,
        lower_strict: StrictComparison,
        upper_strict: StrictComparison,
    ) -> ArrayRef {
        let options = BetweenOptions {
            lower_strict,
            upper_strict,
        };
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, &options, &mut ctx).unwrap()
    }

    #[test]
    fn test_display() {
        let closed_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::Strict,
            },
        );
        assert_eq!(closed_open.to_string(), "(10i32 <= $.score < 50i32)");

        let open_closed = between(
            root(),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::Strict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_eq!(open_closed.to_string(), "(0i32 < $ <= 100i32)");
    }

    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            lower_strict,
            upper_strict,
        );
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), expected);
    }

    #[test]
    fn test_constants() {
        let inclusive = StrictComparison::NonStrict;
        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // A null constant upper bound: no rows match.
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        );
        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            null_upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert!(to_int_indices(result.to_bool()).unwrap().is_empty());

        // A fixed constant upper bound.
        let upper = ConstantArray::new(Scalar::from(2), 5);
        let result = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), vec![0, 1, 3]);

        // Both bounds are constants.
        let const_lower = ConstantArray::new(Scalar::from(0), 5);
        let result = eval_between(
            array.as_ref(),
            const_lower.as_ref(),
            upper.as_ref(),
            inclusive,
            inclusive,
        );
        assert_eq!(to_int_indices(result.to_bool()).unwrap(), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_between_decimal() {
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(
            buffer![100i128, 200i128, 300i128, 400i128],
            decimal_type,
            Validity::NonNullable,
        );

        // Builds a non-nullable constant decimal bound the same length as `array`.
        let bound = |value: i128| {
            ConstantArray::new(
                Scalar::decimal(
                    DecimalValue::I128(value),
                    decimal_type,
                    Nullability::NonNullable,
                ),
                array.len(),
            )
        };
        let lower = bound(100);
        let upper = bound(400);

        // Strict lower bound, non-strict upper bound
        let open_closed = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::Strict,
            StrictComparison::NonStrict,
        );
        assert_arrays_eq!(
            open_closed,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let closed_open = eval_between(
            array.as_ref(),
            lower.as_ref(),
            upper.as_ref(),
            StrictComparison::NonStrict,
            StrictComparison::Strict,
        );
        assert_arrays_eq!(
            closed_open,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}