vortex_expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, or_kleene, sub};
8use vortex_array::operator::OperatorRef;
9use vortex_array::operator::compare::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18    VTable, lit, vtable,
19};
20
// NOTE(review): `BinaryVTable`, which is implemented further down, is not declared anywhere in
// this file, so it is presumably generated by this macro invocation — confirm against the
// definition of `vtable!`.
vtable!(Binary);
22
/// An expression that applies a binary [`Operator`] to a left-hand and a right-hand child
/// expression.
// `Hash` is derived while `PartialEq` is written by hand below, which is what the lint allowance
// is for. The manual `eq` compares the same three fields that the derived `Hash` covers.
// NOTE(review): the reason `PartialEq` is not simply derived is not visible here — confirm before
// replacing the manual impl with a derive.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operator applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
30
31impl PartialEq for BinaryExpr {
32    fn eq(&self, other: &Self) -> bool {
33        self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34    }
35}
36
/// Field-less encoding type for [`BinaryExpr`]; it is the `Encoding` associated type of the
/// `VTable` implementation for binary expressions.
pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40    type Expr = BinaryExpr;
41    type Encoding = BinaryExprEncoding;
42    type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("binary")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        Some(ProstMetadata(pb::BinaryOpts {
54            op: expr.operator.into(),
55        }))
56    }
57
58    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59        vec![expr.lhs(), expr.rhs()]
60    }
61
62    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63        Ok(BinaryExpr::new(
64            children[0].clone(),
65            expr.op(),
66            children[1].clone(),
67        ))
68    }
69
70    fn build(
71        _encoding: &Self::Encoding,
72        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73        children: Vec<ExprRef>,
74    ) -> VortexResult<Self::Expr> {
75        Ok(BinaryExpr::new(
76            children[0].clone(),
77            metadata.op().into(),
78            children[1].clone(),
79        ))
80    }
81
82    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83        let lhs = expr.lhs.unchecked_evaluate(scope)?;
84        let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86        match expr.operator {
87            Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88            Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89            Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90            Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91            Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92            Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93            Operator::And => and_kleene(&lhs, &rhs),
94            Operator::Or => or_kleene(&lhs, &rhs),
95            Operator::Add => add(&lhs, &rhs),
96            Operator::Sub => sub(&lhs, &rhs),
97        }
98    }
99
100    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
101        let lhs = expr.lhs.return_dtype(scope)?;
102        let rhs = expr.rhs.return_dtype(scope)?;
103
104        if expr.operator == Operator::Add {
105            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
106                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
107            }
108            vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
109        }
110
111        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112    }
113
114    fn operator(expr: &BinaryExpr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
115        let Some(lhs) = expr.lhs.operator(scope)? else {
116            return Ok(None);
117        };
118        let Some(rhs) = expr.rhs.operator(scope)? else {
119            return Ok(None);
120        };
121        let Ok(op): VortexResult<compute::Operator> = expr.operator.try_into() else {
122            return Ok(None);
123        };
124        Ok(Some(Arc::new(CompareOperator::try_new(lhs, rhs, op)?)))
125    }
126}
127
128impl BinaryExpr {
129    pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
130        Self { lhs, operator, rhs }
131    }
132
133    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
134        Self::new(lhs, operator, rhs).into_expr()
135    }
136
137    pub fn lhs(&self) -> &ExprRef {
138        &self.lhs
139    }
140
141    pub fn rhs(&self) -> &ExprRef {
142        &self.rhs
143    }
144
145    pub fn op(&self) -> Operator {
146        self.operator
147    }
148}
149
150impl DisplayAs for BinaryExpr {
151    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
152        match df {
153            DisplayFormat::Compact => {
154                write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
155            }
156            DisplayFormat::Tree => {
157                write!(f, "Binary({})", self.operator)
158            }
159        }
160    }
161
162    fn child_names(&self) -> Option<Vec<String>> {
163        Some(vec!["lhs".to_string(), "rhs".to_string()])
164    }
165}
166
167impl AnalysisExpr for BinaryExpr {
168    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
169        // Wrap another predicate with an optional NaNCount check, if the stat is available.
170        //
171        // For example, regular pruning conversion for `A >= B` would be
172        //
173        //      A.max < B.min
174        //
175        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
176        // in:
177        //
178        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
179        //
180        // Non-floating point column and literal expressions should be unaffected as they do not
181        // have a nan_count statistic defined.
182        #[inline]
183        fn with_nan_predicate(
184            lhs: &ExprRef,
185            rhs: &ExprRef,
186            value_predicate: ExprRef,
187            catalog: &mut dyn StatsCatalog,
188        ) -> ExprRef {
189            let nan_predicate = lhs
190                .nan_count(catalog)
191                .into_iter()
192                .chain(rhs.nan_count(catalog))
193                .map(|nans| eq(nans, lit(0u64)))
194                .reduce(and);
195
196            if let Some(nan_check) = nan_predicate {
197                and(nan_check, value_predicate)
198            } else {
199                value_predicate
200            }
201        }
202
203        match self.operator {
204            Operator::Eq => {
205                let min_lhs = self.lhs.min(catalog);
206                let max_lhs = self.lhs.max(catalog);
207
208                let min_rhs = self.rhs.min(catalog);
209                let max_rhs = self.rhs.max(catalog);
210
211                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
212                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
213
214                let min_max_check = left.into_iter().chain(right).reduce(or)?;
215
216                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
217                Some(with_nan_predicate(
218                    self.lhs(),
219                    self.rhs(),
220                    min_max_check,
221                    catalog,
222                ))
223            }
224            Operator::NotEq => {
225                let min_lhs = self.lhs.min(catalog)?;
226                let max_lhs = self.lhs.max(catalog)?;
227
228                let min_rhs = self.rhs.min(catalog)?;
229                let max_rhs = self.rhs.max(catalog)?;
230
231                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
232
233                Some(with_nan_predicate(
234                    self.lhs(),
235                    self.rhs(),
236                    min_max_check,
237                    catalog,
238                ))
239            }
240            Operator::Gt => {
241                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
242
243                Some(with_nan_predicate(
244                    self.lhs(),
245                    self.rhs(),
246                    min_max_check,
247                    catalog,
248                ))
249            }
250            Operator::Gte => {
251                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
252                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
253
254                Some(with_nan_predicate(
255                    self.lhs(),
256                    self.rhs(),
257                    min_max_check,
258                    catalog,
259                ))
260            }
261            Operator::Lt => {
262                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
263                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
264
265                Some(with_nan_predicate(
266                    self.lhs(),
267                    self.rhs(),
268                    min_max_check,
269                    catalog,
270                ))
271            }
272            Operator::Lte => {
273                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
274                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
275
276                Some(with_nan_predicate(
277                    self.lhs(),
278                    self.rhs(),
279                    min_max_check,
280                    catalog,
281                ))
282            }
283            Operator::And => self
284                .lhs
285                .stat_falsification(catalog)
286                .into_iter()
287                .chain(self.rhs.stat_falsification(catalog))
288                .reduce(or),
289            Operator::Or => Some(and(
290                self.lhs.stat_falsification(catalog)?,
291                self.rhs.stat_falsification(catalog)?,
292            )),
293            Operator::Add | Operator::Sub => None,
294        }
295    }
296}
297
298/// Create a new [`BinaryExpr`] using the [`Eq`](crate::Operator::Eq) operator.
299///
300/// ## Example usage
301///
302/// ```
303/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
304/// # use vortex_array::{Array, IntoArray, ToCanonical};
305/// # use vortex_array::validity::Validity;
306/// # use vortex_buffer::buffer;
307/// # use vortex_expr::{eq, root, lit, Scope};
308/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
309/// let result = eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
310///
311/// assert_eq!(
312///     result.to_bool().boolean_buffer(),
313///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
314/// );
315/// ```
316pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
317    BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
318}
319
320/// Create a new [`BinaryExpr`] using the [`NotEq`](crate::Operator::NotEq) operator.
321///
322/// ## Example usage
323///
324/// ```
325/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
326/// # use vortex_array::{IntoArray, ToCanonical};
327/// # use vortex_array::validity::Validity;
328/// # use vortex_buffer::buffer;
329/// # use vortex_expr::{root, lit, not_eq, Scope};
330/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
331/// let result = not_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
332///
333/// assert_eq!(
334///     result.to_bool().boolean_buffer(),
335///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
336/// );
337/// ```
338pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
339    BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
340}
341
342/// Create a new [`BinaryExpr`] using the [`Gte`](crate::Operator::Gte) operator.
343///
344/// ## Example usage
345///
346/// ```
347/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
348/// # use vortex_array::{IntoArray, ToCanonical};
349/// # use vortex_array::validity::Validity;
350/// # use vortex_buffer::buffer;
351/// # use vortex_expr::{gt_eq, root, lit, Scope};
352/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
353/// let result = gt_eq(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
354///
355/// assert_eq!(
356///     result.to_bool().boolean_buffer(),
357///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
358/// );
359/// ```
360pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
361    BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
362}
363
364/// Create a new [`BinaryExpr`] using the [`Gt`](crate::Operator::Gt) operator.
365///
366/// ## Example usage
367///
368/// ```
369/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
370/// # use vortex_array::{IntoArray, ToCanonical};
371/// # use vortex_array::validity::Validity;
372/// # use vortex_buffer::buffer;
373/// # use vortex_expr::{gt, root, lit, Scope};
374/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
375/// let result = gt(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
376///
377/// assert_eq!(
378///     result.to_bool().boolean_buffer(),
379///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
380/// );
381/// ```
382pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
383    BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
384}
385
386/// Create a new [`BinaryExpr`] using the [`Lte`](crate::Operator::Lte) operator.
387///
388/// ## Example usage
389///
390/// ```
391/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
392/// # use vortex_array::{IntoArray, ToCanonical};
393/// # use vortex_array::validity::Validity;
394/// # use vortex_buffer::buffer;
395/// # use vortex_expr::{root, lit, lt_eq, Scope};
396/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
397/// let result = lt_eq(root(), lit(2)).evaluate(&Scope::new(xs.to_array())).unwrap();
398///
399/// assert_eq!(
400///     result.to_bool().boolean_buffer(),
401///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
402/// );
403/// ```
404pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
405    BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
406}
407
408/// Create a new [`BinaryExpr`] using the [`Lt`](crate::Operator::Lt) operator.
409///
410/// ## Example usage
411///
412/// ```
413/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
414/// # use vortex_array::{IntoArray, ToCanonical};
415/// # use vortex_array::validity::Validity;
416/// # use vortex_buffer::buffer;
417/// # use vortex_expr::{root, lit, lt, Scope};
418/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
419/// let result = lt(root(), lit(3)).evaluate(&Scope::new(xs.to_array())).unwrap();
420///
421/// assert_eq!(
422///     result.to_bool().boolean_buffer(),
423///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
424/// );
425/// ```
426pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
427    BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
428}
429
430/// Create a new [`BinaryExpr`] using the [`Or`](crate::Operator::Or) operator.
431///
432/// ## Example usage
433///
434/// ```
435/// # use vortex_array::arrays::BoolArray;
436/// # use vortex_array::{IntoArray, ToCanonical};
437/// # use vortex_expr::{root, lit, or, Scope};
438/// let xs = BoolArray::from_iter(vec![true, false, true]);
439/// let result = or(root(), lit(false)).evaluate(&Scope::new(xs.to_array())).unwrap();
440///
441/// assert_eq!(
442///     result.to_bool().boolean_buffer(),
443///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
444/// );
445/// ```
446pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
447    BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
448}
449
450/// Collects a list of `or`ed values into a single vortex, expr
451/// [x, y, z] => x or (y or z)
452pub fn or_collect<I>(iter: I) -> Option<ExprRef>
453where
454    I: IntoIterator<Item = ExprRef>,
455    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
456{
457    let mut iter = iter.into_iter();
458    let first = iter.next_back()?;
459    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
460}
461
462/// Create a new [`BinaryExpr`] using the [`And`](crate::Operator::And) operator.
463///
464/// ## Example usage
465///
466/// ```
467/// # use vortex_array::arrays::BoolArray;
468/// # use vortex_array::{IntoArray, ToCanonical};
469/// # use vortex_expr::{and, root, lit, Scope};
470/// let xs = BoolArray::from_iter(vec![true, false, true]);
471/// let result = and(root(), lit(true)).evaluate(&Scope::new(xs.to_array())).unwrap();
472///
473/// assert_eq!(
474///     result.to_bool().boolean_buffer(),
475///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
476/// );
477/// ```
478pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
479    BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
480}
481
482/// Collects a list of `and`ed values into a single vortex, expr
483/// [x, y, z] => x and (y and z)
484pub fn and_collect<I>(iter: I) -> Option<ExprRef>
485where
486    I: IntoIterator<Item = ExprRef>,
487    I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
488{
489    let mut iter = iter.into_iter();
490    let first = iter.next_back()?;
491    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
492}
493
494/// Collects a list of `and`ed values into a single vortex, expr
495/// [x, y, z] => x and (y and z)
496pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
497where
498    I: IntoIterator<Item = ExprRef>,
499{
500    let iter = iter.into_iter();
501    iter.reduce(and)
502}
503
504/// Create a new [`BinaryExpr`] using the [`Add`](crate::Operator::Add) operator.
505///
506/// ## Example usage
507///
508/// ```
509/// # use vortex_array::IntoArray;
510/// # use vortex_array::arrow::IntoArrowArray as _;
511/// # use vortex_buffer::buffer;
512/// # use vortex_expr::{Scope, checked_add, lit, root};
513/// let xs = buffer![1, 2, 3].into_array();
514/// let result = checked_add(root(), lit(5))
515///     .evaluate(&Scope::new(xs.to_array()))
516///     .unwrap();
517///
518/// assert_eq!(
519///     &result.into_arrow_preferred().unwrap(),
520///     &buffer![6, 7, 8]
521///         .into_array()
522///         .into_arrow_preferred()
523///         .unwrap()
524/// );
525/// ```
526pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
527    BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
528}
529
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// `and_collect` nests towards the right: `1 and (2 and 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)]);
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` nests towards the left: `(1 and 2) and 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)]);
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    /// Boolean connectives and comparisons report a boolean dtype whose nullability follows
    /// the operands.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        for expr in [
            and(bool1.clone(), bool2.clone()),
            or(bool1.clone(), bool2.clone()),
        ] {
            assert_eq!(expr.return_dtype(&dtype).unwrap(), non_nullable_bool);
        }

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let nullable_bool = DType::Bool(Nullability::Nullable);
        for expr in [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            ),
        ] {
            assert_eq!(expr.return_dtype(&dtype).unwrap(), nullable_bool);
        }
    }
}