1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function implementing the binary [`Operator`]s — comparisons, Kleene boolean
/// `and`/`or`, and arithmetic — over exactly two child expressions, `lhs` and `rhs`.
///
/// The specific operator is carried as this function's options (see the
/// [`ScalarFnVTable`] impl below), so a single `Binary` vtable serves every operator.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 ScalarFnId::from("vortex.binary")
56 }
57
58 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59 Ok(Some(
60 pb::BinaryOpts {
61 op: (*instance).into(),
62 }
63 .encode_to_vec(),
64 ))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<Self::Options> {
72 let opts = pb::BinaryOpts::decode(_metadata)?;
73 Operator::try_from(opts.op)
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Exact(2)
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("lhs"),
83 1 => ChildName::from("rhs"),
84 _ => unreachable!("Binary has only two children"),
85 }
86 }
87
88 fn fmt_sql(
89 &self,
90 operator: &Operator,
91 expr: &Expression,
92 f: &mut Formatter<'_>,
93 ) -> std::fmt::Result {
94 write!(f, "(")?;
95 expr.child(0).fmt_sql(f)?;
96 write!(f, " {} ", operator)?;
97 expr.child(1).fmt_sql(f)?;
98 write!(f, ")")
99 }
100
101 fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102 let lhs = &args[0];
103 let rhs = &args[1];
104 if operator.is_arithmetic() || operator.is_comparison() {
105 let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106 vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107 })?;
108 Ok(vec![supertype.clone(), supertype])
109 } else {
110 Ok(args.to_vec())
112 }
113 }
114
115 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116 let lhs = &arg_dtypes[0];
117 let rhs = &arg_dtypes[1];
118
119 if operator.is_arithmetic() {
120 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122 }
123 vortex_bail!(
124 "incompatible types for arithmetic operation: {} {}",
125 lhs,
126 rhs
127 );
128 }
129
130 if operator.is_comparison()
131 && !lhs.eq_ignore_nullability(rhs)
132 && !lhs.is_extension()
133 && !rhs.is_extension()
134 {
135 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136 }
137
138 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139 }
140
141 fn execute(
142 &self,
143 op: &Operator,
144 args: &dyn ExecutionArgs,
145 _ctx: &mut ExecutionCtx,
146 ) -> VortexResult<ArrayRef> {
147 let lhs = args.get(0)?;
148 let rhs = args.get(1)?;
149
150 match op {
151 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157 Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163 }
164 }
165
166 fn stat_falsification(
167 &self,
168 operator: &Operator,
169 expr: &Expression,
170 catalog: &dyn StatsCatalog,
171 ) -> Option<Expression> {
172 #[inline]
186 fn with_nan_predicate(
187 lhs: &Expression,
188 rhs: &Expression,
189 value_predicate: Expression,
190 catalog: &dyn StatsCatalog,
191 ) -> Expression {
192 let nan_predicate = and_collect(
193 lhs.stat_expression(Stat::NaNCount, catalog)
194 .into_iter()
195 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
196 .map(|nans| eq(nans, lit(0u64))),
197 );
198
199 if let Some(nan_check) = nan_predicate {
200 and(nan_check, value_predicate)
201 } else {
202 value_predicate
203 }
204 }
205
206 let lhs = expr.child(0);
207 let rhs = expr.child(1);
208 match operator {
209 Operator::Eq => {
210 let min_lhs = lhs.stat_min(catalog);
211 let max_lhs = lhs.stat_max(catalog);
212
213 let min_rhs = rhs.stat_min(catalog);
214 let max_rhs = rhs.stat_max(catalog);
215
216 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
217 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
218
219 let min_max_check = or_collect(left.into_iter().chain(right))?;
220
221 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
223 }
224 Operator::NotEq => {
225 let min_lhs = lhs.stat_min(catalog)?;
226 let max_lhs = lhs.stat_max(catalog)?;
227
228 let min_rhs = rhs.stat_min(catalog)?;
229 let max_rhs = rhs.stat_max(catalog)?;
230
231 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
232
233 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
234 }
235 Operator::Gt => {
236 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
237
238 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
239 }
240 Operator::Gte => {
241 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
243
244 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
245 }
246 Operator::Lt => {
247 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
249
250 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
251 }
252 Operator::Lte => {
253 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
255
256 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
257 }
258 Operator::And => or_collect(
259 lhs.stat_falsification(catalog)
260 .into_iter()
261 .chain(rhs.stat_falsification(catalog)),
262 ),
263 Operator::Or => Some(and(
264 lhs.stat_falsification(catalog)?,
265 rhs.stat_falsification(catalog)?,
266 )),
267 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
268 }
269 }
270
271 fn validity(
272 &self,
273 operator: &Operator,
274 expression: &Expression,
275 ) -> VortexResult<Option<Expression>> {
276 let lhs = expression.child(0).validity()?;
277 let rhs = expression.child(1).validity()?;
278
279 Ok(match operator {
280 Operator::And => None,
282 Operator::Or => None,
283 _ => {
284 Some(and(lhs, rhs))
286 }
287 })
288 }
289
290 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
291 false
292 }
293
294 fn is_fallible(&self, operator: &Operator) -> bool {
295 let infallible = matches!(
298 operator,
299 Operator::Eq
300 | Operator::NotEq
301 | Operator::Gt
302 | Operator::Gte
303 | Operator::Lt
304 | Operator::Lte
305 | Operator::And
306 | Operator::Or
307 );
308
309 !infallible
310 }
311}
312
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `and_collect` folds its inputs into a balanced tree of `and` nodes, returns a single
    /// input unchanged, and returns `None` for an empty input.
    #[test]
    fn and_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        let values = vec![lit(1), lit(2), lit(3), lit(4)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        let values = vec![lit(1)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        let values: Vec<Expression> = vec![];

        assert!(and_collect(values.into_iter()).is_none());
    }

    /// `or_collect` builds the same balanced shape as `and_collect`, using `or` nodes.
    #[test]
    fn or_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4)];

        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    /// Result dtypes of boolean and comparison operators against the test-harness schema.
    /// `and`/`or` of the two bool columns stays non-nullable. Comparisons of `col1`/`col2`
    /// are nullable (presumably because the harness declares at least one of them nullable).
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        // Nullability propagates through a boolean combination of nullable comparisons.
        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    /// `Display` renders binary expressions in parenthesised infix form.
    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    /// `Eq` on struct arrays is true only when every field matches.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::StructArray;

        let lhs_struct = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let rhs_struct_equal = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        // Differs from `lhs_struct` only in field `b`.
        let rhs_struct_different = StructArray::from_fields(&[
            (
                "a",
                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
            ),
            (
                "b",
                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();

        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// Kleene OR: `true OR null` is `true`, so the result row is valid even though `b` is null.
    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;
        use crate::expr::col;

        let struct_arr = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            (
                "b",
                BoolArray::from_iter([Option::<bool>::None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
    }

    /// Subtracting a constant from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1u16, 2, 3].into_array();
        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    /// Subtracting a negative constant from a signed array.
    #[test]
    fn test_scalar_subtract_signed() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1i64, 2, 3].into_array();
        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    /// Null input rows stay null after subtracting a constant.
    #[test]
    fn test_scalar_subtract_nullable() {
        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    /// Subtracting a negative constant from a float array.
    #[test]
    fn test_scalar_subtract_float() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1.0f64, 2.0, 3.0].into_array();
        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    /// Float subtraction at the edge of (or beyond) the finite range must not error. Only
    /// success is checked here, not the resulting values.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;

        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
        let _results = values.binary(rhs1, Operator::Sub).unwrap();
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
        let _results = values.binary(rhs2, Operator::Sub).unwrap();
    }
}
607}