1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_proto::expr as pb;
10
11use crate::compute::{add, and_kleene, compare, div, mul, or_kleene, sub};
12use crate::expr::expression::Expression;
13use crate::expr::exprs::literal::lit;
14use crate::expr::exprs::operators::Operator;
15use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
16use crate::{ArrayRef, compute};
17
/// Expression kind for two-operand operations.
///
/// The behaviour lives in the `VTable` implementation for this type: the
/// per-expression instance data is an [`Operator`], and the operands are the
/// expression's two children (named `lhs` and `rhs`).
pub struct Binary;
19
20impl VTable for Binary {
21 type Instance = Operator;
22
23 fn id(&self) -> ExprId {
24 ExprId::from("vortex.binary")
25 }
26
27 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28 Ok(Some(
29 pb::BinaryOpts {
30 op: (*instance).into(),
31 }
32 .encode_to_vec(),
33 ))
34 }
35
36 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37 let opts = pb::BinaryOpts::decode(metadata)?;
38 Ok(Some(Operator::try_from(opts.op)?))
39 }
40
41 fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
42 Ok(())
44 }
45
46 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
47 match child_idx {
48 0 => ChildName::from("lhs"),
49 1 => ChildName::from("rhs"),
50 _ => unreachable!("Binary has only two children"),
51 }
52 }
53
54 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
55 write!(f, "(")?;
56 expr.lhs().fmt_sql(f)?;
57 write!(f, " {} ", expr.operator())?;
58 expr.rhs().fmt_sql(f)?;
59 write!(f, ")")
60 }
61
62 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", *instance)
64 }
65
66 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
67 let lhs = expr.lhs().return_dtype(scope)?;
68 let rhs = expr.rhs().return_dtype(scope)?;
69
70 if expr.operator().is_arithmetic() {
71 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
72 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
73 }
74 vortex_bail!(
75 "incompatible types for arithmetic operation: {} {}",
76 lhs,
77 rhs
78 );
79 }
80
81 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
82 }
83
84 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
85 let lhs = expr.lhs().evaluate(scope)?;
86 let rhs = expr.rhs().evaluate(scope)?;
87
88 match expr.operator() {
89 Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
90 Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
91 Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
92 Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
93 Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
94 Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
95 Operator::And => and_kleene(&lhs, &rhs),
96 Operator::Or => or_kleene(&lhs, &rhs),
97 Operator::Add => add(&lhs, &rhs),
98 Operator::Sub => sub(&lhs, &rhs),
99 Operator::Mul => mul(&lhs, &rhs),
100 Operator::Div => div(&lhs, &rhs),
101 }
102 }
103
104 fn stat_falsification(
105 &self,
106 expr: &ExpressionView<Self>,
107 catalog: &mut dyn StatsCatalog,
108 ) -> Option<Expression> {
109 #[inline]
123 fn with_nan_predicate(
124 lhs: &Expression,
125 rhs: &Expression,
126 value_predicate: Expression,
127 catalog: &mut dyn StatsCatalog,
128 ) -> Expression {
129 let nan_predicate = lhs
130 .stat_nan_count(catalog)
131 .into_iter()
132 .chain(rhs.stat_nan_count(catalog))
133 .map(|nans| eq(nans, lit(0u64)))
134 .reduce(and);
135
136 if let Some(nan_check) = nan_predicate {
137 and(nan_check, value_predicate)
138 } else {
139 value_predicate
140 }
141 }
142
143 match expr.operator() {
144 Operator::Eq => {
145 let min_lhs = expr.lhs().stat_min(catalog);
146 let max_lhs = expr.lhs().stat_max(catalog);
147
148 let min_rhs = expr.rhs().stat_min(catalog);
149 let max_rhs = expr.rhs().stat_max(catalog);
150
151 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
152 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
153
154 let min_max_check = left.into_iter().chain(right).reduce(or)?;
155
156 Some(with_nan_predicate(
158 expr.lhs(),
159 expr.rhs(),
160 min_max_check,
161 catalog,
162 ))
163 }
164 Operator::NotEq => {
165 let min_lhs = expr.lhs().stat_min(catalog)?;
166 let max_lhs = expr.lhs().stat_max(catalog)?;
167
168 let min_rhs = expr.rhs().stat_min(catalog)?;
169 let max_rhs = expr.rhs().stat_max(catalog)?;
170
171 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
172
173 Some(with_nan_predicate(
174 expr.lhs(),
175 expr.rhs(),
176 min_max_check,
177 catalog,
178 ))
179 }
180 Operator::Gt => {
181 let min_max_check =
182 lt_eq(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
183
184 Some(with_nan_predicate(
185 expr.lhs(),
186 expr.rhs(),
187 min_max_check,
188 catalog,
189 ))
190 }
191 Operator::Gte => {
192 let min_max_check =
194 lt(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
195
196 Some(with_nan_predicate(
197 expr.lhs(),
198 expr.rhs(),
199 min_max_check,
200 catalog,
201 ))
202 }
203 Operator::Lt => {
204 let min_max_check =
206 gt_eq(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
207
208 Some(with_nan_predicate(
209 expr.lhs(),
210 expr.rhs(),
211 min_max_check,
212 catalog,
213 ))
214 }
215 Operator::Lte => {
216 let min_max_check =
218 gt(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
219
220 Some(with_nan_predicate(
221 expr.lhs(),
222 expr.rhs(),
223 min_max_check,
224 catalog,
225 ))
226 }
227 Operator::And => expr
228 .lhs()
229 .stat_falsification(catalog)
230 .into_iter()
231 .chain(expr.rhs().stat_falsification(catalog))
232 .reduce(or),
233 Operator::Or => Some(and(
234 expr.lhs().stat_falsification(catalog)?,
235 expr.rhs().stat_falsification(catalog)?,
236 )),
237 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
238 }
239 }
240}
241
242impl ExpressionView<'_, Binary> {
243 pub fn lhs(&self) -> &Expression {
244 &self.children()[0]
245 }
246
247 pub fn rhs(&self) -> &Expression {
248 &self.children()[1]
249 }
250
251 pub fn operator(&self) -> Operator {
252 *self.data()
253 }
254}
255
256pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
275 Binary
276 .try_new_expr(Operator::Eq, [lhs, rhs])
277 .vortex_expect("Failed to create Eq binary expression")
278}
279
280pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
299 Binary
300 .try_new_expr(Operator::NotEq, [lhs, rhs])
301 .vortex_expect("Failed to create NotEq binary expression")
302}
303
304pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
323 Binary
324 .try_new_expr(Operator::Gte, [lhs, rhs])
325 .vortex_expect("Failed to create Gte binary expression")
326}
327
328pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
347 Binary
348 .try_new_expr(Operator::Gt, [lhs, rhs])
349 .vortex_expect("Failed to create Gt binary expression")
350}
351
352pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
371 Binary
372 .try_new_expr(Operator::Lte, [lhs, rhs])
373 .vortex_expect("Failed to create Lte binary expression")
374}
375
376pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
395 Binary
396 .try_new_expr(Operator::Lt, [lhs, rhs])
397 .vortex_expect("Failed to create Lt binary expression")
398}
399
400pub fn or(lhs: Expression, rhs: Expression) -> Expression {
417 Binary
418 .try_new_expr(Operator::Or, [lhs, rhs])
419 .vortex_expect("Failed to create Or binary expression")
420}
421
422pub fn or_collect<I>(iter: I) -> Option<Expression>
425where
426 I: IntoIterator<Item = Expression>,
427 I::IntoIter: DoubleEndedIterator<Item = Expression>,
428{
429 let mut iter = iter.into_iter();
430 let first = iter.next_back()?;
431 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
432}
433
434pub fn and(lhs: Expression, rhs: Expression) -> Expression {
451 Binary
452 .try_new_expr(Operator::And, [lhs, rhs])
453 .vortex_expect("Failed to create And binary expression")
454}
455
456pub fn and_collect<I>(iter: I) -> Option<Expression>
459where
460 I: IntoIterator<Item = Expression>,
461 I::IntoIter: DoubleEndedIterator<Item = Expression>,
462{
463 let mut iter = iter.into_iter();
464 let first = iter.next_back()?;
465 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
466}
467
468pub fn and_collect_right<I>(iter: I) -> Option<Expression>
471where
472 I: IntoIterator<Item = Expression>,
473{
474 let iter = iter.into_iter();
475 iter.reduce(and)
476}
477
478pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
501 Binary
502 .try_new_expr(Operator::Add, [lhs, rhs])
503 .vortex_expect("Failed to create Add binary expression")
504}
505
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};

    use super::{and, and_collect, and_collect_right, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::{Expression, test_harness};

    /// `and_collect` nests towards the right: `1 AND (2 AND 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` nests towards the left: `(1 AND 2) AND 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)].into_iter());
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    /// Boolean connectives and comparisons all produce `Bool`; nullability is
    /// inherited from the operands.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Connectives over the two boolean columns.
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            non_nullable_bool
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            non_nullable_bool
        );

        // Every comparison builder over `col1` / `col2`.
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        let comparisons: [fn(Expression, Expression) -> Expression; 6] =
            [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            assert_eq!(
                build(col1.clone(), col2.clone())
                    .return_dtype(&dtype)
                    .unwrap(),
                nullable_bool
            );
        }

        // A connective whose operands are themselves comparisons.
        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            nullable_bool
        );
    }
}