1use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, or_kleene, sub};
8use vortex_array::operator::OperatorRef;
9use vortex_array::operator::compare::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18 VTable, lit, vtable,
19};
20
// Macro-generated vtable glue for the binary expression. `BinaryVTable`, implemented below,
// is not declared anywhere else in this file, so it presumably comes from this invocation —
// confirm against the `vtable!` macro definition.
vtable!(Binary);
22
/// A binary expression `lhs <operator> rhs` over two child expressions.
///
/// The operator is one of the comparisons (`Eq`, `NotEq`, `Lt`, `Lte`, `Gt`, `Gte`),
/// Kleene boolean logic (`And`, `Or`) or arithmetic (`Add`, `Sub`).
// `Hash` is derived while `PartialEq` is written by hand. The manual `eq` compares exactly
// the same three fields the derived `Hash` consumes, so equal values hash equally — the
// invariant `clippy::derived_hash_with_manual_eq` protects — hence the lint is allowed.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operation applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
30
31impl PartialEq for BinaryExpr {
32 fn eq(&self, other: &Self) -> bool {
33 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34 }
35}
36
/// Field-less encoding type for [`BinaryExpr`]: it is the vtable's `Encoding` associated type
/// and the value handed to `ExprEncodingRef::new_ref` by `BinaryVTable::encoding`.
pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40 type Expr = BinaryExpr;
41 type Encoding = BinaryExprEncoding;
42 type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44 fn id(_encoding: &Self::Encoding) -> ExprId {
45 ExprId::new_ref("binary")
46 }
47
48 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50 }
51
52 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53 Some(ProstMetadata(pb::BinaryOpts {
54 op: expr.operator.into(),
55 }))
56 }
57
58 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59 vec![expr.lhs(), expr.rhs()]
60 }
61
62 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63 Ok(BinaryExpr::new(
64 children[0].clone(),
65 expr.op(),
66 children[1].clone(),
67 ))
68 }
69
70 fn build(
71 _encoding: &Self::Encoding,
72 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73 children: Vec<ExprRef>,
74 ) -> VortexResult<Self::Expr> {
75 Ok(BinaryExpr::new(
76 children[0].clone(),
77 metadata.op().into(),
78 children[1].clone(),
79 ))
80 }
81
82 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83 let lhs = expr.lhs.unchecked_evaluate(scope)?;
84 let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86 match expr.operator {
87 Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88 Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89 Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90 Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91 Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92 Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93 Operator::And => and_kleene(&lhs, &rhs),
94 Operator::Or => or_kleene(&lhs, &rhs),
95 Operator::Add => add(&lhs, &rhs),
96 Operator::Sub => sub(&lhs, &rhs),
97 }
98 }
99
100 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
101 let lhs = expr.lhs.return_dtype(scope)?;
102 let rhs = expr.rhs.return_dtype(scope)?;
103
104 if expr.operator == Operator::Add {
105 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
106 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
107 }
108 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
109 }
110
111 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112 }
113
114 fn operator(expr: &BinaryExpr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
115 let Some(lhs) = expr.lhs.operator(scope)? else {
116 return Ok(None);
117 };
118 let Some(rhs) = expr.rhs.operator(scope)? else {
119 return Ok(None);
120 };
121 let Ok(op): VortexResult<compute::Operator> = expr.operator.try_into() else {
122 return Ok(None);
123 };
124 Ok(Some(Arc::new(CompareOperator::try_new(lhs, rhs, op)?)))
125 }
126}
127
128impl BinaryExpr {
129 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
130 Self { lhs, operator, rhs }
131 }
132
133 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
134 Self::new(lhs, operator, rhs).into_expr()
135 }
136
137 pub fn lhs(&self) -> &ExprRef {
138 &self.lhs
139 }
140
141 pub fn rhs(&self) -> &ExprRef {
142 &self.rhs
143 }
144
145 pub fn op(&self) -> Operator {
146 self.operator
147 }
148}
149
150impl DisplayAs for BinaryExpr {
151 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
152 match df {
153 DisplayFormat::Compact => {
154 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
155 }
156 DisplayFormat::Tree => {
157 write!(f, "Binary({})", self.operator)
158 }
159 }
160 }
161
162 fn child_names(&self) -> Option<Vec<String>> {
163 Some(vec!["lhs".to_string(), "rhs".to_string()])
164 }
165}
166
167impl AnalysisExpr for BinaryExpr {
168 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
169 #[inline]
183 fn with_nan_predicate(
184 lhs: &ExprRef,
185 rhs: &ExprRef,
186 value_predicate: ExprRef,
187 catalog: &mut dyn StatsCatalog,
188 ) -> ExprRef {
189 let nan_predicate = lhs
190 .nan_count(catalog)
191 .into_iter()
192 .chain(rhs.nan_count(catalog))
193 .map(|nans| eq(nans, lit(0u64)))
194 .reduce(and);
195
196 if let Some(nan_check) = nan_predicate {
197 and(nan_check, value_predicate)
198 } else {
199 value_predicate
200 }
201 }
202
203 match self.operator {
204 Operator::Eq => {
205 let min_lhs = self.lhs.min(catalog);
206 let max_lhs = self.lhs.max(catalog);
207
208 let min_rhs = self.rhs.min(catalog);
209 let max_rhs = self.rhs.max(catalog);
210
211 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
212 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
213
214 let min_max_check = left.into_iter().chain(right).reduce(or)?;
215
216 Some(with_nan_predicate(
218 self.lhs(),
219 self.rhs(),
220 min_max_check,
221 catalog,
222 ))
223 }
224 Operator::NotEq => {
225 let min_lhs = self.lhs.min(catalog)?;
226 let max_lhs = self.lhs.max(catalog)?;
227
228 let min_rhs = self.rhs.min(catalog)?;
229 let max_rhs = self.rhs.max(catalog)?;
230
231 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
232
233 Some(with_nan_predicate(
234 self.lhs(),
235 self.rhs(),
236 min_max_check,
237 catalog,
238 ))
239 }
240 Operator::Gt => {
241 let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
242
243 Some(with_nan_predicate(
244 self.lhs(),
245 self.rhs(),
246 min_max_check,
247 catalog,
248 ))
249 }
250 Operator::Gte => {
251 let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
253
254 Some(with_nan_predicate(
255 self.lhs(),
256 self.rhs(),
257 min_max_check,
258 catalog,
259 ))
260 }
261 Operator::Lt => {
262 let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
264
265 Some(with_nan_predicate(
266 self.lhs(),
267 self.rhs(),
268 min_max_check,
269 catalog,
270 ))
271 }
272 Operator::Lte => {
273 let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
275
276 Some(with_nan_predicate(
277 self.lhs(),
278 self.rhs(),
279 min_max_check,
280 catalog,
281 ))
282 }
283 Operator::And => self
284 .lhs
285 .stat_falsification(catalog)
286 .into_iter()
287 .chain(self.rhs.stat_falsification(catalog))
288 .reduce(or),
289 Operator::Or => Some(and(
290 self.lhs.stat_falsification(catalog)?,
291 self.rhs.stat_falsification(catalog)?,
292 )),
293 Operator::Add | Operator::Sub => None,
294 }
295 }
296}
297
298pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
317 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
318}
319
320pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
339 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
340}
341
342pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
361 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
362}
363
364pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
383 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
384}
385
386pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
405 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
406}
407
408pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
427 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
428}
429
430pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
447 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
448}
449
450pub fn or_collect<I>(iter: I) -> Option<ExprRef>
453where
454 I: IntoIterator<Item = ExprRef>,
455 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
456{
457 let mut iter = iter.into_iter();
458 let first = iter.next_back()?;
459 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
460}
461
462pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
479 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
480}
481
482pub fn and_collect<I>(iter: I) -> Option<ExprRef>
485where
486 I: IntoIterator<Item = ExprRef>,
487 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
488{
489 let mut iter = iter.into_iter();
490 let first = iter.next_back()?;
491 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
492}
493
494pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
497where
498 I: IntoIterator<Item = ExprRef>,
499{
500 let iter = iter.into_iter();
501 iter.reduce(and)
502}
503
504pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
527 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
528}
529
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// Shared signature of the binary expression constructors exercised below.
    type Ctor = fn(Arc<dyn VortexExpr>, Arc<dyn VortexExpr>) -> Arc<dyn VortexExpr>;

    /// `and_collect` keeps the first element outermost: `[1, 2, 3]` -> `1 and (2 and 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect([lit(1), lit(2), lit(3)]);
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` folds front to back: `[1, 2, 3]` -> `(1 and 2) and 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right([lit(1), lit(2), lit(3)]);
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Logical operators over two non-nullable boolean columns stay non-nullable.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let logical: [Ctor; 2] = [and, or];
        for build in logical {
            let actual = build(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap();
            assert_eq!(actual, DType::Bool(Nullability::NonNullable));
        }

        // Every comparison of col1 with col2 yields a nullable boolean.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons: [Ctor; 6] = [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            let actual = build(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap();
            assert_eq!(actual, DType::Bool(Nullability::Nullable));
        }

        // Combining nullable comparisons keeps the result nullable.
        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(
            combined.return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}