1use std::fmt::Display;
5use std::hash::Hash;
6
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene, sub};
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_proto::expr as pb;
12
13use crate::{
14 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15 VTable, lit, vtable,
16};
17
// Invokes the crate's `vtable!` macro for the `Binary` expression kind.
// NOTE(review): `BinaryVTable`, which is implemented further down, is not declared anywhere
// else in this file, so it is presumably generated by this invocation — confirm against the
// macro definition.
vtable!(Binary);
19
/// An expression node that applies a binary [`Operator`] to two child expressions.
///
/// Its `Display` implementation renders the node as `(lhs operator rhs)`.
// `Hash` is derived while `PartialEq` is implemented by hand in this file, which is what the
// `derived_hash_with_manual_eq` allowance acknowledges. The manual `eq` compares exactly the
// three fields declared here.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operator applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
27
28impl PartialEq for BinaryExpr {
29 fn eq(&self, other: &Self) -> bool {
30 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31 }
32}
33
/// Zero-sized marker type used as the `Encoding` associated type of the binary expression's
/// `VTable` implementation.
pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37 type Expr = BinaryExpr;
38 type Encoding = BinaryExprEncoding;
39 type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41 fn id(_encoding: &Self::Encoding) -> ExprId {
42 ExprId::new_ref("binary")
43 }
44
45 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47 }
48
49 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50 Some(ProstMetadata(pb::BinaryOpts {
51 op: expr.operator.into(),
52 }))
53 }
54
55 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56 vec![expr.lhs(), expr.rhs()]
57 }
58
59 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60 Ok(BinaryExpr::new(
61 children[0].clone(),
62 expr.op(),
63 children[1].clone(),
64 ))
65 }
66
67 fn build(
68 _encoding: &Self::Encoding,
69 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70 children: Vec<ExprRef>,
71 ) -> VortexResult<Self::Expr> {
72 Ok(BinaryExpr::new(
73 children[0].clone(),
74 metadata.op().into(),
75 children[1].clone(),
76 ))
77 }
78
79 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80 let lhs = expr.lhs.unchecked_evaluate(scope)?;
81 let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83 match expr.operator {
84 Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85 Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86 Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87 Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88 Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89 Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90 Operator::And => and_kleene(&lhs, &rhs),
91 Operator::Or => or_kleene(&lhs, &rhs),
92 Operator::Add => add(&lhs, &rhs),
93 Operator::Sub => sub(&lhs, &rhs),
94 }
95 }
96
97 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98 let lhs = expr.lhs.return_dtype(scope)?;
99 let rhs = expr.rhs.return_dtype(scope)?;
100
101 if expr.operator == Operator::Add {
102 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
103 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
104 }
105 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
106 }
107
108 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
109 }
110}
111
112impl BinaryExpr {
113 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
114 Self { lhs, operator, rhs }
115 }
116
117 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
118 Self::new(lhs, operator, rhs).into_expr()
119 }
120
121 pub fn lhs(&self) -> &ExprRef {
122 &self.lhs
123 }
124
125 pub fn rhs(&self) -> &ExprRef {
126 &self.rhs
127 }
128
129 pub fn op(&self) -> Operator {
130 self.operator
131 }
132}
133
134impl Display for BinaryExpr {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
137 }
138}
139
impl AnalysisExpr for BinaryExpr {
    /// Builds a predicate over statistics taken from `catalog` which, when it holds, shows
    /// that this expression cannot be satisfied by the data those statistics describe.
    ///
    /// Returns `None` when a statistic needed to build the predicate is unavailable, and
    /// always for the arithmetic operators (`Add`, `Sub`).
    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
        // Guards `value_predicate` with a NaN-count check for each operand that has a
        // `nan_count` entry in the catalog.
        //
        // For every available NaN count this adds the conjunct `nan_count = 0`. With both
        // counts present the result is
        //
        //      ((lhs.nan_count = 0) AND (rhs.nan_count = 0)) AND value_predicate
        //
        // When neither operand has a NaN count, `value_predicate` is returned unchanged.
        // NOTE(review): presumably needed because NaN values are not reflected in the min/max
        // statistics — confirm.
        #[inline]
        fn with_nan_predicate(
            lhs: &ExprRef,
            rhs: &ExprRef,
            value_predicate: ExprRef,
            catalog: &mut dyn StatsCatalog,
        ) -> ExprRef {
            let nan_predicate = lhs
                .nan_count(catalog)
                .into_iter()
                .chain(rhs.nan_count(catalog))
                .map(|nans| eq(nans, lit(0u64)))
                .reduce(and);

            if let Some(nan_check) = nan_predicate {
                and(nan_check, value_predicate)
            } else {
                value_predicate
            }
        }

