vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene, sub};
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10use vortex_proto::expr as pb;
11
12use crate::display::{DisplayAs, DisplayFormat};
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15    VTable, lit, vtable,
16};
17
18vtable!(Binary);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct BinaryExpr {
23    lhs: ExprRef,
24    operator: Operator,
25    rhs: ExprRef,
26}
27
28impl PartialEq for BinaryExpr {
29    fn eq(&self, other: &Self) -> bool {
30        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31    }
32}
33
34pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37    type Expr = BinaryExpr;
38    type Encoding = BinaryExprEncoding;
39    type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41    fn id(_encoding: &Self::Encoding) -> ExprId {
42        ExprId::new_ref("binary")
43    }
44
45    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47    }
48
49    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50        Some(ProstMetadata(pb::BinaryOpts {
51            op: expr.operator.into(),
52        }))
53    }
54
55    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56        vec![expr.lhs(), expr.rhs()]
57    }
58
59    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60        Ok(BinaryExpr::new(
61            children[0].clone(),
62            expr.op(),
63            children[1].clone(),
64        ))
65    }
66
67    fn build(
68        _encoding: &Self::Encoding,
69        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70        children: Vec<ExprRef>,
71    ) -> VortexResult<Self::Expr> {
72        Ok(BinaryExpr::new(
73            children[0].clone(),
74            metadata.op().into(),
75            children[1].clone(),
76        ))
77    }
78
79    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80        let lhs = expr.lhs.unchecked_evaluate(scope)?;
81        let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83        match expr.operator {
84            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90            Operator::And => and_kleene(&lhs, &rhs),
91            Operator::Or => or_kleene(&lhs, &rhs),
92            Operator::Add => add(&lhs, &rhs),
93            Operator::Sub => sub(&lhs, &rhs),
94        }
95    }
96
97    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98        let lhs = expr.lhs.return_dtype(scope)?;
99        let rhs = expr.rhs.return_dtype(scope)?;
100
101        if expr.operator == Operator::Add {
102            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
103                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
104            }
105            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
106        }
107
108        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
109    }
110}
111
112impl BinaryExpr {
113    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
114        Self { lhs, operator, rhs }
115    }
116
117    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
118        Self::new(lhs, operator, rhs).into_expr()
119    }
120
121    pub fn lhs(&self) -> &ExprRef {
122        &self.lhs
123    }
124
125    pub fn rhs(&self) -> &ExprRef {
126        &self.rhs
127    }
128
129    pub fn op(&self) -> Operator {
130        self.operator
131    }
132}
133
134impl DisplayAs for BinaryExpr {
135    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136        match df {
137            DisplayFormat::Compact => {
138                write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
139            }
140            DisplayFormat::Tree => {
141                write!(f, "Binary({})", self.operator)
142            }
143        }
144    }
145
146    fn child_names(&self) -> Option<Vec<String>> {
147        Some(vec!["lhs".to_string(), "rhs".to_string()])
148    }
149}
150
151impl AnalysisExpr for BinaryExpr {
152    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
153        // Wrap another predicate with an optional NaNCount check, if the stat is available.
154        //
155        // For example, regular pruning conversion for `A >= B` would be
156        //
157        //      A.max < B.min
158        //
159        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
160        // in:
161        //
162        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
163        //
164        // Non-floating point column and literal expressions should be unaffected as they do not
165        // have a nan_count statistic defined.
166        #[inline]
167        fn with_nan_predicate(
168            lhs: &ExprRef,
169            rhs: &ExprRef,
170            value_predicate: ExprRef,
171            catalog: &mut dyn StatsCatalog,
172        ) -> ExprRef {
173            let nan_predicate = lhs
174                .nan_count(catalog)
175                .into_iter()
176                .chain(rhs.nan_count(catalog))
177                .map(|nans| eq(nans, lit(0u64)))
178                .reduce(and);
179
180            if let Some(nan_check) = nan_predicate {
181                and(nan_check, value_predicate)
182            } else {
183                value_predicate
184            }
185        }
186
187        match self.operator {
188            Operator::Eq => {
189                let min_lhs = self.lhs.min(catalog);
190                let max_lhs = self.lhs.max(catalog);
191
192                let min_rhs = self.rhs.min(catalog);
193                let max_rhs = self.rhs.max(catalog);
194
195                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
196                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
197
198                let min_max_check = left.into_iter().chain(right).reduce(or)?;
199
200                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
201                Some(with_nan_predicate(
202                    self.lhs(),
203                    self.rhs(),
204                    min_max_check,
205                    catalog,
206                ))
207            }
208            Operator::NotEq => {
209                let min_lhs = self.lhs.min(catalog)?;
210                let max_lhs = self.lhs.max(catalog)?;
211
212                let min_rhs = self.rhs.min(catalog)?;
213                let max_rhs = self.rhs.max(catalog)?;
214
215                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
216
217                Some(with_nan_predicate(
218                    self.lhs(),
219                    self.rhs(),
220                    min_max_check,
221                    catalog,
222                ))
223            }
224            Operator::Gt => {
225                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
226
227                Some(with_nan_predicate(
228                    self.lhs(),
229                    self.rhs(),
230                    min_max_check,
231                    catalog,
232                ))
233            }
234            Operator::Gte => {
235                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
236                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
237
238                Some(with_nan_predicate(
239                    self.lhs(),
240                    self.rhs(),
241                    min_max_check,
242                    catalog,
243                ))
244            }
245            Operator::Lt => {
246                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
247                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
248
249                Some(with_nan_predicate(
250                    self.lhs(),
251                    self.rhs(),
252                    min_max_check,
253                    catalog,
254                ))
255            }
256            Operator::Lte => {
257                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
258                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
259
260                Some(with_nan_predicate(
261                    self.lhs(),
262                    self.rhs(),
263                    min_max_check,
264                    catalog,
265                ))
266            }
267            Operator::And => self
268                .lhs
269                .stat_falsification(catalog)
270                .into_iter()
271                .chain(self.rhs.stat_falsification(catalog))
272                .reduce(or),
273            Operator::Or => Some(and(
274                self.lhs.stat_falsification(catalog)?,
275                self.rhs.stat_falsification(catalog)?,
276            )),
277            Operator::Add | Operator::Sub => None,
278        }
279    }
280}
281
282/// Create a new [`BinaryExpr`] using the [`Eq`](crate::Operator::Eq) operator.
283///
284/// ## Example usage
285///
286/// ```
287/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
288/// # use vortex_array::{Array, IntoArray, ToCanonical};
289/// # use vortex_array::validity::Validity;
290/// # use vortex_buffer::buffer;
291/// # use vortex_expr::{eq, root, lit, Scope};
292/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
293/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
294///
295/// assert_eq!(
296///     result.to_bool().unwrap().boolean_buffer(),
297///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
298/// );
299/// ```
300pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
301    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
302}
303
304/// Create a new [`BinaryExpr`] using the [`NotEq`](crate::Operator::NotEq) operator.
305///
306/// ## Example usage
307///
308/// ```
309/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
310/// # use vortex_array::{IntoArray, ToCanonical};
311/// # use vortex_array::validity::Validity;
312/// # use vortex_buffer::buffer;
313/// # use vortex_expr::{root, lit, not_eq, Scope};
314/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
315/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
316///
317/// assert_eq!(
318///     result.to_bool().unwrap().boolean_buffer(),
319///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
320/// );
321/// ```
322pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
323    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
324}
325
326/// Create a new [`BinaryExpr`] using the [`Gte`](crate::Operator::Gte) operator.
327///
328/// ## Example usage
329///
330/// ```
331/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
332/// # use vortex_array::{IntoArray, ToCanonical};
333/// # use vortex_array::validity::Validity;
334/// # use vortex_buffer::buffer;
335/// # use vortex_expr::{gt_eq, root, lit, Scope};
336/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
337/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
338///
339/// assert_eq!(
340///     result.to_bool().unwrap().boolean_buffer(),
341///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
342/// );
343/// ```
344pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
345    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
346}
347
348/// Create a new [`BinaryExpr`] using the [`Gt`](crate::Operator::Gt) operator.
349///
350/// ## Example usage
351///
352/// ```
353/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
354/// # use vortex_array::{IntoArray, ToCanonical};
355/// # use vortex_array::validity::Validity;
356/// # use vortex_buffer::buffer;
357/// # use vortex_expr::{gt, root, lit, Scope};
358/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
359/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
360///
361/// assert_eq!(
362///     result.to_bool().unwrap().boolean_buffer(),
363///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
364/// );
365/// ```
366pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
367    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
368}
369
370/// Create a new [`BinaryExpr`] using the [`Lte`](crate::Operator::Lte) operator.
371///
372/// ## Example usage
373///
374/// ```
375/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
376/// # use vortex_array::{IntoArray, ToCanonical};
377/// # use vortex_array::validity::Validity;
378/// # use vortex_buffer::buffer;
379/// # use vortex_expr::{root, lit, lt_eq, Scope};
380/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
381/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
382///
383/// assert_eq!(
384///     result.to_bool().unwrap().boolean_buffer(),
385///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
386/// );
387/// ```
388pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
389    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
390}
391
392/// Create a new [`BinaryExpr`] using the [`Lt`](crate::Operator::Lt) operator.
393///
394/// ## Example usage
395///
396/// ```
397/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
398/// # use vortex_array::{IntoArray, ToCanonical};
399/// # use vortex_array::validity::Validity;
400/// # use vortex_buffer::buffer;
401/// # use vortex_expr::{root, lit, lt, Scope};
402/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
403/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
404///
405/// assert_eq!(
406///     result.to_bool().unwrap().boolean_buffer(),
407///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
408/// );
409/// ```
410pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
411    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
412}
413
414/// Create a new [`BinaryExpr`] using the [`Or`](crate::Operator::Or) operator.
415///
416/// ## Example usage
417///
418/// ```
419/// # use vortex_array::arrays::BoolArray;
420/// # use vortex_array::{IntoArray, ToCanonical};
421/// # use vortex_expr::{root, lit, or, Scope};
422/// let xs = BoolArray::from_iter(vec![true, false, true]);
423/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
424///
425/// assert_eq!(
426///     result.to_bool().unwrap().boolean_buffer(),
427///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
428/// );
429/// ```
430pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
431    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
432}
433
434/// Collects a list of `or`ed values into a single vortex, expr
435/// [x, y, z] => x or (y or z)
436pub fn or_collect<I>(iter: I) -> Option<ExprRef>
437where
438    I: IntoIterator<Item = ExprRef>,
439    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
440{
441    let mut iter = iter.into_iter();
442    let first = iter.next_back()?;
443    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
444}
445
446/// Create a new [`BinaryExpr`] using the [`And`](crate::Operator::And) operator.
447///
448/// ## Example usage
449///
450/// ```
451/// # use vortex_array::arrays::BoolArray;
452/// # use vortex_array::{IntoArray, ToCanonical};
453/// # use vortex_expr::{and, root, lit, Scope};
454/// let xs = BoolArray::from_iter(vec![true, false, true]);
455/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
456///
457/// assert_eq!(
458///     result.to_bool().unwrap().boolean_buffer(),
459///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
460/// );
461/// ```
462pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
463    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
464}
465
466/// Collects a list of `and`ed values into a single vortex, expr
467/// [x, y, z] => x and (y and z)
468pub fn and_collect<I>(iter: I) -> Option<ExprRef>
469where
470    I: IntoIterator<Item = ExprRef>,
471    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
472{
473    let mut iter = iter.into_iter();
474    let first = iter.next_back()?;
475    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
476}
477
478/// Collects a list of `and`ed values into a single vortex, expr
479/// [x, y, z] => x and (y and z)
480pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
481where
482    I: IntoIterator<Item = ExprRef>,
483{
484    let iter = iter.into_iter();
485    iter.reduce(and)
486}
487
488/// Create a new [`BinaryExpr`] using the [`Add`](crate::Operator::Add) operator.
489///
490/// ## Example usage
491///
492/// ```
493/// # use vortex_array::IntoArray;
494/// # use vortex_array::arrow::IntoArrowArray as _;
495/// # use vortex_buffer::buffer;
496/// # use vortex_expr::{Scope, checked_add, lit, root};
497/// let xs = buffer![1, 2, 3].into_array();
498/// let result = checked_add(root(), lit(5))
499///     .evaluate(&Scope::new(xs.to_array()))
500///     .unwrap();
501///
502/// assert_eq!(
503///     &result.into_arrow_preferred().unwrap(),
504///     &buffer![6, 7, 8]
505///         .into_array()
506///         .into_arrow_preferred()
507///         .unwrap()
508/// );
509/// ```
510pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
511    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
512}
513
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, or_collect, test_harness,
    };

    // `and_collect` nests to the right: x and (y and z).
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect([lit(1), lit(2), lit(3)]);
        assert_eq!(collected, Some(and(lit(1), and(lit(2), lit(3)))));
    }

    // `and_collect_right` nests to the left: (x and y) and z.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right([lit(1), lit(2), lit(3)]);
        assert_eq!(collected, Some(and(and(lit(1), lit(2)), lit(3))));
    }

    // `or_collect` nests to the right: x or (y or z).
    #[test]
    fn or_collect_right_nested() {
        let collected = or_collect([lit(1), lit(2), lit(3)]);
        assert_eq!(collected, Some(or(lit(1), or(lit(2), lit(3)))));
    }

    // Empty input yields `None`; a single expression is returned unchanged.
    #[test]
    fn collect_edge_cases() {
        assert!(and_collect(Vec::<Arc<dyn VortexExpr>>::new()).is_none());
        assert!(and_collect_right(Vec::<Arc<dyn VortexExpr>>::new()).is_none());
        assert!(or_collect(Vec::<Arc<dyn VortexExpr>>::new()).is_none());

        assert_eq!(and_collect([lit(7)]), Some(lit(7)));
        assert_eq!(and_collect_right([lit(7)]), Some(lit(7)));
        assert_eq!(or_collect([lit(7)]), Some(lit(7)));
    }

    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Boolean connectives over `bool1`/`bool2` yield a non-nullable bool.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        for connective in [and, or] {
            let expr = connective(bool1.clone(), bool2.clone());
            assert_eq!(expr.return_dtype(&scope).unwrap(), non_nullable_bool);
        }

        // Comparisons over `col1`/`col2` yield a nullable bool (the harness presumably makes at
        // least one of the two columns nullable).
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        for comparison in [eq, not_eq, gt, gt_eq, lt, lt_eq] {
            let expr = comparison(col1.clone(), col2.clone());
            assert_eq!(expr.return_dtype(&scope).unwrap(), nullable_bool);
        }

        // Nullability propagates through a connective over nullable comparisons.
        let disjunction = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(disjunction.return_dtype(&scope).unwrap(), nullable_bool);
    }
}