vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, div, mul, or_kleene, sub};
8use vortex_array::operator::OperatorRef;
9use vortex_array::operator::compare::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18    VTable, lit, vtable,
19};
20
21vtable!(Binary);
22
23#[allow(clippy::derived_hash_with_manual_eq)]
24#[derive(Debug, Clone, Hash, Eq)]
25pub struct BinaryExpr {
26    lhs: ExprRef,
27    operator: Operator,
28    rhs: ExprRef,
29}
30
31impl PartialEq for BinaryExpr {
32    fn eq(&self, other: &Self) -> bool {
33        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34    }
35}
36
37pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40    type Expr = BinaryExpr;
41    type Encoding = BinaryExprEncoding;
42    type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("binary")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        Some(ProstMetadata(pb::BinaryOpts {
54            op: expr.operator.into(),
55        }))
56    }
57
58    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59        vec![expr.lhs(), expr.rhs()]
60    }
61
62    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63        Ok(BinaryExpr::new(
64            children[0].clone(),
65            expr.op(),
66            children[1].clone(),
67        ))
68    }
69
70    fn build(
71        _encoding: &Self::Encoding,
72        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73        children: Vec<ExprRef>,
74    ) -> VortexResult<Self::Expr> {
75        Ok(BinaryExpr::new(
76            children[0].clone(),
77            metadata.op().into(),
78            children[1].clone(),
79        ))
80    }
81
82    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83        let lhs = expr.lhs.unchecked_evaluate(scope)?;
84        let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86        match expr.operator {
87            Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88            Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89            Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90            Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91            Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92            Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93            Operator::And => and_kleene(&lhs, &rhs),
94            Operator::Or => or_kleene(&lhs, &rhs),
95            Operator::Add => add(&lhs, &rhs),
96            Operator::Sub => sub(&lhs, &rhs),
97            Operator::Mul => mul(&lhs, &rhs),
98            Operator::Div => div(&lhs, &rhs),
99        }
100    }
101
102    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
103        let lhs = expr.lhs.return_dtype(scope)?;
104        let rhs = expr.rhs.return_dtype(scope)?;
105
106        if expr.operator.is_arithmetic() {
107            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
108                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
109            }
110            vortex_bail!(
111                "incompatible types for arithmetic operation: {} {}",
112                lhs,
113                rhs
114            );
115        }
116
117        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
118    }
119
120    fn operator(expr: &BinaryExpr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
121        let Some(lhs) = expr.lhs.operator(scope)? else {
122            return Ok(None);
123        };
124        let Some(rhs) = expr.rhs.operator(scope)? else {
125            return Ok(None);
126        };
127        let Ok(op): VortexResult<compute::Operator> = expr.operator.try_into() else {
128            return Ok(None);
129        };
130        Ok(Some(Arc::new(CompareOperator::try_new(lhs, rhs, op)?)))
131    }
132}
133
134impl BinaryExpr {
135    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
136        Self { lhs, operator, rhs }
137    }
138
139    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
140        Self::new(lhs, operator, rhs).into_expr()
141    }
142
143    pub fn lhs(&self) -> &ExprRef {
144        &self.lhs
145    }
146
147    pub fn rhs(&self) -> &ExprRef {
148        &self.rhs
149    }
150
151    pub fn op(&self) -> Operator {
152        self.operator
153    }
154}
155
156impl DisplayAs for BinaryExpr {
157    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158        match df {
159            DisplayFormat::Compact => {
160                write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
161            }
162            DisplayFormat::Tree => {
163                write!(f, "Binary({})", self.operator)
164            }
165        }
166    }
167
168    fn child_names(&self) -> Option<Vec<String>> {
169        Some(vec!["lhs".to_string(), "rhs".to_string()])
170    }
171}
172
173impl AnalysisExpr for BinaryExpr {
174    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
175        // Wrap another predicate with an optional NaNCount check, if the stat is available.
176        //
177        // For example, regular pruning conversion for `A >= B` would be
178        //
179        //      A.max < B.min
180        //
181        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
182        // in:
183        //
184        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
185        //
186        // Non-floating point column and literal expressions should be unaffected as they do not
187        // have a nan_count statistic defined.
188        #[inline]
189        fn with_nan_predicate(
190            lhs: &ExprRef,
191            rhs: &ExprRef,
192            value_predicate: ExprRef,
193            catalog: &mut dyn StatsCatalog,
194        ) -> ExprRef {
195            let nan_predicate = lhs
196                .nan_count(catalog)
197                .into_iter()
198                .chain(rhs.nan_count(catalog))
199                .map(|nans| eq(nans, lit(0u64)))
200                .reduce(and);
201
202            if let Some(nan_check) = nan_predicate {
203                and(nan_check, value_predicate)
204            } else {
205                value_predicate
206            }
207        }
208
209        match self.operator {
210            Operator::Eq => {
211                let min_lhs = self.lhs.min(catalog);
212                let max_lhs = self.lhs.max(catalog);
213
214                let min_rhs = self.rhs.min(catalog);
215                let max_rhs = self.rhs.max(catalog);
216
217                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
218                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
219
220                let min_max_check = left.into_iter().chain(right).reduce(or)?;
221
222                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
223                Some(with_nan_predicate(
224                    self.lhs(),
225                    self.rhs(),
226                    min_max_check,
227                    catalog,
228                ))
229            }
230            Operator::NotEq => {
231                let min_lhs = self.lhs.min(catalog)?;
232                let max_lhs = self.lhs.max(catalog)?;
233
234                let min_rhs = self.rhs.min(catalog)?;
235                let max_rhs = self.rhs.max(catalog)?;
236
237                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
238
239                Some(with_nan_predicate(
240                    self.lhs(),
241                    self.rhs(),
242                    min_max_check,
243                    catalog,
244                ))
245            }
246            Operator::Gt => {
247                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
248
249                Some(with_nan_predicate(
250                    self.lhs(),
251                    self.rhs(),
252                    min_max_check,
253                    catalog,
254                ))
255            }
256            Operator::Gte => {
257                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
258                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
259
260                Some(with_nan_predicate(
261                    self.lhs(),
262                    self.rhs(),
263                    min_max_check,
264                    catalog,
265                ))
266            }
267            Operator::Lt => {
268                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
269                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
270
271                Some(with_nan_predicate(
272                    self.lhs(),
273                    self.rhs(),
274                    min_max_check,
275                    catalog,
276                ))
277            }
278            Operator::Lte => {
279                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
280                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
281
282                Some(with_nan_predicate(
283                    self.lhs(),
284                    self.rhs(),
285                    min_max_check,
286                    catalog,
287                ))
288            }
289            Operator::And => self
290                .lhs
291                .stat_falsification(catalog)
292                .into_iter()
293                .chain(self.rhs.stat_falsification(catalog))
294                .reduce(or),
295            Operator::Or => Some(and(
296                self.lhs.stat_falsification(catalog)?,
297                self.rhs.stat_falsification(catalog)?,
298            )),
299            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
300        }
301    }
302}
303
304/// Create a new [`BinaryExpr`] using the [`Eq`](crate::Operator::Eq) operator.
305///
306/// ## Example usage
307///
308/// ```
309/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
310/// # use vortex_array::{Array, IntoArray, ToCanonical};
311/// # use vortex_array::validity::Validity;
312/// # use vortex_buffer::buffer;
313/// # use vortex_expr::{eq, root, lit, Scope};
314/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
315/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
316///
317/// assert_eq!(
318///     result.to_bool().boolean_buffer(),
319///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
320/// );
321/// ```
322pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
323    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
324}
325
326/// Create a new [`BinaryExpr`] using the [`NotEq`](crate::Operator::NotEq) operator.
327///
328/// ## Example usage
329///
330/// ```
331/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
332/// # use vortex_array::{IntoArray, ToCanonical};
333/// # use vortex_array::validity::Validity;
334/// # use vortex_buffer::buffer;
335/// # use vortex_expr::{root, lit, not_eq, Scope};
336/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
337/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
338///
339/// assert_eq!(
340///     result.to_bool().boolean_buffer(),
341///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
342/// );
343/// ```
344pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
345    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
346}
347
348/// Create a new [`BinaryExpr`] using the [`Gte`](crate::Operator::Gte) operator.
349///
350/// ## Example usage
351///
352/// ```
353/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
354/// # use vortex_array::{IntoArray, ToCanonical};
355/// # use vortex_array::validity::Validity;
356/// # use vortex_buffer::buffer;
357/// # use vortex_expr::{gt_eq, root, lit, Scope};
358/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
359/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
360///
361/// assert_eq!(
362///     result.to_bool().boolean_buffer(),
363///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
364/// );
365/// ```
366pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
367    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
368}
369
370/// Create a new [`BinaryExpr`] using the [`Gt`](crate::Operator::Gt) operator.
371///
372/// ## Example usage
373///
374/// ```
375/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
376/// # use vortex_array::{IntoArray, ToCanonical};
377/// # use vortex_array::validity::Validity;
378/// # use vortex_buffer::buffer;
379/// # use vortex_expr::{gt, root, lit, Scope};
380/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
381/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
382///
383/// assert_eq!(
384///     result.to_bool().boolean_buffer(),
385///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
386/// );
387/// ```
388pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
389    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
390}
391
392/// Create a new [`BinaryExpr`] using the [`Lte`](crate::Operator::Lte) operator.
393///
394/// ## Example usage
395///
396/// ```
397/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
398/// # use vortex_array::{IntoArray, ToCanonical};
399/// # use vortex_array::validity::Validity;
400/// # use vortex_buffer::buffer;
401/// # use vortex_expr::{root, lit, lt_eq, Scope};
402/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
403/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
404///
405/// assert_eq!(
406///     result.to_bool().boolean_buffer(),
407///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
408/// );
409/// ```
410pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
411    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
412}
413
414/// Create a new [`BinaryExpr`] using the [`Lt`](crate::Operator::Lt) operator.
415///
416/// ## Example usage
417///
418/// ```
419/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
420/// # use vortex_array::{IntoArray, ToCanonical};
421/// # use vortex_array::validity::Validity;
422/// # use vortex_buffer::buffer;
423/// # use vortex_expr::{root, lit, lt, Scope};
424/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
425/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
426///
427/// assert_eq!(
428///     result.to_bool().boolean_buffer(),
429///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
430/// );
431/// ```
432pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
433    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
434}
435
436/// Create a new [`BinaryExpr`] using the [`Or`](crate::Operator::Or) operator.
437///
438/// ## Example usage
439///
440/// ```
441/// # use vortex_array::arrays::BoolArray;
442/// # use vortex_array::{IntoArray, ToCanonical};
443/// # use vortex_expr::{root, lit, or, Scope};
444/// let xs = BoolArray::from_iter(vec![true, false, true]);
445/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
446///
447/// assert_eq!(
448///     result.to_bool().boolean_buffer(),
449///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
450/// );
451/// ```
452pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
453    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
454}
455
456/// Collects a list of `or`ed values into a single vortex, expr
457/// [x, y, z] => x or (y or z)
458pub fn or_collect<I>(iter: I) -> Option<ExprRef>
459where
460    I: IntoIterator<Item = ExprRef>,
461    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
462{
463    let mut iter = iter.into_iter();
464    let first = iter.next_back()?;
465    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
466}
467
468/// Create a new [`BinaryExpr`] using the [`And`](crate::Operator::And) operator.
469///
470/// ## Example usage
471///
472/// ```
473/// # use vortex_array::arrays::BoolArray;
474/// # use vortex_array::{IntoArray, ToCanonical};
475/// # use vortex_expr::{and, root, lit, Scope};
476/// let xs = BoolArray::from_iter(vec![true, false, true]);
477/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
478///
479/// assert_eq!(
480///     result.to_bool().boolean_buffer(),
481///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
482/// );
483/// ```
484pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
485    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
486}
487
488/// Collects a list of `and`ed values into a single vortex, expr
489/// [x, y, z] => x and (y and z)
490pub fn and_collect<I>(iter: I) -> Option<ExprRef>
491where
492    I: IntoIterator<Item = ExprRef>,
493    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
494{
495    let mut iter = iter.into_iter();
496    let first = iter.next_back()?;
497    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
498}
499
500/// Collects a list of `and`ed values into a single vortex, expr
501/// [x, y, z] => x and (y and z)
502pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
503where
504    I: IntoIterator<Item = ExprRef>,
505{
506    let iter = iter.into_iter();
507    iter.reduce(and)
508}
509
510/// Create a new [`BinaryExpr`] using the [`Add`](crate::Operator::Add) operator.
511///
512/// ## Example usage
513///
514/// ```
515/// # use vortex_array::IntoArray;
516/// # use vortex_array::arrow::IntoArrowArray as _;
517/// # use vortex_buffer::buffer;
518/// # use vortex_expr::{Scope, checked_add, lit, root};
519/// let xs = buffer![1, 2, 3].into_array();
520/// let result = checked_add(root(), lit(5))
521///     .evaluate(&Scope::new(xs.to_array()))
522///     .unwrap();
523///
524/// assert_eq!(
525///     &result.into_arrow_preferred().unwrap(),
526///     &buffer![6, 7, 8]
527///         .into_array()
528///         .into_arrow_preferred()
529///         .unwrap()
530/// );
531/// ```
532pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
533    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
534}
535
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ExprRef, VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt,
        lt_eq, not_eq, or, test_harness,
    };
    use super::or_collect;

    /// `and_collect` nests to the right: `[x, y, z] => x and (y and z)`.
    #[test]
    fn and_collect_nests_right() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(lit(1), and(lit(2), lit(3)))),
            and_collect(values)
        );
    }

    /// Despite its name, `and_collect_right` nests to the left: `[x, y, z] => (x and y) and z`.
    #[test]
    fn and_collect_right_nests_left() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(and(lit(1), lit(2)), lit(3))),
            and_collect_right(values)
        );
    }

    /// `or_collect` nests to the right: `[x, y, z] => x or (y or z)`.
    #[test]
    fn or_collect_nests_right() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(Some(or(lit(1), or(lit(2), lit(3)))), or_collect(values));
    }

    #[test]
    fn collect_single_value_is_unchanged() {
        assert_eq!(Some(lit(1)), and_collect(vec![lit(1)]));
        assert_eq!(Some(lit(1)), and_collect_right(vec![lit(1)]));
        assert_eq!(Some(lit(1)), or_collect(vec![lit(1)]));
    }

    #[test]
    fn collect_empty_is_none() {
        assert!(and_collect(Vec::<ExprRef>::new()).is_none());
        assert!(and_collect_right(Vec::<ExprRef>::new()).is_none());
        assert!(or_collect(Vec::<ExprRef>::new()).is_none());
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone())
            )
            .return_dtype(&dtype)
            .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}