vortex_expr/
binary.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::{Operator as ArrayOperator, and_kleene, compare, or_kleene};
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{AnalysisExpr, ExprRef, Operator, Scope, ScopeDType, StatsCatalog, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct BinaryExpr {
16    lhs: ExprRef,
17    operator: Operator,
18    rhs: ExprRef,
19}
20
21impl BinaryExpr {
22    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
23        Arc::new(Self { lhs, operator, rhs })
24    }
25
26    pub fn lhs(&self) -> &ExprRef {
27        &self.lhs
28    }
29
30    pub fn rhs(&self) -> &ExprRef {
31        &self.rhs
32    }
33
34    pub fn op(&self) -> Operator {
35        self.operator
36    }
37}
38
39impl Display for BinaryExpr {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
42    }
43}
44
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{BinaryExpr, ExprDeserialize, ExprRef, ExprSerializable, Id, Operator};

    /// Serde handler that (de)serializes [`BinaryExpr`] under the `"binary"` id.
    pub(crate) struct BinarySerde;

    impl Id for BinarySerde {
        fn id(&self) -> &'static str {
            "binary"
        }
    }

    impl ExprDeserialize for BinarySerde {
        /// Rebuilds a [`BinaryExpr`] from a `Kind::BinaryOp` payload and its children `[lhs, rhs]`.
        ///
        /// # Errors
        ///
        /// Returns an error if:
        /// - `kind` is not `Kind::BinaryOp`;
        /// - the encoded operator cannot be converted into an [`Operator`];
        /// - `children` does not contain exactly two expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::BinaryOp(op) = kind else {
                vortex_bail!("wrong kind {:?}, binary", kind)
            };
            let operator: Operator = (*op).try_into()?;

            // `children` comes from serialized input. A wrong count must be reported as an error
            // instead of panicking on an out-of-bounds index. Destructuring also moves the
            // operands out rather than cloning them.
            let [lhs, rhs]: [ExprRef; 2] = match children.try_into() {
                Ok(pair) => pair,
                Err(children) => vortex_bail!(
                    "binary expression expects exactly 2 children, got {}",
                    children.len()
                ),
            };

            Ok(BinaryExpr::new_expr(lhs, operator, rhs))
        }
    }

    impl ExprSerializable for BinaryExpr {
        fn id(&self) -> &'static str {
            BinarySerde.id()
        }

        /// Encodes only the operator; the operands are serialized as this expression's children.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::BinaryOp(self.operator.into()))
        }
    }
}
84
85impl AnalysisExpr for BinaryExpr {
86    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
87        match self.operator {
88            Operator::Eq => {
89                let min_lhs = self.lhs.min(catalog);
90                let max_lhs = self.lhs.max(catalog);
91
92                let min_rhs = self.rhs.min(catalog);
93                let max_rhs = self.rhs.max(catalog);
94
95                let left = min_lhs.zip_with(max_rhs, gt);
96                let right = min_rhs.zip_with(max_lhs, gt);
97                left.into_iter().chain(right).reduce(or)
98            }
99            Operator::NotEq => {
100                let min_lhs = self.lhs.min(catalog)?;
101                let max_lhs = self.lhs.max(catalog)?;
102
103                let min_rhs = self.rhs.min(catalog)?;
104                let max_rhs = self.rhs.max(catalog)?;
105
106                Some(and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)))
107            }
108            Operator::Gt => Some(lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
109            Operator::Gte => Some(lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
110            Operator::Lt => Some(gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
111            Operator::Lte => Some(gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
112            // We can short circuit pruning expr for and
113            Operator::And => self
114                .lhs
115                .stat_falsification(catalog)
116                .into_iter()
117                .chain(self.rhs.stat_falsification(catalog))
118                .reduce(or),
119            Operator::Or => Some(and(
120                self.lhs.stat_falsification(catalog)?,
121                self.rhs.stat_falsification(catalog)?,
122            )),
123        }
124    }
125}
126
127impl VortexExpr for BinaryExpr {
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
133        let lhs = self.lhs.unchecked_evaluate(scope)?;
134        let rhs = self.rhs.unchecked_evaluate(scope)?;
135
136        match self.operator {
137            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
138            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
139            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
140            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
141            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
142            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
143            Operator::And => and_kleene(&lhs, &rhs),
144            Operator::Or => or_kleene(&lhs, &rhs),
145        }
146    }
147
148    fn children(&self) -> Vec<&ExprRef> {
149        vec![&self.lhs, &self.rhs]
150    }
151
152    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
153        assert_eq!(children.len(), 2);
154        BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
155    }
156
157    fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
158        let lhs = self.lhs.return_dtype(ctx)?;
159        let rhs = self.rhs.return_dtype(ctx)?;
160        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
161    }
162}
163
164impl PartialEq for BinaryExpr {
165    fn eq(&self, other: &BinaryExpr) -> bool {
166        other.operator == self.operator && other.lhs.eq(&self.lhs) && other.rhs.eq(&self.rhs)
167    }
168}
169
170/// Create a new `BinaryExpr` using the `Eq` operator.
171///
172/// ## Example usage
173///
174/// ```
175/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
176/// use vortex_array::{Array, IntoArray, ToCanonical};
177/// use vortex_array::validity::Validity;
178/// use vortex_buffer::buffer;
179/// use vortex_expr::{eq, root, lit, Scope};
180///
181/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
182/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
183///
184/// assert_eq!(
185///     result.to_bool().unwrap().boolean_buffer(),
186///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
187/// );
188/// ```
189pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
190    BinaryExpr::new_expr(lhs, Operator::Eq, rhs)
191}
192
193/// Create a new `BinaryExpr` using the `NotEq` operator.
194///
195/// ## Example usage
196///
197/// ```
198/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
199/// use vortex_array::{IntoArray, ToCanonical};
200/// use vortex_array::validity::Validity;
201/// use vortex_buffer::buffer;
202/// use vortex_expr::{root, lit, not_eq, Scope};
203///
204/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
205/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
206///
207/// assert_eq!(
208///     result.to_bool().unwrap().boolean_buffer(),
209///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
210/// );
211/// ```
212pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
213    BinaryExpr::new_expr(lhs, Operator::NotEq, rhs)
214}
215
216/// Create a new `BinaryExpr` using the `Gte` operator.
217///
218/// ## Example usage
219///
220/// ```
221/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
222/// use vortex_array::{IntoArray, ToCanonical};
223/// use vortex_array::validity::Validity;
224/// use vortex_buffer::buffer;
225/// use vortex_expr::{gt_eq, root, lit, Scope};
226///
227/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
228/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
229///
230/// assert_eq!(
231///     result.to_bool().unwrap().boolean_buffer(),
232///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
233/// );
234/// ```
235pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
236    BinaryExpr::new_expr(lhs, Operator::Gte, rhs)
237}
238
239/// Create a new `BinaryExpr` using the `Gt` operator.
240///
241/// ## Example usage
242///
243/// ```
244/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
245/// use vortex_array::{IntoArray, ToCanonical};
246/// use vortex_array::validity::Validity;
247/// use vortex_buffer::buffer;
248/// use vortex_expr::{gt, root, lit, Scope};
249///
250/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
251/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
252///
253/// assert_eq!(
254///     result.to_bool().unwrap().boolean_buffer(),
255///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
256/// );
257/// ```
258pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
259    BinaryExpr::new_expr(lhs, Operator::Gt, rhs)
260}
261
262/// Create a new `BinaryExpr` using the `Lte` operator.
263///
264/// ## Example usage
265///
266/// ```
267/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
268/// use vortex_array::{IntoArray, ToCanonical};
269/// use vortex_array::validity::Validity;
270/// use vortex_buffer::buffer;
271/// use vortex_expr::{root, lit, lt_eq, Scope};
272///
273/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
274/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
275///
276/// assert_eq!(
277///     result.to_bool().unwrap().boolean_buffer(),
278///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
279/// );
280/// ```
281pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
282    BinaryExpr::new_expr(lhs, Operator::Lte, rhs)
283}
284
285/// Create a new `BinaryExpr` using the `Lt` operator.
286///
287/// ## Example usage
288///
289/// ```
290/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
291/// use vortex_array::{IntoArray, ToCanonical};
292/// use vortex_array::validity::Validity;
293/// use vortex_buffer::buffer;
294/// use vortex_expr::{root, lit, lt, Scope};
295///
296/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
297/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
298///
299/// assert_eq!(
300///     result.to_bool().unwrap().boolean_buffer(),
301///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
302/// );
303/// ```
304pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
305    BinaryExpr::new_expr(lhs, Operator::Lt, rhs)
306}
307
308/// Create a new `BinaryExpr` using the `Or` operator.
309///
310/// ## Example usage
311///
312/// ```
313/// use vortex_array::arrays::BoolArray;
314/// use vortex_array::{IntoArray, ToCanonical};
315/// use vortex_expr::{root, lit, or, Scope};
316///
317/// let xs = BoolArray::from_iter(vec![true, false, true]);
318/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
319///
320/// assert_eq!(
321///     result.to_bool().unwrap().boolean_buffer(),
322///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
323/// );
324/// ```
325pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
326    BinaryExpr::new_expr(lhs, Operator::Or, rhs)
327}
328
329/// Create a new `BinaryExpr` using the `And` operator.
330///
331/// ## Example usage
332///
333/// ```
334/// use vortex_array::arrays::BoolArray;
335/// use vortex_array::{IntoArray, ToCanonical};
336/// use vortex_expr::{and, root, lit, Scope};
337///
338/// let xs = BoolArray::from_iter(vec![true, false, true]);
339/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
340///
341/// assert_eq!(
342///     result.to_bool().unwrap().boolean_buffer(),
343///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
344/// );
345/// ```
346pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
347    BinaryExpr::new_expr(lhs, Operator::And, rhs)
348}
349
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ScopeDType, VortexExpr, and, col, eq, gt, gt_eq, lt, lt_eq, not_eq, or, test_harness,
    };

    /// Checks the return dtype of every binary operator over the test-harness struct.
    ///
    /// Connectives over `bool1`/`bool2` come out non-nullable. Comparisons over `col1`/`col2`,
    /// and expressions nested over them, come out nullable.
    #[test]
    fn dtype() {
        let scope = ScopeDType::new(test_harness::struct_dtype());
        let dtype_of = |expr: Arc<dyn VortexExpr>| expr.return_dtype(&scope).unwrap();

        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        for connective in [and(bool1.clone(), bool2.clone()), or(bool1, bool2)] {
            assert_eq!(dtype_of(connective), DType::Bool(Nullability::NonNullable));
        }

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons = [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ];
        for comparison in comparisons {
            assert_eq!(dtype_of(comparison), DType::Bool(Nullability::Nullable));
        }

        let nested = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(dtype_of(nested), DType::Bool(Nullability::Nullable));
    }
}