1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::dtype::Nullability;
21use crate::expr::and;
22use crate::expr::expression::Expression;
23use crate::expr::lit;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ScalarFnId;
28use crate::scalar_fn::ScalarFnVTable;
29use crate::scalar_fn::SimplifyCtx;
30use crate::scalar_fn::fns::literal::Literal;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::scalar_fn::fns::operators::Operator;
33
34pub mod boolean;
35pub use boolean::BooleanExecuteAdaptor;
36pub use boolean::BooleanKernel;
37pub(crate) use boolean::execute_boolean;
38pub use boolean::kleene_boolean_buffer_scalar;
39pub use boolean::kleene_boolean_buffers;
40mod compare;
41pub use compare::*;
42mod numeric;
43pub(crate) use numeric::*;
44
45use crate::scalar::NumericOperator;
46use crate::scalar::Scalar;
47
/// Scalar function that applies a binary [`Operator`] to two child expressions.
///
/// The operator is carried as the function's options (`type Options = Operator` in the
/// `ScalarFnVTable` impl below); execution dispatches to the compare, boolean, or numeric
/// kernels depending on which operator was chosen.
48#[derive(Clone)]
49pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 static ID: CachedId = CachedId::new("vortex.binary");
56 *ID
57 }
58
59 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60 Ok(Some(
61 pb::BinaryOpts {
62 op: (*instance).into(),
63 }
64 .encode_to_vec(),
65 ))
66 }
67
68 fn deserialize(
69 &self,
70 _metadata: &[u8],
71 _session: &VortexSession,
72 ) -> VortexResult<Self::Options> {
73 let opts = pb::BinaryOpts::decode(_metadata)?;
74 Operator::try_from(opts.op)
75 }
76
77 fn arity(&self, _options: &Self::Options) -> Arity {
78 Arity::Exact(2)
79 }
80
81 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
82 match child_idx {
83 0 => ChildName::from("lhs"),
84 1 => ChildName::from("rhs"),
85 _ => unreachable!("Binary has only two children"),
86 }
87 }
88
89 fn fmt_sql(
90 &self,
91 operator: &Operator,
92 expr: &Expression,
93 f: &mut Formatter<'_>,
94 ) -> std::fmt::Result {
95 write!(f, "(")?;
96 expr.child(0).fmt_sql(f)?;
97 write!(f, " {} ", operator)?;
98 expr.child(1).fmt_sql(f)?;
99 write!(f, ")")
100 }
101
102 fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
103 let lhs = &args[0];
104 let rhs = &args[1];
105 if operator.is_arithmetic() || operator.is_comparison() {
106 let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
107 vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
108 })?;
109 Ok(vec![supertype.clone(), supertype])
110 } else {
111 Ok(args.to_vec())
113 }
114 }
115
116 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
117 let lhs = &arg_dtypes[0];
118 let rhs = &arg_dtypes[1];
119
120 if operator.is_arithmetic() {
121 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
122 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
123 }
124 vortex_bail!(
125 "incompatible types for arithmetic operation: {} {}",
126 lhs,
127 rhs
128 );
129 }
130
131 if operator.is_comparison()
132 && !lhs.eq_ignore_nullability(rhs)
133 && !lhs.is_extension()
134 && !rhs.is_extension()
135 {
136 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
137 }
138
139 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
140 }
141
142 fn execute(
143 &self,
144 op: &Operator,
145 args: &dyn ExecutionArgs,
146 ctx: &mut ExecutionCtx,
147 ) -> VortexResult<ArrayRef> {
148 let lhs = args.get(0)?;
149 let rhs = args.get(1)?;
150
151 match op {
152 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
153 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
154 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
155 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
156 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
157 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
158 Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
159 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
160 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
161 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
162 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
163 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
164 }
165 }
166
167 fn simplify_untyped(
168 &self,
169 operator: &Operator,
170 expr: &Expression,
171 ) -> VortexResult<Option<Expression>> {
172 let lhs = expr.child(0);
173 let rhs = expr.child(1);
174
175 let bool_literal = |expr: &Expression| {
176 expr.as_opt::<Literal>()?
177 .as_bool_opt()
178 .map(|value| value.value())
179 };
180
181 Ok(match operator {
197 Operator::And => match (bool_literal(lhs), bool_literal(rhs)) {
198 (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)),
199 (Some(Some(true)), _) => Some(rhs.clone()),
200 (_, Some(Some(true))) => Some(lhs.clone()),
201 (Some(None), Some(None)) => Some(lhs.clone()),
202 _ => None,
203 },
204 Operator::Or => match (bool_literal(lhs), bool_literal(rhs)) {
205 (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)),
206 (Some(Some(false)), _) => Some(rhs.clone()),
207 (_, Some(Some(false))) => Some(lhs.clone()),
208 (Some(None), Some(None)) => Some(lhs.clone()),
209 _ => None,
210 },
211 _ => None,
212 })
213 }
214
215 fn simplify(
216 &self,
217 operator: &Operator,
218 expr: &Expression,
219 ctx: &dyn SimplifyCtx,
220 ) -> VortexResult<Option<Expression>> {
221 let is_literal_null =
222 |expr: &Expression| expr.as_opt::<Literal>().is_some_and(Scalar::is_null);
223
224 if operator.is_comparison()
225 && (is_literal_null(expr.child(0)) || is_literal_null(expr.child(1)))
226 {
227 ctx.return_dtype(expr)?;
230 return Ok(Some(lit(Scalar::null(DType::Bool(Nullability::Nullable)))));
231 }
232
233 Ok(None)
234 }
235
236 fn validity(
237 &self,
238 operator: &Operator,
239 expression: &Expression,
240 ) -> VortexResult<Option<Expression>> {
241 let lhs = expression.child(0).validity()?;
242 let rhs = expression.child(1).validity()?;
243
244 Ok(match operator {
245 Operator::And => None,
247 Operator::Or => None,
248 _ => {
249 Some(and(lhs, rhs))
251 }
252 })
253 }
254
255 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
256 false
257 }
258
259 fn is_fallible(&self, operator: &Operator) -> bool {
260 let infallible = matches!(
263 operator,
264 Operator::Eq
265 | Operator::NotEq
266 | Operator::Gt
267 | Operator::Gte
268 | Operator::Lt
269 | Operator::Lte
270 | Operator::And
271 | Operator::Or
272 );
273
274 !infallible
275 }
276}
277
278#[cfg(test)]
279mod tests {
280 use vortex_error::VortexExpect;
281 use vortex_error::VortexResult;
282
283 use super::*;
284 use crate::VortexSessionExecute;
285 use crate::array_session;
286 use crate::assert_arrays_eq;
287 use crate::builtins::ArrayBuiltins;
288 use crate::dtype::DType;
289 use crate::dtype::Nullability;
290 use crate::dtype::PType;
291 use crate::expr::Expression;
292 use crate::expr::and_collect;
293 use crate::expr::col;
294 use crate::expr::eq;
295 use crate::expr::gt;
296 use crate::expr::gt_eq;
297 use crate::expr::lit;
298 use crate::expr::lt;
299 use crate::expr::lt_eq;
300 use crate::expr::not_eq;
301 use crate::expr::or;
302 use crate::expr::or_collect;
303 use crate::expr::test_harness;
304 use crate::scalar::Scalar;
// Pins the tree shape `and_collect` produces for 5, 4, 1 and 0 input expressions.
305 #[test]
306 fn and_collect_balanced() {
// Five operands: the snapshot shows `(1 and 2)` on the left and `((3 and 4) and 5)` on the right.
307 let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
308
309 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
310 vortex.binary(and)
311 ├── lhs: vortex.binary(and)
312 │ ├── lhs: vortex.literal(1i32)
313 │ └── rhs: vortex.literal(2i32)
314 └── rhs: vortex.binary(and)
315 ├── lhs: vortex.binary(and)
316 │ ├── lhs: vortex.literal(3i32)
317 │ └── rhs: vortex.literal(4i32)
318 └── rhs: vortex.literal(5i32)
319 ");
320
// Four operands: two pairs joined at the root, `(1 and 2) and (3 and 4)`.
321 let values = vec![lit(1), lit(2), lit(3), lit(4)];
323 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
324 vortex.binary(and)
325 ├── lhs: vortex.binary(and)
326 │ ├── lhs: vortex.literal(1i32)
327 │ └── rhs: vortex.literal(2i32)
328 └── rhs: vortex.binary(and)
329 ├── lhs: vortex.literal(3i32)
330 └── rhs: vortex.literal(4i32)
331 ");
332
// A single operand comes back as-is, with no `and` node wrapped around it.
333 let values = vec![lit(1)];
335 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");
336
// No operands at all: there is nothing to combine, so the result is `None`.
337 let values: Vec<Expression> = vec![];
339 assert!(and_collect(values.into_iter()).is_none());
340 }
341
// Pins the tree shape `or_collect` produces for four input expressions.
342 #[test]
343 fn or_collect_balanced() {
// Four operands: the snapshot shows two pairs joined at the root, `(1 or 2) or (3 or 4)`.
344 let values = vec![lit(1), lit(2), lit(3), lit(4)];
346 insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
347 vortex.binary(or)
348 ├── lhs: vortex.binary(or)
349 │ ├── lhs: vortex.literal(1i32)
350 │ └── rhs: vortex.literal(2i32)
351 └── rhs: vortex.binary(or)
352 ├── lhs: vortex.literal(3i32)
353 └── rhs: vortex.literal(4i32)
354 ");
355 }
356
357 #[test]
358 fn dtype() {
359 let dtype = test_harness::struct_dtype();
360 let bool1: Expression = col("bool1");
361 let bool2: Expression = col("bool2");
362 assert_eq!(
363 and(bool1.clone(), bool2.clone())
364 .return_dtype(&dtype)
365 .unwrap(),
366 DType::Bool(Nullability::NonNullable)
367 );
368 assert_eq!(
369 or(bool1, bool2).return_dtype(&dtype).unwrap(),
370 DType::Bool(Nullability::NonNullable)
371 );
372
373 let col1: Expression = col("col1");
374 let col2: Expression = col("col2");
375
376 assert_eq!(
377 eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
378 DType::Bool(Nullability::Nullable)
379 );
380 assert_eq!(
381 not_eq(col1.clone(), col2.clone())
382 .return_dtype(&dtype)
383 .unwrap(),
384 DType::Bool(Nullability::Nullable)
385 );
386 assert_eq!(
387 gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
388 DType::Bool(Nullability::Nullable)
389 );
390 assert_eq!(
391 gt_eq(col1.clone(), col2.clone())
392 .return_dtype(&dtype)
393 .unwrap(),
394 DType::Bool(Nullability::Nullable)
395 );
396 assert_eq!(
397 lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
398 DType::Bool(Nullability::Nullable)
399 );
400 assert_eq!(
401 lt_eq(col1.clone(), col2.clone())
402 .return_dtype(&dtype)
403 .unwrap(),
404 DType::Bool(Nullability::Nullable)
405 );
406
407 assert_eq!(
408 or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
409 .return_dtype(&dtype)
410 .unwrap(),
411 DType::Bool(Nullability::Nullable)
412 );
413 }
414
415 #[test]
416 fn comparison_with_typed_null_simplifies_after_type_check() -> VortexResult<()> {
417 let dtype = test_harness::struct_dtype();
418
419 let expr = eq(
420 col("col1"),
421 lit(Scalar::null(DType::Primitive(
422 PType::U16,
423 Nullability::Nullable,
424 ))),
425 );
426
427 assert_eq!(
428 expr.optimize_recursive(&dtype)?,
429 lit(Scalar::null(DType::Bool(Nullability::Nullable)))
430 );
431 Ok(())
432 }
433
434 #[test]
435 fn comparison_with_incompatible_null_still_type_checks() {
436 let dtype = test_harness::struct_dtype();
437 let expr = eq(
438 col("col1"),
439 lit(Scalar::null(DType::Utf8(Nullability::Nullable))),
440 );
441
442 assert!(expr.optimize_recursive(&dtype).is_err());
443 }
444
445 #[test]
446 fn test_display_print() {
447 let expr = gt(lit(1), lit(2));
448 assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
449 }
450
451 #[test]
454 fn test_struct_comparison() {
455 use crate::IntoArray;
456 use crate::arrays::StructArray;
457
458 let lhs_struct = StructArray::from_fields(&[
460 (
461 "a",
462 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
463 ),
464 (
465 "b",
466 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
467 ),
468 ])
469 .unwrap()
470 .into_array();
471
472 let rhs_struct_equal = StructArray::from_fields(&[
473 (
474 "a",
475 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
476 ),
477 (
478 "b",
479 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
480 ),
481 ])
482 .unwrap()
483 .into_array();
484
485 let rhs_struct_different = StructArray::from_fields(&[
486 (
487 "a",
488 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
489 ),
490 (
491 "b",
492 crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
493 ),
494 ])
495 .unwrap()
496 .into_array();
497
498 let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
500 assert_eq!(
501 result_equal
502 .execute_scalar(0, &mut array_session().create_execution_ctx())
503 .vortex_expect("value"),
504 Scalar::bool(true, Nullability::NonNullable),
505 "Equal structs should be equal"
506 );
507
508 let result_different = lhs_struct
509 .binary(rhs_struct_different, Operator::Eq)
510 .unwrap();
511 assert_eq!(
512 result_different
513 .execute_scalar(0, &mut array_session().create_execution_ctx())
514 .vortex_expect("value"),
515 Scalar::bool(false, Nullability::NonNullable),
516 "Different structs should not be equal"
517 );
518 }
519
520 #[test]
521 fn test_or_kleene_validity() {
522 let mut ctx = array_session().create_execution_ctx();
523 use crate::IntoArray;
524 use crate::arrays::BoolArray;
525 use crate::arrays::StructArray;
526 use crate::expr::col;
527
528 let struct_arr = StructArray::from_fields(&[
529 ("a", BoolArray::from_iter([Some(true)]).into_array()),
530 (
531 "b",
532 BoolArray::from_iter([Option::<bool>::None]).into_array(),
533 ),
534 ])
535 .unwrap()
536 .into_array();
537
538 let expr = or(col("a"), col("b"));
539 let result = struct_arr.apply(&expr).unwrap();
540
541 assert_arrays_eq!(
542 result,
543 BoolArray::from_iter([Some(true)]).into_array(),
544 &mut ctx
545 )
546 }
547
548 #[test]
549 fn test_scalar_subtract_unsigned() {
550 let mut ctx = array_session().create_execution_ctx();
551 use vortex_buffer::buffer;
552
553 use crate::IntoArray;
554 use crate::arrays::ConstantArray;
555 use crate::arrays::PrimitiveArray;
556
557 let values = buffer![1u16, 2, 3].into_array();
558 let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
559 let result = values.binary(rhs, Operator::Sub).unwrap();
560 assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]), &mut ctx);
561 }
562
563 #[test]
564 fn test_scalar_subtract_signed() {
565 let mut ctx = array_session().create_execution_ctx();
566 use vortex_buffer::buffer;
567
568 use crate::IntoArray;
569 use crate::arrays::ConstantArray;
570 use crate::arrays::PrimitiveArray;
571
572 let values = buffer![1i64, 2, 3].into_array();
573 let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
574 let result = values.binary(rhs, Operator::Sub).unwrap();
575 assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]), &mut ctx);
576 }
577
578 #[test]
579 fn test_scalar_subtract_nullable() {
580 let mut ctx = array_session().create_execution_ctx();
581 use crate::IntoArray;
582 use crate::arrays::ConstantArray;
583 use crate::arrays::PrimitiveArray;
584
585 let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
586 let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
587 let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
588 assert_arrays_eq!(
589 result,
590 PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)]),
591 &mut ctx
592 );
593 }
594
595 #[test]
596 fn test_scalar_subtract_float() {
597 let mut ctx = array_session().create_execution_ctx();
598 use vortex_buffer::buffer;
599
600 use crate::IntoArray;
601 use crate::arrays::ConstantArray;
602 use crate::arrays::PrimitiveArray;
603
604 let values = buffer![1.0f64, 2.0, 3.0].into_array();
605 let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
606 let result = values.binary(rhs, Operator::Sub).unwrap();
607 assert_arrays_eq!(
608 result,
609 PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]),
610 &mut ctx
611 );
612 }
613
614 #[test]
615 fn test_scalar_subtract_float_underflow_is_ok() {
616 use vortex_buffer::buffer;
617
618 use crate::IntoArray;
619 use crate::arrays::ConstantArray;
620
621 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
622 let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
623 let _results = values.binary(rhs1, Operator::Sub).unwrap();
624 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
625 let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
626 let _results = values.binary(rhs2, Operator::Sub).unwrap();
627 }
628}