1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Marker type for the `vortex.binary` scalar function.
///
/// It carries no state of its own: the operator to apply is supplied through the
/// `Operator` options of its `ScalarFnVTable` implementation, which always takes
/// exactly two children (`lhs` and `rhs`).
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 ScalarFnId::from("vortex.binary")
56 }
57
58 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59 Ok(Some(
60 pb::BinaryOpts {
61 op: (*instance).into(),
62 }
63 .encode_to_vec(),
64 ))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<Self::Options> {
72 let opts = pb::BinaryOpts::decode(_metadata)?;
73 Operator::try_from(opts.op)
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Exact(2)
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("lhs"),
83 1 => ChildName::from("rhs"),
84 _ => unreachable!("Binary has only two children"),
85 }
86 }
87
88 fn fmt_sql(
89 &self,
90 operator: &Operator,
91 expr: &Expression,
92 f: &mut Formatter<'_>,
93 ) -> std::fmt::Result {
94 write!(f, "(")?;
95 expr.child(0).fmt_sql(f)?;
96 write!(f, " {} ", operator)?;
97 expr.child(1).fmt_sql(f)?;
98 write!(f, ")")
99 }
100
101 fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102 let lhs = &args[0];
103 let rhs = &args[1];
104 if operator.is_arithmetic() || operator.is_comparison() {
105 let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106 vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107 })?;
108 Ok(vec![supertype.clone(), supertype])
109 } else {
110 Ok(args.to_vec())
112 }
113 }
114
115 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116 let lhs = &arg_dtypes[0];
117 let rhs = &arg_dtypes[1];
118
119 if operator.is_arithmetic() {
120 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122 }
123 vortex_bail!(
124 "incompatible types for arithmetic operation: {} {}",
125 lhs,
126 rhs
127 );
128 }
129
130 if operator.is_comparison()
131 && !lhs.eq_ignore_nullability(rhs)
132 && !lhs.is_extension()
133 && !rhs.is_extension()
134 {
135 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136 }
137
138 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139 }
140
141 fn execute(
142 &self,
143 op: &Operator,
144 args: &dyn ExecutionArgs,
145 _ctx: &mut ExecutionCtx,
146 ) -> VortexResult<ArrayRef> {
147 let lhs = args.get(0)?;
148 let rhs = args.get(1)?;
149
150 match op {
151 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157 Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163 }
164 }
165
166 fn stat_falsification(
167 &self,
168 operator: &Operator,
169 expr: &Expression,
170 catalog: &dyn StatsCatalog,
171 ) -> Option<Expression> {
172 fn with_nan_predicate(
186 lhs: &Expression,
187 rhs: &Expression,
188 value_predicate: Expression,
189 catalog: &dyn StatsCatalog,
190 ) -> Expression {
191 let nan_predicate = and_collect(
192 lhs.stat_expression(Stat::NaNCount, catalog)
193 .into_iter()
194 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
195 .map(|nans| eq(nans, lit(0u64))),
196 );
197
198 if let Some(nan_check) = nan_predicate {
199 and(nan_check, value_predicate)
200 } else {
201 value_predicate
202 }
203 }
204
205 let lhs = expr.child(0);
206 let rhs = expr.child(1);
207 match operator {
208 Operator::Eq => {
209 let min_lhs = lhs.stat_min(catalog);
210 let max_lhs = lhs.stat_max(catalog);
211
212 let min_rhs = rhs.stat_min(catalog);
213 let max_rhs = rhs.stat_max(catalog);
214
215 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
216 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
217
218 let min_max_check = or_collect(left.into_iter().chain(right))?;
219
220 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
222 }
223 Operator::NotEq => {
224 let min_lhs = lhs.stat_min(catalog)?;
225 let max_lhs = lhs.stat_max(catalog)?;
226
227 let min_rhs = rhs.stat_min(catalog)?;
228 let max_rhs = rhs.stat_max(catalog)?;
229
230 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
231
232 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
233 }
234 Operator::Gt => {
235 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
236
237 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
238 }
239 Operator::Gte => {
240 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
242
243 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
244 }
245 Operator::Lt => {
246 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
248
249 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
250 }
251 Operator::Lte => {
252 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
254
255 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
256 }
257 Operator::And => or_collect(
258 lhs.stat_falsification(catalog)
259 .into_iter()
260 .chain(rhs.stat_falsification(catalog)),
261 ),
262 Operator::Or => Some(and(
263 lhs.stat_falsification(catalog)?,
264 rhs.stat_falsification(catalog)?,
265 )),
266 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
267 }
268 }
269
270 fn validity(
271 &self,
272 operator: &Operator,
273 expression: &Expression,
274 ) -> VortexResult<Option<Expression>> {
275 let lhs = expression.child(0).validity()?;
276 let rhs = expression.child(1).validity()?;
277
278 Ok(match operator {
279 Operator::And => None,
281 Operator::Or => None,
282 _ => {
283 Some(and(lhs, rhs))
285 }
286 })
287 }
288
289 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
290 false
291 }
292
293 fn is_fallible(&self, operator: &Operator) -> bool {
294 let infallible = matches!(
297 operator,
298 Operator::Eq
299 | Operator::NotEq
300 | Operator::Gt
301 | Operator::Gte
302 | Operator::Lt
303 | Operator::Lte
304 | Operator::And
305 | Operator::Or
306 );
307
308 !infallible
309 }
310}
311
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `and_collect` should build a balanced tree of `and` nodes, pass a single
    /// expression through unchanged, and yield `None` for an empty input.
    #[test]
    fn and_collect_balanced() {
        let five = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
        insta::assert_snapshot!(and_collect(five.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(four.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        let single = vec![lit(1)];
        insta::assert_snapshot!(and_collect(single.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        let empty: Vec<Expression> = Vec::new();
        assert!(and_collect(empty.into_iter()).is_none());
    }

    /// `or_collect` should build a balanced tree of `or` nodes.
    #[test]
    fn or_collect_balanced() {
        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(four.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    /// Checks the return dtype of boolean and comparison expressions against the
    /// test-harness struct dtype.
    #[test]
    fn dtype() {
        let schema = test_harness::struct_dtype();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&schema)
                .unwrap(),
            non_nullable_bool
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&schema).unwrap(),
            non_nullable_bool
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        // Every comparison over col1/col2 is expected to be a nullable bool.
        let comparisons = [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ];
        for comparison in comparisons {
            assert_eq!(comparison.return_dtype(&schema).unwrap(), nullable_bool);
        }

        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(combined.return_dtype(&schema).unwrap(), nullable_bool);
    }

    /// A binary expression displays as `(<lhs> <op> <rhs>)`.
    #[test]
    fn test_display_print() {
        let rendered = gt(lit(1), lit(2)).to_string();
        assert_eq!(rendered, "(1i32 > 2i32)");
    }

    /// Struct arrays compare equal only when their fields match.
    #[test]
    fn test_struct_comparison() {
        // Builds a one-row struct array with i32 fields `a` and `b`.
        let one_row = |a: i32, b: i32| {
            StructArray::from_fields(&[
                ("a", PrimitiveArray::from_iter([a]).into_array()),
                ("b", PrimitiveArray::from_iter([b]).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs_struct = one_row(1, 3);
        let rhs_struct_equal = one_row(1, 3);
        let rhs_struct_different = one_row(1, 4);

        let equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
        assert_eq!(
            equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let different = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        assert_eq!(
            different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// `true OR null` must evaluate to a valid `true`.
    #[test]
    fn test_or_kleene_validity() {
        let input = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            ("b", BoolArray::from_iter([None::<bool>]).into_array()),
        ])
        .unwrap()
        .into_array();

        let result = input.apply(&or(col("a"), col("b"))).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array());
    }

    /// Subtracting a constant from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let minuend = buffer![1u16, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(1u16), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    /// Subtracting a negative constant from a signed array.
    #[test]
    fn test_scalar_subtract_signed() {
        let minuend = buffer![1i64, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    /// Null entries stay null when subtracting a nullable constant.
    #[test]
    fn test_scalar_subtract_nullable() {
        let minuend =
            PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]).into_array();
        let subtrahend = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(
            difference,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    /// Subtracting a negative constant from a float array.
    #[test]
    fn test_scalar_subtract_float() {
        let minuend = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    /// Float subtraction that leaves the representable range must not return an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        for subtrahend in [1.0f32, f32::MAX] {
            let minuend = buffer![f32::MIN, 2.0, 3.0].into_array();
            let rhs = ConstantArray::new(Scalar::from(subtrahend), 3).into_array();
            // Only success matters here; the resulting values are not inspected.
            let _results = minuend.binary(rhs, Operator::Sub).unwrap();
        }
    }
}