1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::{Operator as ArrayOperator, and_kleene, compare, or_kleene};
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{AnalysisExpr, ExprRef, Operator, Scope, ScopeDType, StatsCatalog, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct BinaryExpr {
16 lhs: ExprRef,
17 operator: Operator,
18 rhs: ExprRef,
19}
20
21impl BinaryExpr {
22 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
23 Arc::new(Self { lhs, operator, rhs })
24 }
25
26 pub fn lhs(&self) -> &ExprRef {
27 &self.lhs
28 }
29
30 pub fn rhs(&self) -> &ExprRef {
31 &self.rhs
32 }
33
34 pub fn op(&self) -> Operator {
35 self.operator
36 }
37}
38
39impl Display for BinaryExpr {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
42 }
43}
44
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{BinaryExpr, ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde entry point for [`BinaryExpr`], identified by the id `"binary"`.
    pub(crate) struct BinarySerde;

    impl Id for BinarySerde {
        fn id(&self) -> &'static str {
            "binary"
        }
    }

    impl ExprDeserialize for BinarySerde {
        /// Rebuilds a [`BinaryExpr`] from a `Kind::BinaryOp` node and its child expressions.
        ///
        /// # Errors
        ///
        /// Returns an error if:
        /// - `kind` is not `Kind::BinaryOp`;
        /// - `children` does not hold exactly two expressions (left operand, then right);
        /// - the encoded operator cannot be converted into an operator.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::BinaryOp(op) = kind else {
                vortex_bail!("wrong kind {:?}, binary", kind)
            };

            // The child list comes from serialized input. Reject a malformed one with an
            // error rather than panicking on an out-of-bounds index.
            let [lhs, rhs] = children.as_slice() else {
                vortex_bail!(
                    "binary expression expects exactly 2 children, got {}",
                    children.len()
                )
            };

            Ok(BinaryExpr::new_expr(
                lhs.clone(),
                (*op).try_into()?,
                rhs.clone(),
            ))
        }
    }

    impl ExprSerializable for BinaryExpr {
        /// Shares its id with [`BinarySerde`] so serialized nodes route back to it.
        fn id(&self) -> &'static str {
            BinarySerde.id()
        }

        /// Encodes only the operator; the operands are carried as child expressions.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::BinaryOp(self.operator.into()))
        }
    }
}
84
85impl AnalysisExpr for BinaryExpr {
86 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
87 match self.operator {
88 Operator::Eq => {
89 let min_lhs = self.lhs.min(catalog);
90 let max_lhs = self.lhs.max(catalog);
91
92 let min_rhs = self.rhs.min(catalog);
93 let max_rhs = self.rhs.max(catalog);
94
95 let left = min_lhs.zip_with(max_rhs, gt);
96 let right = min_rhs.zip_with(max_lhs, gt);
97 left.into_iter().chain(right).reduce(or)
98 }
99 Operator::NotEq => {
100 let min_lhs = self.lhs.min(catalog)?;
101 let max_lhs = self.lhs.max(catalog)?;
102
103 let min_rhs = self.rhs.min(catalog)?;
104 let max_rhs = self.rhs.max(catalog)?;
105
106 Some(and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)))
107 }
108 Operator::Gt => Some(lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
109 Operator::Gte => Some(lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?)),
110 Operator::Lt => Some(gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
111 Operator::Lte => Some(gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?)),
112 Operator::And => self
114 .lhs
115 .stat_falsification(catalog)
116 .into_iter()
117 .chain(self.rhs.stat_falsification(catalog))
118 .reduce(or),
119 Operator::Or => Some(and(
120 self.lhs.stat_falsification(catalog)?,
121 self.rhs.stat_falsification(catalog)?,
122 )),
123 }
124 }
125}
126
127impl VortexExpr for BinaryExpr {
128 fn as_any(&self) -> &dyn Any {
129 self
130 }
131
132 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
133 let lhs = self.lhs.unchecked_evaluate(scope)?;
134 let rhs = self.rhs.unchecked_evaluate(scope)?;
135
136 match self.operator {
137 Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
138 Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
139 Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
140 Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
141 Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
142 Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
143 Operator::And => and_kleene(&lhs, &rhs),
144 Operator::Or => or_kleene(&lhs, &rhs),
145 }
146 }
147
148 fn children(&self) -> Vec<&ExprRef> {
149 vec![&self.lhs, &self.rhs]
150 }
151
152 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
153 assert_eq!(children.len(), 2);
154 BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
155 }
156
157 fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
158 let lhs = self.lhs.return_dtype(ctx)?;
159 let rhs = self.rhs.return_dtype(ctx)?;
160 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
161 }
162}
163
164impl PartialEq for BinaryExpr {
165 fn eq(&self, other: &BinaryExpr) -> bool {
166 other.operator == self.operator && other.lhs.eq(&self.lhs) && other.rhs.eq(&self.rhs)
167 }
168}
169
170pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
190 BinaryExpr::new_expr(lhs, Operator::Eq, rhs)
191}
192
193pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
213 BinaryExpr::new_expr(lhs, Operator::NotEq, rhs)
214}
215
216pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
236 BinaryExpr::new_expr(lhs, Operator::Gte, rhs)
237}
238
239pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
259 BinaryExpr::new_expr(lhs, Operator::Gt, rhs)
260}
261
262pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
282 BinaryExpr::new_expr(lhs, Operator::Lte, rhs)
283}
284
285pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
305 BinaryExpr::new_expr(lhs, Operator::Lt, rhs)
306}
307
308pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
326 BinaryExpr::new_expr(lhs, Operator::Or, rhs)
327}
328
329pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
347 BinaryExpr::new_expr(lhs, Operator::And, rhs)
348}
349
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        ScopeDType, VortexExpr, and, col, eq, gt, gt_eq, lt, lt_eq, not_eq, or, test_harness,
    };

    /// Checks the result dtype of every binary operator against the test-harness schema:
    /// - connectives over `bool1`/`bool2` are expected to be non-nullable;
    /// - comparisons over `col1`/`col2`, and an `or` built from such comparisons, are
    ///   expected to be nullable.
    #[test]
    fn dtype() {
        let schema = test_harness::struct_dtype();
        let result_dtype = |expr: Arc<dyn VortexExpr>| {
            expr.return_dtype(&ScopeDType::new(schema.clone()))
                .unwrap()
        };

        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        for connective in [
            and(bool1.clone(), bool2.clone()),
            or(bool1.clone(), bool2.clone()),
        ] {
            assert_eq!(
                result_dtype(connective),
                DType::Bool(Nullability::NonNullable)
            );
        }

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons = [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ];
        for comparison in comparisons {
            assert_eq!(
                result_dtype(comparison),
                DType::Bool(Nullability::Nullable)
            );
        }

        // Nullability propagates through a connective whose operands are nullable.
        let mixed = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(result_dtype(mixed), DType::Bool(Nullability::Nullable));
    }
}