1use std::hash::Hash;
5
6use vortex_array::compute::{Operator as ArrayOperator, add, and_kleene, compare, or_kleene, sub};
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10use vortex_proto::expr as pb;
11
12use crate::display::{DisplayAs, DisplayFormat};
13use crate::{
14 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
15 VTable, lit, vtable,
16};
17
// Invokes the crate's `vtable!` macro for `Binary`.
// NOTE(review): presumably this is what declares the `BinaryVTable` type that
// the `impl VTable for BinaryVTable` in this file targets — confirm against the
// macro definition.
vtable!(Binary);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct BinaryExpr {
23 lhs: ExprRef,
24 operator: Operator,
25 rhs: ExprRef,
26}
27
28impl PartialEq for BinaryExpr {
29 fn eq(&self, other: &Self) -> bool {
30 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
31 }
32}
33
/// Field-less marker type used as the `Encoding` associated type of the binary
/// expression vtable.
pub struct BinaryExprEncoding;
35
36impl VTable for BinaryVTable {
37 type Expr = BinaryExpr;
38 type Encoding = BinaryExprEncoding;
39 type Metadata = ProstMetadata<pb::BinaryOpts>;
40
41 fn id(_encoding: &Self::Encoding) -> ExprId {
42 ExprId::new_ref("binary")
43 }
44
45 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
47 }
48
49 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50 Some(ProstMetadata(pb::BinaryOpts {
51 op: expr.operator.into(),
52 }))
53 }
54
55 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56 vec![expr.lhs(), expr.rhs()]
57 }
58
59 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60 Ok(BinaryExpr::new(
61 children[0].clone(),
62 expr.op(),
63 children[1].clone(),
64 ))
65 }
66
67 fn build(
68 _encoding: &Self::Encoding,
69 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
70 children: Vec<ExprRef>,
71 ) -> VortexResult<Self::Expr> {
72 Ok(BinaryExpr::new(
73 children[0].clone(),
74 metadata.op().into(),
75 children[1].clone(),
76 ))
77 }
78
79 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80 let lhs = expr.lhs.unchecked_evaluate(scope)?;
81 let rhs = expr.rhs.unchecked_evaluate(scope)?;
82
83 match expr.operator {
84 Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
85 Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
86 Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
87 Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
88 Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
89 Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
90 Operator::And => and_kleene(&lhs, &rhs),
91 Operator::Or => or_kleene(&lhs, &rhs),
92 Operator::Add => add(&lhs, &rhs),
93 Operator::Sub => sub(&lhs, &rhs),
94 }
95 }
96
97 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
98 let lhs = expr.lhs.return_dtype(scope)?;
99 let rhs = expr.rhs.return_dtype(scope)?;
100
101 if expr.operator == Operator::Add {
102 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
103 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
104 }
105 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
106 }
107
108 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
109 }
110}
111
112impl BinaryExpr {
113 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
114 Self { lhs, operator, rhs }
115 }
116
117 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
118 Self::new(lhs, operator, rhs).into_expr()
119 }
120
121 pub fn lhs(&self) -> &ExprRef {
122 &self.lhs
123 }
124
125 pub fn rhs(&self) -> &ExprRef {
126 &self.rhs
127 }
128
129 pub fn op(&self) -> Operator {
130 self.operator
131 }
132}
133
134impl DisplayAs for BinaryExpr {
135 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136 match df {
137 DisplayFormat::Compact => {
138 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
139 }
140 DisplayFormat::Tree => {
141 write!(f, "Binary({})", self.operator)
142 }
143 }
144 }
145
146 fn child_names(&self) -> Option<Vec<String>> {
147 Some(vec!["lhs".to_string(), "rhs".to_string()])
148 }
149}
150
151impl AnalysisExpr for BinaryExpr {
152 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
153 #[inline]
167 fn with_nan_predicate(
168 lhs: &ExprRef,
169 rhs: &ExprRef,
170 value_predicate: ExprRef,
171 catalog: &mut dyn StatsCatalog,
172 ) -> ExprRef {
173 let nan_predicate = lhs
174 .nan_count(catalog)
175 .into_iter()
176 .chain(rhs.nan_count(catalog))
177 .map(|nans| eq(nans, lit(0u64)))
178 .reduce(and);
179
180 if let Some(nan_check) = nan_predicate {
181 and(nan_check, value_predicate)
182 } else {
183 value_predicate
184 }
185 }
186
187 match self.operator {
188 Operator::Eq => {
189 let min_lhs = self.lhs.min(catalog);
190 let max_lhs = self.lhs.max(catalog);
191
192 let min_rhs = self.rhs.min(catalog);
193 let max_rhs = self.rhs.max(catalog);
194
195 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
196 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
197
198 let min_max_check = left.into_iter().chain(right).reduce(or)?;
199
200 Some(with_nan_predicate(
202 self.lhs(),
203 self.rhs(),
204 min_max_check,
205 catalog,
206 ))
207 }
208 Operator::NotEq => {
209 let min_lhs = self.lhs.min(catalog)?;
210 let max_lhs = self.lhs.max(catalog)?;
211
212 let min_rhs = self.rhs.min(catalog)?;
213 let max_rhs = self.rhs.max(catalog)?;
214
215 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
216
217 Some(with_nan_predicate(
218 self.lhs(),
219 self.rhs(),
220 min_max_check,
221 catalog,
222 ))
223 }
224 Operator::Gt => {
225 let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
226
227 Some(with_nan_predicate(
228 self.lhs(),
229 self.rhs(),
230 min_max_check,
231 catalog,
232 ))
233 }
234 Operator::Gte => {
235 let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
237
238 Some(with_nan_predicate(
239 self.lhs(),
240 self.rhs(),
241 min_max_check,
242 catalog,
243 ))
244 }
245 Operator::Lt => {
246 let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
248
249 Some(with_nan_predicate(
250 self.lhs(),
251 self.rhs(),
252 min_max_check,
253 catalog,
254 ))
255 }
256 Operator::Lte => {
257 let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
259
260 Some(with_nan_predicate(
261 self.lhs(),
262 self.rhs(),
263 min_max_check,
264 catalog,
265 ))
266 }
267 Operator::And => self
268 .lhs
269 .stat_falsification(catalog)
270 .into_iter()
271 .chain(self.rhs.stat_falsification(catalog))
272 .reduce(or),
273 Operator::Or => Some(and(
274 self.lhs.stat_falsification(catalog)?,
275 self.rhs.stat_falsification(catalog)?,
276 )),
277 Operator::Add | Operator::Sub => None,
278 }
279 }
280}
281
282pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
301 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
302}
303
304pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
323 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
324}
325
326pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
345 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
346}
347
348pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
367 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
368}
369
370pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
389 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
390}
391
392pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
411 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
412}
413
414pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
431 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
432}
433
434pub fn or_collect<I>(iter: I) -> Option<ExprRef>
437where
438 I: IntoIterator<Item = ExprRef>,
439 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
440{
441 let mut iter = iter.into_iter();
442 let first = iter.next_back()?;
443 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
444}
445
446pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
463 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
464}
465
466pub fn and_collect<I>(iter: I) -> Option<ExprRef>
469where
470 I: IntoIterator<Item = ExprRef>,
471 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
472{
473 let mut iter = iter.into_iter();
474 let first = iter.next_back()?;
475 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
476}
477
478pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
481where
482 I: IntoIterator<Item = ExprRef>,
483{
484 let iter = iter.into_iter();
485 iter.reduce(and)
486}
487
488pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
511 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
512}
513
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// `and_collect` over `[1, 2, 3]` must produce `1 AND (2 AND 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` over `[1, 2, 3]` must produce `(1 AND 2) AND 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    /// Checks the dtype reported for logical and comparison operators over the
    /// columns of the test harness struct dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let dtype_of = |expr: Arc<dyn VortexExpr>| expr.return_dtype(&dtype).unwrap();

        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Logical operators over the two bool columns.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        assert_eq!(
            dtype_of(and(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );
        assert_eq!(
            dtype_of(or(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );

        // Comparison operators over `col1` / `col2`.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        assert_eq!(dtype_of(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(not_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt_eq(col1.clone(), col2.clone())), nullable_bool);

        // A logical operator over two comparison results.
        let combined = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(dtype_of(combined), nullable_bool);
    }
}