1use std::hash::Hash;
5use std::sync::Arc;
6
7use vortex_array::compute::{add, and_kleene, compare, or_kleene, sub};
8use vortex_array::pipeline::OperatorRef;
9use vortex_array::pipeline::operators::CompareOperator;
10use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, compute};
11use vortex_dtype::DType;
12use vortex_error::{VortexExpect, VortexResult, vortex_bail};
13use vortex_proto::expr as pb;
14
15use crate::display::{DisplayAs, DisplayFormat};
16use crate::{
17 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Operator, Scope, StatsCatalog,
18 VTable, lit, vtable,
19};
20
// NOTE(review): `BinaryVTable` is implemented below but never declared in this file, so this
// macro invocation presumably generates it (and related glue) for `BinaryExpr` — confirm against
// the `vtable!` macro definition.
vtable!(Binary);
22
/// An expression that applies a binary [`Operator`] to two child expressions.
///
/// `Hash` is derived while `PartialEq` is written by hand (see the `PartialEq` impl in this
/// file), which is why the `derived_hash_with_manual_eq` lint is silenced. Both cover the same
/// three fields.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    lhs: ExprRef,
    /// Operator applied to the two operands.
    operator: Operator,
    /// Right-hand operand.
    rhs: ExprRef,
}
30
31impl PartialEq for BinaryExpr {
32 fn eq(&self, other: &Self) -> bool {
33 self.lhs.eq(&other.lhs) && self.operator == other.operator && self.rhs.eq(&other.rhs)
34 }
35}
36
/// Unit type used as the `Encoding` associated type of the [`VTable`] implementation for binary
/// expressions.
pub struct BinaryExprEncoding;
38
39impl VTable for BinaryVTable {
40 type Expr = BinaryExpr;
41 type Encoding = BinaryExprEncoding;
42 type Metadata = ProstMetadata<pb::BinaryOpts>;
43
44 fn id(_encoding: &Self::Encoding) -> ExprId {
45 ExprId::new_ref("binary")
46 }
47
48 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49 ExprEncodingRef::new_ref(BinaryExprEncoding.as_ref())
50 }
51
52 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53 Some(ProstMetadata(pb::BinaryOpts {
54 op: expr.operator.into(),
55 }))
56 }
57
58 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
59 vec![expr.lhs(), expr.rhs()]
60 }
61
62 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
63 Ok(BinaryExpr::new(
64 children[0].clone(),
65 expr.op(),
66 children[1].clone(),
67 ))
68 }
69
70 fn build(
71 _encoding: &Self::Encoding,
72 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
73 children: Vec<ExprRef>,
74 ) -> VortexResult<Self::Expr> {
75 Ok(BinaryExpr::new(
76 children[0].clone(),
77 metadata.op().into(),
78 children[1].clone(),
79 ))
80 }
81
82 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83 let lhs = expr.lhs.unchecked_evaluate(scope)?;
84 let rhs = expr.rhs.unchecked_evaluate(scope)?;
85
86 match expr.operator {
87 Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
88 Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
89 Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
90 Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
91 Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
92 Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
93 Operator::And => and_kleene(&lhs, &rhs),
94 Operator::Or => or_kleene(&lhs, &rhs),
95 Operator::Add => add(&lhs, &rhs),
96 Operator::Sub => sub(&lhs, &rhs),
97 }
98 }
99
100 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
101 let lhs = expr.lhs.return_dtype(scope)?;
102 let rhs = expr.rhs.return_dtype(scope)?;
103
104 if expr.operator == Operator::Add {
105 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
106 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
107 }
108 vortex_bail!("incompatible types for checked add: {} {}", lhs, rhs);
109 }
110
111 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112 }
113
114 fn operator(expr: &BinaryExpr, children: Vec<OperatorRef>) -> Option<OperatorRef> {
115 let [lhs, rhs] = children
116 .try_into()
117 .ok()
118 .vortex_expect("Expected 2 children");
119 let op = expr.operator.try_into().ok()?;
120
121 Some(Arc::new(CompareOperator::new(lhs, rhs, op)) as OperatorRef)
122 }
123}
124
125impl BinaryExpr {
126 pub fn new(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> Self {
127 Self { lhs, operator, rhs }
128 }
129
130 pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
131 Self::new(lhs, operator, rhs).into_expr()
132 }
133
134 pub fn lhs(&self) -> &ExprRef {
135 &self.lhs
136 }
137
138 pub fn rhs(&self) -> &ExprRef {
139 &self.rhs
140 }
141
142 pub fn op(&self) -> Operator {
143 self.operator
144 }
145}
146
147impl DisplayAs for BinaryExpr {
148 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
149 match df {
150 DisplayFormat::Compact => {
151 write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
152 }
153 DisplayFormat::Tree => {
154 write!(f, "Binary({})", self.operator)
155 }
156 }
157 }
158
159 fn child_names(&self) -> Option<Vec<String>> {
160 Some(vec!["lhs".to_string(), "rhs".to_string()])
161 }
162}
163
164impl AnalysisExpr for BinaryExpr {
165 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
166 #[inline]
180 fn with_nan_predicate(
181 lhs: &ExprRef,
182 rhs: &ExprRef,
183 value_predicate: ExprRef,
184 catalog: &mut dyn StatsCatalog,
185 ) -> ExprRef {
186 let nan_predicate = lhs
187 .nan_count(catalog)
188 .into_iter()
189 .chain(rhs.nan_count(catalog))
190 .map(|nans| eq(nans, lit(0u64)))
191 .reduce(and);
192
193 if let Some(nan_check) = nan_predicate {
194 and(nan_check, value_predicate)
195 } else {
196 value_predicate
197 }
198 }
199
200 match self.operator {
201 Operator::Eq => {
202 let min_lhs = self.lhs.min(catalog);
203 let max_lhs = self.lhs.max(catalog);
204
205 let min_rhs = self.rhs.min(catalog);
206 let max_rhs = self.rhs.max(catalog);
207
208 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
209 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
210
211 let min_max_check = left.into_iter().chain(right).reduce(or)?;
212
213 Some(with_nan_predicate(
215 self.lhs(),
216 self.rhs(),
217 min_max_check,
218 catalog,
219 ))
220 }
221 Operator::NotEq => {
222 let min_lhs = self.lhs.min(catalog)?;
223 let max_lhs = self.lhs.max(catalog)?;
224
225 let min_rhs = self.rhs.min(catalog)?;
226 let max_rhs = self.rhs.max(catalog)?;
227
228 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
229
230 Some(with_nan_predicate(
231 self.lhs(),
232 self.rhs(),
233 min_max_check,
234 catalog,
235 ))
236 }
237 Operator::Gt => {
238 let min_max_check = lt_eq(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
239
240 Some(with_nan_predicate(
241 self.lhs(),
242 self.rhs(),
243 min_max_check,
244 catalog,
245 ))
246 }
247 Operator::Gte => {
248 let min_max_check = lt(self.lhs.max(catalog)?, self.rhs.min(catalog)?);
250
251 Some(with_nan_predicate(
252 self.lhs(),
253 self.rhs(),
254 min_max_check,
255 catalog,
256 ))
257 }
258 Operator::Lt => {
259 let min_max_check = gt_eq(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
261
262 Some(with_nan_predicate(
263 self.lhs(),
264 self.rhs(),
265 min_max_check,
266 catalog,
267 ))
268 }
269 Operator::Lte => {
270 let min_max_check = gt(self.lhs.min(catalog)?, self.rhs.max(catalog)?);
272
273 Some(with_nan_predicate(
274 self.lhs(),
275 self.rhs(),
276 min_max_check,
277 catalog,
278 ))
279 }
280 Operator::And => self
281 .lhs
282 .stat_falsification(catalog)
283 .into_iter()
284 .chain(self.rhs.stat_falsification(catalog))
285 .reduce(or),
286 Operator::Or => Some(and(
287 self.lhs.stat_falsification(catalog)?,
288 self.rhs.stat_falsification(catalog)?,
289 )),
290 Operator::Add | Operator::Sub => None,
291 }
292 }
293}
294
295pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
314 BinaryExpr::new(lhs, Operator::Eq, rhs).into_expr()
315}
316
317pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
336 BinaryExpr::new(lhs, Operator::NotEq, rhs).into_expr()
337}
338
339pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
358 BinaryExpr::new(lhs, Operator::Gte, rhs).into_expr()
359}
360
361pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
380 BinaryExpr::new(lhs, Operator::Gt, rhs).into_expr()
381}
382
383pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
402 BinaryExpr::new(lhs, Operator::Lte, rhs).into_expr()
403}
404
405pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
424 BinaryExpr::new(lhs, Operator::Lt, rhs).into_expr()
425}
426
427pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
444 BinaryExpr::new(lhs, Operator::Or, rhs).into_expr()
445}
446
447pub fn or_collect<I>(iter: I) -> Option<ExprRef>
450where
451 I: IntoIterator<Item = ExprRef>,
452 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
453{
454 let mut iter = iter.into_iter();
455 let first = iter.next_back()?;
456 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
457}
458
459pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
476 BinaryExpr::new(lhs, Operator::And, rhs).into_expr()
477}
478
479pub fn and_collect<I>(iter: I) -> Option<ExprRef>
482where
483 I: IntoIterator<Item = ExprRef>,
484 I::IntoIter: DoubleEndedIterator<Item = ExprRef>,
485{
486 let mut iter = iter.into_iter();
487 let first = iter.next_back()?;
488 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
489}
490
491pub fn and_collect_right<I>(iter: I) -> Option<ExprRef>
494where
495 I: IntoIterator<Item = ExprRef>,
496{
497 let iter = iter.into_iter();
498 iter.reduce(and)
499}
500
501pub fn checked_add(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
524 BinaryExpr::new(lhs, Operator::Add, rhs).into_expr()
525}
526
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{
        VortexExpr, and, and_collect, and_collect_right, col, eq, gt, gt_eq, lit, lt, lt_eq,
        not_eq, or, test_harness,
    };

    /// Signature shared by all binary expression builders exercised below.
    type Builder = fn(Arc<dyn VortexExpr>, Arc<dyn VortexExpr>) -> Arc<dyn VortexExpr>;

    // `and_collect` nests to the right: the last element is innermost.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    // `and_collect_right` nests to the left: the first element is innermost.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    // Boolean results are nullable exactly when one of the operands is.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let connectives: [(&str, Builder); 2] = [("and", and), ("or", or)];
        for (name, build) in connectives {
            let expr = build(bool1.clone(), bool2.clone());
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                non_nullable_bool,
                "operator: {}",
                name
            );
        }

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons: [(&str, Builder); 6] = [
            ("eq", eq),
            ("not_eq", not_eq),
            ("gt", gt),
            ("gt_eq", gt_eq),
            ("lt", lt),
            ("lt_eq", lt_eq),
        ];
        for (name, build) in comparisons {
            let expr = build(col1.clone(), col2.clone());
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                nullable_bool,
                "operator: {}",
                name
            );
        }

        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(nested.return_dtype(&dtype).unwrap(), nullable_bool);
    }
}