        match self.operator {
            Operator::Eq => {
                // `lhs == rhs` is impossible when the two value ranges do not overlap:
                //
                //      lhs.min > rhs.max  OR  rhs.min > lhs.max
                //
                // Each side of the OR is built only if both of its statistics exist. Having
                // either one is enough, and `None` is returned if neither can be built.
                let min_lhs = self.lhs.min(catalog);
                let max_lhs = self.lhs.max(catalog);

                let min_rhs = self.rhs.min(catalog);
                let max_rhs = self.rhs.max(catalog);

                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));

                let min_max_check = left.into_iter().chain(right).reduce(or)?;

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            Operator::NotEq => {
                // `lhs != rhs` is impossible when both operands are pinned to one shared
                // value: lhs.min = rhs.max AND lhs.max = rhs.min. All four statistics are
                // required here.
                let min_lhs = self.lhs.min(catalog)?;
                let max_lhs = self.lhs.max(catalog)?;

                let min_rhs = self.rhs.min(catalog)?;
                let max_rhs = self.rhs.max(catalog)?;

                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            Operator::Gt => {
                // `lhs > rhs` is impossible when lhs.max <= rhs.min.
                let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            Operator::Gte => {
                // `lhs >= rhs` is impossible when lhs.max < rhs.min.
                let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            Operator::Lt => {
                // `lhs < rhs` is impossible when lhs.min >= rhs.max.
                let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            Operator::Lte => {
                // `lhs <= rhs` is impossible when lhs.min > rhs.max.
                let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);

                Some(with_nan_predicate(
                    self.lhs(),
                    self.rhs(),
                    min_max_check,
                    catalog,
                ))
            }
            // A conjunction is falsified as soon as either side is, so the available
            // falsifiers are OR-ed together. One side alone is enough.
            Operator::And => self
                .lhs
                .stat_falsification(catalog)
                .into_iter()
                .chain(self.rhs.stat_falsification(catalog))
                .reduce(or),
            // A disjunction is falsified only if both sides are, so both falsifiers must
            // exist and are AND-ed together.
            Operator::Or => Some(and(
                self.lhs.stat_falsification(catalog)?,
                self.rhs.stat_falsification(catalog)?,
            )),
            // No falsifying predicate is produced for arithmetic operators.
            Operator::Add | Operator::Sub => None,
        }
    }
}
270
271pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
290 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
291}
292
293pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
312 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
313}
314
315pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
334 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
335}
336
337pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
356 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
357}
358
359pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
378 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
379}
380
381pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
400 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
401}
402
403pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
420 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
421}
422
423pub fn or_collect<I>(iter: I) -> Option<ExprRef>
426where
427 I: IntoIterator<Item = ExprRef>,
428 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
429{
430 let mut iter = iter.into_iter();
431 let first = iter.next_back()?;
432 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
433}
434
435pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
452 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
453}
454
455pub fn and_collect<I>(iter: I) -> Option<ExprRef>
458where
459 I: IntoIterator<Item = ExprRef>,
460 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
461{
462 let mut iter = iter.into_iter();
463 let first = iter.next_back()?;
464 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
465}
466
467pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
470where
471 I: IntoIterator<Item = ExprRef>,
472{
473 let iter = iter.into_iter();
474 iter.reduce(and)
475}
476
477pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
500 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
501}
502
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// Expression handle type used by the tests in this module.
    type TestExpr = Arc<dyn VortexExpr>;

    /// `and_collect` nests towards the right: `[1, 2, 3]` becomes `1 AND (2 AND 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let operands = vec![lit(1), lit(2), lit(3)];

        assert_eq!(Some(expected), and_collect(operands.into_iter()));
    }

    /// `and_collect_right` nests towards the left: `[1, 2, 3]` becomes `(1 AND 2) AND 3`.
    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let operands = vec![lit(1), lit(2), lit(3)];

        assert_eq!(Some(expected), and_collect_right(operands.into_iter()));
    }

    /// Logical and comparison expressions report a boolean dtype whose nullability follows
    /// that of their operands.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Resolves the result dtype of an expression against the test struct dtype.
        let dtype_of = |expr: TestExpr| expr.return_dtype(&scope).unwrap();

        let bool1: TestExpr = col("bool1");
        let bool2: TestExpr = col("bool2");
        assert_eq!(
            dtype_of(and(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );
        assert_eq!(
            dtype_of(or(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );

        let col1: TestExpr = col("col1");
        let col2: TestExpr = col("col2");

        assert_eq!(dtype_of(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(not_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt_eq(col1.clone(), col2.clone())), nullable_bool);

        // Nullability propagates through a logical operator applied to nullable comparisons.
        let combined = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(dtype_of(combined), nullable_bool);
    }
}