vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene, sub};
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15    VTable, lit, vtable,
16};
17
18vtable!(Binary);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct BinaryExpr {
23    lhs: ExprRef,
24    operator: Operator,
25    rhs: ExprRef,
26}
27
28impl PartialEq for BinaryExpr {
29    fn eq(&self, other: &Self) -> bool {
30        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31    }
32}
33
34pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37    type Expr = BinaryExpr;
38    type Encoding = BinaryExprEncoding;
39    type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41    fn id(_encoding: &Self::Encoding) -> ExprId {
42        ExprId::new_ref("binary")
43    }
44
45    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47    }
48
49    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50        Some(ProstMetadata(pb::BinaryOpts {
51            op: expr.operator.into(),
52        }))
53    }
54
55    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56        vec![expr.lhs(), expr.rhs()]
57    }
58
59    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60        Ok(BinaryExpr::new(
61            children[0].clone(),
62            expr.op(),
63            children[1].clone(),
64        ))
65    }
66
67    fn build(
68        _encoding: &Self::Encoding,
69        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70        children: Vec<ExprRef>,
71    ) -> VortexResult<Self::Expr> {
72        Ok(BinaryExpr::new(
73            children[0].clone(),
74            metadata.op().into(),
75            children[1].clone(),
76        ))
77    }
78
79    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80        let lhs = expr.lhs.unchecked_evaluate(scope)?;
81        let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83        match expr.operator {
84            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90            Operator::And => and_kleene(&lhs, &rhs),
91            Operator::Or => or_kleene(&lhs, &rhs),
92            Operator::Add => add(&lhs, &rhs),
93            Operator::Sub => sub(&lhs, &rhs),
94        }
95    }
96
97    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98        let lhs = expr.lhs.return_dtype(scope)?;
99        let rhs = expr.rhs.return_dtype(scope)?;
100
101        if expr.operator == Operator::Add {
102            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
103                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
104            }
105            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
106        }
107
108        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
109    }
110}
111
112impl BinaryExpr {
113    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
114        Self { lhs, operator, rhs }
115    }
116
117    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
118        Self::new(lhs, operator, rhs).into_expr()
119    }
120
121    pub fn lhs(&self) -> &ExprRef {
122        &self.lhs
123    }
124
125    pub fn rhs(&self) -> &ExprRef {
126        &self.rhs
127    }
128
129    pub fn op(&self) -> Operator {
130        self.operator
131    }
132}
133
134impl Display for BinaryExpr {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
137    }
138}
139
140impl AnalysisExpr for BinaryExpr {
141    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
142        // Wrap another predicate with an optional NaNCount check, if the stat is available.
143        //
144        // For example, regular pruning conversion for `A >= B` would be
145        //
146        //      A.max < B.min
147        //
148        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
149        // in:
150        //
151        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
152        //
153        // Non-floating point column and literal expressions should be unaffected as they do not
154        // have a nan_count statistic defined.
155        #[inline]
156        fn with_nan_predicate(
157            lhs: &ExprRef,
158            rhs: &ExprRef,
159            value_predicate: ExprRef,
160            catalog: &mut dyn StatsCatalog,
161        ) -> ExprRef {
162            let nan_predicate = lhs
163                .nan_count(catalog)
164                .into_iter()
165                .chain(rhs.nan_count(catalog))
166                .map(|nans| eq(nans, lit(0u64)))
167                .reduce(and);
168
169            if let Some(nan_check) = nan_predicate {
170                and(nan_check, value_predicate)
171            } else {
172                value_predicate
173            }
174        }
175
176        match self.operator {
177            Operator::Eq => {
178                let min_lhs = self.lhs.min(catalog);
179                let max_lhs = self.lhs.max(catalog);
180
181                let min_rhs = self.rhs.min(catalog);
182                let max_rhs = self.rhs.max(catalog);
183
184                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
185                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
186
187                let min_max_check = left.into_iter().chain(right).reduce(or)?;
188
189                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
190                Some(with_nan_predicate(
191                    self.lhs(),
192                    self.rhs(),
193                    min_max_check,
194                    catalog,
195                ))
196            }
197            Operator::NotEq => {
198                let min_lhs = self.lhs.min(catalog)?;
199                let max_lhs = self.lhs.max(catalog)?;
200
201                let min_rhs = self.rhs.min(catalog)?;
202                let max_rhs = self.rhs.max(catalog)?;
203
204                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
205
206                Some(with_nan_predicate(
207                    self.lhs(),
208                    self.rhs(),
209                    min_max_check,
210                    catalog,
211                ))
212            }
213            Operator::Gt => {
214                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
215
216                Some(with_nan_predicate(
217                    self.lhs(),
218                    self.rhs(),
219                    min_max_check,
220                    catalog,
221                ))
222            }
223            Operator::Gte => {
224                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
225                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
226
227                Some(with_nan_predicate(
228                    self.lhs(),
229                    self.rhs(),
230                    min_max_check,
231                    catalog,
232                ))
233            }
234            Operator::Lt => {
235                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
236                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
237
238                Some(with_nan_predicate(
239                    self.lhs(),
240                    self.rhs(),
241                    min_max_check,
242                    catalog,
243                ))
244            }
245            Operator::Lte => {
246                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
247                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
248
249                Some(with_nan_predicate(
250                    self.lhs(),
251                    self.rhs(),
252                    min_max_check,
253                    catalog,
254                ))
255            }
256            Operator::And => self
257                .lhs
258                .stat_falsification(catalog)
259                .into_iter()
260                .chain(self.rhs.stat_falsification(catalog))
261                .reduce(or),
262            Operator::Or => Some(and(
263                self.lhs.stat_falsification(catalog)?,
264                self.rhs.stat_falsification(catalog)?,
265            )),
266            Operator::Add | Operator::Sub => None,
267        }
268    }
269}
270
271/// Create a new [`BinaryExpr`] using the [`Eq`](crate::Operator::Eq) operator.
272///
273/// ## Example usage
274///
275/// ```
276/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
277/// # use vortex_array::{Array, IntoArray, ToCanonical};
278/// # use vortex_array::validity::Validity;
279/// # use vortex_buffer::buffer;
280/// # use vortex_expr::{eq, root, lit, Scope};
281/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
282/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
283///
284/// assert_eq!(
285///     result.to_bool().unwrap().boolean_buffer(),
286///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
287/// );
288/// ```
289pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
290    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
291}
292
293/// Create a new [`BinaryExpr`] using the [`NotEq`](crate::Operator::NotEq) operator.
294///
295/// ## Example usage
296///
297/// ```
298/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
299/// # use vortex_array::{IntoArray, ToCanonical};
300/// # use vortex_array::validity::Validity;
301/// # use vortex_buffer::buffer;
302/// # use vortex_expr::{root, lit, not_eq, Scope};
303/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
304/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
305///
306/// assert_eq!(
307///     result.to_bool().unwrap().boolean_buffer(),
308///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
309/// );
310/// ```
311pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
312    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
313}
314
315/// Create a new [`BinaryExpr`] using the [`Gte`](crate::Operator::Gte) operator.
316///
317/// ## Example usage
318///
319/// ```
320/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
321/// # use vortex_array::{IntoArray, ToCanonical};
322/// # use vortex_array::validity::Validity;
323/// # use vortex_buffer::buffer;
324/// # use vortex_expr::{gt_eq, root, lit, Scope};
325/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
326/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
327///
328/// assert_eq!(
329///     result.to_bool().unwrap().boolean_buffer(),
330///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
331/// );
332/// ```
333pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
334    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
335}
336
337/// Create a new [`BinaryExpr`] using the [`Gt`](crate::Operator::Gt) operator.
338///
339/// ## Example usage
340///
341/// ```
342/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
343/// # use vortex_array::{IntoArray, ToCanonical};
344/// # use vortex_array::validity::Validity;
345/// # use vortex_buffer::buffer;
346/// # use vortex_expr::{gt, root, lit, Scope};
347/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
348/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
349///
350/// assert_eq!(
351///     result.to_bool().unwrap().boolean_buffer(),
352///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
353/// );
354/// ```
355pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
356    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
357}
358
359/// Create a new [`BinaryExpr`] using the [`Lte`](crate::Operator::Lte) operator.
360///
361/// ## Example usage
362///
363/// ```
364/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
365/// # use vortex_array::{IntoArray, ToCanonical};
366/// # use vortex_array::validity::Validity;
367/// # use vortex_buffer::buffer;
368/// # use vortex_expr::{root, lit, lt_eq, Scope};
369/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
370/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
371///
372/// assert_eq!(
373///     result.to_bool().unwrap().boolean_buffer(),
374///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
375/// );
376/// ```
377pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
378    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
379}
380
381/// Create a new [`BinaryExpr`] using the [`Lt`](crate::Operator::Lt) operator.
382///
383/// ## Example usage
384///
385/// ```
386/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
387/// # use vortex_array::{IntoArray, ToCanonical};
388/// # use vortex_array::validity::Validity;
389/// # use vortex_buffer::buffer;
390/// # use vortex_expr::{root, lit, lt, Scope};
391/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
392/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
393///
394/// assert_eq!(
395///     result.to_bool().unwrap().boolean_buffer(),
396///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
397/// );
398/// ```
399pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
400    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
401}
402
403/// Create a new [`BinaryExpr`] using the [`Or`](crate::Operator::Or) operator.
404///
405/// ## Example usage
406///
407/// ```
408/// # use vortex_array::arrays::BoolArray;
409/// # use vortex_array::{IntoArray, ToCanonical};
410/// # use vortex_expr::{root, lit, or, Scope};
411/// let xs = BoolArray::from_iter(vec![true, false, true]);
412/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
413///
414/// assert_eq!(
415///     result.to_bool().unwrap().boolean_buffer(),
416///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
417/// );
418/// ```
419pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
420    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
421}
422
423/// Collects a list of `or`ed values into a single vortex, expr
424/// [x, y, z] => x or (y or z)
425pub fn or_collect<I>(iter: I) -> Option<ExprRef>
426where
427    I: IntoIterator<Item = ExprRef>,
428    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
429{
430    let mut iter = iter.into_iter();
431    let first = iter.next_back()?;
432    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
433}
434
435/// Create a new [`BinaryExpr`] using the [`And`](crate::Operator::And) operator.
436///
437/// ## Example usage
438///
439/// ```
440/// # use vortex_array::arrays::BoolArray;
441/// # use vortex_array::{IntoArray, ToCanonical};
442/// # use vortex_expr::{and, root, lit, Scope};
443/// let xs = BoolArray::from_iter(vec![true, false, true]);
444/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
445///
446/// assert_eq!(
447///     result.to_bool().unwrap().boolean_buffer(),
448///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
449/// );
450/// ```
451pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
452    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
453}
454
455/// Collects a list of `and`ed values into a single vortex, expr
456/// [x, y, z] => x and (y and z)
457pub fn and_collect<I>(iter: I) -> Option<ExprRef>
458where
459    I: IntoIterator<Item = ExprRef>,
460    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
461{
462    let mut iter = iter.into_iter();
463    let first = iter.next_back()?;
464    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
465}
466
467/// Collects a list of `and`ed values into a single vortex, expr
468/// [x, y, z] => x and (y and z)
469pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
470where
471    I: IntoIterator<Item = ExprRef>,
472{
473    let iter = iter.into_iter();
474    iter.reduce(and)
475}
476
477/// Create a new [`BinaryExpr`] using the [`Add`](crate::Operator::Add) operator.
478///
479/// ## Example usage
480///
481/// ```
482/// # use vortex_array::IntoArray;
483/// # use vortex_array::arrow::IntoArrowArray as _;
484/// # use vortex_buffer::buffer;
485/// # use vortex_expr::{Scope, checked_add, lit, root};
486/// let xs = buffer![1, 2, 3].into_array();
487/// let result = checked_add(root(), lit(5))
488///     .evaluate(&Scope::new(xs.to_array()))
489///     .unwrap();
490///
491/// assert_eq!(
492///     &result.into_arrow_preferred().unwrap(),
493///     &buffer![6, 7, 8]
494///         .into_array()
495///         .into_arrow_preferred()
496///         .unwrap()
497/// );
498/// ```
499pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
500    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
501}
502
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ExprRef, VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt,
        lt_eq, not_eq, or, or_collect, test_harness,
    };

    /// `and_collect` nests to the right: `[1, 2, 3] => 1 and (2 and 3)`.
    #[test]
    fn and_collect_nests_right() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(lit(1), and(lit(2), lit(3)))),
            and_collect(values.into_iter())
        );
    }

    /// `and_collect_right` nests to the left: `[1, 2, 3] => (1 and 2) and 3`.
    #[test]
    fn and_collect_right_nests_left() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(and(lit(1), lit(2)), lit(3))),
            and_collect_right(values.into_iter())
        );
    }

    /// `or_collect` nests to the right: `[1, 2, 3] => 1 or (2 or 3)`.
    #[test]
    fn or_collect_nests_right() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(or(lit(1), or(lit(2), lit(3)))),
            or_collect(values.into_iter())
        );
    }

    #[test]
    fn collect_empty_is_none() {
        assert!(and_collect(Vec::<ExprRef>::new()).is_none());
        assert!(and_collect_right(Vec::<ExprRef>::new()).is_none());
        assert!(or_collect(Vec::<ExprRef>::new()).is_none());
    }

    #[test]
    fn collect_single_is_identity() {
        assert_eq!(Some(lit(1)), and_collect(vec![lit(1)]));
        assert_eq!(Some(lit(1)), and_collect_right(vec![lit(1)]));
        assert_eq!(Some(lit(1)), or_collect(vec![lit(1)]));
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone())
            )
            .return_dtype(&dtype)
            .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}