vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, or_kleene, sub};
8use vortex_array::pipeline::OperatorRef;
9use vortex_array::pipeline::operators::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexExpect, VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18    VTable, lit, vtable,
19};
20
21vtable!(Binary);
22
23#[allow(clippy::derived_hash_with_manual_eq)]
24#[derive(Debug, Clone, Hash, Eq)]
25pub struct BinaryExpr {
26    lhs: ExprRef,
27    operator: Operator,
28    rhs: ExprRef,
29}
30
31impl PartialEq for BinaryExpr {
32    fn eq(&self, other: &Self) -> bool {
33        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34    }
35}
36
/// Stateless encoding marker for [`BinaryExpr`]; it is the `Encoding` type of `BinaryVTable`
/// and is registered under the expression id `"binary"`.
pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40    type Expr = BinaryExpr;
41    type Encoding = BinaryExprEncoding;
42    type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("binary")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        Some(ProstMetadata(pb::BinaryOpts {
54            op: expr.operator.into(),
55        }))
56    }
57
58    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59        vec![expr.lhs(), expr.rhs()]
60    }
61
62    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63        Ok(BinaryExpr::new(
64            children[0].clone(),
65            expr.op(),
66            children[1].clone(),
67        ))
68    }
69
70    fn build(
71        _encoding: &Self::Encoding,
72        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73        children: Vec<ExprRef>,
74    ) -> VortexResult<Self::Expr> {
75        Ok(BinaryExpr::new(
76            children[0].clone(),
77            metadata.op().into(),
78            children[1].clone(),
79        ))
80    }
81
82    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83        let lhs = expr.lhs.unchecked_evaluate(scope)?;
84        let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86        match expr.operator {
87            Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88            Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89            Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90            Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91            Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92            Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93            Operator::And => and_kleene(&lhs, &rhs),
94            Operator::Or => or_kleene(&lhs, &rhs),
95            Operator::Add => add(&lhs, &rhs),
96            Operator::Sub => sub(&lhs, &rhs),
97        }
98    }
99
100    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
101        let lhs = expr.lhs.return_dtype(scope)?;
102        let rhs = expr.rhs.return_dtype(scope)?;
103
104        if expr.operator == Operator::Add {
105            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
106                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
107            }
108            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
109        }
110
111        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112    }
113
114    fn operator(expr: &BinaryExpr, children: Vec<OperatorRef>) -> Option<OperatorRef> {
115        let [lhs, rhs] = children
116            .try_into()
117            .ok()
118            .vortex_expect("Expected 2 children");
119        let op = expr.operator.try_into().ok()?;
120
121        Some(Arc::new(CompareOperator::new(lhs, rhs, op)) as OperatorRef)
122    }
123}
124
125impl BinaryExpr {
126    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
127        Self { lhs, operator, rhs }
128    }
129
130    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
131        Self::new(lhs, operator, rhs).into_expr()
132    }
133
134    pub fn lhs(&self) -> &ExprRef {
135        &self.lhs
136    }
137
138    pub fn rhs(&self) -> &ExprRef {
139        &self.rhs
140    }
141
142    pub fn op(&self) -> Operator {
143        self.operator
144    }
145}
146
147impl DisplayAs for BinaryExpr {
148    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
149        match df {
150            DisplayFormat::Compact => {
151                write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
152            }
153            DisplayFormat::Tree => {
154                write!(f, "Binary({})", self.operator)
155            }
156        }
157    }
158
159    fn child_names(&self) -> Option<Vec<String>> {
160        Some(vec!["lhs".to_string(), "rhs".to_string()])
161    }
162}
163
164impl AnalysisExpr for BinaryExpr {
165    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
166        // Wrap another predicate with an optional NaNCount check, if the stat is available.
167        //
168        // For example, regular pruning conversion for `A >= B` would be
169        //
170        //      A.max < B.min
171        //
172        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
173        // in:
174        //
175        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
176        //
177        // Non-floating point column and literal expressions should be unaffected as they do not
178        // have a nan_count statistic defined.
179        #[inline]
180        fn with_nan_predicate(
181            lhs: &ExprRef,
182            rhs: &ExprRef,
183            value_predicate: ExprRef,
184            catalog: &mut dyn StatsCatalog,
185        ) -> ExprRef {
186            let nan_predicate = lhs
187                .nan_count(catalog)
188                .into_iter()
189                .chain(rhs.nan_count(catalog))
190                .map(|nans| eq(nans, lit(0u64)))
191                .reduce(and);
192
193            if let Some(nan_check) = nan_predicate {
194                and(nan_check, value_predicate)
195            } else {
196                value_predicate
197            }
198        }
199
200        match self.operator {
201            Operator::Eq => {
202                let min_lhs = self.lhs.min(catalog);
203                let max_lhs = self.lhs.max(catalog);
204
205                let min_rhs = self.rhs.min(catalog);
206                let max_rhs = self.rhs.max(catalog);
207
208                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
209                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
210
211                let min_max_check = left.into_iter().chain(right).reduce(or)?;
212
213                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
214                Some(with_nan_predicate(
215                    self.lhs(),
216                    self.rhs(),
217                    min_max_check,
218                    catalog,
219                ))
220            }
221            Operator::NotEq => {
222                let min_lhs = self.lhs.min(catalog)?;
223                let max_lhs = self.lhs.max(catalog)?;
224
225                let min_rhs = self.rhs.min(catalog)?;
226                let max_rhs = self.rhs.max(catalog)?;
227
228                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
229
230                Some(with_nan_predicate(
231                    self.lhs(),
232                    self.rhs(),
233                    min_max_check,
234                    catalog,
235                ))
236            }
237            Operator::Gt => {
238                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
239
240                Some(with_nan_predicate(
241                    self.lhs(),
242                    self.rhs(),
243                    min_max_check,
244                    catalog,
245                ))
246            }
247            Operator::Gte => {
248                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
249                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
250
251                Some(with_nan_predicate(
252                    self.lhs(),
253                    self.rhs(),
254                    min_max_check,
255                    catalog,
256                ))
257            }
258            Operator::Lt => {
259                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
260                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
261
262                Some(with_nan_predicate(
263                    self.lhs(),
264                    self.rhs(),
265                    min_max_check,
266                    catalog,
267                ))
268            }
269            Operator::Lte => {
270                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
271                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
272
273                Some(with_nan_predicate(
274                    self.lhs(),
275                    self.rhs(),
276                    min_max_check,
277                    catalog,
278                ))
279            }
280            Operator::And => self
281                .lhs
282                .stat_falsification(catalog)
283                .into_iter()
284                .chain(self.rhs.stat_falsification(catalog))
285                .reduce(or),
286            Operator::Or => Some(and(
287                self.lhs.stat_falsification(catalog)?,
288                self.rhs.stat_falsification(catalog)?,
289            )),
290            Operator::Add | Operator::Sub => None,
291        }
292    }
293}
294
295/// Create a new [`BinaryExpr`] using the [`Eq`](crate::Operator::Eq) operator.
296///
297/// ## Example usage
298///
299/// ```
300/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
301/// # use vortex_array::{Array, IntoArray, ToCanonical};
302/// # use vortex_array::validity::Validity;
303/// # use vortex_buffer::buffer;
304/// # use vortex_expr::{eq, root, lit, Scope};
305/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
306/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
307///
308/// assert_eq!(
309///     result.to_bool().unwrap().boolean_buffer(),
310///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
311/// );
312/// ```
313pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
314    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
315}
316
317/// Create a new [`BinaryExpr`] using the [`NotEq`](crate::Operator::NotEq) operator.
318///
319/// ## Example usage
320///
321/// ```
322/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
323/// # use vortex_array::{IntoArray, ToCanonical};
324/// # use vortex_array::validity::Validity;
325/// # use vortex_buffer::buffer;
326/// # use vortex_expr::{root, lit, not_eq, Scope};
327/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
328/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
329///
330/// assert_eq!(
331///     result.to_bool().unwrap().boolean_buffer(),
332///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
333/// );
334/// ```
335pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
336    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
337}
338
339/// Create a new [`BinaryExpr`] using the [`Gte`](crate::Operator::Gte) operator.
340///
341/// ## Example usage
342///
343/// ```
344/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
345/// # use vortex_array::{IntoArray, ToCanonical};
346/// # use vortex_array::validity::Validity;
347/// # use vortex_buffer::buffer;
348/// # use vortex_expr::{gt_eq, root, lit, Scope};
349/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
350/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
351///
352/// assert_eq!(
353///     result.to_bool().unwrap().boolean_buffer(),
354///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
355/// );
356/// ```
357pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
358    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
359}
360
361/// Create a new [`BinaryExpr`] using the [`Gt`](crate::Operator::Gt) operator.
362///
363/// ## Example usage
364///
365/// ```
366/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
367/// # use vortex_array::{IntoArray, ToCanonical};
368/// # use vortex_array::validity::Validity;
369/// # use vortex_buffer::buffer;
370/// # use vortex_expr::{gt, root, lit, Scope};
371/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
372/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
373///
374/// assert_eq!(
375///     result.to_bool().unwrap().boolean_buffer(),
376///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
377/// );
378/// ```
379pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
380    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
381}
382
383/// Create a new [`BinaryExpr`] using the [`Lte`](crate::Operator::Lte) operator.
384///
385/// ## Example usage
386///
387/// ```
388/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
389/// # use vortex_array::{IntoArray, ToCanonical};
390/// # use vortex_array::validity::Validity;
391/// # use vortex_buffer::buffer;
392/// # use vortex_expr::{root, lit, lt_eq, Scope};
393/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
394/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
395///
396/// assert_eq!(
397///     result.to_bool().unwrap().boolean_buffer(),
398///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
399/// );
400/// ```
401pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
402    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
403}
404
405/// Create a new [`BinaryExpr`] using the [`Lt`](crate::Operator::Lt) operator.
406///
407/// ## Example usage
408///
409/// ```
410/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
411/// # use vortex_array::{IntoArray, ToCanonical};
412/// # use vortex_array::validity::Validity;
413/// # use vortex_buffer::buffer;
414/// # use vortex_expr::{root, lit, lt, Scope};
415/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
416/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
417///
418/// assert_eq!(
419///     result.to_bool().unwrap().boolean_buffer(),
420///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
421/// );
422/// ```
423pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
424    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
425}
426
427/// Create a new [`BinaryExpr`] using the [`Or`](crate::Operator::Or) operator.
428///
429/// ## Example usage
430///
431/// ```
432/// # use vortex_array::arrays::BoolArray;
433/// # use vortex_array::{IntoArray, ToCanonical};
434/// # use vortex_expr::{root, lit, or, Scope};
435/// let xs = BoolArray::from_iter(vec![true, false, true]);
436/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
437///
438/// assert_eq!(
439///     result.to_bool().unwrap().boolean_buffer(),
440///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
441/// );
442/// ```
443pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
444    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
445}
446
447/// Collects a list of `or`ed values into a single vortex, expr
448/// [x, y, z] => x or (y or z)
449pub fn or_collect<I>(iter: I) -> Option<ExprRef>
450where
451    I: IntoIterator<Item = ExprRef>,
452    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
453{
454    let mut iter = iter.into_iter();
455    let first = iter.next_back()?;
456    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
457}
458
459/// Create a new [`BinaryExpr`] using the [`And`](crate::Operator::And) operator.
460///
461/// ## Example usage
462///
463/// ```
464/// # use vortex_array::arrays::BoolArray;
465/// # use vortex_array::{IntoArray, ToCanonical};
466/// # use vortex_expr::{and, root, lit, Scope};
467/// let xs = BoolArray::from_iter(vec![true, false, true]);
468/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
469///
470/// assert_eq!(
471///     result.to_bool().unwrap().boolean_buffer(),
472///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
473/// );
474/// ```
475pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
476    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
477}
478
479/// Collects a list of `and`ed values into a single vortex, expr
480/// [x, y, z] => x and (y and z)
481pub fn and_collect<I>(iter: I) -> Option<ExprRef>
482where
483    I: IntoIterator<Item = ExprRef>,
484    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
485{
486    let mut iter = iter.into_iter();
487    let first = iter.next_back()?;
488    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
489}
490
491/// Collects a list of `and`ed values into a single vortex, expr
492/// [x, y, z] => x and (y and z)
493pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
494where
495    I: IntoIterator<Item = ExprRef>,
496{
497    let iter = iter.into_iter();
498    iter.reduce(and)
499}
500
501/// Create a new [`BinaryExpr`] using the [`Add`](crate::Operator::Add) operator.
502///
503/// ## Example usage
504///
505/// ```
506/// # use vortex_array::IntoArray;
507/// # use vortex_array::arrow::IntoArrowArray as _;
508/// # use vortex_buffer::buffer;
509/// # use vortex_expr::{Scope, checked_add, lit, root};
510/// let xs = buffer![1, 2, 3].into_array();
511/// let result = checked_add(root(), lit(5))
512///     .evaluate(&Scope::new(xs.to_array()))
513///     .unwrap();
514///
515/// assert_eq!(
516///     &result.into_arrow_preferred().unwrap(),
517///     &buffer![6, 7, 8]
518///         .into_array()
519///         .into_arrow_preferred()
520///         .unwrap()
521/// );
522/// ```
523pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
524    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
525}
526
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ExprRef, VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt,
        lt_eq, not_eq, or, test_harness,
    };

    /// Signature shared by every binary-expression constructor exercised below.
    type BinaryCtor = fn(ExprRef, ExprRef) -> ExprRef;

    #[test]
    fn and_collect_left_assoc() {
        // `and_collect` nests towards the right: [1, 2, 3] => 1 and (2 and 3).
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn and_collect_right_assoc() {
        // `and_collect_right` nests towards the left: [1, 2, 3] => (1 and 2) and 3.
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Connectives over the harness's `bool1`/`bool2` columns are expected to be
        // non-nullable booleans.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let connectives: [BinaryCtor; 2] = [and, or];
        for connective in connectives {
            assert_eq!(
                connective(bool1.clone(), bool2.clone())
                    .return_dtype(&dtype)
                    .unwrap(),
                DType::Bool(Nullability::NonNullable)
            );
        }

        // Comparisons between the harness's `col1`/`col2` columns are expected to be
        // nullable booleans.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons: [BinaryCtor; 6] = [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for comparison in comparisons {
            assert_eq!(
                comparison(col1.clone(), col2.clone())
                    .return_dtype(&dtype)
                    .unwrap(),
                DType::Bool(Nullability::Nullable)
            );
        }

        // Nullability propagates through a connective combining two comparisons.
        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(
            nested.return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}