vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene};
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15    VTable, lit, vtable,
16};
17
18vtable!(Binary);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash)]
22pub struct BinaryExpr {
23    lhs: ExprRef,
24    operator: Operator,
25    rhs: ExprRef,
26}
27
28impl PartialEq for BinaryExpr {
29    fn eq(&self, other: &Self) -> bool {
30        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31    }
32}
33
34pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37    type Expr = BinaryExpr;
38    type Encoding = BinaryExprEncoding;
39    type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41    fn id(_encoding: &Self::Encoding) -> ExprId {
42        ExprId::new_ref("binary")
43    }
44
45    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47    }
48
49    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50        Some(ProstMetadata(pb::BinaryOpts {
51            op: expr.operator.into(),
52        }))
53    }
54
55    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56        vec![expr.lhs(), expr.rhs()]
57    }
58
59    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60        Ok(BinaryExpr::new(
61            children[0].clone(),
62            expr.op(),
63            children[1].clone(),
64        ))
65    }
66
67    fn build(
68        _encoding: &Self::Encoding,
69        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70        children: Vec<ExprRef>,
71    ) -> VortexResult<Self::Expr> {
72        Ok(BinaryExpr::new(
73            children[0].clone(),
74            metadata.op().into(),
75            children[1].clone(),
76        ))
77    }
78
79    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80        let lhs = expr.lhs.unchecked_evaluate(scope)?;
81        let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83        match expr.operator {
84            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90            Operator::And => and_kleene(&lhs, &rhs),
91            Operator::Or => or_kleene(&lhs, &rhs),
92            Operator::Add => add(&lhs, &rhs),
93        }
94    }
95
96    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
97        let lhs = expr.lhs.return_dtype(scope)?;
98        let rhs = expr.rhs.return_dtype(scope)?;
99
100        if expr.operator == Operator::Add {
101            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
102                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
103            }
104            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
105        }
106
107        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
108    }
109}
110
111impl BinaryExpr {
112    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
113        Self { lhs, operator, rhs }
114    }
115
116    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
117        Self::new(lhs, operator, rhs).into_expr()
118    }
119
120    pub fn lhs(&self) -> &ExprRef {
121        &self.lhs
122    }
123
124    pub fn rhs(&self) -> &ExprRef {
125        &self.rhs
126    }
127
128    pub fn op(&self) -> Operator {
129        self.operator
130    }
131}
132
133impl Display for BinaryExpr {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
136    }
137}
138
139impl AnalysisExpr for BinaryExpr {
140    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
141        // Wrap another predicate with an optional NaNCount check, if the stat is available.
142        //
143        // For example, regular pruning conversion for `A >= B` would be
144        //
145        //      A.max < B.min
146        //
147        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
148        // in:
149        //
150        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
151        //
152        // Non-floating point column and literal expressions should be unaffected as they do not
153        // have a nan_count statistic defined.
154        #[inline]
155        fn with_nan_predicate(
156            lhs: &ExprRef,
157            rhs: &ExprRef,
158            value_predicate: ExprRef,
159            catalog: &mut dyn StatsCatalog,
160        ) -> ExprRef {
161            let nan_predicate = lhs
162                .nan_count(catalog)
163                .into_iter()
164                .chain(rhs.nan_count(catalog))
165                .map(|nans| eq(nans, lit(0u64)))
166                .reduce(and);
167
168            if let Some(nan_check) = nan_predicate {
169                and(nan_check, value_predicate)
170            } else {
171                value_predicate
172            }
173        }
174
175        match self.operator {
176            Operator::Eq => {
177                let min_lhs = self.lhs.min(catalog);
178                let max_lhs = self.lhs.max(catalog);
179
180                let min_rhs = self.rhs.min(catalog);
181                let max_rhs = self.rhs.max(catalog);
182
183                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
184                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
185
186                let min_max_check = left.into_iter().chain(right).reduce(or)?;
187
188                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
189                Some(with_nan_predicate(
190                    self.lhs(),
191                    self.rhs(),
192                    min_max_check,
193                    catalog,
194                ))
195            }
196            Operator::NotEq => {
197                let min_lhs = self.lhs.min(catalog)?;
198                let max_lhs = self.lhs.max(catalog)?;
199
200                let min_rhs = self.rhs.min(catalog)?;
201                let max_rhs = self.rhs.max(catalog)?;
202
203                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
204
205                Some(with_nan_predicate(
206                    self.lhs(),
207                    self.rhs(),
208                    min_max_check,
209                    catalog,
210                ))
211            }
212            Operator::Gt => {
213                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
214
215                Some(with_nan_predicate(
216                    self.lhs(),
217                    self.rhs(),
218                    min_max_check,
219                    catalog,
220                ))
221            }
222            Operator::Gte => {
223                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
224                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
225
226                Some(with_nan_predicate(
227                    self.lhs(),
228                    self.rhs(),
229                    min_max_check,
230                    catalog,
231                ))
232            }
233            Operator::Lt => {
234                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
235                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
236
237                Some(with_nan_predicate(
238                    self.lhs(),
239                    self.rhs(),
240                    min_max_check,
241                    catalog,
242                ))
243            }
244            Operator::Lte => {
245                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
246                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
247
248                Some(with_nan_predicate(
249                    self.lhs(),
250                    self.rhs(),
251                    min_max_check,
252                    catalog,
253                ))
254            }
255            Operator::And => self
256                .lhs
257                .stat_falsification(catalog)
258                .into_iter()
259                .chain(self.rhs.stat_falsification(catalog))
260                .reduce(or),
261            Operator::Or => Some(and(
262                self.lhs.stat_falsification(catalog)?,
263                self.rhs.stat_falsification(catalog)?,
264            )),
265            Operator::Add => None,
266        }
267    }
268}
269
270/// Create a new `BinaryExpr` using the `Eq` operator.
271///
272/// ## Example usage
273///
274/// ```
275/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
276/// use vortex_array::{Array, IntoArray, ToCanonical};
277/// use vortex_array::validity::Validity;
278/// use vortex_buffer::buffer;
279/// use vortex_expr::{eq, root, lit, Scope};
280///
281/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
282/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
283///
284/// assert_eq!(
285///     result.to_bool().unwrap().boolean_buffer(),
286///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
287/// );
288/// ```
289pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
290    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
291}
292
293/// Create a new `BinaryExpr` using the `NotEq` operator.
294///
295/// ## Example usage
296///
297/// ```
298/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
299/// use vortex_array::{IntoArray, ToCanonical};
300/// use vortex_array::validity::Validity;
301/// use vortex_buffer::buffer;
302/// use vortex_expr::{root, lit, not_eq, Scope};
303///
304/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
305/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
306///
307/// assert_eq!(
308///     result.to_bool().unwrap().boolean_buffer(),
309///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
310/// );
311/// ```
312pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
313    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
314}
315
316/// Create a new `BinaryExpr` using the `Gte` operator.
317///
318/// ## Example usage
319///
320/// ```
321/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
322/// use vortex_array::{IntoArray, ToCanonical};
323/// use vortex_array::validity::Validity;
324/// use vortex_buffer::buffer;
325/// use vortex_expr::{gt_eq, root, lit, Scope};
326///
327/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
328/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
329///
330/// assert_eq!(
331///     result.to_bool().unwrap().boolean_buffer(),
332///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
333/// );
334/// ```
335pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
336    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
337}
338
339/// Create a new `BinaryExpr` using the `Gt` operator.
340///
341/// ## Example usage
342///
343/// ```
344/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
345/// use vortex_array::{IntoArray, ToCanonical};
346/// use vortex_array::validity::Validity;
347/// use vortex_buffer::buffer;
348/// use vortex_expr::{gt, root, lit, Scope};
349///
350/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
351/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
352///
353/// assert_eq!(
354///     result.to_bool().unwrap().boolean_buffer(),
355///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
356/// );
357/// ```
358pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
359    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
360}
361
362/// Create a new `BinaryExpr` using the `Lte` operator.
363///
364/// ## Example usage
365///
366/// ```
367/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
368/// use vortex_array::{IntoArray, ToCanonical};
369/// use vortex_array::validity::Validity;
370/// use vortex_buffer::buffer;
371/// use vortex_expr::{root, lit, lt_eq, Scope};
372///
373/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
374/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
375///
376/// assert_eq!(
377///     result.to_bool().unwrap().boolean_buffer(),
378///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
379/// );
380/// ```
381pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
382    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
383}
384
385/// Create a new `BinaryExpr` using the `Lt` operator.
386///
387/// ## Example usage
388///
389/// ```
390/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
391/// use vortex_array::{IntoArray, ToCanonical};
392/// use vortex_array::validity::Validity;
393/// use vortex_buffer::buffer;
394/// use vortex_expr::{root, lit, lt, Scope};
395///
396/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
397/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
398///
399/// assert_eq!(
400///     result.to_bool().unwrap().boolean_buffer(),
401///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
402/// );
403/// ```
404pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
405    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
406}
407
408/// Create a new `BinaryExpr` using the `Or` operator.
409///
410/// ## Example usage
411///
412/// ```
413/// use vortex_array::arrays::BoolArray;
414/// use vortex_array::{IntoArray, ToCanonical};
415/// use vortex_expr::{root, lit, or, Scope};
416///
417/// let xs = BoolArray::from_iter(vec![true, false, true]);
418/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
419///
420/// assert_eq!(
421///     result.to_bool().unwrap().boolean_buffer(),
422///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
423/// );
424/// ```
425pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
426    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
427}
428
429/// Collects a list of `or`ed values into a single vortex, expr
430/// [x, y, z] => x or (y or z)
431pub fn or_collect<I>(iter: I) -> Option<ExprRef>
432where
433    I: IntoIterator<Item = ExprRef>,
434    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
435{
436    let mut iter = iter.into_iter();
437    let first = iter.next_back()?;
438    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
439}
440
441/// Create a new `BinaryExpr` using the `And` operator.
442///
443/// ## Example usage
444///
445/// ```
446/// use vortex_array::arrays::BoolArray;
447/// use vortex_array::{IntoArray, ToCanonical};
448/// use vortex_expr::{and, root, lit, Scope};
449///
450/// let xs = BoolArray::from_iter(vec![true, false, true]);
451/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
452///
453/// assert_eq!(
454///     result.to_bool().unwrap().boolean_buffer(),
455///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
456/// );
457/// ```
458pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
459    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
460}
461
462/// Collects a list of `and`ed values into a single vortex, expr
463/// [x, y, z] => x and (y and z)
464pub fn and_collect<I>(iter: I) -> Option<ExprRef>
465where
466    I: IntoIterator<Item = ExprRef>,
467    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
468{
469    let mut iter = iter.into_iter();
470    let first = iter.next_back()?;
471    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
472}
473
474/// Collects a list of `and`ed values into a single vortex, expr
475/// [x, y, z] => x and (y and z)
476pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
477where
478    I: IntoIterator<Item = ExprRef>,
479{
480    let iter = iter.into_iter();
481    iter.reduce(and)
482}
483
484/// Create a new `BinaryExpr` using the `CheckedAdd` operator.
485///
486/// ## Example usage
487///
488/// ```
489/// use vortex_array::IntoArray;
490/// use vortex_array::arrow::IntoArrowArray as _;
491/// use vortex_buffer::buffer;
492/// use vortex_expr::{Scope, checked_add, lit, root};
493///
494/// let xs = buffer![1, 2, 3].into_array();
495/// let result = checked_add(root(), lit(5))
496///     .evaluate(&Scope::new(xs.to_array()))
497///     .unwrap();
498///
499/// assert_eq!(
500///     &result.into_arrow_preferred().unwrap(),
501///     &buffer![6, 7, 8]
502///         .into_array()
503///         .into_arrow_preferred()
504///         .unwrap()
505/// );
506/// ```
507pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
508    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
509}
510
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// `and_collect` nests towards the right: `1 and (2 and 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` nests towards the left: `(1 and 2) and 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    /// Logical and comparison operators yield `Bool`, with nullability taken from the operands.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");

        let non_nullable = vec![
            and(bool1.clone(), bool2.clone()),
            or(bool1, bool2),
        ];
        for expr in non_nullable {
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                DType::Bool(Nullability::NonNullable),
                "unexpected dtype for {}",
                expr
            );
        }

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        let nullable = vec![
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
            // Nullability propagates through nested expressions as well.
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1, col2),
            ),
        ];
        for expr in nullable {
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                DType::Bool(Nullability::Nullable),
                "unexpected dtype for {}",
                expr
            );
        }
    }
}