Skip to main content

vortex_array/scalar_fn/fns/between/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9pub use kernel::*;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::arrays::Decimal;
22use crate::arrays::Primitive;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::dtype::DType::Bool;
26use crate::expr::StatsCatalog;
27use crate::expr::expression::Expression;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::binary::execute_boolean;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct BetweenOptions {
42    pub lower_strict: StrictComparison,
43    pub upper_strict: StrictComparison,
44}
45
46impl Display for BetweenOptions {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        let lower_op = if self.lower_strict.is_strict() {
49            "<"
50        } else {
51            "<="
52        };
53        let upper_op = if self.upper_strict.is_strict() {
54            "<"
55        } else {
56            "<="
57        };
58        write!(f, "lower_strict: {}, upper_strict: {}", lower_op, upper_op)
59    }
60}
61
/// Whether a bound of a between comparison is exclusive or inclusive.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StrictComparison {
    /// Exclusive bound, compared with `<`.
    Strict,
    /// Inclusive bound, compared with `<=`.
    NonStrict,
}
70
71impl StrictComparison {
72    pub const fn to_compare_operator(&self) -> CompareOperator {
73        match self {
74            StrictComparison::Strict => CompareOperator::Lt,
75            StrictComparison::NonStrict => CompareOperator::Lte,
76        }
77    }
78
79    pub const fn to_operator(&self) -> Operator {
80        match self {
81            StrictComparison::Strict => Operator::Lt,
82            StrictComparison::NonStrict => Operator::Lte,
83        }
84    }
85
86    pub const fn is_strict(&self) -> bool {
87        matches!(self, StrictComparison::Strict)
88    }
89}
90
91/// Common preconditions for between operations that apply to all arrays.
92///
93/// Returns `Some(result)` if the precondition short-circuits the between operation
94/// (empty array, null bounds), or `None` if between should proceed with the
95/// encoding-specific implementation.
96pub(super) fn precondition(
97    arr: &ArrayRef,
98    lower: &ArrayRef,
99    upper: &ArrayRef,
100) -> VortexResult<Option<ArrayRef>> {
101    let return_dtype =
102        Bool(arr.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability());
103
104    // Bail early if the array is empty.
105    if arr.is_empty() {
106        return Ok(Some(Canonical::empty(&return_dtype).into_array()));
107    }
108
109    if lower.as_constant().is_some_and(|v| v.is_null())
110        || upper.as_constant().is_some_and(|v| v.is_null())
111    {
112        return Ok(Some(
113            ConstantArray::new(Scalar::null(return_dtype), arr.len()).into_array(),
114        ));
115    }
116
117    Ok(None)
118}
119
120/// Between on a canonical array by directly dispatching to the appropriate kernel.
121///
122/// Falls back to compare + boolean and if no kernel handles the input.
123fn between_canonical(
124    arr: &ArrayRef,
125    lower: &ArrayRef,
126    upper: &ArrayRef,
127    options: &BetweenOptions,
128    ctx: &mut ExecutionCtx,
129) -> VortexResult<ArrayRef> {
130    if let Some(result) = precondition(arr, lower, upper)? {
131        return Ok(result);
132    }
133
134    // Try type-specific kernels
135    if let Some(prim) = arr.as_opt::<Primitive>()
136        && let Some(result) =
137            <Primitive as BetweenKernel>::between(prim, lower, upper, options, ctx)?
138    {
139        return Ok(result);
140    }
141    if let Some(dec) = arr.as_opt::<Decimal>()
142        && let Some(result) = <Decimal as BetweenKernel>::between(dec, lower, upper, options, ctx)?
143    {
144        return Ok(result);
145    }
146
147    // TODO(joe): return lazy compare once the executor supports this
148    // Fall back to compare + boolean and
149    let lower_cmp = lower.clone().binary(
150        arr.clone(),
151        Operator::from(options.lower_strict.to_compare_operator()),
152    )?;
153    let upper_cmp = arr.clone().binary(
154        upper.clone(),
155        Operator::from(options.upper_strict.to_compare_operator()),
156    )?;
157    execute_boolean(&lower_cmp, &upper_cmp, Operator::And)
158}
159
/// Scalar function that tests whether each value lies between a lower and an upper bound.
///
/// The expression has exactly three children, in this order:
/// 1. the array of values to test,
/// 2. the lower bound,
/// 3. the upper bound.
///
/// Whether each bound is strict (`<`) or non-strict (`<=`) comes from the expression's
/// options (its serialized metadata).
///
/// NOTE: this expression will shortly be removed in favor of pipelined computation of two
/// separate comparisons combined with a logical AND.
#[derive(Clone)]
pub struct Between;
173
174impl ScalarFnVTable for Between {
175    type Options = BetweenOptions;
176
177    fn id(&self) -> ScalarFnId {
178        ScalarFnId::new("vortex.between")
179    }
180
181    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
182        Ok(Some(
183            pb::BetweenOpts {
184                lower_strict: instance.lower_strict.is_strict(),
185                upper_strict: instance.upper_strict.is_strict(),
186            }
187            .encode_to_vec(),
188        ))
189    }
190
191    fn deserialize(
192        &self,
193        _metadata: &[u8],
194        _session: &VortexSession,
195    ) -> VortexResult<Self::Options> {
196        let opts = pb::BetweenOpts::decode(_metadata)?;
197        Ok(BetweenOptions {
198            lower_strict: if opts.lower_strict {
199                StrictComparison::Strict
200            } else {
201                StrictComparison::NonStrict
202            },
203            upper_strict: if opts.upper_strict {
204                StrictComparison::Strict
205            } else {
206                StrictComparison::NonStrict
207            },
208        })
209    }
210
211    fn arity(&self, _options: &Self::Options) -> Arity {
212        Arity::Exact(3)
213    }
214
215    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
216        match child_idx {
217            0 => ChildName::from("array"),
218            1 => ChildName::from("lower"),
219            2 => ChildName::from("upper"),
220            _ => unreachable!("Invalid child index {} for Between expression", child_idx),
221        }
222    }
223
224    fn fmt_sql(
225        &self,
226        options: &Self::Options,
227        expr: &Expression,
228        f: &mut Formatter<'_>,
229    ) -> std::fmt::Result {
230        let lower_op = if options.lower_strict.is_strict() {
231            "<"
232        } else {
233            "<="
234        };
235        let upper_op = if options.upper_strict.is_strict() {
236            "<"
237        } else {
238            "<="
239        };
240        write!(
241            f,
242            "({} {} {} {} {})",
243            expr.child(1),
244            lower_op,
245            expr.child(0),
246            upper_op,
247            expr.child(2)
248        )
249    }
250
251    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
252        let arr_dt = &arg_dtypes[0];
253        let lower_dt = &arg_dtypes[1];
254        let upper_dt = &arg_dtypes[2];
255
256        if !arr_dt.eq_ignore_nullability(lower_dt) {
257            vortex_bail!(
258                "Array dtype {} does not match lower dtype {}",
259                arr_dt,
260                lower_dt
261            );
262        }
263        if !arr_dt.eq_ignore_nullability(upper_dt) {
264            vortex_bail!(
265                "Array dtype {} does not match upper dtype {}",
266                arr_dt,
267                upper_dt
268            );
269        }
270
271        Ok(Bool(
272            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
273        ))
274    }
275
276    fn execute(
277        &self,
278        options: &Self::Options,
279        args: &dyn ExecutionArgs,
280        ctx: &mut ExecutionCtx,
281    ) -> VortexResult<ArrayRef> {
282        let arr = args.get(0)?;
283        let lower = args.get(1)?;
284        let upper = args.get(2)?;
285
286        // canonicalize the arr and we might be able to run a between kernels over that.
287        if !arr.is_canonical() {
288            return arr.execute::<Canonical>(ctx)?.into_array().between(
289                lower,
290                upper,
291                options.clone(),
292            );
293        }
294
295        between_canonical(&arr, &lower, &upper, options, ctx)
296    }
297
298    fn stat_falsification(
299        &self,
300        options: &Self::Options,
301        expr: &Expression,
302        catalog: &dyn StatsCatalog,
303    ) -> Option<Expression> {
304        let arr = expr.child(0).clone();
305        let lower = expr.child(1).clone();
306        let upper = expr.child(2).clone();
307
308        let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
309        let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
310
311        Binary
312            .new_expr(Operator::And, [lhs, rhs])
313            .stat_falsification(catalog)
314    }
315
316    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
317        false
318    }
319
320    fn is_fallible(&self, _options: &Self::Options) -> bool {
321        false
322    }
323}
324
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::between;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Bundles the two strictness flags into a [`BetweenOptions`].
    fn options(lower_strict: StrictComparison, upper_strict: StrictComparison) -> BetweenOptions {
        BetweenOptions {
            lower_strict,
            upper_strict,
        }
    }

    /// Runs [`between_canonical`] with a fresh execution context, panicking on error.
    fn run_between(
        array: &ArrayRef,
        lower: &ArrayRef,
        upper: &ArrayRef,
        opts: &BetweenOptions,
    ) -> ArrayRef {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        between_canonical(array, lower, upper, opts, &mut ctx).unwrap()
    }

    /// The SQL-style rendering places the lower bound first and uses `<` / `<=` per bound.
    #[test]
    fn test_display() {
        let half_open = between(
            get_item("score", root()),
            lit(10),
            lit(50),
            options(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_eq!(half_open.to_string(), "(10i32 <= $.score < 50i32)");

        let open_below = between(
            root(),
            lit(0),
            lit(100),
            options(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_eq!(open_below.to_string(), "(0i32 < $ <= 100i32)");
    }

    /// Every combination of strict / non-strict bounds selects the expected row indices.
    #[rstest]
    #[case(StrictComparison::NonStrict, StrictComparison::NonStrict, vec![0, 1, 2, 3])]
    #[case(StrictComparison::NonStrict, StrictComparison::Strict, vec![0, 1])]
    #[case(StrictComparison::Strict, StrictComparison::NonStrict, vec![0, 2])]
    #[case(StrictComparison::Strict, StrictComparison::Strict, vec![0])]
    fn test_bounds(
        #[case] lower_strict: StrictComparison,
        #[case] upper_strict: StrictComparison,
        #[case] expected: Vec<u64>,
    ) {
        let lower = buffer![0, 0, 0, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();
        let upper = buffer![2, 1, 1, 0, 0].into_array();

        let opts = options(lower_strict, upper_strict);
        let matches = run_between(&array, &lower, &upper, &opts).to_bool();

        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, expected);
    }

    /// Constant bounds: a null constant matches nothing, value constants compare per row.
    #[test]
    fn test_constants() {
        let inclusive = options(StrictComparison::NonStrict, StrictComparison::NonStrict);

        let lower = buffer![0, 0, 2, 0, 2].into_array();
        let array = buffer![1, 0, 1, 0, 1].into_array();

        // upper is null
        let null_upper = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            5,
        )
        .into_array();

        let matches = run_between(&array, &lower, &null_upper, &inclusive).to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert!(indices.is_empty());

        // upper is a fixed constant
        let upper = ConstantArray::new(Scalar::from(2), 5).into_array();
        let matches = run_between(&array, &lower, &upper, &inclusive).to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 3]);

        // lower is also a constant
        let lower = ConstantArray::new(Scalar::from(0), 5).into_array();
        let matches = run_between(&array, &lower, &upper, &inclusive).to_bool();
        let indices = to_int_indices(matches).unwrap();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    /// Decimal arrays with constant decimal bounds respect the strictness of each bound.
    #[test]
    fn test_between_decimal() {
        let values = buffer![100i128, 200i128, 300i128, 400i128];
        let decimal_type = DecimalDType::new(3, 2);
        let array = DecimalArray::new(values, decimal_type, Validity::NonNullable).into_array();

        let lower = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(100i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();
        let upper = ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I128(400i128),
                decimal_type,
                Nullability::NonNullable,
            ),
            array.len(),
        )
        .into_array();

        // Strict lower bound, non-strict upper bound
        let strict_lower = run_between(
            &array,
            &lower,
            &upper,
            &options(StrictComparison::Strict, StrictComparison::NonStrict),
        );
        assert_arrays_eq!(
            strict_lower,
            BoolArray::from_iter([false, true, true, true])
        );

        // Non-strict lower bound, strict upper bound
        let strict_upper = run_between(
            &array,
            &lower,
            &upper,
            &options(StrictComparison::NonStrict, StrictComparison::Strict),
        );
        assert_arrays_eq!(
            strict_upper,
            BoolArray::from_iter([true, true, true, false])
        );
    }
}