vortex_expr/exprs/
between.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display};
5
6use vortex_array::compute::{BetweenOptions, StrictComparison, between as between_compute};
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_dtype::DType::Bool;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::{
14    AnalysisExpr, BinaryExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable,
15};
16
// Declares the vtable plumbing for the `Between` expression through the crate's `vtable!` macro.
// NOTE(review): `BetweenVTable` (implemented further down) is not defined anywhere else in this
// file, so it is presumably generated by this invocation — confirm against the macro definition.
vtable!(Between);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
20#[derive(Clone, Debug, Hash)]
21pub struct BetweenExpr {
22    arr: ExprRef,
23    lower: ExprRef,
24    upper: ExprRef,
25    options: BetweenOptions,
26}
27
28impl PartialEq for BetweenExpr {
29    fn eq(&self, other: &Self) -> bool {
30        self.arr.eq(&other.arr)
31            && self.lower.eq(&other.lower)
32            && self.upper.eq(&other.upper)
33            && self.options == other.options
34    }
35}
36
/// Zero-sized encoding marker for [`BetweenExpr`].
///
/// `BetweenVTable::id` maps this encoding to the expression id `"between"`, and
/// `BetweenVTable::build` takes it when reconstructing an expression from metadata and children.
pub struct BetweenExprEncoding;
38
39impl VTable for BetweenVTable {
40    type Expr = BetweenExpr;
41    type Encoding = BetweenExprEncoding;
42    type Metadata = ProstMetadata<pb::BetweenOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("between")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(BetweenExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        Some(ProstMetadata(pb::BetweenOpts {
54            lower_strict: expr.options.lower_strict == StrictComparison::Strict,
55            upper_strict: expr.options.upper_strict == StrictComparison::Strict,
56        }))
57    }
58
59    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
60        vec![&expr.arr, &expr.lower, &expr.upper]
61    }
62
63    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
64        Ok(BetweenExpr::new(
65            children[0].clone(),
66            children[1].clone(),
67            children[2].clone(),
68            expr.options.clone(),
69        ))
70    }
71
72    fn build(
73        _encoding: &Self::Encoding,
74        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
75        children: Vec<ExprRef>,
76    ) -> VortexResult<Self::Expr> {
77        Ok(BetweenExpr::new(
78            children[0].clone(),
79            children[1].clone(),
80            children[2].clone(),
81            BetweenOptions {
82                lower_strict: if metadata.lower_strict {
83                    StrictComparison::Strict
84                } else {
85                    StrictComparison::NonStrict
86                },
87                upper_strict: if metadata.upper_strict {
88                    StrictComparison::Strict
89                } else {
90                    StrictComparison::NonStrict
91                },
92            },
93        ))
94    }
95
96    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
97        let arr_val = expr.arr.unchecked_evaluate(scope)?;
98        let lower_arr_val = expr.lower.unchecked_evaluate(scope)?;
99        let upper_arr_val = expr.upper.unchecked_evaluate(scope)?;
100
101        between_compute(&arr_val, &lower_arr_val, &upper_arr_val, &expr.options)
102    }
103
104    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
105        let arr_dt = expr.arr.return_dtype(scope)?;
106        let lower_dt = expr.lower.return_dtype(scope)?;
107        let upper_dt = expr.upper.return_dtype(scope)?;
108
109        if !arr_dt.eq_ignore_nullability(&lower_dt) {
110            vortex_bail!(
111                "Array dtype {} does not match lower dtype {}",
112                arr_dt,
113                lower_dt
114            );
115        }
116        if !arr_dt.eq_ignore_nullability(&upper_dt) {
117            vortex_bail!(
118                "Array dtype {} does not match upper dtype {}",
119                arr_dt,
120                upper_dt
121            );
122        }
123
124        Ok(Bool(
125            arr_dt.nullability() | lower_dt.nullability() | upper_dt.nullability(),
126        ))
127    }
128}
129
130impl BetweenExpr {
131    pub fn new(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> Self {
132        Self {
133            arr,
134            lower,
135            upper,
136            options,
137        }
138    }
139
140    pub fn new_expr(
141        arr: ExprRef,
142        lower: ExprRef,
143        upper: ExprRef,
144        options: BetweenOptions,
145    ) -> ExprRef {
146        Self::new(arr, lower, upper, options).into_expr()
147    }
148
149    pub fn to_binary_expr(&self) -> ExprRef {
150        let lhs = BinaryExpr::new(
151            self.lower.clone(),
152            self.options.lower_strict.to_operator().into(),
153            self.arr.clone(),
154        );
155        let rhs = BinaryExpr::new(
156            self.arr.clone(),
157            self.options.upper_strict.to_operator().into(),
158            self.upper.clone(),
159        );
160        BinaryExpr::new(lhs.into_expr(), crate::Operator::And, rhs.into_expr()).into_expr()
161    }
162}
163
164impl Display for BetweenExpr {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        write!(
167            f,
168            "({} {} {} {} {})",
169            self.lower,
170            self.options.lower_strict.to_operator(),
171            self.arr,
172            self.options.upper_strict.to_operator(),
173            self.upper
174        )
175    }
176}
177
// Empty impl: `BetweenExpr` relies on the default implementations of every `AnalysisExpr`
// method and provides no between-specific analysis.
impl AnalysisExpr for BetweenExpr {}
179
180pub fn between(arr: ExprRef, lower: ExprRef, upper: ExprRef, options: BetweenOptions) -> ExprRef {
181    BetweenExpr::new(arr, lower, upper, options).into_expr()
182}