1use std::fmt::Formatter;
5
6use arrow_ord::cmp;
7use prost::Message;
8use vortex_compute::arrow::IntoArrow;
9use vortex_compute::arrow::IntoVector;
10use vortex_compute::logical::LogicalAndKleene;
11use vortex_compute::logical::LogicalOrKleene;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_proto::expr as pb;
18use vortex_vector::Datum;
19use vortex_vector::VectorOps;
20
21use crate::ArrayRef;
22use crate::compute;
23use crate::compute::add;
24use crate::compute::and_kleene;
25use crate::compute::compare;
26use crate::compute::div;
27use crate::compute::mul;
28use crate::compute::or_kleene;
29use crate::compute::sub;
30use crate::expr::Arity;
31use crate::expr::ChildName;
32use crate::expr::ExecutionArgs;
33use crate::expr::ExprId;
34use crate::expr::StatsCatalog;
35use crate::expr::VTable;
36use crate::expr::VTableExt;
37use crate::expr::expression::Expression;
38use crate::expr::exprs::literal::lit;
39use crate::expr::exprs::operators::Operator;
40use crate::expr::stats::Stat;
41
/// Expression vtable for two-operand expressions: comparisons (`Eq`, `NotEq`, `Lt`, `Lte`,
/// `Gt`, `Gte`), logical `And`/`Or`, and arithmetic (`Add`, `Sub`, `Mul`, `Div`).
///
/// The concrete operator is carried as the expression's options (`VTable::Options = Operator`).
pub struct Binary;
43
/// Expression vtable implementation for [`Binary`]; the concrete operator is stored as the
/// expression's options.
impl VTable for Binary {
    type Options = Operator;

    /// Identifier under which this expression kind is registered (`"vortex.binary"`).
    fn id(&self) -> ExprId {
        ExprId::from("vortex.binary")
    }

