1use std::fmt::Display;
5use std::hash::Hash;
6
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene};
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::{
14 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15 VTable, lit, vtable,
16};
17
// NOTE(review): presumably this macro declares the `BinaryVTable` type that is implemented
// below (it is not defined anywhere else in this file) — confirm against the `vtable!`
// macro definition.
vtable!(Binary);
19
/// A binary expression: an [`Operator`] applied to a left-hand and a right-hand operand.
// `Hash` is derived while `PartialEq` is implemented by hand for this type, which is what
// the lint allowance below acknowledges.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operator applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
27
28impl PartialEq for BinaryExpr {
29 fn eq(&self, other: &Self) -> bool {
30 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31 }
32}
33
/// Unit marker type for the binary expression encoding; it is the `Encoding` associated
/// type of the `VTable` implementation for [`BinaryExpr`].
pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37 type Expr = BinaryExpr;
38 type Encoding = BinaryExprEncoding;
39 type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41 fn id(_encoding: &Self::Encoding) -> ExprId {
42 ExprId::new_ref("binary")
43 }
44
45 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47 }
48
49 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50 Some(ProstMetadata(pb::BinaryOpts {
51 op: expr.operator.into(),
52 }))
53 }
54
55 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56 vec![expr.lhs(), expr.rhs()]
57 }
58
59 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60 Ok(BinaryExpr::new(
61 children[0].clone(),
62 expr.op(),
63 children[1].clone(),
64 ))
65 }
66
67 fn build(
68 _encoding: &Self::Encoding,
69 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70 children: Vec<ExprRef>,
71 ) -> VortexResult<Self::Expr> {
72 Ok(BinaryExpr::new(
73 children[0].clone(),
74 metadata.op().into(),
75 children[1].clone(),
76 ))
77 }
78
79 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80 let lhs = expr.lhs.unchecked_evaluate(scope)?;
81 let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83 match expr.operator {
84 Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85 Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86 Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87 Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88 Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89 Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90 Operator::And => and_kleene(&lhs, &rhs),
91 Operator::Or => or_kleene(&lhs, &rhs),
92 Operator::Add => add(&lhs, &rhs),
93 }
94 }
95
96 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
97 let lhs = expr.lhs.return_dtype(scope)?;
98 let rhs = expr.rhs.return_dtype(scope)?;
99
100 if expr.operator == Operator::Add {
101 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
102 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
103 }
104 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
105 }
106
107 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
108 }
109}
110
111impl BinaryExpr {
112 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
113 Self { lhs, operator, rhs }
114 }
115
116 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
117 Self::new(lhs, operator, rhs).into_expr()
118 }
119
120 pub fn lhs(&self) -> &ExprRef {
121 &self.lhs
122 }
123
124 pub fn rhs(&self) -> &ExprRef {
125 &self.rhs
126 }
127
128 pub fn op(&self) -> Operator {
129 self.operator
130 }
131}
132
133impl Display for BinaryExpr {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
136 }
137}
138
139impl AnalysisExpr for BinaryExpr {
140 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
141 #[inline]
155 fn with_nan_predicate(
156 lhs: &ExprRef,
157 rhs: &ExprRef,
158 value_predicate: ExprRef,
159 catalog: &mut dyn StatsCatalog,
160 ) -> ExprRef {
161 let nan_predicate = lhs
162 .nan_count(catalog)
163 .into_iter()
164 .chain(rhs.nan_count(catalog))
165 .map(|nans| eq(nans, lit(0u64)))
166 .reduce(and);
167
168 if let Some(nan_check) = nan_predicate {
169 and(nan_check, value_predicate)
170 } else {
171 value_predicate
172 }
173 }
174
175 match self.operator {
176 Operator::Eq => {
177 let min_lhs = self.lhs.min(catalog);
178 let max_lhs = self.lhs.max(catalog);
179
180 let min_rhs = self.rhs.min(catalog);
181 let max_rhs = self.rhs.max(catalog);
182
183 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
184 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
185
186 let min_max_check = left.into_iter().chain(right).reduce(or)?;
187
188 Some(with_nan_predicate(
190 self.lhs(),
191 self.rhs(),
192 min_max_check,
193 catalog,
194 ))
195 }
196 Operator::NotEq => {
197 let min_lhs = self.lhs.min(catalog)?;
198 let max_lhs = self.lhs.max(catalog)?;
199
200 let min_rhs = self.rhs.min(catalog)?;
201 let max_rhs = self.rhs.max(catalog)?;
202
203 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
204
205 Some(with_nan_predicate(
206 self.lhs(),
207 self.rhs(),
208 min_max_check,
209 catalog,
210 ))
211 }
212 Operator::Gt => {
213 let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
214
215 Some(with_nan_predicate(
216 self.lhs(),
217 self.rhs(),
218 min_max_check,
219 catalog,
220 ))
221 }
222 Operator::Gte => {
223 let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
225
226 Some(with_nan_predicate(
227 self.lhs(),
228 self.rhs(),
229 min_max_check,
230 catalog,
231 ))
232 }
233 Operator::Lt => {
234 let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
236
237 Some(with_nan_predicate(
238 self.lhs(),
239 self.rhs(),
240 min_max_check,
241 catalog,
242 ))
243 }
244 Operator::Lte => {
245 let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
247
248 Some(with_nan_predicate(
249 self.lhs(),
250 self.rhs(),
251 min_max_check,
252 catalog,
253 ))
254 }
255 Operator::And => self
256 .lhs
257 .stat_falsification(catalog)
258 .into_iter()
259 .chain(self.rhs.stat_falsification(catalog))
260 .reduce(or),
261 Operator::Or => Some(and(
262 self.lhs.stat_falsification(catalog)?,
263 self.rhs.stat_falsification(catalog)?,
264 )),
265 Operator::Add => None,
266 }
267 }
268}
269
270pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
290 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
291}
292
293pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
313 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
314}
315
316pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
336 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
337}
338
339pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
359 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
360}
361
362pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
382 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
383}
384
385pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
405 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
406}
407
408pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
426 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
427}
428
429pub fn or_collect<I>(iter: I) -> Option<ExprRef>
432where
433 I: IntoIterator<Item = ExprRef>,
434 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
435{
436 let mut iter = iter.into_iter();
437 let first = iter.next_back()?;
438 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
439}
440
441pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
459 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
460}
461
462pub fn and_collect<I>(iter: I) -> Option<ExprRef>
465where
466 I: IntoIterator<Item = ExprRef>,
467 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
468{
469 let mut iter = iter.into_iter();
470 let first = iter.next_back()?;
471 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
472}
473
474pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
477where
478 I: IntoIterator<Item = ExprRef>,
479{
480 let iter = iter.into_iter();
481 iter.reduce(and)
482}
483
484pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
508 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
509}
510
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// `and_collect` yields `and(1, and(2, 3))` for `[1, 2, 3]`.
    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)]);
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` yields `and(and(1, 2), 3)` for `[1, 2, 3]`.
    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)]);
        assert_eq!(Some(expected), collected);
    }

    /// Checks the boolean result dtype (and its nullability) of logical and comparison
    /// expressions over the columns of the test harness struct dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Logical operators over the two boolean columns.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let logical = vec![
            and(bool1.clone(), bool2.clone()),
            or(bool1.clone(), bool2.clone()),
        ];
        for expr in logical.iter() {
            assert_eq!(expr.return_dtype(&dtype).unwrap(), non_nullable_bool);
        }

        // Every comparison operator over `col1` and `col2`.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons = vec![
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ];
        for expr in comparisons.iter() {
            assert_eq!(expr.return_dtype(&dtype).unwrap(), nullable_bool);
        }

        // A logical operator combining two comparisons.
        let combined = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(combined.return_dtype(&dtype).unwrap(), nullable_bool);
    }
}