1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function for two-operand expressions.
///
/// The struct itself carries no state; the concrete operation is selected by the
/// [`Operator`] passed as the function's options (see the `ScalarFnVTable` impl).
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52 type Options = Operator;
53
54 fn id(&self) -> ScalarFnId {
55 ScalarFnId::from("vortex.binary")
56 }
57
58 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59 Ok(Some(
60 pb::BinaryOpts {
61 op: (*instance).into(),
62 }
63 .encode_to_vec(),
64 ))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<Self::Options> {
72 let opts = pb::BinaryOpts::decode(_metadata)?;
73 Operator::try_from(opts.op)
74 }
75
    /// A binary function always takes exactly two children, regardless of operator.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(2)
    }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 match child_idx {
82 0 => ChildName::from("lhs"),
83 1 => ChildName::from("rhs"),
84 _ => unreachable!("Binary has only two children"),
85 }
86 }
87
88 fn fmt_sql(
89 &self,
90 operator: &Operator,
91 expr: &Expression,
92 f: &mut Formatter<'_>,
93 ) -> std::fmt::Result {
94 write!(f, "(")?;
95 expr.child(0).fmt_sql(f)?;
96 write!(f, " {} ", operator)?;
97 expr.child(1).fmt_sql(f)?;
98 write!(f, ")")
99 }
100
101 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
102 let lhs = &arg_dtypes[0];
103 let rhs = &arg_dtypes[1];
104
105 if operator.is_arithmetic() {
106 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
107 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
108 }
109 vortex_bail!(
110 "incompatible types for arithmetic operation: {} {}",
111 lhs,
112 rhs
113 );
114 }
115
116 if operator.is_comparison()
117 && !lhs.eq_ignore_nullability(rhs)
118 && !lhs.is_extension()
119 && !rhs.is_extension()
120 {
121 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
122 }
123
124 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
125 }
126
127 fn execute(
128 &self,
129 op: &Operator,
130 args: &dyn ExecutionArgs,
131 _ctx: &mut ExecutionCtx,
132 ) -> VortexResult<ArrayRef> {
133 let lhs = args.get(0)?;
134 let rhs = args.get(1)?;
135
136 match op {
137 Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
138 Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
139 Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
140 Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
141 Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
142 Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
143 Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
144 Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
145 Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
146 Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
147 Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
148 Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
149 }
150 }
151
152 fn stat_falsification(
153 &self,
154 operator: &Operator,
155 expr: &Expression,
156 catalog: &dyn StatsCatalog,
157 ) -> Option<Expression> {
158 #[inline]
172 fn with_nan_predicate(
173 lhs: &Expression,
174 rhs: &Expression,
175 value_predicate: Expression,
176 catalog: &dyn StatsCatalog,
177 ) -> Expression {
178 let nan_predicate = and_collect(
179 lhs.stat_expression(Stat::NaNCount, catalog)
180 .into_iter()
181 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
182 .map(|nans| eq(nans, lit(0u64))),
183 );
184
185 if let Some(nan_check) = nan_predicate {
186 and(nan_check, value_predicate)
187 } else {
188 value_predicate
189 }
190 }
191
192 let lhs = expr.child(0);
193 let rhs = expr.child(1);
194 match operator {
195 Operator::Eq => {
196 let min_lhs = lhs.stat_min(catalog);
197 let max_lhs = lhs.stat_max(catalog);
198
199 let min_rhs = rhs.stat_min(catalog);
200 let max_rhs = rhs.stat_max(catalog);
201
202 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
203 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
204
205 let min_max_check = or_collect(left.into_iter().chain(right))?;
206
207 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
209 }
210 Operator::NotEq => {
211 let min_lhs = lhs.stat_min(catalog)?;
212 let max_lhs = lhs.stat_max(catalog)?;
213
214 let min_rhs = rhs.stat_min(catalog)?;
215 let max_rhs = rhs.stat_max(catalog)?;
216
217 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
218
219 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
220 }
221 Operator::Gt => {
222 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
223
224 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
225 }
226 Operator::Gte => {
227 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
229
230 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
231 }
232 Operator::Lt => {
233 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
235
236 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
237 }
238 Operator::Lte => {
239 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
241
242 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
243 }
244 Operator::And => or_collect(
245 lhs.stat_falsification(catalog)
246 .into_iter()
247 .chain(rhs.stat_falsification(catalog)),
248 ),
249 Operator::Or => Some(and(
250 lhs.stat_falsification(catalog)?,
251 rhs.stat_falsification(catalog)?,
252 )),
253 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
254 }
255 }
256
257 fn validity(
258 &self,
259 operator: &Operator,
260 expression: &Expression,
261 ) -> VortexResult<Option<Expression>> {
262 let lhs = expression.child(0).validity()?;
263 let rhs = expression.child(1).validity()?;
264
265 Ok(match operator {
266 Operator::And => None,
268 Operator::Or => None,
269 _ => {
270 Some(and(lhs, rhs))
272 }
273 })
274 }
275
    /// Reports every binary operator as not null-sensitive; the operator is ignored.
    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
        false
    }
