vortex_array/expr/exprs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_proto::expr as pb;
10
11use crate::compute::{add, and_kleene, compare, div, mul, or_kleene, sub};
12use crate::expr::expression::Expression;
13use crate::expr::exprs::literal::lit;
14use crate::expr::exprs::operators::Operator;
15use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
16use crate::{ArrayRef, compute};
17
/// Expression vtable for binary operators.
///
/// Its instance data is an [`Operator`], it is identified as `vortex.binary`, and its two
/// children are named `lhs` and `rhs`.
pub struct Binary;
19
20impl VTable for Binary {
21    type Instance = Operator;
22
23    fn id(&self) -> ExprId {
24        ExprId::from("vortex.binary")
25    }
26
27    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28        Ok(Some(
29            pb::BinaryOpts {
30                op: (*instance).into(),
31            }
32            .encode_to_vec(),
33        ))
34    }
35
36    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37        let opts = pb::BinaryOpts::decode(metadata)?;
38        Ok(Some(Operator::try_from(opts.op)?))
39    }
40
41    fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
42        // TODO(ngates): check the dtypes.
43        Ok(())
44    }
45
46    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
47        match child_idx {
48            0 => ChildName::from("lhs"),
49            1 => ChildName::from("rhs"),
50            _ => unreachable!("Binary has only two children"),
51        }
52    }
53
54    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
55        write!(f, "(")?;
56        expr.lhs().fmt_sql(f)?;
57        write!(f, " {} ", expr.operator())?;
58        expr.rhs().fmt_sql(f)?;
59        write!(f, ")")
60    }
61
62    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{}", *instance)
64    }
65
66    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
67        let lhs = expr.lhs().return_dtype(scope)?;
68        let rhs = expr.rhs().return_dtype(scope)?;
69
70        if expr.operator().is_arithmetic() {
71            if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
72                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
73            }
74            vortex_bail!(
75                "incompatible types for arithmetic operation: {} {}",
76                lhs,
77                rhs
78            );
79        }
80
81        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
82    }
83
84    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
85        let lhs = expr.lhs().evaluate(scope)?;
86        let rhs = expr.rhs().evaluate(scope)?;
87
88        match expr.operator() {
89            Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
90            Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
91            Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
92            Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
93            Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
94            Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
95            Operator::And => and_kleene(&lhs, &rhs),
96            Operator::Or => or_kleene(&lhs, &rhs),
97            Operator::Add => add(&lhs, &rhs),
98            Operator::Sub => sub(&lhs, &rhs),
99            Operator::Mul => mul(&lhs, &rhs),
100            Operator::Div => div(&lhs, &rhs),
101        }
102    }
103
104    fn stat_falsification(
105        &self,
106        expr: &ExpressionView<Self>,
107        catalog: &mut dyn StatsCatalog,
108    ) -> Option<Expression> {
109        // Wrap another predicate with an optional NaNCount check, if the stat is available.
110        //
111        // For example, regular pruning conversion for `A >= B` would be
112        //
113        //      A.max < B.min
114        //
115        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
116        // in:
117        //
118        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
119        //
120        // Non-floating point column and literal expressions should be unaffected as they do not
121        // have a nan_count statistic defined.
122        #[inline]
123        fn with_nan_predicate(
124            lhs: &Expression,
125            rhs: &Expression,
126            value_predicate: Expression,
127            catalog: &mut dyn StatsCatalog,
128        ) -> Expression {
129            let nan_predicate = lhs
130                .stat_nan_count(catalog)
131                .into_iter()
132                .chain(rhs.stat_nan_count(catalog))
133                .map(|nans| eq(nans, lit(0u64)))
134                .reduce(and);
135
136            if let Some(nan_check) = nan_predicate {
137                and(nan_check, value_predicate)
138            } else {
139                value_predicate
140            }
141        }
142
143        match expr.operator() {
144            Operator::Eq => {
145                let min_lhs = expr.lhs().stat_min(catalog);
146                let max_lhs = expr.lhs().stat_max(catalog);
147
148                let min_rhs = expr.rhs().stat_min(catalog);
149                let max_rhs = expr.rhs().stat_max(catalog);
150
151                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
152                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
153
154                let min_max_check = left.into_iter().chain(right).reduce(or)?;
155
156                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
157                Some(with_nan_predicate(
158                    expr.lhs(),
159                    expr.rhs(),
160                    min_max_check,
161                    catalog,
162                ))
163            }
164            Operator::NotEq => {
165                let min_lhs = expr.lhs().stat_min(catalog)?;
166                let max_lhs = expr.lhs().stat_max(catalog)?;
167
168                let min_rhs = expr.rhs().stat_min(catalog)?;
169                let max_rhs = expr.rhs().stat_max(catalog)?;
170
171                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
172
173                Some(with_nan_predicate(
174                    expr.lhs(),
175                    expr.rhs(),
176                    min_max_check,
177                    catalog,
178                ))
179            }
180            Operator::Gt => {
181                let min_max_check =
182                    lt_eq(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
183
184                Some(with_nan_predicate(
185                    expr.lhs(),
186                    expr.rhs(),
187                    min_max_check,
188                    catalog,
189                ))
190            }
191            Operator::Gte => {
192                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
193                let min_max_check =
194                    lt(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
195
196                Some(with_nan_predicate(
197                    expr.lhs(),
198                    expr.rhs(),
199                    min_max_check,
200                    catalog,
201                ))
202            }
203            Operator::Lt => {
204                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
205                let min_max_check =
206                    gt_eq(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
207
208                Some(with_nan_predicate(
209                    expr.lhs(),
210                    expr.rhs(),
211                    min_max_check,
212                    catalog,
213                ))
214            }
215            Operator::Lte => {
216                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
217                let min_max_check =
218                    gt(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
219
220                Some(with_nan_predicate(
221                    expr.lhs(),
222                    expr.rhs(),
223                    min_max_check,
224                    catalog,
225                ))
226            }
227            Operator::And => expr
228                .lhs()
229                .stat_falsification(catalog)
230                .into_iter()
231                .chain(expr.rhs().stat_falsification(catalog))
232                .reduce(or),
233            Operator::Or => Some(and(
234                expr.lhs().stat_falsification(catalog)?,
235                expr.rhs().stat_falsification(catalog)?,
236            )),
237            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
238        }
239    }
240}
241
242impl ExpressionView<'_, Binary> {
243    pub fn lhs(&self) -> &Expression {
244        &self.children()[0]
245    }
246
247    pub fn rhs(&self) -> &Expression {
248        &self.children()[1]
249    }
250
251    pub fn operator(&self) -> Operator {
252        *self.data()
253    }
254}
255
256/// Create a new [`Binary`] using the [`Eq`](crate::expr::exprs::operators::Operator::Eq) operator.
257///
258/// ## Example usage
259///
260/// ```
261/// # use vortex_array::arrays::{BoolArray, PrimitiveArray};
262/// # use vortex_array::{Array, IntoArray, ToCanonical};
263/// # use vortex_array::validity::Validity;
264/// # use vortex_buffer::buffer;
265/// # use vortex_array::expr::{eq, root, lit};
266/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
267/// let result = eq(root(), lit(3)).evaluate(&xs.to_array()).unwrap();
268///
269/// assert_eq!(
270///     result.to_bool().bit_buffer(),
271///     BoolArray::from_iter(vec![false, false, true]).bit_buffer(),
272/// );
273/// ```
274pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
275    Binary
276        .try_new_expr(Operator::Eq, [lhs, rhs])
277        .vortex_expect("Failed to create Eq binary expression")
278}
279
280/// Create a new [`Binary`] using the [`NotEq`](crate::expr::exprs::operators::Operator::NotEq) operator.
281///
282/// ## Example usage
283///
284/// ```
285/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
286/// # use vortex_array::{IntoArray, ToCanonical};
287/// # use vortex_array::validity::Validity;
288/// # use vortex_buffer::buffer;
289/// # use vortex_array::expr::{root, lit, not_eq};
290/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
291/// let result = not_eq(root(), lit(3)).evaluate(&xs.to_array()).unwrap();
292///
293/// assert_eq!(
294///     result.to_bool().bit_buffer(),
295///     BoolArray::from_iter(vec![true, true, false]).bit_buffer(),
296/// );
297/// ```
298pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
299    Binary
300        .try_new_expr(Operator::NotEq, [lhs, rhs])
301        .vortex_expect("Failed to create NotEq binary expression")
302}
303
304/// Create a new [`Binary`] using the [`Gte`](crate::expr::exprs::operators::Operator::Gte) operator.
305///
306/// ## Example usage
307///
308/// ```
309/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
310/// # use vortex_array::{IntoArray, ToCanonical};
311/// # use vortex_array::validity::Validity;
312/// # use vortex_buffer::buffer;
313/// # use vortex_array::expr::{gt_eq, root, lit};
314/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
315/// let result = gt_eq(root(), lit(3)).evaluate(&xs.to_array()).unwrap();
316///
317/// assert_eq!(
318///     result.to_bool().bit_buffer(),
319///     BoolArray::from_iter(vec![false, false, true]).bit_buffer(),
320/// );
321/// ```
322pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
323    Binary
324        .try_new_expr(Operator::Gte, [lhs, rhs])
325        .vortex_expect("Failed to create Gte binary expression")
326}
327
328/// Create a new [`Binary`] using the [`Gt`](crate::expr::exprs::operators::Operator::Gt) operator.
329///
330/// ## Example usage
331///
332/// ```
333/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
334/// # use vortex_array::{IntoArray, ToCanonical};
335/// # use vortex_array::validity::Validity;
336/// # use vortex_buffer::buffer;
337/// # use vortex_array::expr::{gt, root, lit};
338/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
339/// let result = gt(root(), lit(2)).evaluate(&xs.to_array()).unwrap();
340///
341/// assert_eq!(
342///     result.to_bool().bit_buffer(),
343///     BoolArray::from_iter(vec![false, false, true]).bit_buffer(),
344/// );
345/// ```
346pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
347    Binary
348        .try_new_expr(Operator::Gt, [lhs, rhs])
349        .vortex_expect("Failed to create Gt binary expression")
350}
351
352/// Create a new [`Binary`] using the [`Lte`](crate::expr::exprs::operators::Operator::Lte) operator.
353///
354/// ## Example usage
355///
356/// ```
357/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
358/// # use vortex_array::{IntoArray, ToCanonical};
359/// # use vortex_array::validity::Validity;
360/// # use vortex_buffer::buffer;
361/// # use vortex_array::expr::{root, lit, lt_eq};
362/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
363/// let result = lt_eq(root(), lit(2)).evaluate(&xs.to_array()).unwrap();
364///
365/// assert_eq!(
366///     result.to_bool().bit_buffer(),
367///     BoolArray::from_iter(vec![true, true, false]).bit_buffer(),
368/// );
369/// ```
370pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
371    Binary
372        .try_new_expr(Operator::Lte, [lhs, rhs])
373        .vortex_expect("Failed to create Lte binary expression")
374}
375
376/// Create a new [`Binary`] using the [`Lt`](crate::expr::exprs::operators::Operator::Lt) operator.
377///
378/// ## Example usage
379///
380/// ```
381/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
382/// # use vortex_array::{IntoArray, ToCanonical};
383/// # use vortex_array::validity::Validity;
384/// # use vortex_buffer::buffer;
385/// # use vortex_array::expr::{root, lit, lt};
386/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
387/// let result = lt(root(), lit(3)).evaluate(&xs.to_array()).unwrap();
388///
389/// assert_eq!(
390///     result.to_bool().bit_buffer(),
391///     BoolArray::from_iter(vec![true, true, false]).bit_buffer(),
392/// );
393/// ```
394pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
395    Binary
396        .try_new_expr(Operator::Lt, [lhs, rhs])
397        .vortex_expect("Failed to create Lt binary expression")
398}
399
400/// Create a new [`Binary`] using the [`Or`](crate::expr::exprs::operators::Operator::Or) operator.
401///
402/// ## Example usage
403///
404/// ```
405/// # use vortex_array::arrays::BoolArray;
406/// # use vortex_array::{IntoArray, ToCanonical};
407/// # use vortex_array::expr::{root, lit, or};
408/// let xs = BoolArray::from_iter(vec![true, false, true]);
409/// let result = or(root(), lit(false)).evaluate(&xs.to_array()).unwrap();
410///
411/// assert_eq!(
412///     result.to_bool().bit_buffer(),
413///     BoolArray::from_iter(vec![true, false, true]).bit_buffer(),
414/// );
415/// ```
416pub fn or(lhs: Expression, rhs: Expression) -> Expression {
417    Binary
418        .try_new_expr(Operator::Or, [lhs, rhs])
419        .vortex_expect("Failed to create Or binary expression")
420}
421
422/// Collects a list of `or`ed values into a single vortex, expr
423/// [x, y, z] => x or (y or z)
424pub fn or_collect<I>(iter: I) -> Option<Expression>
425where
426    I: IntoIterator<Item = Expression>,
427    I::IntoIter: DoubleEndedIterator<Item = Expression>,
428{
429    let mut iter = iter.into_iter();
430    let first = iter.next_back()?;
431    Some(iter.rfold(first, |acc, elem| or(elem, acc)))
432}
433
434/// Create a new [`Binary`] using the [`And`](crate::expr::exprs::operators::Operator::And) operator.
435///
436/// ## Example usage
437///
438/// ```
439/// # use vortex_array::arrays::BoolArray;
440/// # use vortex_array::{IntoArray, ToCanonical};
441/// # use vortex_array::expr::{and, root, lit};
442/// let xs = BoolArray::from_iter(vec![true, false, true]);
443/// let result = and(root(), lit(true)).evaluate(&xs.to_array()).unwrap();
444///
445/// assert_eq!(
446///     result.to_bool().bit_buffer(),
447///     BoolArray::from_iter(vec![true, false, true]).bit_buffer(),
448/// );
449/// ```
450pub fn and(lhs: Expression, rhs: Expression) -> Expression {
451    Binary
452        .try_new_expr(Operator::And, [lhs, rhs])
453        .vortex_expect("Failed to create And binary expression")
454}
455
456/// Collects a list of `and`ed values into a single vortex, expr
457/// [x, y, z] => x and (y and z)
458pub fn and_collect<I>(iter: I) -> Option<Expression>
459where
460    I: IntoIterator<Item = Expression>,
461    I::IntoIter: DoubleEndedIterator<Item = Expression>,
462{
463    let mut iter = iter.into_iter();
464    let first = iter.next_back()?;
465    Some(iter.rfold(first, |acc, elem| and(elem, acc)))
466}
467
468/// Collects a list of `and`ed values into a single vortex, expr
469/// [x, y, z] => x and (y and z)
470pub fn and_collect_right<I>(iter: I) -> Option<Expression>
471where
472    I: IntoIterator<Item = Expression>,
473{
474    let iter = iter.into_iter();
475    iter.reduce(and)
476}
477
478/// Create a new [`Binary`] using the [`Add`](crate::expr::exprs::operators::Operator::Add) operator.
479///
480/// ## Example usage
481///
482/// ```
483/// # use vortex_array::IntoArray;
484/// # use vortex_array::arrow::IntoArrowArray as _;
485/// # use vortex_buffer::buffer;
486/// # use vortex_array::expr::{checked_add, lit, root};
487/// let xs = buffer![1, 2, 3].into_array();
488/// let result = checked_add(root(), lit(5))
489///     .evaluate(&xs.to_array())
490///     .unwrap();
491///
492/// assert_eq!(
493///     &result.into_arrow_preferred().unwrap(),
494///     &buffer![6, 7, 8]
495///         .into_array()
496///         .into_arrow_preferred()
497///         .unwrap()
498/// );
499/// ```
500pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
501    Binary
502        .try_new_expr(Operator::Add, [lhs, rhs])
503        .vortex_expect("Failed to create Add binary expression")
504}
505
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};

    use super::{and, and_collect, and_collect_right, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::{Expression, test_harness};

    /// `and_collect` nests towards the right: `[1, 2, 3]` becomes `1 and (2 and 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` nests towards the left: `[1, 2, 3]` becomes `(1 and 2) and 3`.
    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    /// Checks the dtype that boolean and comparison operators report for the harness columns.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            non_nullable_bool
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            non_nullable_bool
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        // Every comparison over `col1` and `col2` is expected to report a nullable bool.
        let comparisons: [fn(Expression, Expression) -> Expression; 6] =
            [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            assert_eq!(
                build(col1.clone(), col2.clone())
                    .return_dtype(&dtype)
                    .unwrap(),
                nullable_bool
            );
        }

        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(combined.return_dtype(&dtype).unwrap(), nullable_bool);
    }
}