vortex_expr/exprs/
binary.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10
11use crate::{AnalysisExpr, ExprRef, Operator, Scope, ScopeDType, StatsCatalog, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct BinaryExpr {
16    lhs: ExprRef,
17    operator: Operator,
18    rhs: ExprRef,
19}
20
21impl BinaryExpr {
22    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
23        Arc::new(Self { lhs, operator, rhs })
24    }
25
26    pub fn lhs(&self) -> &ExprRef {
27        &self.lhs
28    }
29
30    pub fn rhs(&self) -> &ExprRef {
31        &self.rhs
32    }
33
34    pub fn op(&self) -> Operator {
35        self.operator
36    }
37}
38
39impl Display for BinaryExpr {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
42    }
43}
44
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{BinaryExpr, ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serialization registration for [`BinaryExpr`], identified by the id `"binary"`.
    pub(crate) struct BinarySerde;

    impl Id for BinarySerde {
        fn id(&self) -> &'static str {
            "binary"
        }
    }

    impl ExprDeserialize for BinarySerde {
        /// Rebuilds a [`BinaryExpr`] from a `Kind::BinaryOp` and its children `[lhs, rhs]`.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is not `Kind::BinaryOp`, if the encoded operator cannot be converted
        /// into an [`Operator`](crate::Operator), or if `children` does not contain exactly two
        /// expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::BinaryOp(op) = kind else {
                vortex_bail!("wrong kind {:?}, binary", kind)
            };

            // Deserialized input is external data: report a malformed child list as an error
            // instead of panicking on an out-of-bounds index.
            let n_children = children.len();
            let Ok([lhs, rhs]) = <[ExprRef; 2]>::try_from(children) else {
                vortex_bail!(
                    "binary expression expects exactly 2 children, got {}",
                    n_children
                )
            };

            Ok(BinaryExpr::new_expr(lhs, (*op).try_into()?, rhs))
        }
    }

    impl ExprSerializable for BinaryExpr {
        fn id(&self) -> &'static str {
            BinarySerde.id()
        }

        /// Encodes only the operator; the operands are serialized as this node's children.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::BinaryOp(self.operator.into()))
        }
    }
}
84
85impl AnalysisExpr for BinaryExpr {
86    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
87        match self.operator {
88            Operator::Eq => {
89                let min_lhs = self.lhs.min(catalog);
90                let max_lhs = self.lhs.max(catalog);
91
92                let min_rhs = self.rhs.min(catalog);
93                let max_rhs = self.rhs.max(catalog);
94
95                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
96                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
97                left.into_iter().chain(right).reduce(or)
98            }
99            Operator::NotEq => {
100                let min_lhs = self.lhs.min(catalog)?;
101                let max_lhs = self.lhs.max(catalog)?;
102
103                let min_rhs = self.rhs.min(catalog)?;
104                let max_rhs = self.rhs.max(catalog)?;
105
106                Some(and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)))
107            }
108            Operator::Gt => Some(lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
109            Operator::Gte => Some(lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
110            Operator::Lt => Some(gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
111            Operator::Lte => Some(gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
112            Operator::And => self
113                .lhs
114                .stat_falsification(catalog)
115                .into_iter()
116                .chain(self.rhs.stat_falsification(catalog))
117                .reduce(or),
118            Operator::Or => Some(and(
119                self.lhs.stat_falsification(catalog)?,
120                self.rhs.stat_falsification(catalog)?,
121            )),
122            Operator::Add => None,
123        }
124    }
125}
126
127impl VortexExpr for BinaryExpr {
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
133        let lhs = self.lhs.unchecked_evaluate(scope)?;
134        let rhs = self.rhs.unchecked_evaluate(scope)?;
135
136        match self.operator {
137            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
138            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
139            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
140            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
141            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
142            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
143            Operator::And => and_kleene(&lhs, &rhs),
144            Operator::Or => or_kleene(&lhs, &rhs),
145            Operator::Add => add(&lhs, &rhs),
146        }
147    }
148
149    fn children(&self) -> Vec<&ExprRef> {
150        vec![&self.lhs, &self.rhs]
151    }
152
153    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
154        assert_eq!(children.len(), 2);
155        BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
156    }
157
158    fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
159        let lhs = self.lhs.return_dtype(ctx)?;
160        let rhs = self.rhs.return_dtype(ctx)?;
161
162        if self.operator == Operator::Add {
163            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
164                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
165            }
166            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
167        }
168
169        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
170    }
171}
172
173impl PartialEq for BinaryExpr {
174    fn eq(&self, other: &BinaryExpr) -> bool {
175        other.operator == self.operator && other.lhs.eq(&self.lhs) && other.rhs.eq(&self.rhs)
176    }
177}
178
179/// Create a new `BinaryExpr` using the `Eq` operator.
180///
181/// ## Example usage
182///
183/// ```
184/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
185/// use vortex_array::{Array, IntoArray, ToCanonical};
186/// use vortex_array::validity::Validity;
187/// use vortex_buffer::buffer;
188/// use vortex_expr::{eq, root, lit, Scope};
189///
190/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
191/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
192///
193/// assert_eq!(
194///     result.to_bool().unwrap().boolean_buffer(),
195///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
196/// );
197/// ```
198pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
199    BinaryExpr::new_expr(lhs, Operator::Eq, rhs)
200}
201
202/// Create a new `BinaryExpr` using the `NotEq` operator.
203///
204/// ## Example usage
205///
206/// ```
207/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
208/// use vortex_array::{IntoArray, ToCanonical};
209/// use vortex_array::validity::Validity;
210/// use vortex_buffer::buffer;
211/// use vortex_expr::{root, lit, not_eq, Scope};
212///
213/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
214/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
215///
216/// assert_eq!(
217///     result.to_bool().unwrap().boolean_buffer(),
218///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
219/// );
220/// ```
221pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
222    BinaryExpr::new_expr(lhs, Operator::NotEq, rhs)
223}
224
225/// Create a new `BinaryExpr` using the `Gte` operator.
226///
227/// ## Example usage
228///
229/// ```
230/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
231/// use vortex_array::{IntoArray, ToCanonical};
232/// use vortex_array::validity::Validity;
233/// use vortex_buffer::buffer;
234/// use vortex_expr::{gt_eq, root, lit, Scope};
235///
236/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
237/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
238///
239/// assert_eq!(
240///     result.to_bool().unwrap().boolean_buffer(),
241///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
242/// );
243/// ```
244pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
245    BinaryExpr::new_expr(lhs, Operator::Gte, rhs)
246}
247
248/// Create a new `BinaryExpr` using the `Gt` operator.
249///
250/// ## Example usage
251///
252/// ```
253/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
254/// use vortex_array::{IntoArray, ToCanonical};
255/// use vortex_array::validity::Validity;
256/// use vortex_buffer::buffer;
257/// use vortex_expr::{gt, root, lit, Scope};
258///
259/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
260/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
261///
262/// assert_eq!(
263///     result.to_bool().unwrap().boolean_buffer(),
264///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
265/// );
266/// ```
267pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
268    BinaryExpr::new_expr(lhs, Operator::Gt, rhs)
269}
270
271/// Create a new `BinaryExpr` using the `Lte` operator.
272///
273/// ## Example usage
274///
275/// ```
276/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
277/// use vortex_array::{IntoArray, ToCanonical};
278/// use vortex_array::validity::Validity;
279/// use vortex_buffer::buffer;
280/// use vortex_expr::{root, lit, lt_eq, Scope};
281///
282/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
283/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
284///
285/// assert_eq!(
286///     result.to_bool().unwrap().boolean_buffer(),
287///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
288/// );
289/// ```
290pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
291    BinaryExpr::new_expr(lhs, Operator::Lte, rhs)
292}
293
294/// Create a new `BinaryExpr` using the `Lt` operator.
295///
296/// ## Example usage
297///
298/// ```
299/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
300/// use vortex_array::{IntoArray, ToCanonical};
301/// use vortex_array::validity::Validity;
302/// use vortex_buffer::buffer;
303/// use vortex_expr::{root, lit, lt, Scope};
304///
305/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
306/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
307///
308/// assert_eq!(
309///     result.to_bool().unwrap().boolean_buffer(),
310///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
311/// );
312/// ```
313pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
314    BinaryExpr::new_expr(lhs, Operator::Lt, rhs)
315}
316
317/// Create a new `BinaryExpr` using the `Or` operator.
318///
319/// ## Example usage
320///
321/// ```
322/// use vortex_array::arrays::BoolArray;
323/// use vortex_array::{IntoArray, ToCanonical};
324/// use vortex_expr::{root, lit, or, Scope};
325///
326/// let xs = BoolArray::from_iter(vec![true, false, true]);
327/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
328///
329/// assert_eq!(
330///     result.to_bool().unwrap().boolean_buffer(),
331///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
332/// );
333/// ```
334pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
335    BinaryExpr::new_expr(lhs, Operator::Or, rhs)
336}
337
338/// Collects a list of `or`ed values into a single vortex, expr
339/// [x, y, z] => x or (y or z)
340pub fn or_collect<I>(iter: I) -> Option<ExprRef>
341where
342    I: IntoIterator<Item = ExprRef>,
343    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
344{
345    let mut iter = iter.into_iter();
346    let first = iter.next_back()?;
347    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
348}
349
350/// Create a new `BinaryExpr` using the `And` operator.
351///
352/// ## Example usage
353///
354/// ```
355/// use vortex_array::arrays::BoolArray;
356/// use vortex_array::{IntoArray, ToCanonical};
357/// use vortex_expr::{and, root, lit, Scope};
358///
359/// let xs = BoolArray::from_iter(vec![true, false, true]);
360/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
361///
362/// assert_eq!(
363///     result.to_bool().unwrap().boolean_buffer(),
364///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
365/// );
366/// ```
367pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
368    BinaryExpr::new_expr(lhs, Operator::And, rhs)
369}
370
371/// Collects a list of `and`ed values into a single vortex, expr
372/// [x, y, z] => x and (y and z)
373pub fn and_collect<I>(iter: I) -> Option<ExprRef>
374where
375    I: IntoIterator<Item = ExprRef>,
376    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
377{
378    let mut iter = iter.into_iter();
379    let first = iter.next_back()?;
380    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
381}
382
383/// Collects a list of `and`ed values into a single vortex, expr
384/// [x, y, z] => x and (y and z)
385pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
386where
387    I: IntoIterator<Item = ExprRef>,
388{
389    let iter = iter.into_iter();
390    iter.reduce(and)
391}
392
393/// Create a new `BinaryExpr` using the `CheckedAdd` operator.
394///
395/// ## Example usage
396///
397/// ```
398/// use vortex_array::IntoArray;
399/// use vortex_array::arrow::IntoArrowArray as _;
400/// use vortex_buffer::buffer;
401/// use vortex_expr::{Scope, checked_add, lit, root};
402///
403/// let xs = buffer![1, 2, 3].into_array();
404/// let result = checked_add(root(), lit(5))
405///     .evaluate(&Scope::new(xs.to_array()))
406///     .unwrap();
407///
408/// assert_eq!(
409///     &result.into_arrow_preferred().unwrap(),
410///     &buffer![6, 7, 8]
411///         .into_array()
412///         .into_arrow_preferred()
413///         .unwrap()
414/// );
415/// ```
416pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
417    BinaryExpr::new_expr(lhs, Operator::Add, rhs)
418}
419
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ExprRef, ScopeDType, VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq,
        lit, lt, lt_eq, not_eq, or, test_harness,
    };

    /// Signature shared by every binary-expression builder exercised below.
    type Builder = fn(ExprRef, ExprRef) -> ExprRef;

    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)]);
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)]);
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn dtype() {
        let scope = ScopeDType::new(test_harness::struct_dtype());
        let dtype_of = |expr: ExprRef| expr.return_dtype(&scope).unwrap();

        // Logical connectives over the harness's boolean columns stay non-nullable.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let logical: [Builder; 2] = [and, or];
        for build in logical {
            assert_eq!(
                dtype_of(build(bool1.clone(), bool2.clone())),
                DType::Bool(Nullability::NonNullable)
            );
        }

        // Every comparison between `col1` and `col2` yields a nullable boolean.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons: [Builder; 6] = [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            assert_eq!(
                dtype_of(build(col1.clone(), col2.clone())),
                DType::Bool(Nullability::Nullable)
            );
        }

        // Nullability propagates through a nested connective.
        assert_eq!(
            dtype_of(or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))),
            DType::Bool(Nullability::Nullable)
        );
    }
}