279
280 fn is_fallible(&self, operator: &Operator) -> bool {
281 let infallible = matches!(
284 operator,
285 Operator::Eq
286 | Operator::NotEq
287 | Operator::Gt
288 | Operator::Gte
289 | Operator::Lt
290 | Operator::Lte
291 | Operator::And
292 | Operator::Or
293 );
294
295 !infallible
296 }
297}
298
299#[cfg(test)]
300mod tests {
301 use vortex_error::VortexExpect;
302
303 use super::*;
304 use crate::assert_arrays_eq;
305 use crate::builtins::ArrayBuiltins;
306 use crate::dtype::DType;
307 use crate::dtype::Nullability;
308 use crate::expr::Expression;
309 use crate::expr::and_collect;
310 use crate::expr::col;
311 use crate::expr::lit;
312 use crate::expr::lt;
313 use crate::expr::not_eq;
314 use crate::expr::or;
315 use crate::expr::or_collect;
316 use crate::expr::test_harness;
317 use crate::scalar::Scalar;
    /// `and_collect` should combine its inputs into a balanced tree of `and` nodes,
    /// return a lone input unchanged, and return `None` for an empty input.
    #[test]
    fn and_collect_balanced() {
        // Odd number of operands: the extra operand ends up on the right-hand side.
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // Even number of operands: a perfectly balanced tree.
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // A single operand is returned as-is, without an `and` wrapper.
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        // No operands: nothing to combine.
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values.into_iter()).is_none());
    }
354
    /// `or_collect` should combine four inputs into a balanced tree of `or` nodes.
    #[test]
    fn or_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }
