1use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, div, mul, or_kleene, sub};
8use vortex_array::operator::OperatorRef;
9use vortex_array::operator::compare::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18 VTable, lit, vtable,
19};
20
// Invokes the crate-local `vtable!` macro for the `Binary` expression.
// NOTE(review): presumably this declares the `BinaryVTable` type that the `VTable` impl in this
// file is written for — confirm against the macro definition.
vtable!(Binary);
22
/// An expression that applies a binary [`Operator`] to a left-hand and a right-hand child
/// expression.
// `Hash` is derived while `PartialEq` is written by hand for this type. The manual `eq` compares
// the same three fields that the derived `Hash` feeds to the hasher, which is why the clippy lint
// is silenced here.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operator applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
30
31impl PartialEq for BinaryExpr {
32 fn eq(&self, other: &Self) -> bool {
33 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34 }
35}
36
/// Marker type used as the `Encoding` associated type of the binary expression's `VTable` impl.
pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40 type Expr = BinaryExpr;
41 type Encoding = BinaryExprEncoding;
42 type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44 fn id(_encoding: &Self::Encoding) -> ExprId {
45 ExprId::new_ref("binary")
46 }
47
48 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50 }
51
52 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53 Some(ProstMetadata(pb::BinaryOpts {
54 op: expr.operator.into(),
55 }))
56 }
57
58 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59 vec![expr.lhs(), expr.rhs()]
60 }
61
62 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63 Ok(BinaryExpr::new(
64 children[0].clone(),
65 expr.op(),
66 children[1].clone(),
67 ))
68 }
69
70 fn build(
71 _encoding: &Self::Encoding,
72 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73 children: Vec<ExprRef>,
74 ) -> VortexResult<Self::Expr> {
75 Ok(BinaryExpr::new(
76 children[0].clone(),
77 metadata.op().into(),
78 children[1].clone(),
79 ))
80 }
81
82 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83 let lhs = expr.lhs.unchecked_evaluate(scope)?;
84 let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86 match expr.operator {
87 Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88 Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89 Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90 Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91 Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92 Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93 Operator::And => and_kleene(&lhs, &rhs),
94 Operator::Or => or_kleene(&lhs, &rhs),
95 Operator::Add => add(&lhs, &rhs),
96 Operator::Sub => sub(&lhs, &rhs),
97 Operator::Mul => mul(&lhs, &rhs),
98 Operator::Div => div(&lhs, &rhs),
99 }
100 }
101
102 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
103 let lhs = expr.lhs.return_dtype(scope)?;
104 let rhs = expr.rhs.return_dtype(scope)?;
105
106 if expr.operator.is_arithmetic() {
107 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
108 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
109 }
110 vortex_bail!(
111 "incompatible types for arithmetic operation: {} {}",
112 lhs,
113 rhs
114 );
115 }
116
117 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
118 }
119
120 fn operator(expr: &BinaryExpr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
121 let Some(lhs) = expr.lhs.operator(scope)? else {
122 return Ok(None);
123 };
124 let Some(rhs) = expr.rhs.operator(scope)? else {
125 return Ok(None);
126 };
127 let Ok(op): VortexResult<compute::Operator> = expr.operator.try_into() else {
128 return Ok(None);
129 };
130 Ok(Some(Arc::new(CompareOperator::try_new(lhs, rhs, op)?)))
131 }
132}
133
134impl BinaryExpr {
135 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
136 Self { lhs, operator, rhs }
137 }
138
139 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
140 Self::new(lhs, operator, rhs).into_expr()
141 }
142
143 pub fn lhs(&self) -> &ExprRef {
144 &self.lhs
145 }
146
147 pub fn rhs(&self) -> &ExprRef {
148 &self.rhs
149 }
150
151 pub fn op(&self) -> Operator {
152 self.operator
153 }
154}
155
156impl DisplayAs for BinaryExpr {
157 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158 match df {
159 DisplayFormat::Compact => {
160 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
161 }
162 DisplayFormat::Tree => {
163 write!(f, "Binary({})", self.operator)
164 }
165 }
166 }
167
168 fn child_names(&self) -> Option<Vec<String>> {
169 Some(vec!["lhs".to_string(), "rhs".to_string()])
170 }
171}
172
173impl AnalysisExpr for BinaryExpr {
174 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
175 #[inline]
189 fn with_nan_predicate(
190 lhs: &ExprRef,
191 rhs: &ExprRef,
192 value_predicate: ExprRef,
193 catalog: &mut dyn StatsCatalog,
194 ) -> ExprRef {
195 let nan_predicate = lhs
196 .nan_count(catalog)
197 .into_iter()
198 .chain(rhs.nan_count(catalog))
199 .map(|nans| eq(nans, lit(0u64)))
200 .reduce(and);
201
202 if let Some(nan_check) = nan_predicate {
203 and(nan_check, value_predicate)
204 } else {
205 value_predicate
206 }
207 }
208
209 match self.operator {
210 Operator::Eq => {
211 let min_lhs = self.lhs.min(catalog);
212 let max_lhs = self.lhs.max(catalog);
213
214 let min_rhs = self.rhs.min(catalog);
215 let max_rhs = self.rhs.max(catalog);
216
217 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
218 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
219
220 let min_max_check = left.into_iter().chain(right).reduce(or)?;
221
222 Some(with_nan_predicate(
224 self.lhs(),
225 self.rhs(),
226 min_max_check,
227 catalog,
228 ))
229 }
230 Operator::NotEq => {
231 let min_lhs = self.lhs.min(catalog)?;
232 let max_lhs = self.lhs.max(catalog)?;
233
234 let min_rhs = self.rhs.min(catalog)?;
235 let max_rhs = self.rhs.max(catalog)?;
236
237 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
238
239 Some(with_nan_predicate(
240 self.lhs(),
241 self.rhs(),
242 min_max_check,
243 catalog,
244 ))
245 }
246 Operator::Gt => {
247 let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
248
249 Some(with_nan_predicate(
250 self.lhs(),
251 self.rhs(),
252 min_max_check,
253 catalog,
254 ))
255 }
256 Operator::Gte => {
257 let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
259
260 Some(with_nan_predicate(
261 self.lhs(),
262 self.rhs(),
263 min_max_check,
264 catalog,
265 ))
266 }
267 Operator::Lt => {
268 let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
270
271 Some(with_nan_predicate(
272 self.lhs(),
273 self.rhs(),
274 min_max_check,
275 catalog,
276 ))
277 }
278 Operator::Lte => {
279 let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
281
282 Some(with_nan_predicate(
283 self.lhs(),
284 self.rhs(),
285 min_max_check,
286 catalog,
287 ))
288 }
289 Operator::And => self
290 .lhs
291 .stat_falsification(catalog)
292 .into_iter()
293 .chain(self.rhs.stat_falsification(catalog))
294 .reduce(or),
295 Operator::Or => Some(and(
296 self.lhs.stat_falsification(catalog)?,
297 self.rhs.stat_falsification(catalog)?,
298 )),
299 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
300 }
301 }
302}
303
304pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
323 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
324}
325
326pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
345 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
346}
347
348pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
367 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
368}
369
370pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
389 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
390}
391
392pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
411 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
412}
413
414pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
433 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
434}
435
436pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
453 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
454}
455
456pub fn or_collect<I>(iter: I) -> Option<ExprRef>
459where
460 I: IntoIterator<Item = ExprRef>,
461 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
462{
463 let mut iter = iter.into_iter();
464 let first = iter.next_back()?;
465 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
466}
467
468pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
485 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
486}
487
488pub fn and_collect<I>(iter: I) -> Option<ExprRef>
491where
492 I: IntoIterator<Item = ExprRef>,
493 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
494{
495 let mut iter = iter.into_iter();
496 let first = iter.next_back()?;
497 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
498}
499
500pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
503where
504 I: IntoIterator<Item = ExprRef>,
505{
506 let iter = iter.into_iter();
507 iter.reduce(and)
508}
509
510pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
533 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
534}
535
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// `and_collect` nests towards the right: `[1, 2, 3]` becomes `1 AND (2 AND 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let expected = and(lit(1), and(lit(2), lit(3)));
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` nests towards the left: `[1, 2, 3]` becomes `(1 AND 2) AND 3`.
    #[test]
    fn and_collect_right_assoc() {
        let expected = and(and(lit(1), lit(2)), lit(3));
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        assert_eq!(Some(expected), collected);
    }

    /// Boolean and comparison expressions report a `Bool` dtype whose nullability follows the
    /// nullability of their operands.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Boolean connectives over the two `bool*` columns are expected to be non-nullable.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        for connective in [and, or] {
            let expr = connective(bool1.clone(), bool2.clone());
            assert_eq!(expr.return_dtype(&dtype).unwrap(), non_nullable_bool);
        }

        // Every comparison over `col1` and `col2` is expected to be a nullable bool.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let nullable_bool = DType::Bool(Nullability::Nullable);
        for comparison in [eq, not_eq, gt, gt_eq, lt, lt_eq] {
            let expr = comparison(col1.clone(), col2.clone());
            assert_eq!(expr.return_dtype(&dtype).unwrap(), nullable_bool);
        }

        // Nullability carries through a connective over two nullable comparisons.
        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(nested.return_dtype(&dtype).unwrap(), nullable_bool);
    }
}