Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::dtype::Nullability;
21use crate::expr::and;
22use crate::expr::expression::Expression;
23use crate::expr::lit;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ScalarFnId;
28use crate::scalar_fn::ScalarFnVTable;
29use crate::scalar_fn::SimplifyCtx;
30use crate::scalar_fn::fns::literal::Literal;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::scalar_fn::fns::operators::Operator;
33
34pub mod boolean;
35pub use boolean::BooleanExecuteAdaptor;
36pub use boolean::BooleanKernel;
37pub(crate) use boolean::execute_boolean;
38pub use boolean::kleene_boolean_buffer_scalar;
39pub use boolean::kleene_boolean_buffers;
40mod compare;
41pub use compare::*;
42mod numeric;
43pub(crate) use numeric::*;
44
45use crate::scalar::NumericOperator;
46use crate::scalar::Scalar;
47
/// Scalar function vtable for two-operand expressions.
///
/// The concrete operation is selected by the [`Operator`] carried as the function's options:
/// comparisons, Kleene `AND`/`OR`, or arithmetic.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
54    fn id(&self) -> ScalarFnId {
55        static ID: CachedId = CachedId::new("vortex.binary");
56        *ID
57    }
58
59    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60        Ok(Some(
61            pb::BinaryOpts {
62                op: (*instance).into(),
63            }
64            .encode_to_vec(),
65        ))
66    }
67
68    fn deserialize(
69        &self,
70        _metadata: &[u8],
71        _session: &VortexSession,
72    ) -> VortexResult<Self::Options> {
73        let opts = pb::BinaryOpts::decode(_metadata)?;
74        Operator::try_from(opts.op)
75    }
76
77    fn arity(&self, _options: &Self::Options) -> Arity {
78        Arity::Exact(2)
79    }
80
81    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
82        match child_idx {
83            0 => ChildName::from("lhs"),
84            1 => ChildName::from("rhs"),
85            _ => unreachable!("Binary has only two children"),
86        }
87    }
88
89    fn fmt_sql(
90        &self,
91        operator: &Operator,
92        expr: &Expression,
93        f: &mut Formatter<'_>,
94    ) -> std::fmt::Result {
95        write!(f, "(")?;
96        expr.child(0).fmt_sql(f)?;
97        write!(f, " {} ", operator)?;
98        expr.child(1).fmt_sql(f)?;
99        write!(f, ")")
100    }
101
102    fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
103        let lhs = &args[0];
104        let rhs = &args[1];
105        if operator.is_arithmetic() || operator.is_comparison() {
106            let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
107                vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
108            })?;
109            Ok(vec![supertype.clone(), supertype])
110        } else {
111            // Boolean And/Or: no coercion
112            Ok(args.to_vec())
113        }
114    }
115
116    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
117        let lhs = &arg_dtypes[0];
118        let rhs = &arg_dtypes[1];
119
120        if operator.is_arithmetic() {
121            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
122                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
123            }
124            vortex_bail!(
125                "incompatible types for arithmetic operation: {} {}",
126                lhs,
127                rhs
128            );
129        }
130
131        if operator.is_comparison()
132            && !lhs.eq_ignore_nullability(rhs)
133            && !lhs.is_extension()
134            && !rhs.is_extension()
135        {
136            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
137        }
138
139        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
140    }
141
142    fn execute(
143        &self,
144        op: &Operator,
145        args: &dyn ExecutionArgs,
146        ctx: &mut ExecutionCtx,
147    ) -> VortexResult<ArrayRef> {
148        let lhs = args.get(0)?;
149        let rhs = args.get(1)?;
150
151        match op {
152            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
153            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
154            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
155            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
156            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
157            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
158            Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
159            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
160            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
161            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
162            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
163            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
164        }
165    }
166
167    fn simplify_untyped(
168        &self,
169        operator: &Operator,
170        expr: &Expression,
171    ) -> VortexResult<Option<Expression>> {
172        let lhs = expr.child(0);
173        let rhs = expr.child(1);
174
175        let bool_literal = |expr: &Expression| {
176            expr.as_opt::<Literal>()?
177                .as_bool_opt()
178                .map(|value| value.value())
179        };
180
181        // AND/OR use Kleene three-valued logic. `None` below is a boolean null.
182        //
183        // AND:
184        // - false AND x => false
185        // - true  AND x => x
186        // - null  AND null => null
187        //
188        // OR:
189        // - true  OR x => true
190        // - false OR x => x
191        // - null  OR null => null
192        //
193        // Other null cases either fall out of the identity/annihilator rules
194        // above (`null AND true`, `null OR false`) or cannot be simplified under
195        // Kleene semantics (`null AND x`, `null OR x` for non-literal `x`).
196        Ok(match operator {
197            Operator::And => match (bool_literal(lhs), bool_literal(rhs)) {
198                (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)),
199                (Some(Some(true)), _) => Some(rhs.clone()),
200                (_, Some(Some(true))) => Some(lhs.clone()),
201                (Some(None), Some(None)) => Some(lhs.clone()),
202                _ => None,
203            },
204            Operator::Or => match (bool_literal(lhs), bool_literal(rhs)) {
205                (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)),
206                (Some(Some(false)), _) => Some(rhs.clone()),
207                (_, Some(Some(false))) => Some(lhs.clone()),
208                (Some(None), Some(None)) => Some(lhs.clone()),
209                _ => None,
210            },
211            _ => None,
212        })
213    }
214
215    fn simplify(
216        &self,
217        operator: &Operator,
218        expr: &Expression,
219        ctx: &dyn SimplifyCtx,
220    ) -> VortexResult<Option<Expression>> {
221        let is_literal_null =
222            |expr: &Expression| expr.as_opt::<Literal>().is_some_and(Scalar::is_null);
223
224        if operator.is_comparison()
225            && (is_literal_null(expr.child(0)) || is_literal_null(expr.child(1)))
226        {
227            // Validate the comparison before reducing it. This preserves type
228            // errors for expressions like `int_col = null_utf8`.
229            ctx.return_dtype(expr)?;
230            return Ok(Some(lit(Scalar::null(DType::Bool(Nullability::Nullable)))));
231        }
232
233        Ok(None)
234    }
235
236    fn validity(
237        &self,
238        operator: &Operator,
239        expression: &Expression,
240    ) -> VortexResult<Option<Expression>> {
241        let lhs = expression.child(0).validity()?;
242        let rhs = expression.child(1).validity()?;
243
244        Ok(match operator {
245            // AND and OR are kleene logic.
246            Operator::And => None,
247            Operator::Or => None,
248            _ => {
249                // All other binary operators are null if either side is null.
250                Some(and(lhs, rhs))
251            }
252        })
253    }
254
255    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
256        false
257    }
258
259    fn is_fallible(&self, operator: &Operator) -> bool {
260        // Opt-in not out for fallibility.
261        // Arithmetic operations could be better modelled here.
262        let infallible = matches!(
263            operator,
264            Operator::Eq
265                | Operator::NotEq
266                | Operator::Gt
267                | Operator::Gte
268                | Operator::Lt
269                | Operator::Lte
270                | Operator::And
271                | Operator::Or
272        );
273
274        !infallible
275    }
276}
277
278#[cfg(test)]
279mod tests {
280    use vortex_error::VortexExpect;
281    use vortex_error::VortexResult;
282
283    use super::*;
284    use crate::VortexSessionExecute;
285    use crate::array_session;
286    use crate::assert_arrays_eq;
287    use crate::builtins::ArrayBuiltins;
288    use crate::dtype::DType;
289    use crate::dtype::Nullability;
290    use crate::dtype::PType;
291    use crate::expr::Expression;
292    use crate::expr::and_collect;
293    use crate::expr::col;
294    use crate::expr::eq;
295    use crate::expr::gt;
296    use crate::expr::gt_eq;
297    use crate::expr::lit;
298    use crate::expr::lt;
299    use crate::expr::lt_eq;
300    use crate::expr::not_eq;
301    use crate::expr::or;
302    use crate::expr::or_collect;
303    use crate::expr::test_harness;
304    use crate::scalar::Scalar;
305    #[test]
306    fn and_collect_balanced() {
307        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
308
309        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
310        vortex.binary(and)
311        ├── lhs: vortex.binary(and)
312        │   ├── lhs: vortex.literal(1i32)
313        │   └── rhs: vortex.literal(2i32)
314        └── rhs: vortex.binary(and)
315            ├── lhs: vortex.binary(and)
316            │   ├── lhs: vortex.literal(3i32)
317            │   └── rhs: vortex.literal(4i32)
318            └── rhs: vortex.literal(5i32)
319        ");
320
321        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
322        let values = vec![lit(1), lit(2), lit(3), lit(4)];
323        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
324        vortex.binary(and)
325        ├── lhs: vortex.binary(and)
326        │   ├── lhs: vortex.literal(1i32)
327        │   └── rhs: vortex.literal(2i32)
328        └── rhs: vortex.binary(and)
329            ├── lhs: vortex.literal(3i32)
330            └── rhs: vortex.literal(4i32)
331        ");
332
333        // 1 element: just the element
334        let values = vec![lit(1)];
335        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");
336
337        // 0 elements: None
338        let values: Vec<Expression> = vec![];
339        assert!(and_collect(values.into_iter()).is_none());
340    }
341
342    #[test]
343    fn or_collect_balanced() {
344        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
345        let values = vec![lit(1), lit(2), lit(3), lit(4)];
346        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
347        vortex.binary(or)
348        ├── lhs: vortex.binary(or)
349        │   ├── lhs: vortex.literal(1i32)
350        │   └── rhs: vortex.literal(2i32)
351        └── rhs: vortex.binary(or)
352            ├── lhs: vortex.literal(3i32)
353            └── rhs: vortex.literal(4i32)
354        ");
355    }
356
357    #[test]
358    fn dtype() {
359        let dtype = test_harness::struct_dtype();
360        let bool1: Expression = col("bool1");
361        let bool2: Expression = col("bool2");
362        assert_eq!(
363            and(bool1.clone(), bool2.clone())
364                .return_dtype(&dtype)
365                .unwrap(),
366            DType::Bool(Nullability::NonNullable)
367        );
368        assert_eq!(
369            or(bool1, bool2).return_dtype(&dtype).unwrap(),
370            DType::Bool(Nullability::NonNullable)
371        );
372
373        let col1: Expression = col("col1");
374        let col2: Expression = col("col2");
375
376        assert_eq!(
377            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
378            DType::Bool(Nullability::Nullable)
379        );
380        assert_eq!(
381            not_eq(col1.clone(), col2.clone())
382                .return_dtype(&dtype)
383                .unwrap(),
384            DType::Bool(Nullability::Nullable)
385        );
386        assert_eq!(
387            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
388            DType::Bool(Nullability::Nullable)
389        );
390        assert_eq!(
391            gt_eq(col1.clone(), col2.clone())
392                .return_dtype(&dtype)
393                .unwrap(),
394            DType::Bool(Nullability::Nullable)
395        );
396        assert_eq!(
397            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
398            DType::Bool(Nullability::Nullable)
399        );
400        assert_eq!(
401            lt_eq(col1.clone(), col2.clone())
402                .return_dtype(&dtype)
403                .unwrap(),
404            DType::Bool(Nullability::Nullable)
405        );
406
407        assert_eq!(
408            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
409                .return_dtype(&dtype)
410                .unwrap(),
411            DType::Bool(Nullability::Nullable)
412        );
413    }
414
415    #[test]
416    fn comparison_with_typed_null_simplifies_after_type_check() -> VortexResult<()> {
417        let dtype = test_harness::struct_dtype();
418
419        let expr = eq(
420            col("col1"),
421            lit(Scalar::null(DType::Primitive(
422                PType::U16,
423                Nullability::Nullable,
424            ))),
425        );
426
427        assert_eq!(
428            expr.optimize_recursive(&dtype)?,
429            lit(Scalar::null(DType::Bool(Nullability::Nullable)))
430        );
431        Ok(())
432    }
433
434    #[test]
435    fn comparison_with_incompatible_null_still_type_checks() {
436        let dtype = test_harness::struct_dtype();
437        let expr = eq(
438            col("col1"),
439            lit(Scalar::null(DType::Utf8(Nullability::Nullable))),
440        );
441
442        assert!(expr.optimize_recursive(&dtype).is_err());
443    }
444
445    #[test]
446    fn test_display_print() {
447        let expr = gt(lit(1), lit(2));
448        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
449    }
450
451    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
452    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
453    #[test]
454    fn test_struct_comparison() {
455        use crate::IntoArray;
456        use crate::arrays::StructArray;
457
458        // Create a struct array with one element for testing.
459        let lhs_struct = StructArray::from_fields(&[
460            (
461                "a",
462                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
463            ),
464            (
465                "b",
466                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
467            ),
468        ])
469        .unwrap()
470        .into_array();
471
472        let rhs_struct_equal = StructArray::from_fields(&[
473            (
474                "a",
475                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
476            ),
477            (
478                "b",
479                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
480            ),
481        ])
482        .unwrap()
483        .into_array();
484
485        let rhs_struct_different = StructArray::from_fields(&[
486            (
487                "a",
488                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
489            ),
490            (
491                "b",
492                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
493            ),
494        ])
495        .unwrap()
496        .into_array();
497
498        // Test using binary method directly
499        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
500        assert_eq!(
501            result_equal
502                .execute_scalar(0, &mut array_session().create_execution_ctx())
503                .vortex_expect("value"),
504            Scalar::bool(true, Nullability::NonNullable),
505            "Equal structs should be equal"
506        );
507
508        let result_different = lhs_struct
509            .binary(rhs_struct_different, Operator::Eq)
510            .unwrap();
511        assert_eq!(
512            result_different
513                .execute_scalar(0, &mut array_session().create_execution_ctx())
514                .vortex_expect("value"),
515            Scalar::bool(false, Nullability::NonNullable),
516            "Different structs should not be equal"
517        );
518    }
519
520    #[test]
521    fn test_or_kleene_validity() {
522        let mut ctx = array_session().create_execution_ctx();
523        use crate::IntoArray;
524        use crate::arrays::BoolArray;
525        use crate::arrays::StructArray;
526        use crate::expr::col;
527
528        let struct_arr = StructArray::from_fields(&[
529            ("a", BoolArray::from_iter([Some(true)]).into_array()),
530            (
531                "b",
532                BoolArray::from_iter([Option::<bool>::None]).into_array(),
533            ),
534        ])
535        .unwrap()
536        .into_array();
537
538        let expr = or(col("a"), col("b"));
539        let result = struct_arr.apply(&expr).unwrap();
540
541        assert_arrays_eq!(
542            result,
543            BoolArray::from_iter([Some(true)]).into_array(),
544            &mut ctx
545        )
546    }
547
548    #[test]
549    fn test_scalar_subtract_unsigned() {
550        let mut ctx = array_session().create_execution_ctx();
551        use vortex_buffer::buffer;
552
553        use crate::IntoArray;
554        use crate::arrays::ConstantArray;
555        use crate::arrays::PrimitiveArray;
556
557        let values = buffer![1u16, 2, 3].into_array();
558        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
559        let result = values.binary(rhs, Operator::Sub).unwrap();
560        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]), &mut ctx);
561    }
562
563    #[test]
564    fn test_scalar_subtract_signed() {
565        let mut ctx = array_session().create_execution_ctx();
566        use vortex_buffer::buffer;
567
568        use crate::IntoArray;
569        use crate::arrays::ConstantArray;
570        use crate::arrays::PrimitiveArray;
571
572        let values = buffer![1i64, 2, 3].into_array();
573        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
574        let result = values.binary(rhs, Operator::Sub).unwrap();
575        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]), &mut ctx);
576    }
577
578    #[test]
579    fn test_scalar_subtract_nullable() {
580        let mut ctx = array_session().create_execution_ctx();
581        use crate::IntoArray;
582        use crate::arrays::ConstantArray;
583        use crate::arrays::PrimitiveArray;
584
585        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
586        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
587        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
588        assert_arrays_eq!(
589            result,
590            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)]),
591            &mut ctx
592        );
593    }
594
595    #[test]
596    fn test_scalar_subtract_float() {
597        let mut ctx = array_session().create_execution_ctx();
598        use vortex_buffer::buffer;
599
600        use crate::IntoArray;
601        use crate::arrays::ConstantArray;
602        use crate::arrays::PrimitiveArray;
603
604        let values = buffer![1.0f64, 2.0, 3.0].into_array();
605        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
606        let result = values.binary(rhs, Operator::Sub).unwrap();
607        assert_arrays_eq!(
608            result,
609            PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]),
610            &mut ctx
611        );
612    }
613
614    #[test]
615    fn test_scalar_subtract_float_underflow_is_ok() {
616        use vortex_buffer::buffer;
617
618        use crate::IntoArray;
619        use crate::arrays::ConstantArray;
620
621        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
622        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
623        let _results = values.binary(rhs1, Operator::Sub).unwrap();
624        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
625        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
626        let _results = values.binary(rhs2, Operator::Sub).unwrap();
627    }
628}