1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10
11use crate::{AnalysisExpr, ExprRef, Operator, Scope, ScopeDType, StatsCatalog, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct BinaryExpr {
16 lhs: ExprRef,
17 operator: Operator,
18 rhs: ExprRef,
19}
20
21impl BinaryExpr {
22 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
23 Arc::new(Self { lhs, operator, rhs })
24 }
25
26 pub fn lhs(&self) -> &ExprRef {
27 &self.lhs
28 }
29
30 pub fn rhs(&self) -> &ExprRef {
31 &self.rhs
32 }
33
34 pub fn op(&self) -> Operator {
35 self.operator
36 }
37}
38
39impl Display for BinaryExpr {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
42 }
43}
44
/// Protobuf (de)serialization support for [`BinaryExpr`], compiled only with the
/// `proto` feature.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{BinaryExpr, ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Deserializer for expressions registered under the `"binary"` id.
    pub(crate) struct BinarySerde;

    impl Id for BinarySerde {
        fn id(&self) -> &'static str {
            "binary"
        }
    }

    impl ExprDeserialize for BinarySerde {
        /// Rebuilds a [`BinaryExpr`] from its serialized `kind` and its already
        /// deserialized `children`.
        ///
        /// # Errors
        ///
        /// Returns an error when `kind` is not `Kind::BinaryOp`, when `children` does
        /// not contain exactly two expressions (left operand, right operand), or when
        /// the serialized operator cannot be converted into an operator.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::BinaryOp(op) = kind else {
                vortex_bail!("wrong kind {:?}, binary", kind)
            };

            // Serialized input may be malformed: report a wrong child count as an
            // error instead of panicking on an out-of-bounds index.
            let [lhs, rhs] = children.as_slice() else {
                vortex_bail!(
                    "binary expression requires exactly 2 children, got {}",
                    children.len()
                )
            };

            Ok(BinaryExpr::new_expr(
                lhs.clone(),
                (*op).try_into()?,
                rhs.clone(),
            ))
        }
    }

    impl ExprSerializable for BinaryExpr {
        /// Uses the same id as [`BinarySerde`], so the two sides always agree.
        fn id(&self) -> &'static str {
            BinarySerde.id()
        }

        /// Serializes only the operator as `Kind::BinaryOp`; the operands are not
        /// part of the returned kind.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::BinaryOp(self.operator.into()))
        }
    }
}
84
85impl AnalysisExpr for BinaryExpr {
86 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
87 match self.operator {
88 Operator::Eq => {
89 let min_lhs = self.lhs.min(catalog);
90 let max_lhs = self.lhs.max(catalog);
91
92 let min_rhs = self.rhs.min(catalog);
93 let max_rhs = self.rhs.max(catalog);
94
95 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
96 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
97 left.into_iter().chain(right).reduce(or)
98 }
99 Operator::NotEq => {
100 let min_lhs = self.lhs.min(catalog)?;
101 let max_lhs = self.lhs.max(catalog)?;
102
103 let min_rhs = self.rhs.min(catalog)?;
104 let max_rhs = self.rhs.max(catalog)?;
105
106 Some(and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)))
107 }
108 Operator::Gt => Some(lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
109 Operator::Gte => Some(lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
110 Operator::Lt => Some(gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
111 Operator::Lte => Some(gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
112 Operator::And => self
113 .lhs
114 .stat_falsification(catalog)
115 .into_iter()
116 .chain(self.rhs.stat_falsification(catalog))
117 .reduce(or),
118 Operator::Or => Some(and(
119 self.lhs.stat_falsification(catalog)?,
120 self.rhs.stat_falsification(catalog)?,
121 )),
122 Operator::Add => None,
123 }
124 }
125}
126
127impl VortexExpr for BinaryExpr {
128 fn as_any(&self) -> &dyn Any {
129 self
130 }
131
132 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
133 let lhs = self.lhs.unchecked_evaluate(scope)?;
134 let rhs = self.rhs.unchecked_evaluate(scope)?;
135
136 match self.operator {
137 Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
138 Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
139 Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
140 Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
141 Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
142 Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
143 Operator::And => and_kleene(&lhs, &rhs),
144 Operator::Or => or_kleene(&lhs, &rhs),
145 Operator::Add => add(&lhs, &rhs),
146 }
147 }
148
149 fn children(&self) -> Vec<&ExprRef> {
150 vec![&self.lhs, &self.rhs]
151 }
152
153 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
154 assert_eq!(children.len(), 2);
155 BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
156 }
157
158 fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
159 let lhs = self.lhs.return_dtype(ctx)?;
160 let rhs = self.rhs.return_dtype(ctx)?;
161
162 if self.operator == Operator::Add {
163 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
164 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
165 }
166 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
167 }
168
169 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
170 }
171}
172
173impl PartialEq for BinaryExpr {
174 fn eq(&self, other: &BinaryExpr) -> bool {
175 other.operator == self.operator && other.lhs.eq(&self.lhs) && other.rhs.eq(&self.rhs)
176 }
177}
178
179pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
199 BinaryExpr::new_expr(lhs, Operator::Eq, rhs)
200}
201
202pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
222 BinaryExpr::new_expr(lhs, Operator::NotEq, rhs)
223}
224
225pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
245 BinaryExpr::new_expr(lhs, Operator::Gte, rhs)
246}
247
248pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
268 BinaryExpr::new_expr(lhs, Operator::Gt, rhs)
269}
270
271pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
291 BinaryExpr::new_expr(lhs, Operator::Lte, rhs)
292}
293
294pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
314 BinaryExpr::new_expr(lhs, Operator::Lt, rhs)
315}
316
317pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
335 BinaryExpr::new_expr(lhs, Operator::Or, rhs)
336}
337
338pub fn or_collect<I>(iter: I) -> Option<ExprRef>
341where
342 I: IntoIterator<Item = ExprRef>,
343 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
344{
345 let mut iter = iter.into_iter();
346 let first = iter.next_back()?;
347 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
348}
349
350pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
368 BinaryExpr::new_expr(lhs, Operator::And, rhs)
369}
370
371pub fn and_collect<I>(iter: I) -> Option<ExprRef>
374where
375 I: IntoIterator<Item = ExprRef>,
376 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
377{
378 let mut iter = iter.into_iter();
379 let first = iter.next_back()?;
380 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
381}
382
383pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
386where
387 I: IntoIterator<Item = ExprRef>,
388{
389 let iter = iter.into_iter();
390 iter.reduce(and)
391}
392
393pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
417 BinaryExpr::new_expr(lhs, Operator::Add, rhs)
418}
419
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ScopeDType, VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt,
        lt_eq, not_eq, or, test_harness,
    };

    // `and_collect` yields `x and (y and z)`.
    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    // `and_collect_right` yields `(x and y) and z`.
    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn dtype() {
        let scope_dtype = test_harness::struct_dtype();
        // Resolves an expression's dtype against the harness struct dtype.
        let dtype_of = |expr: &Arc<dyn VortexExpr>| {
            expr.return_dtype(&ScopeDType::new(scope_dtype.clone()))
                .unwrap()
        };

        // Boolean connectives over the `bool1`/`bool2` columns stay non-nullable.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let non_nullable = DType::Bool(Nullability::NonNullable);
        for expr in [
            and(bool1.clone(), bool2.clone()),
            or(bool1.clone(), bool2.clone()),
        ] {
            assert_eq!(dtype_of(&expr), non_nullable, "unexpected dtype for {expr}");
        }

        // Every comparison over `col1`/`col2` is expected to be a nullable boolean.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let nullable = DType::Bool(Nullability::Nullable);
        for expr in [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ] {
            assert_eq!(dtype_of(&expr), nullable, "unexpected dtype for {expr}");
        }

        // Nullability carries through a connective built from nullable comparisons.
        let combined = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(dtype_of(&combined), nullable);
    }
}