Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function for binary operators, registered under the id `vortex.binary`.
///
/// The concrete operator (comparison, boolean or arithmetic) is carried as the function's
/// options; see the [`ScalarFnVTable`] implementation for this type.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
54    fn id(&self) -> ScalarFnId {
55        ScalarFnId::from("vortex.binary")
56    }
57
58    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::BinaryOpts {
61                op: (*instance).into(),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let opts = pb::BinaryOpts::decode(_metadata)?;
73        Operator::try_from(opts.op)
74    }
75
76    fn arity(&self, _options: &Self::Options) -> Arity {
77        Arity::Exact(2)
78    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("lhs"),
83            1 => ChildName::from("rhs"),
84            _ => unreachable!("Binary has only two children"),
85        }
86    }
87
88    fn fmt_sql(
89        &self,
90        operator: &Operator,
91        expr: &Expression,
92        f: &mut Formatter<'_>,
93    ) -> std::fmt::Result {
94        write!(f, "(")?;
95        expr.child(0).fmt_sql(f)?;
96        write!(f, " {} ", operator)?;
97        expr.child(1).fmt_sql(f)?;
98        write!(f, ")")
99    }
100
101    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
102        let lhs = &arg_dtypes[0];
103        let rhs = &arg_dtypes[1];
104
105        if operator.is_arithmetic() {
106            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
107                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
108            }
109            vortex_bail!(
110                "incompatible types for arithmetic operation: {} {}",
111                lhs,
112                rhs
113            );
114        }
115
116        if operator.is_comparison()
117            && !lhs.eq_ignore_nullability(rhs)
118            && !lhs.is_extension()
119            && !rhs.is_extension()
120        {
121            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
122        }
123
124        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
125    }
126
127    fn execute(
128        &self,
129        op: &Operator,
130        args: &dyn ExecutionArgs,
131        _ctx: &mut ExecutionCtx,
132    ) -> VortexResult<ArrayRef> {
133        let lhs = args.get(0)?;
134        let rhs = args.get(1)?;
135
136        match op {
137            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
138            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
139            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
140            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
141            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
142            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
143            Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
144            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
145            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
146            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
147            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
148            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
149        }
150    }
151
152    fn stat_falsification(
153        &self,
154        operator: &Operator,
155        expr: &Expression,
156        catalog: &dyn StatsCatalog,
157    ) -> Option<Expression> {
158        // Wrap another predicate with an optional NaNCount check, if the stat is available.
159        //
160        // For example, regular pruning conversion for `A >= B` would be
161        //
162        //      A.max < B.min
163        //
164        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
165        // in:
166        //
167        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
168        //
169        // Non-floating point column and literal expressions should be unaffected as they do not
170        // have a nan_count statistic defined.
171        #[inline]
172        fn with_nan_predicate(
173            lhs: &Expression,
174            rhs: &Expression,
175            value_predicate: Expression,
176            catalog: &dyn StatsCatalog,
177        ) -> Expression {
178            let nan_predicate = and_collect(
179                lhs.stat_expression(Stat::NaNCount, catalog)
180                    .into_iter()
181                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
182                    .map(|nans| eq(nans, lit(0u64))),
183            );
184
185            if let Some(nan_check) = nan_predicate {
186                and(nan_check, value_predicate)
187            } else {
188                value_predicate
189            }
190        }
191
192        let lhs = expr.child(0);
193        let rhs = expr.child(1);
194        match operator {
195            Operator::Eq => {
196                let min_lhs = lhs.stat_min(catalog);
197                let max_lhs = lhs.stat_max(catalog);
198
199                let min_rhs = rhs.stat_min(catalog);
200                let max_rhs = rhs.stat_max(catalog);
201
202                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
203                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
204
205                let min_max_check = or_collect(left.into_iter().chain(right))?;
206
207                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
208                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
209            }
210            Operator::NotEq => {
211                let min_lhs = lhs.stat_min(catalog)?;
212                let max_lhs = lhs.stat_max(catalog)?;
213
214                let min_rhs = rhs.stat_min(catalog)?;
215                let max_rhs = rhs.stat_max(catalog)?;
216
217                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
218
219                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
220            }
221            Operator::Gt => {
222                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
223
224                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
225            }
226            Operator::Gte => {
227                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
228                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
229
230                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
231            }
232            Operator::Lt => {
233                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
234                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
235
236                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
237            }
238            Operator::Lte => {
239                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
240                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
241
242                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
243            }
244            Operator::And => or_collect(
245                lhs.stat_falsification(catalog)
246                    .into_iter()
247                    .chain(rhs.stat_falsification(catalog)),
248            ),
249            Operator::Or => Some(and(
250                lhs.stat_falsification(catalog)?,
251                rhs.stat_falsification(catalog)?,
252            )),
253            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
254        }
255    }
256
257    fn validity(
258        &self,
259        operator: &Operator,
260        expression: &Expression,
261    ) -> VortexResult<Option<Expression>> {
262        let lhs = expression.child(0).validity()?;
263        let rhs = expression.child(1).validity()?;
264
265        Ok(match operator {
266            // AND and OR are kleene logic.
267            Operator::And => None,
268            Operator::Or => None,
269            _ => {
270                // All other binary operators are null if either side is null.
271                Some(and(lhs, rhs))
272            }
273        })
274    }
275
276    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
277        false
278    }
279
280    fn is_fallible(&self, operator: &Operator) -> bool {
281        // Opt-in not out for fallibility.
282        // Arithmetic operations could be better modelled here.
283        let infallible = matches!(
284            operator,
285            Operator::Eq
286                | Operator::NotEq
287                | Operator::Gt
288                | Operator::Gte
289                | Operator::Lt
290                | Operator::Lte
291                | Operator::And
292                | Operator::Or
293        );
294
295        !infallible
296    }
297}
298
299#[cfg(test)]
300mod tests {
301    use vortex_error::VortexExpect;
302
303    use super::*;
304    use crate::assert_arrays_eq;
305    use crate::builtins::ArrayBuiltins;
306    use crate::dtype::DType;
307    use crate::dtype::Nullability;
308    use crate::expr::Expression;
309    use crate::expr::and_collect;
310    use crate::expr::col;
311    use crate::expr::lit;
312    use crate::expr::lt;
313    use crate::expr::not_eq;
314    use crate::expr::or;
315    use crate::expr::or_collect;
316    use crate::expr::test_harness;
317    use crate::scalar::Scalar;
318    #[test]
319    fn and_collect_balanced() {
320        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
321
322        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
323        vortex.binary(and)
324        ├── lhs: vortex.binary(and)
325        │   ├── lhs: vortex.literal(1i32)
326        │   └── rhs: vortex.literal(2i32)
327        └── rhs: vortex.binary(and)
328            ├── lhs: vortex.binary(and)
329            │   ├── lhs: vortex.literal(3i32)
330            │   └── rhs: vortex.literal(4i32)
331            └── rhs: vortex.literal(5i32)
332        ");
333
334        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
335        let values = vec![lit(1), lit(2), lit(3), lit(4)];
336        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
337        vortex.binary(and)
338        ├── lhs: vortex.binary(and)
339        │   ├── lhs: vortex.literal(1i32)
340        │   └── rhs: vortex.literal(2i32)
341        └── rhs: vortex.binary(and)
342            ├── lhs: vortex.literal(3i32)
343            └── rhs: vortex.literal(4i32)
344        ");
345
346        // 1 element: just the element
347        let values = vec![lit(1)];
348        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");
349
350        // 0 elements: None
351        let values: Vec<Expression> = vec![];
352        assert!(and_collect(values.into_iter()).is_none());
353    }
354
355    #[test]
356    fn or_collect_balanced() {
357        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
358        let values = vec![lit(1), lit(2), lit(3), lit(4)];
359        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
360        vortex.binary(or)
361        ├── lhs: vortex.binary(or)
362        │   ├── lhs: vortex.literal(1i32)
363        │   └── rhs: vortex.literal(2i32)
364        └── rhs: vortex.binary(or)
365            ├── lhs: vortex.literal(3i32)
366            └── rhs: vortex.literal(4i32)
367        ");
368    }
369
370    #[test]
371    fn dtype() {
372        let dtype = test_harness::struct_dtype();
373        let bool1: Expression = col("bool1");
374        let bool2: Expression = col("bool2");
375        assert_eq!(
376            and(bool1.clone(), bool2.clone())
377                .return_dtype(&dtype)
378                .unwrap(),
379            DType::Bool(Nullability::NonNullable)
380        );
381        assert_eq!(
382            or(bool1, bool2).return_dtype(&dtype).unwrap(),
383            DType::Bool(Nullability::NonNullable)
384        );
385
386        let col1: Expression = col("col1");
387        let col2: Expression = col("col2");
388
389        assert_eq!(
390            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
391            DType::Bool(Nullability::Nullable)
392        );
393        assert_eq!(
394            not_eq(col1.clone(), col2.clone())
395                .return_dtype(&dtype)
396                .unwrap(),
397            DType::Bool(Nullability::Nullable)
398        );
399        assert_eq!(
400            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
401            DType::Bool(Nullability::Nullable)
402        );
403        assert_eq!(
404            gt_eq(col1.clone(), col2.clone())
405                .return_dtype(&dtype)
406                .unwrap(),
407            DType::Bool(Nullability::Nullable)
408        );
409        assert_eq!(
410            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
411            DType::Bool(Nullability::Nullable)
412        );
413        assert_eq!(
414            lt_eq(col1.clone(), col2.clone())
415                .return_dtype(&dtype)
416                .unwrap(),
417            DType::Bool(Nullability::Nullable)
418        );
419
420        assert_eq!(
421            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
422                .return_dtype(&dtype)
423                .unwrap(),
424            DType::Bool(Nullability::Nullable)
425        );
426    }
427
428    #[test]
429    fn test_display_print() {
430        let expr = gt(lit(1), lit(2));
431        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
432    }
433
434    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
435    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
436    #[test]
437    fn test_struct_comparison() {
438        use crate::IntoArray;
439        use crate::arrays::StructArray;
440
441        // Create a struct array with one element for testing.
442        let lhs_struct = StructArray::from_fields(&[
443            (
444                "a",
445                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
446            ),
447            (
448                "b",
449                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
450            ),
451        ])
452        .unwrap()
453        .into_array();
454
455        let rhs_struct_equal = StructArray::from_fields(&[
456            (
457                "a",
458                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
459            ),
460            (
461                "b",
462                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
463            ),
464        ])
465        .unwrap()
466        .into_array();
467
468        let rhs_struct_different = StructArray::from_fields(&[
469            (
470                "a",
471                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
472            ),
473            (
474                "b",
475                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
476            ),
477        ])
478        .unwrap()
479        .into_array();
480
481        // Test using binary method directly
482        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
483        assert_eq!(
484            result_equal.scalar_at(0).vortex_expect("value"),
485            Scalar::bool(true, Nullability::NonNullable),
486            "Equal structs should be equal"
487        );
488
489        let result_different = lhs_struct
490            .binary(rhs_struct_different, Operator::Eq)
491            .unwrap();
492        assert_eq!(
493            result_different.scalar_at(0).vortex_expect("value"),
494            Scalar::bool(false, Nullability::NonNullable),
495            "Different structs should not be equal"
496        );
497    }
498
499    #[test]
500    fn test_or_kleene_validity() {
501        use crate::IntoArray;
502        use crate::arrays::BoolArray;
503        use crate::arrays::StructArray;
504        use crate::expr::col;
505
506        let struct_arr = StructArray::from_fields(&[
507            ("a", BoolArray::from_iter([Some(true)]).into_array()),
508            (
509                "b",
510                BoolArray::from_iter([Option::<bool>::None]).into_array(),
511            ),
512        ])
513        .unwrap()
514        .into_array();
515
516        let expr = or(col("a"), col("b"));
517        let result = struct_arr.apply(&expr).unwrap();
518
519        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
520    }
521
522    #[test]
523    fn test_scalar_subtract_unsigned() {
524        use vortex_buffer::buffer;
525
526        use crate::IntoArray;
527        use crate::arrays::ConstantArray;
528        use crate::arrays::PrimitiveArray;
529
530        let values = buffer![1u16, 2, 3].into_array();
531        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
532        let result = values.binary(rhs, Operator::Sub).unwrap();
533        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
534    }
535
536    #[test]
537    fn test_scalar_subtract_signed() {
538        use vortex_buffer::buffer;
539
540        use crate::IntoArray;
541        use crate::arrays::ConstantArray;
542        use crate::arrays::PrimitiveArray;
543
544        let values = buffer![1i64, 2, 3].into_array();
545        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
546        let result = values.binary(rhs, Operator::Sub).unwrap();
547        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
548    }
549
550    #[test]
551    fn test_scalar_subtract_nullable() {
552        use crate::IntoArray;
553        use crate::arrays::ConstantArray;
554        use crate::arrays::PrimitiveArray;
555
556        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
557        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
558        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
559        assert_arrays_eq!(
560            result,
561            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
562        );
563    }
564
565    #[test]
566    fn test_scalar_subtract_float() {
567        use vortex_buffer::buffer;
568
569        use crate::IntoArray;
570        use crate::arrays::ConstantArray;
571        use crate::arrays::PrimitiveArray;
572
573        let values = buffer![1.0f64, 2.0, 3.0].into_array();
574        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
575        let result = values.binary(rhs, Operator::Sub).unwrap();
576        assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
577    }
578
579    #[test]
580    fn test_scalar_subtract_float_underflow_is_ok() {
581        use vortex_buffer::buffer;
582
583        use crate::IntoArray;
584        use crate::arrays::ConstantArray;
585
586        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
587        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
588        let _results = values.binary(rhs1, Operator::Sub).unwrap();
589        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
590        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
591        let _results = values.binary(rhs2, Operator::Sub).unwrap();
592    }
593}