369
370 #[test]
371 fn dtype() {
372 let dtype = test_harness::struct_dtype();
373 let bool1: Expression = col("bool1");
374 let bool2: Expression = col("bool2");
375 assert_eq!(
376 and(bool1.clone(), bool2.clone())
377 .return_dtype(&dtype)
378 .unwrap(),
379 DType::Bool(Nullability::NonNullable)
380 );
381 assert_eq!(
382 or(bool1, bool2).return_dtype(&dtype).unwrap(),
383 DType::Bool(Nullability::NonNullable)
384 );
385
386 let col1: Expression = col("col1");
387 let col2: Expression = col("col2");
388
389 assert_eq!(
390 eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
391 DType::Bool(Nullability::Nullable)
392 );
393 assert_eq!(
394 not_eq(col1.clone(), col2.clone())
395 .return_dtype(&dtype)
396 .unwrap(),
397 DType::Bool(Nullability::Nullable)
398 );
399 assert_eq!(
400 gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
401 DType::Bool(Nullability::Nullable)
402 );
403 assert_eq!(
404 gt_eq(col1.clone(), col2.clone())
405 .return_dtype(&dtype)
406 .unwrap(),
407 DType::Bool(Nullability::Nullable)
408 );
409 assert_eq!(
410 lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
411 DType::Bool(Nullability::Nullable)
412 );
413 assert_eq!(
414 lt_eq(col1.clone(), col2.clone())
415 .return_dtype(&dtype)
416 .unwrap(),
417 DType::Bool(Nullability::Nullable)
418 );
419
420 assert_eq!(
421 or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
422 .return_dtype(&dtype)
423 .unwrap(),
424 DType::Bool(Nullability::Nullable)
425 );
426 }
427
428 #[test]
429 fn test_display_print() {
430 let expr = gt(lit(1), lit(2));
431 assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
432 }
433
434 #[test]
437 fn test_struct_comparison() {
438 use crate::IntoArray;
439 use crate::arrays::StructArray;
440
441 let lhs_struct = StructArray::from_fields(&[
443 (
444 "a",
445 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
446 ),
447 (
448 "b",
449 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
450 ),
451 ])
452 .unwrap()
453 .into_array();
454
455 let rhs_struct_equal = StructArray::from_fields(&[
456 (
457 "a",
458 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
459 ),
460 (
461 "b",
462 crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
463 ),
464 ])
465 .unwrap()
466 .into_array();
467
468 let rhs_struct_different = StructArray::from_fields(&[
469 (
470 "a",
471 crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
472 ),
473 (
474 "b",
475 crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
476 ),
477 ])
478 .unwrap()
479 .into_array();
480
481 let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
483 assert_eq!(
484 result_equal.scalar_at(0).vortex_expect("value"),
485 Scalar::bool(true, Nullability::NonNullable),
486 "Equal structs should be equal"
487 );
488
489 let result_different = lhs_struct
490 .binary(rhs_struct_different, Operator::Eq)
491 .unwrap();
492 assert_eq!(
493 result_different.scalar_at(0).vortex_expect("value"),
494 Scalar::bool(false, Nullability::NonNullable),
495 "Different structs should not be equal"
496 );
497 }
498
499 #[test]
500 fn test_or_kleene_validity() {
501 use crate::IntoArray;
502 use crate::arrays::BoolArray;
503 use crate::arrays::StructArray;
504 use crate::expr::col;
505
506 let struct_arr = StructArray::from_fields(&[
507 ("a", BoolArray::from_iter([Some(true)]).into_array()),
508 (
509 "b",
510 BoolArray::from_iter([Option::<bool>::None]).into_array(),
511 ),
512 ])
513 .unwrap()
514 .into_array();
515
516 let expr = or(col("a"), col("b"));
517 let result = struct_arr.apply(&expr).unwrap();
518
519 assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
520 }
521
522 #[test]
523 fn test_scalar_subtract_unsigned() {
524 use vortex_buffer::buffer;
525
526 use crate::IntoArray;
527 use crate::arrays::ConstantArray;
528 use crate::arrays::PrimitiveArray;
529
530 let values = buffer![1u16, 2, 3].into_array();
531 let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
532 let result = values.binary(rhs, Operator::Sub).unwrap();
533 assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
534 }
535
536 #[test]
537 fn test_scalar_subtract_signed() {
538 use vortex_buffer::buffer;
539
540 use crate::IntoArray;
541 use crate::arrays::ConstantArray;
542 use crate::arrays::PrimitiveArray;
543
544 let values = buffer![1i64, 2, 3].into_array();
545 let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
546 let result = values.binary(rhs, Operator::Sub).unwrap();
547 assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
548 }
549
550 #[test]
551 fn test_scalar_subtract_nullable() {
552 use crate::IntoArray;
553 use crate::arrays::ConstantArray;
554 use crate::arrays::PrimitiveArray;
555
556 let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
557 let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
558 let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
559 assert_arrays_eq!(
560 result,
561 PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
562 );
563 }
564
565 #[test]
566 fn test_scalar_subtract_float() {
567 use vortex_buffer::buffer;
568
569 use crate::IntoArray;
570 use crate::arrays::ConstantArray;
571 use crate::arrays::PrimitiveArray;
572
573 let values = buffer![1.0f64, 2.0, 3.0].into_array();
574 let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
575 let result = values.binary(rhs, Operator::Sub).unwrap();
576 assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
577 }
578
579 #[test]
580 fn test_scalar_subtract_float_underflow_is_ok() {
581 use vortex_buffer::buffer;
582
583 use crate::IntoArray;
584 use crate::arrays::ConstantArray;
585
586 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
587 let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
588 let _results = values.binary(rhs1, Operator::Sub).unwrap();
589 let values = buffer![f32::MIN, 2.0, 3.0].into_array();
590 let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
591 let _results = values.binary(rhs2, Operator::Sub).unwrap();
592 }
593}