    /// Serializes the operator as a protobuf `BinaryOpts` message.
    ///
    /// Always returns `Some`: a binary expression always carries its operator as metadata.
    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(
            pb::BinaryOpts {
                op: (*instance).into(),
            }
            .encode_to_vec(),
        ))
    }

    /// Inverse of `serialize`: decodes `BinaryOpts` and converts its `op` field back into an
    /// [`Operator`].
    ///
    /// # Errors
    /// Fails if `metadata` is not a valid `BinaryOpts` message, or if `Operator::try_from`
    /// rejects the encoded operator value.
    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
        let opts = pb::BinaryOpts::decode(metadata)?;
        Operator::try_from(opts.op)
    }

    /// Every binary expression has exactly two children: `lhs` and `rhs`.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(2)
    }

    /// Names child 0 `lhs` and child 1 `rhs`.
    ///
    /// # Panics
    /// Panics for any other index; `arity` fixes the child count at two.
    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
        match child_idx {
            0 => ChildName::from("lhs"),
            1 => ChildName::from("rhs"),
            _ => unreachable!("Binary has only two children"),
        }
    }

    /// Renders the expression as fully parenthesized infix SQL, e.g. `(1i32 > 2i32)`.
    fn fmt_sql(
        &self,
        operator: &Operator,
        expr: &Expression,
        f: &mut Formatter<'_>,
    ) -> std::fmt::Result {
        write!(f, "(")?;
        expr.child(0).fmt_sql(f)?;
        write!(f, " {} ", operator)?;
        expr.child(1).fmt_sql(f)?;
        write!(f, ")")
    }

    /// Computes the result dtype for `operator` applied to the two child dtypes.
    ///
    /// * Arithmetic operators require a primitive `lhs` whose dtype equals `rhs` ignoring
    ///   nullability. The result is that dtype, nullable if either input is nullable.
    /// * All other operators (comparisons, `And`, `Or`) produce `Bool`, nullable if either
    ///   input is nullable. Operand types are not validated here for these operators.
    ///
    /// Indexes `arg_dtypes[0]` and `arg_dtypes[1]` directly, relying on `arity` being 2.
    ///
    /// # Errors
    /// Bails if an arithmetic operator is applied to non-primitive or mismatched dtypes.
    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
        let lhs = &arg_dtypes[0];
        let rhs = &arg_dtypes[1];

        if operator.is_arithmetic() {
            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
                // Result nullability is the union of the operands' nullability.
                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
            }
            vortex_bail!(
                "incompatible types for arithmetic operation: {} {}",
                lhs,
                rhs
            );
        }

        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
    }

    /// Evaluates both children against `scope`.
    ///
    /// It then dispatches to the array-level compute function for `operator`:
    /// `compare`, `and_kleene`/`or_kleene`, or `add`/`sub`/`mul`/`div`.
    fn evaluate(
        &self,
        operator: &Operator,
        expr: &Expression,
        scope: &ArrayRef,
    ) -> VortexResult<ArrayRef> {
        let lhs = expr.child(0).evaluate(scope)?;
        let rhs = expr.child(1).evaluate(scope)?;

        match operator {
            Operator::Eq => compare(&lhs, &rhs, compute::Operator::Eq),
            Operator::NotEq => compare(&lhs, &rhs, compute::Operator::NotEq),
            Operator::Lt => compare(&lhs, &rhs, compute::Operator::Lt),
            Operator::Lte => compare(&lhs, &rhs, compute::Operator::Lte),
            Operator::Gt => compare(&lhs, &rhs, compute::Operator::Gt),
            Operator::Gte => compare(&lhs, &rhs, compute::Operator::Gte),
            Operator::And => and_kleene(&lhs, &rhs),
            Operator::Or => or_kleene(&lhs, &rhs),
            Operator::Add => add(&lhs, &rhs),
            Operator::Sub => sub(&lhs, &rhs),
            Operator::Mul => mul(&lhs, &rhs),
            Operator::Div => div(&lhs, &rhs),
        }
    }

    /// Executes the operator over already-computed child [`Datum`]s.
    ///
    /// `And`/`Or` are evaluated with Kleene logic directly on the datums converted to
    /// booleans. Every other operator converts both operands to Arrow and uses the Arrow
    /// kernels: `arrow_ord::cmp` for comparisons, `arrow_arith::numeric` for arithmetic.
    /// The Arrow result is converted back into a vector.
    ///
    /// # Errors
    /// Fails if `args` does not hold exactly two datums, if Arrow conversion fails, or if
    /// the Arrow kernel itself returns an error.
    fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<Datum> {
        let [lhs, rhs]: [Datum; _] = args
            .datums
            .try_into()
            .map_err(|_| vortex_err!("Wrong arg count"))?;

        // Logical operators take a native Kleene fast path; everything else falls through
        // to the Arrow-based implementation below.
        match op {
            Operator::And => {
                return Ok(LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
            }
            Operator::Or => {
                return Ok(LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into());
            }
            _ => {}
        }

        let lhs = lhs.into_arrow()?;
        let rhs = rhs.into_arrow()?;

        // Comparison kernels yield a boolean array, which is converted into a vector and
        // then `.into()` the same vector type the arithmetic arms produce.
        let vector = match op {
            Operator::Eq => cmp::eq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
            Operator::NotEq => cmp::neq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
            Operator::Gt => cmp::gt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
            Operator::Gte => cmp::gt_eq(lhs.as_ref(), rhs.as_ref())?
                .into_vector()?
                .into(),
            Operator::Lt => cmp::lt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(),
            Operator::Lte => cmp::lt_eq(lhs.as_ref(), rhs.as_ref())?
                .into_vector()?
                .into(),

            Operator::Add => {
                arrow_arith::numeric::add(lhs.as_ref(), rhs.as_ref())?.into_vector()?
            }
            Operator::Sub => {
                arrow_arith::numeric::sub(lhs.as_ref(), rhs.as_ref())?.into_vector()?
            }
            Operator::Mul => {
                arrow_arith::numeric::mul(lhs.as_ref(), rhs.as_ref())?.into_vector()?
            }
            Operator::Div => {
                arrow_arith::numeric::div(lhs.as_ref(), rhs.as_ref())?.into_vector()?
            }
            Operator::And | Operator::Or => {
                unreachable!("Already dealt with above")
            }
        };

        // `get().1` is Arrow's `Datum` is-scalar flag. When both inputs were scalars, the
        // kernel output is expected to hold a single row, so surface it as a scalar datum.
        if lhs.get().1 && rhs.get().1 {
            return Ok(Datum::Scalar(vector.scalar_at(0)));
        }

        Ok(Datum::Vector(vector))
    }

    /// Builds a predicate over column statistics (min/max and NaN count).
    ///
    /// When that predicate evaluates to `true`, `lhs <operator> rhs` cannot be true for any
    /// row those statistics describe.
    ///
    /// Returns `None` when the statistics needed for the operator are unavailable. It also
    /// returns `None` for arithmetic operators, which are not predicates.
    fn stat_falsification(
        &self,
        operator: &Operator,
        expr: &Expression,
        catalog: &dyn StatsCatalog,
    ) -> Option<Expression> {
        /// Guards `value_predicate` with `NaNCount == 0` checks on whichever of `lhs`/`rhs`
        /// exposes a NaN-count statistic. Presumably this is because min/max bounds do not
        /// account for NaN values; NOTE(review): confirm.
        ///
        /// If neither side has a NaN-count statistic, `value_predicate` is returned unguarded.
        #[inline]
        fn with_nan_predicate(
            lhs: &Expression,
            rhs: &Expression,
            value_predicate: Expression,
            catalog: &dyn StatsCatalog,
        ) -> Expression {
            let nan_predicate = lhs
                .stat_expression(Stat::NaNCount, catalog)
                .into_iter()
                .chain(rhs.stat_expression(Stat::NaNCount, catalog))
                .map(|nans| eq(nans, lit(0u64)))
                .reduce(and);

            if let Some(nan_check) = nan_predicate {
                and(nan_check, value_predicate)
            } else {
                value_predicate
            }
        }

        let lhs = expr.child(0);
        let rhs = expr.child(1);
        match operator {
            Operator::Eq => {
                let min_lhs = lhs.stat_min(catalog);
                let max_lhs = lhs.stat_max(catalog);

                let min_rhs = rhs.stat_min(catalog);
                let max_rhs = rhs.stat_max(catalog);

                // Equality is impossible when the value ranges are disjoint: one side's
                // minimum exceeds the other side's maximum. Either half may be missing;
                // give up only when both are.
                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));

                let min_max_check = left.into_iter().chain(right).reduce(or)?;

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            Operator::NotEq => {
                let min_lhs = lhs.stat_min(catalog)?;
                let max_lhs = lhs.stat_max(catalog)?;

                let min_rhs = rhs.stat_min(catalog)?;
                let max_rhs = rhs.stat_max(catalog)?;

                // min_lhs == max_rhs && max_lhs == min_rhs forces all four bounds to be
                // equal. Every lhs value then equals every rhs value, so `!=` never holds.
                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            Operator::Gt => {
                // `lhs > rhs` is impossible if even the largest lhs is <= the smallest rhs.
                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            Operator::Gte => {
                // `lhs >= rhs` is impossible if even the largest lhs is < the smallest rhs.
                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            Operator::Lt => {
                // `lhs < rhs` is impossible if even the smallest lhs is >= the largest rhs.
                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            Operator::Lte => {
                // `lhs <= rhs` is impossible if even the smallest lhs is > the largest rhs.
                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);

                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
            }
            // A conjunction is falsified as soon as either side is falsified.
            Operator::And => lhs
                .stat_falsification(catalog)
                .into_iter()
                .chain(rhs.stat_falsification(catalog))
                .reduce(or),
            // A disjunction is falsified only when both sides are falsified.
            Operator::Or => Some(and(
                lhs.stat_falsification(catalog)?,
                rhs.stat_falsification(catalog)?,
            )),
            // Arithmetic expressions are not predicates; there is nothing to falsify.
            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
        }
    }

    /// Returns `false` for every operator.
    ///
    /// NOTE(review): the precise meaning of "null sensitive" is defined by the `VTable`
    /// trait. Confirm it is compatible with Kleene `And`/`Or`, where e.g. `false AND null`
    /// is `false`.
    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
        false
    }

    /// Comparisons and logical operators are infallible; arithmetic operators are fallible.
    ///
    /// The Arrow numeric kernels used by `execute` can return errors (e.g. on overflow).
    fn is_fallible(&self, operator: &Operator) -> bool {
        let infallible = matches!(
            // Exhaustively list the infallible operators so any new operator defaults to
            // being treated as fallible.
            operator,
            Operator::Eq
                | Operator::NotEq
                | Operator::Gt
                | Operator::Gte
                | Operator::Lt
                | Operator::Lte
                | Operator::And
                | Operator::Or
        );

        !infallible
    }
}
316
317pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
336 Binary
337 .try_new_expr(Operator::Eq, [lhs, rhs])
338 .vortex_expect("Failed to create Eq binary expression")
339}
340
341pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
360 Binary
361 .try_new_expr(Operator::NotEq, [lhs, rhs])
362 .vortex_expect("Failed to create NotEq binary expression")
363}
364
365pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
384 Binary
385 .try_new_expr(Operator::Gte, [lhs, rhs])
386 .vortex_expect("Failed to create Gte binary expression")
387}
388
389pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
408 Binary
409 .try_new_expr(Operator::Gt, [lhs, rhs])
410 .vortex_expect("Failed to create Gt binary expression")
411}
412
413pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
432 Binary
433 .try_new_expr(Operator::Lte, [lhs, rhs])
434 .vortex_expect("Failed to create Lte binary expression")
435}
436
437pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
456 Binary
457 .try_new_expr(Operator::Lt, [lhs, rhs])
458 .vortex_expect("Failed to create Lt binary expression")
459}
460
461pub fn or(lhs: Expression, rhs: Expression) -> Expression {
478 Binary
479 .try_new_expr(Operator::Or, [lhs, rhs])
480 .vortex_expect("Failed to create Or binary expression")
481}
482
483pub fn or_collect<I>(iter: I) -> Option<Expression>
486where
487 I: IntoIterator<Item = Expression>,
488 I::IntoIter: DoubleEndedIterator<Item = Expression>,
489{
490 let mut iter = iter.into_iter();
491 let first = iter.next_back()?;
492 Some(iter.rfold(first, |acc, elem| or(elem, acc)))
493}
494
495pub fn and(lhs: Expression, rhs: Expression) -> Expression {
512 Binary
513 .try_new_expr(Operator::And, [lhs, rhs])
514 .vortex_expect("Failed to create And binary expression")
515}
516
517pub fn and_collect<I>(iter: I) -> Option<Expression>
520where
521 I: IntoIterator<Item = Expression>,
522 I::IntoIter: DoubleEndedIterator<Item = Expression>,
523{
524 let mut iter = iter.into_iter();
525 let first = iter.next_back()?;
526 Some(iter.rfold(first, |acc, elem| and(elem, acc)))
527}
528
529pub fn and_collect_right<I>(iter: I) -> Option<Expression>
532where
533 I: IntoIterator<Item = Expression>,
534{
535 let iter = iter.into_iter();
536 iter.reduce(and)
537}
538
539pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
562 Binary
563 .try_new_expr(Operator::Add, [lhs, rhs])
564 .vortex_expect("Failed to create Add binary expression")
565}
566
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use super::and;
    use super::and_collect;
    use super::and_collect_right;
    use super::eq;
    use super::gt;
    use super::gt_eq;
    use super::lt;
    use super::lt_eq;
    use super::not_eq;
    use super::or;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::test_harness;

    /// `and_collect` folds from the back, so the resulting tree is `1 AND (2 AND 3)`.
    #[test]
    fn and_collect_left_assoc() {
        let collected = and_collect(vec![lit(1), lit(2), lit(3)]);
        let expected = and(lit(1), and(lit(2), lit(3)));
        assert_eq!(Some(expected), collected);
    }

    /// `and_collect_right` reduces front to back, so the resulting tree is
    /// `(1 AND 2) AND 3`.
    #[test]
    fn and_collect_right_assoc() {
        let collected = and_collect_right(vec![lit(1), lit(2), lit(3)]);
        let expected = and(and(lit(1), lit(2)), lit(3));
        assert_eq!(Some(expected), collected);
    }

    /// Logical operators over the non-nullable `bool1`/`bool2` columns stay non-nullable.
    /// Every comparison over `col1`/`col2` yields a nullable boolean, which implies at least
    /// one of those columns is nullable in the harness schema.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        let conjunction = and(bool1.clone(), bool2.clone());
        assert_eq!(conjunction.return_dtype(&dtype).unwrap(), non_nullable_bool);
        let disjunction = or(bool1, bool2);
        assert_eq!(disjunction.return_dtype(&dtype).unwrap(), non_nullable_bool);

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        let comparisons: [fn(Expression, Expression) -> Expression; 6] =
            [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            let actual = build(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap();
            assert_eq!(actual, nullable_bool);
        }

        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(combined.return_dtype(&dtype).unwrap(), nullable_bool);
    }

    /// Binary expressions render as parenthesized infix SQL.
    #[test]
    fn test_display_print() {
        let rendered = format!("{}", gt(lit(1), lit(2)));
        assert_eq!(rendered, "(1i32 > 2i32)");
    }
}