1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Zero-sized marker type for the binary-operator scalar function.
///
/// All behavior lives in this type's `ScalarFnVTable` implementation; the
/// concrete operator is supplied separately through the vtable's `Options`.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 ScalarFnId::new("vortex.binary")
56 }
57
58 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59 Ok(Some(
60 pb::BinaryOpts {
61 op: (*instance).into(),
62 }
63 .encode_to_vec(),
64 ))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<Self::Options> {
72 let opts = pb::BinaryOpts::decode(_metadata)?;
73 Operator::try_from(opts.op)
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Exact(2)
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("lhs"),
83 1 => ChildName::from("rhs"),
84 _ => unreachable!("Binary has only two children"),
85 }
86 }
87
88 fn fmt_sql(
89 &self,
90 operator: &Operator,
91 expr: &Expression,
92 f: &mut Formatter<'_>,
93 ) -> std::fmt::Result {
94 write!(f, "(")?;
95 expr.child(0).fmt_sql(f)?;
96 write!(f, " {} ", operator)?;
97 expr.child(1).fmt_sql(f)?;
98 write!(f, ")")
99 }
100
101 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
102 let lhs = &arg_dtypes[0];
103 let rhs = &arg_dtypes[1];
104
105 if operator.is_arithmetic() {
106 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
107 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
108 }
109 vortex_bail!(
110 "incompatible types for arithmetic operation: {} {}",
111 lhs,
112 rhs
113 );
114 }
115
116 if operator.is_comparison()
117 && !lhs.eq_ignore_nullability(rhs)
118 && !lhs.is_extension()
119 && !rhs.is_extension()
120 {
121 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
122 }
123
124 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
125 }
126
127 fn execute(
128 &self,
129 op: &Operator,
130 args: &dyn ExecutionArgs,
131 ctx: &mut ExecutionCtx,
132 ) -> VortexResult<ArrayRef> {
133 let lhs = args.get(0)?;
134 let rhs = args.get(1)?;
135
136 match op {
137 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
138 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
139 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
140 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
141 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
142 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
143 Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
144 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
145 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
146 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
147 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
148 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
149 }
150 }
151
152 fn stat_falsification(
153 &self,
154 operator: &Operator,
155 expr: &Expression,
156 catalog: &dyn StatsCatalog,
157 ) -> Option<Expression> {
158 fn with_nan_predicate(
172 lhs: &Expression,
173 rhs: &Expression,
174 value_predicate: Expression,
175 catalog: &dyn StatsCatalog,
176 ) -> Expression {
177 let nan_predicate = and_collect(
178 lhs.stat_expression(Stat::NaNCount, catalog)
179 .into_iter()
180 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
181 .map(|nans| eq(nans, lit(0u64))),
182 );
183
184 if let Some(nan_check) = nan_predicate {
185 and(nan_check, value_predicate)
186 } else {
187 value_predicate
188 }
189 }
190
191 let lhs = expr.child(0);
192 let rhs = expr.child(1);
193 match operator {
194 Operator::Eq => {
195 let min_lhs = lhs.stat_min(catalog);
196 let max_lhs = lhs.stat_max(catalog);
197
198 let min_rhs = rhs.stat_min(catalog);
199 let max_rhs = rhs.stat_max(catalog);
200
201 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
202 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
203
204 let min_max_check = or_collect(left.into_iter().chain(right))?;
205
206 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
208 }
209 Operator::NotEq => {
210 let min_lhs = lhs.stat_min(catalog)?;
211 let max_lhs = lhs.stat_max(catalog)?;
212
213 let min_rhs = rhs.stat_min(catalog)?;
214 let max_rhs = rhs.stat_max(catalog)?;
215
216 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
217
218 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
219 }
220 Operator::Gt => {
221 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
222
223 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224 }
225 Operator::Gte => {
226 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
228
229 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
230 }
231 Operator::Lt => {
232 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
234
235 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
236 }
237 Operator::Lte => {
238 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
240
241 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
242 }
243 Operator::And => or_collect(
244 lhs.stat_falsification(catalog)
245 .into_iter()
246 .chain(rhs.stat_falsification(catalog)),
247 ),
248 Operator::Or => Some(and(
249 lhs.stat_falsification(catalog)?,
250 rhs.stat_falsification(catalog)?,
251 )),
252 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
253 }
254 }
255
256 fn validity(
257 &self,
258 operator: &Operator,
259 expression: &Expression,
260 ) -> VortexResult<Option<Expression>> {
261 let lhs = expression.child(0).validity()?;
262 let rhs = expression.child(1).validity()?;
263
264 Ok(match operator {
265 Operator::And => None,
267 Operator::Or => None,
268 _ => {
269 Some(and(lhs, rhs))
271 }
272 })
273 }
274
275 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
276 false
277 }
278
279 fn is_fallible(&self, operator: &Operator) -> bool {
280 let infallible = matches!(
283 operator,
284 Operator::Eq
285 | Operator::NotEq
286 | Operator::Gt
287 | Operator::Gte
288 | Operator::Lt
289 | Operator::Lte
290 | Operator::And
291 | Operator::Or
292 );
293
294 !infallible
295 }
296}
297
298#[cfg(test)]
299mod tests {
300 use vortex_error::VortexExpect;
301
302 use super::*;
303 use crate::LEGACY_SESSION;
304 use crate::VortexSessionExecute;
305 use crate::assert_arrays_eq;
306 use crate::builtins::ArrayBuiltins;
307 use crate::dtype::DType;
308 use crate::dtype::Nullability;
309 use crate::expr::Expression;
310 use crate::expr::and_collect;
311 use crate::expr::col;
312 use crate::expr::lit;
313 use crate::expr::lt;
314 use crate::expr::not_eq;
315 use crate::expr::or;
316 use crate::expr::or_collect;
317 use crate::expr::test_harness;
318 use crate::scalar::Scalar;
319 #[test]
320 fn and_collect_balanced() {
321 let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
322
323 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
324 vortex.binary(and)
325 ├── lhs: vortex.binary(and)
326 │ ├── lhs: vortex.literal(1i32)
327 │ └── rhs: vortex.literal(2i32)
328 └── rhs: vortex.binary(and)
329 ├── lhs: vortex.binary(and)
330 │ ├── lhs: vortex.literal(3i32)
331 │ └── rhs: vortex.literal(4i32)
332 └── rhs: vortex.literal(5i32)
333 ");
334
335 let values = vec![lit(1), lit(2), lit(3), lit(4)];
337 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
338 vortex.binary(and)
339 ├── lhs: vortex.binary(and)
340 │ ├── lhs: vortex.literal(1i32)
341 │ └── rhs: vortex.literal(2i32)
342 └── rhs: vortex.binary(and)
343 ├── lhs: vortex.literal(3i32)
344 └── rhs: vortex.literal(4i32)
345 ");
346
347 let values = vec![lit(1)];
349 insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");
350
351 let values: Vec<Expression> = vec![];
353 assert!(and_collect(values.into_iter()).is_none());
354 }
355
356 #[test]
357 fn or_collect_balanced() {
358 let values = vec![lit(1), lit(2), lit(3), lit(4)];
360 insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
361 vortex.binary(or)
362 ├── lhs: vortex.binary(or)
363 │ ├── lhs: vortex.literal(1i32)
364 │ └── rhs: vortex.literal(2i32)
365 └── rhs: vortex.binary(or)
366 ├── lhs: vortex.literal(3i32)
367 └── rhs: vortex.literal(4i32)
368 ");
369 }
370
371 #[test]
372 fn dtype() {
373 let dtype = test_harness::struct_dtype();
374 let bool1: Expression = col("bool1");
375 let bool2: Expression = col("bool2");
376 assert_eq!(
377 and(bool1.clone(), bool2.clone())
378 .return_dtype(&dtype)
379 .unwrap(),
380 DType::Bool(Nullability::NonNullable)
381 );
382 assert_eq!(
383 or(bool1, bool2).return_dtype(&dtype).unwrap(),
384 DType::Bool(Nullability::NonNullable)
385 );
386
387 let col1: Expression = col("col1");
388 let col2: Expression = col("col2");
389
390 assert_eq!(
391 eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
392 DType::Bool(Nullability::Nullable)
393 );
394 assert_eq!(
395 not_eq(col1.clone(), col2.clone())
396 .return_dtype(&dtype)
397 .unwrap(),
398 DType::Bool(Nullability::Nullable)
399 );
400 assert_eq!(
401 gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
402 DType::Bool(Nullability::Nullable)
403 );
404 assert_eq!(
405 gt_eq(col1.clone(), col2.clone())
406 .return_dtype(&dtype)
407 .unwrap(),
408 DType::Bool(Nullability::Nullable)
409 );
410 assert_eq!(
411 lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
412 DType::Bool(Nullability::Nullable)
413 );
414 assert_eq!(
415 lt_eq(col1.clone(), col2.clone())
416 .return_dtype(&dtype)
417 .unwrap(),
418 DType::Bool(Nullability::Nullable)
419 );
420
421 assert_eq!(
422 or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
423 .return_dtype(&dtype)
424 .unwrap(),
425 DType::Bool(Nullability::Nullable)
426 );
427 }
428
429 #[test]
430 fn test_display_print() {
431 let expr = gt(lit(1), lit(2));
432 assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
433 }
434
435 #[test]
438 fn test_struct_comparison() {
439 use crate::IntoArray;
440 use crate::arrays::StructArray;
441
442 let lhs_struct = StructArray::from_fields(&[
444 (
445 "a",
446 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
447 ),
448 (
449 "b",
450 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
451 ),
452 ])
453 .unwrap()
454 .into_array();
455
456 let rhs_struct_equal = StructArray::from_fields(&[
457 (
458 "a",
459 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
460 ),
461 (
462 "b",
463 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
464 ),
465 ])
466 .unwrap()
467 .into_array();
468
469 let rhs_struct_different = StructArray::from_fields(&[
470 (
471 "a",
472 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
473 ),
474 (
475 "b",
476 crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
477 ),
478 ])
479 .unwrap()
480 .into_array();
481
482 let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
484 assert_eq!(
485 result_equal
486 .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
487 .vortex_expect("value"),
488 Scalar::bool(true, Nullability::NonNullable),
489 "Equal structs should be equal"
490 );
491
492 let result_different = lhs_struct
493 .binary(rhs_struct_different, Operator::Eq)
494 .unwrap();
495 assert_eq!(
496 result_different
497 .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
498 .vortex_expect("value"),
499 Scalar::bool(false, Nullability::NonNullable),
500 "Different structs should not be equal"
501 );
502 }
503
504 #[test]
505 fn test_or_kleene_validity() {
506 use crate::IntoArray;
507 use crate::arrays::BoolArray;
508 use crate::arrays::StructArray;
509 use crate::expr::col;
510
511 let struct_arr = StructArray::from_fields(&[
512 ("a", BoolArray::from_iter([Some(true)]).into_array()),
513 (
514 "b",
515 BoolArray::from_iter([Option::<bool>::None]).into_array(),
516 ),
517 ])
518 .unwrap()
519 .into_array();
520
521 let expr = or(col("a"), col("b"));
522 let result = struct_arr.apply(&expr).unwrap();
523
524 assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
525 }
526
527 #[test]
528 fn test_scalar_subtract_unsigned() {
529 use vortex_buffer::buffer;
530
531 use crate::IntoArray;
532 use crate::arrays::ConstantArray;
533 use crate::arrays::PrimitiveArray;
534
535 let values = buffer![1u16, 2, 3].into_array();
536 let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
537 let result = values.binary(rhs, Operator::Sub).unwrap();
538 assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
539 }
540
541 #[test]
542 fn test_scalar_subtract_signed() {
543 use vortex_buffer::buffer;
544
545 use crate::IntoArray;
546 use crate::arrays::ConstantArray;
547 use crate::arrays::PrimitiveArray;
548
549 let values = buffer![1i64, 2, 3].into_array();
550 let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
551 let result = values.binary(rhs, Operator::Sub).unwrap();
552 assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
553 }
554
555 #[test]
556 fn test_scalar_subtract_nullable() {
557 use crate::IntoArray;
558 use crate::arrays::ConstantArray;
559 use crate::arrays::PrimitiveArray;
560
561 let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
562 let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
563 let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
564 assert_arrays_eq!(
565 result,
566 PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
567 );
568 }
569
570 #[test]
571 fn test_scalar_subtract_float() {
572 use vortex_buffer::buffer;
573
574 use crate::IntoArray;
575 use crate::arrays::ConstantArray;
576 use crate::arrays::PrimitiveArray;
577
578 let values = buffer![1.0f64, 2.0, 3.0].into_array();
579 let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
580 let result = values.binary(rhs, Operator::Sub).unwrap();
581 assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
582 }
583
584 #[test]
585 fn test_scalar_subtract_float_underflow_is_ok() {
586 use vortex_buffer::buffer;
587
588 use crate::IntoArray;
589 use crate::arrays::ConstantArray;
590
591 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
592 let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
593 let _results = values.binary(rhs1, Operator::Sub).unwrap();
594 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
595 let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
596 let _results = values.binary(rhs2, Operator::Sub).unwrap();
597 }
598}