1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::expr::StatsCatalog;
21use crate::expr::and;
22use crate::expr::and_collect;
23use crate::expr::eq;
24use crate::expr::expression::Expression;
25use crate::expr::gt;
26use crate::expr::gt_eq;
27use crate::expr::lit;
28use crate::expr::lt;
29use crate::expr::lt_eq;
30use crate::expr::or_collect;
31use crate::expr::stats::Stat;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40pub(crate) mod boolean;
41pub(crate) use boolean::*;
42mod compare;
43pub use compare::*;
44mod numeric;
45pub(crate) use numeric::*;
46
47use crate::scalar::NumericOperator;
48
/// Marker type for the binary scalar function, identified as `vortex.binary`.
///
/// The type carries no state; the operator to apply is supplied separately as the
/// `Options` of its `ScalarFnVTable` implementation.
#[derive(Clone)]
pub struct Binary;
51
52impl ScalarFnVTable for Binary {
53 type Options = Operator;
54
55 fn id(&self) -> ScalarFnId {
56 static ID: CachedId = CachedId::new("vortex.binary");
57 *ID
58 }
59
60 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
61 Ok(Some(
62 pb::BinaryOpts {
63 op: (*instance).into(),
64 }
65 .encode_to_vec(),
66 ))
67 }
68
69 fn deserialize(
70 &self,
71 _metadata: &[u8],
72 _session: &VortexSession,
73 ) -> VortexResult<Self::Options> {
74 let opts = pb::BinaryOpts::decode(_metadata)?;
75 Operator::try_from(opts.op)
76 }
77
78 fn arity(&self, _options: &Self::Options) -> Arity {
79 Arity::Exact(2)
80 }
81
82 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
83 match child_idx {
84 0 => ChildName::from("lhs"),
85 1 => ChildName::from("rhs"),
86 _ => unreachable!("Binary has only two children"),
87 }
88 }
89
90 fn fmt_sql(
91 &self,
92 operator: &Operator,
93 expr: &Expression,
94 f: &mut Formatter<'_>,
95 ) -> std::fmt::Result {
96 write!(f, "(")?;
97 expr.child(0).fmt_sql(f)?;
98 write!(f, " {} ", operator)?;
99 expr.child(1).fmt_sql(f)?;
100 write!(f, ")")
101 }
102
103 fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
104 let lhs = &args[0];
105 let rhs = &args[1];
106 if operator.is_arithmetic() || operator.is_comparison() {
107 let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
108 vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
109 })?;
110 Ok(vec![supertype.clone(), supertype])
111 } else {
112 Ok(args.to_vec())
114 }
115 }
116
117 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
118 let lhs = &arg_dtypes[0];
119 let rhs = &arg_dtypes[1];
120
121 if operator.is_arithmetic() {
122 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
123 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
124 }
125 vortex_bail!(
126 "incompatible types for arithmetic operation: {} {}",
127 lhs,
128 rhs
129 );
130 }
131
132 if operator.is_comparison()
133 && !lhs.eq_ignore_nullability(rhs)
134 && !lhs.is_extension()
135 && !rhs.is_extension()
136 {
137 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
138 }
139
140 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
141 }
142
143 fn execute(
144 &self,
145 op: &Operator,
146 args: &dyn ExecutionArgs,
147 ctx: &mut ExecutionCtx,
148 ) -> VortexResult<ArrayRef> {
149 let lhs = args.get(0)?;
150 let rhs = args.get(1)?;
151
152 match op {
153 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
154 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
155 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
156 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
157 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
158 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
159 Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
160 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
161 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
162 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
163 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
164 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
165 }
166 }
167
168 fn stat_falsification(
169 &self,
170 operator: &Operator,
171 expr: &Expression,
172 catalog: &dyn StatsCatalog,
173 ) -> Option<Expression> {
174 fn with_nan_predicate(
188 lhs: &Expression,
189 rhs: &Expression,
190 value_predicate: Expression,
191 catalog: &dyn StatsCatalog,
192 ) -> Expression {
193 let nan_predicate = and_collect(
194 lhs.stat_expression(Stat::NaNCount, catalog)
195 .into_iter()
196 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
197 .map(|nans| eq(nans, lit(0u64))),
198 );
199
200 if let Some(nan_check) = nan_predicate {
201 and(nan_check, value_predicate)
202 } else {
203 value_predicate
204 }
205 }
206
207 let lhs = expr.child(0);
208 let rhs = expr.child(1);
209 match operator {
210 Operator::Eq => {
211 let min_lhs = lhs.stat_min(catalog);
212 let max_lhs = lhs.stat_max(catalog);
213
214 let min_rhs = rhs.stat_min(catalog);
215 let max_rhs = rhs.stat_max(catalog);
216
217 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
218 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
219
220 let min_max_check = or_collect(left.into_iter().chain(right))?;
221
222 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224 }
225 Operator::NotEq => {
226 let min_lhs = lhs.stat_min(catalog)?;
227 let max_lhs = lhs.stat_max(catalog)?;
228
229 let min_rhs = rhs.stat_min(catalog)?;
230 let max_rhs = rhs.stat_max(catalog)?;
231
232 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
233
234 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
235 }
236 Operator::Gt => {
237 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
238
239 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
240 }
241 Operator::Gte => {
242 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
244
245 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
246 }
247 Operator::Lt => {
248 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
250
251 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
252 }
253 Operator::Lte => {
254 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
256
257 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
258 }
259 Operator::And => or_collect(
260 lhs.stat_falsification(catalog)
261 .into_iter()
262 .chain(rhs.stat_falsification(catalog)),
263 ),
264 Operator::Or => Some(and(
265 lhs.stat_falsification(catalog)?,
266 rhs.stat_falsification(catalog)?,
267 )),
268 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
269 }
270 }
271
272 fn validity(
273 &self,
274 operator: &Operator,
275 expression: &Expression,
276 ) -> VortexResult<Option<Expression>> {
277 let lhs = expression.child(0).validity()?;
278 let rhs = expression.child(1).validity()?;
279
280 Ok(match operator {
281 Operator::And => None,
283 Operator::Or => None,
284 _ => {
285 Some(and(lhs, rhs))
287 }
288 })
289 }
290
291 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
292 false
293 }
294
295 fn is_fallible(&self, operator: &Operator) -> bool {
296 let infallible = matches!(
299 operator,
300 Operator::Eq
301 | Operator::NotEq
302 | Operator::Gt
303 | Operator::Gte
304 | Operator::Lt
305 | Operator::Lte
306 | Operator::And
307 | Operator::Or
308 );
309
310 !infallible
311 }
312}
313
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Yields the `i32` literal expressions `1, 2, ..., count` (nothing when `count` is 0).
    fn literals(count: i32) -> std::vec::IntoIter<Expression> {
        (1..=count).map(lit).collect::<Vec<Expression>>().into_iter()
    }

    /// Builds a one-row struct array with fields `a = 1` and `b = b`.
    fn i32_pair(b: i32) -> ArrayRef {
        StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([1i32]).into_array()),
            ("b", PrimitiveArray::from_iter([b]).into_array()),
        ])
        .unwrap()
        .into_array()
    }

    // NOTE(review): the inline snapshots below are whitespace-sensitive. Their column alignment
    // follows the usual 4-column tree layout; confirm with `cargo insta test` before merging.
    #[test]
    fn and_collect_balanced() {
        // Odd number of operands.
        insta::assert_snapshot!(and_collect(literals(5)).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // Even number of operands.
        insta::assert_snapshot!(and_collect(literals(4)).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // A single operand is returned as-is.
        insta::assert_snapshot!(and_collect(literals(1)).unwrap().display_tree(), @"vortex.literal(1i32)");

        // No operands produce no expression.
        assert!(and_collect(literals(0)).is_none());
    }

    #[test]
    fn or_collect_balanced() {
        insta::assert_snapshot!(or_collect(literals(4)).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let resolve = |expr: Expression| expr.return_dtype(&scope).unwrap();
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Boolean connectives over the harness's `bool1`/`bool2` columns.
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(resolve(and(bool1.clone(), bool2.clone())), non_nullable_bool);
        assert_eq!(resolve(or(bool1, bool2)), non_nullable_bool);

        // Every comparison over `col1`/`col2` resolves to a nullable boolean.
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        assert_eq!(resolve(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(not_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(gt_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(lt_eq(col1.clone(), col2.clone())), nullable_bool);

        // A connective over comparisons keeps the nullable result.
        assert_eq!(
            resolve(or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))),
            nullable_bool
        );
    }

    #[test]
    fn test_display_print() {
        assert_eq!(gt(lit(1), lit(2)).to_string(), "(1i32 > 2i32)");
    }

    #[test]
    fn test_struct_comparison() {
        let lhs_struct = i32_pair(3);
        let cases = [
            (i32_pair(3), true, "Equal structs should be equal"),
            (i32_pair(4), false, "Different structs should not be equal"),
        ];

        for (rhs_struct, expected, message) in cases {
            let outcome = lhs_struct.binary(rhs_struct, Operator::Eq).unwrap();
            assert_eq!(
                outcome
                    .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                    .vortex_expect("value"),
                Scalar::bool(expected, Nullability::NonNullable),
                "{message}"
            );
        }
    }

    #[test]
    fn test_or_kleene_validity() {
        // `true OR null` must evaluate to `true` rather than null.
        let a = BoolArray::from_iter([Some(true)]).into_array();
        let b = BoolArray::from_iter([None::<bool>]).into_array();
        let input = StructArray::from_fields(&[("a", a), ("b", b)])
            .unwrap()
            .into_array();

        let result = input.apply(&or(col("a"), col("b"))).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array());
    }

    #[test]
    fn test_scalar_subtract_unsigned() {
        let minuend = buffer![1u16, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(1u16), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let minuend = buffer![1i64, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1i64), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let minuend =
            PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]).into_array();
        let subtrahend = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(
            difference,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    #[test]
    fn test_scalar_subtract_float() {
        let minuend = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1f64), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        // Only checks that the subtraction succeeds; the resulting values are not inspected.
        for subtrahend in [1.0f32, f32::MAX] {
            let values = buffer![f32::MIN, 2.0, 3.0].into_array();
            let rhs = ConstantArray::new(Scalar::from(subtrahend), 3).into_array();
            let _results = values.binary(rhs, Operator::Sub).unwrap();
        }
    }
}