1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::compute;
16use crate::compute::BooleanOperator;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::StatsCatalog;
22use crate::expr::VTable;
23use crate::expr::VTableExt;
24use crate::expr::expression::Expression;
25use crate::expr::exprs::literal::lit;
26use crate::expr::exprs::operators::Operator;
27use crate::expr::stats::Stat;
28
29mod boolean;
30pub(crate) use boolean::*;
31mod compare;
32pub use compare::*;
33mod numeric;
34pub(crate) use numeric::*;
35
36pub struct Binary;
37
38impl VTable for Binary {
39 type Options = Operator;
40
41 fn id(&self) -> ExprId {
42 ExprId::from("vortex.binary")
43 }
44
45 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
46 Ok(Some(
47 pb::BinaryOpts {
48 op: (*instance).into(),
49 }
50 .encode_to_vec(),
51 ))
52 }
53
54 fn deserialize(
55 &self,
56 _metadata: &[u8],
57 _session: &VortexSession,
58 ) -> VortexResult<Self::Options> {
59 let opts = pb::BinaryOpts::decode(_metadata)?;
60 Operator::try_from(opts.op)
61 }
62
63 fn arity(&self, _options: &Self::Options) -> Arity {
64 Arity::Exact(2)
65 }
66
67 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
68 match child_idx {
69 0 => ChildName::from("lhs"),
70 1 => ChildName::from("rhs"),
71 _ => unreachable!("Binary has only two children"),
72 }
73 }
74
75 fn fmt_sql(
76 &self,
77 operator: &Operator,
78 expr: &Expression,
79 f: &mut Formatter<'_>,
80 ) -> std::fmt::Result {
81 write!(f, "(")?;
82 expr.child(0).fmt_sql(f)?;
83 write!(f, " {} ", operator)?;
84 expr.child(1).fmt_sql(f)?;
85 write!(f, ")")
86 }
87
88 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
89 let lhs = &arg_dtypes[0];
90 let rhs = &arg_dtypes[1];
91
92 if operator.is_arithmetic() {
93 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
94 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
95 }
96 vortex_bail!(
97 "incompatible types for arithmetic operation: {} {}",
98 lhs,
99 rhs
100 );
101 }
102
103 if operator.is_comparison()
104 && !lhs.eq_ignore_nullability(rhs)
105 && !lhs.is_extension()
106 && !rhs.is_extension()
107 {
108 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
109 }
110
111 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112 }
113
114 fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<ArrayRef> {
115 let [lhs, rhs] = &args.inputs[..] else {
116 vortex_bail!("Wrong arg count")
117 };
118
119 match op {
120 Operator::Eq => execute_compare(lhs, rhs, compute::Operator::Eq),
121 Operator::NotEq => execute_compare(lhs, rhs, compute::Operator::NotEq),
122 Operator::Lt => execute_compare(lhs, rhs, compute::Operator::Lt),
123 Operator::Lte => execute_compare(lhs, rhs, compute::Operator::Lte),
124 Operator::Gt => execute_compare(lhs, rhs, compute::Operator::Gt),
125 Operator::Gte => execute_compare(lhs, rhs, compute::Operator::Gte),
126 Operator::And => execute_boolean(lhs, rhs, BooleanOperator::AndKleene),
127 Operator::Or => execute_boolean(lhs, rhs, BooleanOperator::OrKleene),
128 Operator::Add => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Add),
129 Operator::Sub => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Sub),
130 Operator::Mul => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Mul),
131 Operator::Div => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Div),
132 }
133 }
134
135 fn stat_falsification(
136 &self,
137 operator: &Operator,
138 expr: &Expression,
139 catalog: &dyn StatsCatalog,
140 ) -> Option<Expression> {
141 #[inline]
155 fn with_nan_predicate(
156 lhs: &Expression,
157 rhs: &Expression,
158 value_predicate: Expression,
159 catalog: &dyn StatsCatalog,
160 ) -> Expression {
161 let nan_predicate = lhs
162 .stat_expression(Stat::NaNCount, catalog)
163 .into_iter()
164 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
165 .map(|nans| eq(nans, lit(0u64)))
166 .reduce(and);
167
168 if let Some(nan_check) = nan_predicate {
169 and(nan_check, value_predicate)
170 } else {
171 value_predicate
172 }
173 }
174
175 let lhs = expr.child(0);
176 let rhs = expr.child(1);
177 match operator {
178 Operator::Eq => {
179 let min_lhs = lhs.stat_min(catalog);
180 let max_lhs = lhs.stat_max(catalog);
181
182 let min_rhs = rhs.stat_min(catalog);
183 let max_rhs = rhs.stat_max(catalog);
184
185 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
186 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
187
188 let min_max_check = left.into_iter().chain(right).reduce(or)?;
189
190 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
192 }
193 Operator::NotEq => {
194 let min_lhs = lhs.stat_min(catalog)?;
195 let max_lhs = lhs.stat_max(catalog)?;
196
197 let min_rhs = rhs.stat_min(catalog)?;
198 let max_rhs = rhs.stat_max(catalog)?;
199
200 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
201
202 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
203 }
204 Operator::Gt => {
205 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
206
207 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
208 }
209 Operator::Gte => {
210 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
212
213 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
214 }
215 Operator::Lt => {
216 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
218
219 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
220 }
221 Operator::Lte => {
222 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
224
225 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
226 }
227 Operator::And => lhs
228 .stat_falsification(catalog)
229 .into_iter()
230 .chain(rhs.stat_falsification(catalog))
231 .reduce(or),
232 Operator::Or => Some(and(
233 lhs.stat_falsification(catalog)?,
234 rhs.stat_falsification(catalog)?,
235 )),
236 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
237 }
238 }
239
240 fn validity(
241 &self,
242 operator: &Operator,
243 expression: &Expression,
244 ) -> VortexResult<Option<Expression>> {
245 let lhs = expression.child(0).validity()?;
246 let rhs = expression.child(1).validity()?;
247
248 Ok(match operator {
249 Operator::And => None,
251 Operator::Or => None,
252 _ => {
253 Some(and(lhs, rhs))
255 }
256 })
257 }
258
259 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
260 false
261 }
262
263 fn is_fallible(&self, operator: &Operator) -> bool {
264 let infallible = matches!(
267 operator,
268 Operator::Eq
269 | Operator::NotEq
270 | Operator::Gt
271 | Operator::Gte
272 | Operator::Lt
273 | Operator::Lte
274 | Operator::And
275 | Operator::Or
276 );
277
278 !infallible
279 }
280}
281
282pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
301 Binary
302 .try_new_expr(Operator::Eq, [lhs, rhs])
303 .vortex_expect("Failed to create Eq binary expression")
304}
305
306pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
325 Binary
326 .try_new_expr(Operator::NotEq, [lhs, rhs])
327 .vortex_expect("Failed to create NotEq binary expression")
328}
329
330pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
349 Binary
350 .try_new_expr(Operator::Gte, [lhs, rhs])
351 .vortex_expect("Failed to create Gte binary expression")
352}
353
354pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
373 Binary
374 .try_new_expr(Operator::Gt, [lhs, rhs])
375 .vortex_expect("Failed to create Gt binary expression")
376}
377
378pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
397 Binary
398 .try_new_expr(Operator::Lte, [lhs, rhs])
399 .vortex_expect("Failed to create Lte binary expression")
400}
401
402pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
421 Binary
422 .try_new_expr(Operator::Lt, [lhs, rhs])
423 .vortex_expect("Failed to create Lt binary expression")
424}
425
426pub fn or(lhs: Expression, rhs: Expression) -> Expression {
443 Binary
444 .try_new_expr(Operator::Or, [lhs, rhs])
445 .vortex_expect("Failed to create Or binary expression")
446}
447
448pub fn or_collect<I>(iter: I) -> Option<Expression>
455where
456 I: IntoIterator<Item = Expression>,
457{
458 let exprs: Vec<_> = iter.into_iter().collect();
459 balanced_reduce(exprs, or)
460}
461
462pub fn and(lhs: Expression, rhs: Expression) -> Expression {
479 Binary
480 .try_new_expr(Operator::And, [lhs, rhs])
481 .vortex_expect("Failed to create And binary expression")
482}
483
484pub fn and_collect<I>(iter: I) -> Option<Expression>
491where
492 I: IntoIterator<Item = Expression>,
493{
494 let exprs: Vec<_> = iter.into_iter().collect();
495 balanced_reduce(exprs, and)
496}
497
/// Folds `exprs` pairwise with `combine` into a balanced binary tree.
///
/// Each round combines adjacent pairs from left to right. When a round holds an odd number of
/// items, the trailing item is folded into that round's last pair. The result has
/// logarithmic depth (rather than the linear depth of a left fold) and keeps the operands in
/// their original left-to-right order.
///
/// Returns `None` for empty input and the sole element, unchanged, for one-element input.
/// Elements are moved, never cloned.
fn balanced_reduce<T, F>(mut exprs: Vec<T>, combine: F) -> Option<T>
where
    F: Fn(T, T) -> T,
{
    while exprs.len() > 1 {
        // Set the odd trailing item aside so the remainder pairs up exactly.
        let odd_tail = if exprs.len() % 2 == 1 {
            exprs.pop()
        } else {
            None
        };

        let mut next = Vec::with_capacity(exprs.len() / 2 + 1);
        let mut items = exprs.into_iter();
        while let (Some(left), Some(right)) = (items.next(), items.next()) {
            next.push(combine(left, right));
        }

        if let Some(tail) = odd_tail {
            // An odd length > 1 means at least one pair was combined, so `next` is non-empty;
            // the `None` arm only avoids an unreachable panic path.
            let merged = match next.pop() {
                Some(last) => combine(last, tail),
                None => tail,
            };
            next.push(merged);
        }

        exprs = next;
    }

    exprs.pop()
}
531
532pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
553 Binary
554 .try_new_expr(Operator::Add, [lhs, rhs])
555 .vortex_expect("Failed to create Add binary expression")
556}
557
// Unit tests for building, typing, displaying and executing binary expressions.
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::assert_arrays_eq;
    use crate::compute::compare;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    // Pins the tree shape produced by `and_collect` for odd, even, single and empty inputs.
    #[test]
    fn and_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // A single operand is returned as-is, without wrapping it in an `and`.
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        // No operands means no expression at all.
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values.into_iter()).is_none());
    }

    // Same shape check as above for `or_collect` with an even operand count.
    #[test]
    fn or_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    // Result nullability follows the operands: per the expected values below, the harness's
    // `bool1`/`bool2` fields are non-nullable while `col1`/`col2` are nullable.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        // Nullability also propagates through nested boolean combinations.
        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    // `Display` renders the parenthesized infix form produced by `fmt_sql`.
    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    // Equality on struct arrays compares field-by-field: identical fields compare equal and a
    // single differing field makes the structs unequal.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::StructArray;

        let lhs_struct = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let rhs_struct_equal = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        // Differs from `lhs_struct` only in field `b` (4 vs 3).
        let rhs_struct_different = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let result_equal = compare(&lhs_struct, &rhs_struct_equal, compute::Operator::Eq).unwrap();
        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different =
            compare(&lhs_struct, &rhs_struct_different, compute::Operator::Eq).unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    // Kleene OR: `true OR null` must evaluate to a valid `true`, not null. This is why
    // `Binary::validity` does not AND the child validities for `Or`.
    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;
        use crate::expr::exprs::get_item::col;

        let struct_arr = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            (
                "b",
                BoolArray::from_iter([Option::<bool>::None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
    }
}