1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_proto::expr as pb;
12
13use crate::ArrayRef;
14use crate::compute;
15use crate::compute::add;
16use crate::compute::and_kleene;
17use crate::compute::compare;
18use crate::compute::div;
19use crate::compute::mul;
20use crate::compute::or_kleene;
21use crate::compute::sub;
22use crate::expr::ChildName;
23use crate::expr::ExprId;
24use crate::expr::ExpressionView;
25use crate::expr::ScalarFnExprExt;
26use crate::expr::StatsCatalog;
27use crate::expr::VTable;
28use crate::expr::VTableExt;
29use crate::expr::expression::Expression;
30use crate::expr::exprs::literal::lit;
31use crate::expr::exprs::operators::Operator;
32use crate::expr::stats::Stat;
33use crate::scalar_fns::binary;
34
/// Expression vtable for binary operators.
///
/// The concrete operator (comparison, boolean, or arithmetic) is stored as the
/// expression's instance data (an [`Operator`]), and the two operands are the
/// expression's children, named `lhs` (child 0) and `rhs` (child 1).
pub struct Binary;
36
37impl VTable for Binary {
38 type Instance = Operator;
39
40 fn id(&self) -> ExprId {
41 ExprId::from("vortex.binary")
42 }
43
44 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
45 Ok(Some(
46 pb::BinaryOpts {
47 op: (*instance).into(),
48 }
49 .encode_to_vec(),
50 ))
51 }
52
53 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
54 let opts = pb::BinaryOpts::decode(metadata)?;
55 Ok(Some(Operator::try_from(opts.op)?))
56 }
57
58 fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
59 Ok(())
61 }
62
63 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
64 match child_idx {
65 0 => ChildName::from("lhs"),
66 1 => ChildName::from("rhs"),
67 _ => unreachable!("Binary has only two children"),
68 }
69 }
70
71 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
72 write!(f, "(")?;
73 expr.lhs().fmt_sql(f)?;
74 write!(f, " {} ", expr.operator())?;
75 expr.rhs().fmt_sql(f)?;
76 write!(f, ")")
77 }
78
79 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", *instance)
81 }
82
83 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
84 let lhs = expr.lhs().return_dtype(scope)?;
85 let rhs = expr.rhs().return_dtype(scope)?;
86
87 if expr.operator().is_arithmetic() {
88 if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
89 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
90 }
91 vortex_bail!(
92 "incompatible types for arithmetic operation: {} {}",
93 lhs,
94 rhs
95 );
96 }
97
98 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
99 }
100
101 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
102 let lhs = expr.lhs().evaluate(scope)?;
103 let rhs = expr.rhs().evaluate(scope)?;
104
105 match expr.operator() {
106 Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
107 Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
108 Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
109 Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
110 Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
111 Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
112 Operator::And => and_kleene(&lhs, &rhs),
113 Operator::Or => or_kleene(&lhs, &rhs),
114 Operator::Add => add(&lhs, &rhs),
115 Operator::Sub => sub(&lhs, &rhs),
116 Operator::Mul => mul(&lhs, &rhs),
117 Operator::Div => div(&lhs, &rhs),
118 }
119 }
120
121 fn stat_falsification(
122 &self,
123 expr: &ExpressionView<Self>,
124 catalog: &dyn StatsCatalog,
125 ) -> Option<Expression> {
126 #[inline]
140 fn with_nan_predicate(
141 lhs: &Expression,
142 rhs: &Expression,
143 value_predicate: Expression,
144 catalog: &dyn StatsCatalog,
145 ) -> Expression {
146 let nan_predicate = lhs
147 .stat_expression(Stat::NaNCount, catalog)
148 .into_iter()
149 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
150 .map(|nans| eq(nans, lit(0u64)))
151 .reduce(and);
152
153 if let Some(nan_check) = nan_predicate {
154 and(nan_check, value_predicate)
155 } else {
156 value_predicate
157 }
158 }
159
160 match expr.operator() {
161 Operator::Eq => {
162 let min_lhs = expr.lhs().stat_min(catalog);
163 let max_lhs = expr.lhs().stat_max(catalog);
164
165 let min_rhs = expr.rhs().stat_min(catalog);
166 let max_rhs = expr.rhs().stat_max(catalog);
167
168 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
169 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
170
171 let min_max_check = left.into_iter().chain(right).reduce(or)?;
172
173 Some(with_nan_predicate(
175 expr.lhs(),
176 expr.rhs(),
177 min_max_check,
178 catalog,
179 ))
180 }
181 Operator::NotEq => {
182 let min_lhs = expr.lhs().stat_min(catalog)?;
183 let max_lhs = expr.lhs().stat_max(catalog)?;
184
185 let min_rhs = expr.rhs().stat_min(catalog)?;
186 let max_rhs = expr.rhs().stat_max(catalog)?;
187
188 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
189
190 Some(with_nan_predicate(
191 expr.lhs(),
192 expr.rhs(),
193 min_max_check,
194 catalog,
195 ))
196 }
197 Operator::Gt => {
198 let min_max_check =
199 lt_eq(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
200
201 Some(with_nan_predicate(
202 expr.lhs(),
203 expr.rhs(),
204 min_max_check,
205 catalog,
206 ))
207 }
208 Operator::Gte => {
209 let min_max_check =
211 lt(expr.lhs().stat_max(catalog)?, expr.rhs().stat_min(catalog)?);
212
213 Some(with_nan_predicate(
214 expr.lhs(),
215 expr.rhs(),
216 min_max_check,
217 catalog,
218 ))
219 }
220 Operator::Lt => {
221 let min_max_check =
223 gt_eq(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
224
225 Some(with_nan_predicate(
226 expr.lhs(),
227 expr.rhs(),
228 min_max_check,
229 catalog,
230 ))
231 }
232 Operator::Lte => {
233 let min_max_check =
235 gt(expr.lhs().stat_min(catalog)?, expr.rhs().stat_max(catalog)?);
236
237 Some(with_nan_predicate(
238 expr.lhs(),
239 expr.rhs(),
240 min_max_check,
241 catalog,
242 ))
243 }
244 Operator::And => expr
245 .lhs()
246 .stat_falsification(catalog)
247 .into_iter()
248 .chain(expr.rhs().stat_falsification(catalog))
249 .reduce(or),
250 Operator::Or => Some(and(
251 expr.lhs().stat_falsification(catalog)?,
252 expr.rhs().stat_falsification(catalog)?,
253 )),
254 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
255 }
256 }
257
258 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
259 false
260 }
261
262 fn is_fallible(&self, instance: &Self::Instance) -> bool {
263 let infallible = matches!(
266 instance,
267 Operator::Eq
268 | Operator::NotEq
269 | Operator::Gt
270 | Operator::Gte
271 | Operator::Lt
272 | Operator::Lte
273 | Operator::And
274 | Operator::Or
275 );
276
277 !infallible
278 }
279
280 fn expr_v2(&self, view: &ExpressionView<Self>) -> VortexResult<Expression> {
281 ScalarFnExprExt::try_new_expr(&binary::BinaryFn, view.operator(), view.children().clone())
282 }
283}
284
285impl ExpressionView<'_, Binary> {
286 pub fn lhs(&self) -> &Expression {
287 &self.children()[0]
288 }
289
290 pub fn rhs(&self) -> &Expression {
291 &self.children()[1]
292 }
293
294 pub fn operator(&self) -> Operator {
295 *self.data()
296 }
297}
298
299pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
318 Binary
319 .try_new_expr(Operator::Eq, [lhs, rhs])
320 .vortex_expect("Failed to create Eq binary expression")
321}
322
323pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
342 Binary
343 .try_new_expr(Operator::NotEq, [lhs, rhs])
344 .vortex_expect("Failed to create NotEq binary expression")
345}
346
347pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
366 Binary
367 .try_new_expr(Operator::Gte, [lhs, rhs])
368 .vortex_expect("Failed to create Gte binary expression")
369}
370
371pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
390 Binary
391 .try_new_expr(Operator::Gt, [lhs, rhs])
392 .vortex_expect("Failed to create Gt binary expression")
393}
394
395pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
414 Binary
415 .try_new_expr(Operator::Lte, [lhs, rhs])
416 .vortex_expect("Failed to create Lte binary expression")
417}
418
419pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
438 Binary
439 .try_new_expr(Operator::Lt, [lhs, rhs])
440 .vortex_expect("Failed to create Lt binary expression")
441}
442
443pub fn or(lhs: Expression, rhs: Expression) -> Expression {
460 Binary
461 .try_new_expr(Operator::Or, [lhs, rhs])
462 .vortex_expect("Failed to create Or binary expression")
463}
464
465pub fn or_collect<I>(iter: I) -> Option<Expression>
468where
469 I: IntoIterator<Item = Expression>,
470 I::IntoIter: DoubleEndedIterator<Item = Expression>,
471{
472 let mut iter = iter.into_iter();
473 let first = iter.next_back()?;
474 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
475}
476
477pub fn and(lhs: Expression, rhs: Expression) -> Expression {
494 Binary
495 .try_new_expr(Operator::And, [lhs, rhs])
496 .vortex_expect("Failed to create And binary expression")
497}
498
499pub fn and_collect<I>(iter: I) -> Option<Expression>
502where
503 I: IntoIterator<Item = Expression>,
504 I::IntoIter: DoubleEndedIterator<Item = Expression>,
505{
506 let mut iter = iter.into_iter();
507 let first = iter.next_back()?;
508 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
509}
510
511pub fn and_collect_right<I>(iter: I) -> Option<Expression>
514where
515 I: IntoIterator<Item = Expression>,
516{
517 let iter = iter.into_iter();
518 iter.reduce(and)
519}
520
521pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
544 Binary
545 .try_new_expr(Operator::Add, [lhs, rhs])
546 .vortex_expect("Failed to create Add binary expression")
547}
548
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use super::and;
    use super::and_collect;
    use super::and_collect_right;
    use super::eq;
    use super::gt;
    use super::gt_eq;
    use super::lt;
    use super::lt_eq;
    use super::not_eq;
    use super::or;
    use super::or_collect;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::test_harness;

    /// `and_collect` nests to the right: `a AND (b AND c)`.
    #[test]
    fn and_collect_left_assoc() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(lit(1), and(lit(2), lit(3)))),
            and_collect(values.into_iter())
        );
    }

    /// `and_collect_right` nests to the left: `(a AND b) AND c`.
    #[test]
    fn and_collect_right_assoc() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(and(and(lit(1), lit(2)), lit(3))),
            and_collect_right(values.into_iter())
        );
    }

    /// `or_collect` mirrors `and_collect`: `a OR (b OR c)`.
    #[test]
    fn or_collect_nests_right() {
        let values = vec![lit(1), lit(2), lit(3)];
        assert_eq!(
            Some(or(lit(1), or(lit(2), lit(3)))),
            or_collect(values.into_iter())
        );
    }

    /// Empty inputs yield no expression; a single input is returned unchanged.
    #[test]
    fn collect_helpers_empty_and_single() {
        assert_eq!(None, and_collect(Vec::<Expression>::new()));
        assert_eq!(None, or_collect(Vec::<Expression>::new()));
        assert_eq!(None, and_collect_right(Vec::<Expression>::new()));

        assert_eq!(Some(lit(7)), and_collect(vec![lit(7)]));
        assert_eq!(Some(lit(7)), or_collect(vec![lit(7)]));
        assert_eq!(Some(lit(7)), and_collect_right(vec![lit(7)]));
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn test_debug_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(
            format!("{expr:?}"),
            "Expression { vtable: vortex.binary, data: >, children: [Expression { vtable: vortex.literal, data: 1i32, children: [] }, Expression { vtable: vortex.literal, data: 2i32, children: [] }] }"
        );
    }

    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }
}