1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function vtable for expressions that apply an [`Operator`] to two children.
///
/// The operator is carried as the function's options; the two children are named `lhs` and
/// `rhs`.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 ScalarFnId::new("vortex.binary")
56 }
57
58 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59 Ok(Some(
60 pb::BinaryOpts {
61 op: (*instance).into(),
62 }
63 .encode_to_vec(),
64 ))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<Self::Options> {
72 let opts = pb::BinaryOpts::decode(_metadata)?;
73 Operator::try_from(opts.op)
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Exact(2)
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("lhs"),
83 1 => ChildName::from("rhs"),
84 _ => unreachable!("Binary has only two children"),
85 }
86 }
87
88 fn fmt_sql(
89 &self,
90 operator: &Operator,
91 expr: &Expression,
92 f: &mut Formatter<'_>,
93 ) -> std::fmt::Result {
94 write!(f, "(")?;
95 expr.child(0).fmt_sql(f)?;
96 write!(f, " {} ", operator)?;
97 expr.child(1).fmt_sql(f)?;
98 write!(f, ")")
99 }
100
101 fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102 let lhs = &args[0];
103 let rhs = &args[1];
104 if operator.is_arithmetic() || operator.is_comparison() {
105 let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106 vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107 })?;
108 Ok(vec![supertype.clone(), supertype])
109 } else {
110 Ok(args.to_vec())
112 }
113 }
114
115 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116 let lhs = &arg_dtypes[0];
117 let rhs = &arg_dtypes[1];
118
119 if operator.is_arithmetic() {
120 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122 }
123 vortex_bail!(
124 "incompatible types for arithmetic operation: {} {}",
125 lhs,
126 rhs
127 );
128 }
129
130 if operator.is_comparison()
131 && !lhs.eq_ignore_nullability(rhs)
132 && !lhs.is_extension()
133 && !rhs.is_extension()
134 {
135 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136 }
137
138 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139 }
140
141 fn execute(
142 &self,
143 op: &Operator,
144 args: &dyn ExecutionArgs,
145 _ctx: &mut ExecutionCtx,
146 ) -> VortexResult<ArrayRef> {
147 let lhs = args.get(0)?;
148 let rhs = args.get(1)?;
149
150 match op {
151 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157 Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163 }
164 }
165
166 fn stat_falsification(
167 &self,
168 operator: &Operator,
169 expr: &Expression,
170 catalog: &dyn StatsCatalog,
171 ) -> Option<Expression> {
172 fn with_nan_predicate(
186 lhs: &Expression,
187 rhs: &Expression,
188 value_predicate: Expression,
189 catalog: &dyn StatsCatalog,
190 ) -> Expression {
191 let nan_predicate = and_collect(
192 lhs.stat_expression(Stat::NaNCount, catalog)
193 .into_iter()
194 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
195 .map(|nans| eq(nans, lit(0u64))),
196 );
197
198 if let Some(nan_check) = nan_predicate {
199 and(nan_check, value_predicate)
200 } else {
201 value_predicate
202 }
203 }
204
205 let lhs = expr.child(0);
206 let rhs = expr.child(1);
207 match operator {
208 Operator::Eq => {
209 let min_lhs = lhs.stat_min(catalog);
210 let max_lhs = lhs.stat_max(catalog);
211
212 let min_rhs = rhs.stat_min(catalog);
213 let max_rhs = rhs.stat_max(catalog);
214
215 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
216 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
217
218 let min_max_check = or_collect(left.into_iter().chain(right))?;
219
220 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
222 }
223 Operator::NotEq => {
224 let min_lhs = lhs.stat_min(catalog)?;
225 let max_lhs = lhs.stat_max(catalog)?;
226
227 let min_rhs = rhs.stat_min(catalog)?;
228 let max_rhs = rhs.stat_max(catalog)?;
229
230 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
231
232 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
233 }
234 Operator::Gt => {
235 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
236
237 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
238 }
239 Operator::Gte => {
240 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
242
243 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
244 }
245 Operator::Lt => {
246 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
248
249 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
250 }
251 Operator::Lte => {
252 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
254
255 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
256 }
257 Operator::And => or_collect(
258 lhs.stat_falsification(catalog)
259 .into_iter()
260 .chain(rhs.stat_falsification(catalog)),
261 ),
262 Operator::Or => Some(and(
263 lhs.stat_falsification(catalog)?,
264 rhs.stat_falsification(catalog)?,
265 )),
266 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
267 }
268 }
269
270 fn validity(
271 &self,
272 operator: &Operator,
273 expression: &Expression,
274 ) -> VortexResult<Option<Expression>> {
275 let lhs = expression.child(0).validity()?;
276 let rhs = expression.child(1).validity()?;
277
278 Ok(match operator {
279 Operator::And => None,
281 Operator::Or => None,
282 _ => {
283 Some(and(lhs, rhs))
285 }
286 })
287 }
288
/// Reports that no binary operator is null-sensitive: returns `false` for every operator.
fn is_null_sensitive(&self, _operator: &Operator) -> bool {
    false
}
292
293 fn is_fallible(&self, operator: &Operator) -> bool {
294 let infallible = matches!(
297 operator,
298 Operator::Eq
299 | Operator::NotEq
300 | Operator::Gt
301 | Operator::Gte
302 | Operator::Lt
303 | Operator::Lte
304 | Operator::And
305 | Operator::Or
306 );
307
308 !infallible
309 }
310}
311
312#[cfg(test)]
313mod tests {
314 use vortex_error::VortexExpect;
315
316 use super::*;
317 use crate::LEGACY_SESSION;
318 use crate::VortexSessionExecute;
319 use crate::assert_arrays_eq;
320 use crate::builtins::ArrayBuiltins;
321 use crate::dtype::DType;
322 use crate::dtype::Nullability;
323 use crate::expr::Expression;
324 use crate::expr::and_collect;
325 use crate::expr::col;
326 use crate::expr::lit;
327 use crate::expr::lt;
328 use crate::expr::not_eq;
329 use crate::expr::or;
330 use crate::expr::or_collect;
331 use crate::expr::test_harness;
332 use crate::scalar::Scalar;
/// Snapshot-tests the expression tree that `and_collect` builds for five, four, one and zero
/// inputs.
// NOTE(review): the inline snapshots are reproduced as they appear in this copy of the file;
// their internal indentation looks collapsed — confirm against `cargo insta` output.
#[test]
fn and_collect_balanced() {
    // Odd number of inputs.
    let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

    insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
    vortex.binary(and)
    ├── lhs: vortex.binary(and)
    │ ├── lhs: vortex.literal(1i32)
    │ └── rhs: vortex.literal(2i32)
    └── rhs: vortex.binary(and)
    ├── lhs: vortex.binary(and)
    │ ├── lhs: vortex.literal(3i32)
    │ └── rhs: vortex.literal(4i32)
    └── rhs: vortex.literal(5i32)
    ");

    // Even number of inputs.
    let values = vec![lit(1), lit(2), lit(3), lit(4)];
    insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
    vortex.binary(and)
    ├── lhs: vortex.binary(and)
    │ ├── lhs: vortex.literal(1i32)
    │ └── rhs: vortex.literal(2i32)
    └── rhs: vortex.binary(and)
    ├── lhs: vortex.literal(3i32)
    └── rhs: vortex.literal(4i32)
    ");

    // A single input comes back without any `and` wrapper.
    let values = vec![lit(1)];
    insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

    // No inputs: there is nothing to combine.
    let values: Vec<Expression> = vec![];
    assert!(and_collect(values.into_iter()).is_none());
}
369
/// Snapshot-tests the expression tree that `or_collect` builds for four inputs.
// NOTE(review): the inline snapshot is reproduced as it appears in this copy of the file;
// its internal indentation looks collapsed — confirm against `cargo insta` output.
#[test]
fn or_collect_balanced() {
    let values = vec![lit(1), lit(2), lit(3), lit(4)];
    insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
    vortex.binary(or)
    ├── lhs: vortex.binary(or)
    │ ├── lhs: vortex.literal(1i32)
    │ └── rhs: vortex.literal(2i32)
    └── rhs: vortex.binary(or)
    ├── lhs: vortex.literal(3i32)
    └── rhs: vortex.literal(4i32)
    ");
}
384
/// Checks the dtype reported for boolean and comparison expressions, resolved against the
/// test-harness struct dtype.
#[test]
fn dtype() {
    let scope = test_harness::struct_dtype();
    let non_nullable_bool = DType::Bool(Nullability::NonNullable);
    let nullable_bool = DType::Bool(Nullability::Nullable);

    // `and` / `or` over the two boolean columns.
    let (bool1, bool2): (Expression, Expression) = (col("bool1"), col("bool2"));
    let conjunction = and(bool1.clone(), bool2.clone());
    assert_eq!(conjunction.return_dtype(&scope).unwrap(), non_nullable_bool);
    let disjunction = or(bool1, bool2);
    assert_eq!(disjunction.return_dtype(&scope).unwrap(), non_nullable_bool);

    // Every comparison between `col1` and `col2` is expected to be a nullable bool.
    let (col1, col2): (Expression, Expression) = (col("col1"), col("col2"));
    let comparisons = [
        eq(col1.clone(), col2.clone()),
        not_eq(col1.clone(), col2.clone()),
        gt(col1.clone(), col2.clone()),
        gt_eq(col1.clone(), col2.clone()),
        lt(col1.clone(), col2.clone()),
        lt_eq(col1.clone(), col2.clone()),
    ];
    for comparison in comparisons {
        assert_eq!(comparison.return_dtype(&scope).unwrap(), nullable_bool);
    }

    // `or` over two nullable comparisons is nullable as well.
    let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
    assert_eq!(combined.return_dtype(&scope).unwrap(), nullable_bool);
}
442
/// The `Display` form of a binary expression is the parenthesised infix form.
#[test]
fn test_display_print() {
    let rendered = gt(lit(1), lit(2)).to_string();
    assert_eq!(rendered, "(1i32 > 2i32)");
}
448
/// `Eq` on single-row struct arrays: identical field values compare equal, and a differing
/// field value compares unequal.
#[test]
fn test_struct_comparison() {
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;

    // Builds a one-row struct array `{a: 1, b: <b>}`.
    let single_row = |b: i32| {
        StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([1i32]).into_array()),
            ("b", PrimitiveArray::from_iter([b]).into_array()),
        ])
        .unwrap()
        .into_array()
    };

    let lhs = single_row(3);

    let same = lhs.binary(single_row(3), Operator::Eq).unwrap();
    let same_value = same
        .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
        .vortex_expect("value");
    assert_eq!(
        same_value,
        Scalar::bool(true, Nullability::NonNullable),
        "Equal structs should be equal"
    );

    let differing = lhs.binary(single_row(4), Operator::Eq).unwrap();
    let differing_value = differing
        .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
        .vortex_expect("value");
    assert_eq!(
        differing_value,
        Scalar::bool(false, Nullability::NonNullable),
        "Different structs should not be equal"
    );
}
517
/// `true OR null` must evaluate to `true` rather than null.
#[test]
fn test_or_kleene_validity() {
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::StructArray;

    let always_true = BoolArray::from_iter([Some(true)]).into_array();
    let always_null = BoolArray::from_iter([Option::<bool>::None]).into_array();
    let input = StructArray::from_fields(&[("a", always_true), ("b", always_null)])
        .unwrap()
        .into_array();

    let disjunction = or(col("a"), col("b"));
    let actual = input.apply(&disjunction).unwrap();

    let expected = BoolArray::from_iter([Some(true)]).into_array();
    assert_arrays_eq!(actual, expected);
}
540
/// Subtracting a constant `1u16` from `[1, 2, 3]` yields `[0, 1, 2]`.
#[test]
fn test_scalar_subtract_unsigned() {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;

    let minuend = buffer![1u16, 2, 3].into_array();
    let subtrahend = ConstantArray::new(Scalar::from(1u16), 3).into_array();

    let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

    let expected = PrimitiveArray::from_iter([0u16, 1, 2]);
    assert_arrays_eq!(difference, expected);
}
554
/// Subtracting a constant `-1i64` from `[1, 2, 3]` yields `[2, 3, 4]`.
#[test]
fn test_scalar_subtract_signed() {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;

    let minuend = buffer![1i64, 2, 3].into_array();
    let subtrahend = ConstantArray::new(Scalar::from(-1i64), 3).into_array();

    let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

    let expected = PrimitiveArray::from_iter([2i64, 3, 4]);
    assert_arrays_eq!(difference, expected);
}
568
/// Subtracting a constant from a nullable array keeps the null slot null and subtracts from
/// the valid slots.
#[test]
fn test_scalar_subtract_nullable() {
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;

    let minuend =
        PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]).into_array();
    let subtrahend = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();

    let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

    let expected = PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)]);
    assert_arrays_eq!(difference, expected);
}
583
/// Subtracting a constant `-1.0f64` from `[1.0, 2.0, 3.0]` yields `[2.0, 3.0, 4.0]`.
#[test]
fn test_scalar_subtract_float() {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;

    let minuend = buffer![1.0f64, 2.0, 3.0].into_array();
    let subtrahend = ConstantArray::new(Scalar::from(-1f64), 3).into_array();

    let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

    let expected = PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]);
    assert_arrays_eq!(difference, expected);
}
597
/// Subtracting `1.0` or `f32::MAX` from an array containing `f32::MIN` must not return an
/// error. Only success is checked; the resulting values are not inspected.
#[test]
fn test_scalar_subtract_float_underflow_is_ok() {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::ConstantArray;

    for subtrahend in [1.0f32, f32::MAX] {
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let rhs = ConstantArray::new(Scalar::from(subtrahend), 3).into_array();
        let _result = values.binary(rhs, Operator::Sub).unwrap();
    }
}
612}