Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function for two-operand expressions: comparisons, Kleene `and`/`or`, and arithmetic.
///
/// The concrete operation is selected by the [`Operator`] passed as this function's options;
/// the struct itself carries no state.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
54    fn id(&self) -> ScalarFnId {
55        ScalarFnId::new("vortex.binary")
56    }
57
58    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::BinaryOpts {
61                op: (*instance).into(),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let opts = pb::BinaryOpts::decode(_metadata)?;
73        Operator::try_from(opts.op)
74    }
75
    /// Every binary operator takes exactly two arguments; the options are not consulted.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(2)
    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("lhs"),
83            1 => ChildName::from("rhs"),
84            _ => unreachable!("Binary has only two children"),
85        }
86    }
87
88    fn fmt_sql(
89        &self,
90        operator: &Operator,
91        expr: &Expression,
92        f: &mut Formatter<'_>,
93    ) -> std::fmt::Result {
94        write!(f, "(")?;
95        expr.child(0).fmt_sql(f)?;
96        write!(f, " {} ", operator)?;
97        expr.child(1).fmt_sql(f)?;
98        write!(f, ")")
99    }
100
101    fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102        let lhs = &args[0];
103        let rhs = &args[1];
104        if operator.is_arithmetic() || operator.is_comparison() {
105            let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106                vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107            })?;
108            Ok(vec![supertype.clone(), supertype])
109        } else {
110            // Boolean And/Or: no coercion
111            Ok(args.to_vec())
112        }
113    }
114
115    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116        let lhs = &arg_dtypes[0];
117        let rhs = &arg_dtypes[1];
118
119        if operator.is_arithmetic() {
120            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122            }
123            vortex_bail!(
124                "incompatible types for arithmetic operation: {} {}",
125                lhs,
126                rhs
127            );
128        }
129
130        if operator.is_comparison()
131            && !lhs.eq_ignore_nullability(rhs)
132            && !lhs.is_extension()
133            && !rhs.is_extension()
134        {
135            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136        }
137
138        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139    }
140
141    fn execute(
142        &self,
143        op: &Operator,
144        args: &dyn ExecutionArgs,
145        _ctx: &mut ExecutionCtx,
146    ) -> VortexResult<ArrayRef> {
147        let lhs = args.get(0)?;
148        let rhs = args.get(1)?;
149
150        match op {
151            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157            Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163        }
164    }
165
166    fn stat_falsification(
167        &self,
168        operator: &Operator,
169        expr: &Expression,
170        catalog: &dyn StatsCatalog,
171    ) -> Option<Expression> {
172        // Wrap another predicate with an optional NaNCount check, if the stat is available.
173        //
174        // For example, regular pruning conversion for `A >= B` would be
175        //
176        //      A.max < B.min
177        //
178        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
179        // in:
180        //
181        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
182        //
183        // Non-floating point column and literal expressions should be unaffected as they do not
184        // have a nan_count statistic defined.
185        fn with_nan_predicate(
186            lhs: &Expression,
187            rhs: &Expression,
188            value_predicate: Expression,
189            catalog: &dyn StatsCatalog,
190        ) -> Expression {
191            let nan_predicate = and_collect(
192                lhs.stat_expression(Stat::NaNCount, catalog)
193                    .into_iter()
194                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
195                    .map(|nans| eq(nans, lit(0u64))),
196            );
197
198            if let Some(nan_check) = nan_predicate {
199                and(nan_check, value_predicate)
200            } else {
201                value_predicate
202            }
203        }
204
205        let lhs = expr.child(0);
206        let rhs = expr.child(1);
207        match operator {
208            Operator::Eq => {
209                let min_lhs = lhs.stat_min(catalog);
210                let max_lhs = lhs.stat_max(catalog);
211
212                let min_rhs = rhs.stat_min(catalog);
213                let max_rhs = rhs.stat_max(catalog);
214
215                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
216                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
217
218                let min_max_check = or_collect(left.into_iter().chain(right))?;
219
220                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
221                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
222            }
223            Operator::NotEq => {
224                let min_lhs = lhs.stat_min(catalog)?;
225                let max_lhs = lhs.stat_max(catalog)?;
226
227                let min_rhs = rhs.stat_min(catalog)?;
228                let max_rhs = rhs.stat_max(catalog)?;
229
230                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
231
232                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
233            }
234            Operator::Gt => {
235                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
236
237                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
238            }
239            Operator::Gte => {
240                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
241                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
242
243                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
244            }
245            Operator::Lt => {
246                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
247                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
248
249                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
250            }
251            Operator::Lte => {
252                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
253                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
254
255                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
256            }
257            Operator::And => or_collect(
258                lhs.stat_falsification(catalog)
259                    .into_iter()
260                    .chain(rhs.stat_falsification(catalog)),
261            ),
262            Operator::Or => Some(and(
263                lhs.stat_falsification(catalog)?,
264                rhs.stat_falsification(catalog)?,
265            )),
266            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
267        }
268    }
269
270    fn validity(
271        &self,
272        operator: &Operator,
273        expression: &Expression,
274    ) -> VortexResult<Option<Expression>> {
275        let lhs = expression.child(0).validity()?;
276        let rhs = expression.child(1).validity()?;
277
278        Ok(match operator {
279            // AND and OR are kleene logic.
280            Operator::And => None,
281            Operator::Or => None,
282            _ => {
283                // All other binary operators are null if either side is null.
284                Some(and(lhs, rhs))
285            }
286        })
287    }
288
    /// Reports every binary operator as not null-sensitive: the operator is ignored and the
    /// result is always `false`.
    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
        false
    }
292
293    fn is_fallible(&self, operator: &Operator) -> bool {
294        // Opt-in not out for fallibility.
295        // Arithmetic operations could be better modelled here.
296        let infallible = matches!(
297            operator,
298            Operator::Eq
299                | Operator::NotEq
300                | Operator::Gt
301                | Operator::Gte
302                | Operator::Lt
303                | Operator::Lte
304                | Operator::And
305                | Operator::Or
306        );
307
308        !infallible
309    }
310}
311
312#[cfg(test)]
313mod tests {
314    use vortex_error::VortexExpect;
315
316    use super::*;
317    use crate::LEGACY_SESSION;
318    use crate::VortexSessionExecute;
319    use crate::assert_arrays_eq;
320    use crate::builtins::ArrayBuiltins;
321    use crate::dtype::DType;
322    use crate::dtype::Nullability;
323    use crate::expr::Expression;
324    use crate::expr::and_collect;
325    use crate::expr::col;
326    use crate::expr::lit;
327    use crate::expr::lt;
328    use crate::expr::not_eq;
329    use crate::expr::or;
330    use crate::expr::or_collect;
331    use crate::expr::test_harness;
332    use crate::scalar::Scalar;
    /// Checks the tree shape `and_collect` produces for 5, 4, 1 and 0 input expressions.
    #[test]
    fn and_collect_balanced() {
        // 5 elements: and(and(1, 2), and(and(3, 4), 5))
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values.into_iter()).is_none());
    }
369
    /// Checks that `or_collect` arranges four input expressions into a balanced tree.
    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }
384
385    #[test]
386    fn dtype() {
387        let dtype = test_harness::struct_dtype();
388        let bool1: Expression = col("bool1");
389        let bool2: Expression = col("bool2");
390        assert_eq!(
391            and(bool1.clone(), bool2.clone())
392                .return_dtype(&dtype)
393                .unwrap(),
394            DType::Bool(Nullability::NonNullable)
395        );
396        assert_eq!(
397            or(bool1, bool2).return_dtype(&dtype).unwrap(),
398            DType::Bool(Nullability::NonNullable)
399        );
400
401        let col1: Expression = col("col1");
402        let col2: Expression = col("col2");
403
404        assert_eq!(
405            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
406            DType::Bool(Nullability::Nullable)
407        );
408        assert_eq!(
409            not_eq(col1.clone(), col2.clone())
410                .return_dtype(&dtype)
411                .unwrap(),
412            DType::Bool(Nullability::Nullable)
413        );
414        assert_eq!(
415            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
416            DType::Bool(Nullability::Nullable)
417        );
418        assert_eq!(
419            gt_eq(col1.clone(), col2.clone())
420                .return_dtype(&dtype)
421                .unwrap(),
422            DType::Bool(Nullability::Nullable)
423        );
424        assert_eq!(
425            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
426            DType::Bool(Nullability::Nullable)
427        );
428        assert_eq!(
429            lt_eq(col1.clone(), col2.clone())
430                .return_dtype(&dtype)
431                .unwrap(),
432            DType::Bool(Nullability::Nullable)
433        );
434
435        assert_eq!(
436            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
437                .return_dtype(&dtype)
438                .unwrap(),
439            DType::Bool(Nullability::Nullable)
440        );
441    }
442
443    #[test]
444    fn test_display_print() {
445        let expr = gt(lit(1), lit(2));
446        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
447    }
448
449    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
450    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
451    #[test]
452    fn test_struct_comparison() {
453        use crate::IntoArray;
454        use crate::arrays::StructArray;
455
456        // Create a struct array with one element for testing.
457        let lhs_struct = StructArray::from_fields(&[
458            (
459                "a",
460                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
461            ),
462            (
463                "b",
464                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
465            ),
466        ])
467        .unwrap()
468        .into_array();
469
470        let rhs_struct_equal = StructArray::from_fields(&[
471            (
472                "a",
473                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
474            ),
475            (
476                "b",
477                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
478            ),
479        ])
480        .unwrap()
481        .into_array();
482
483        let rhs_struct_different = StructArray::from_fields(&[
484            (
485                "a",
486                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
487            ),
488            (
489                "b",
490                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
491            ),
492        ])
493        .unwrap()
494        .into_array();
495
496        // Test using binary method directly
497        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
498        assert_eq!(
499            result_equal
500                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
501                .vortex_expect("value"),
502            Scalar::bool(true, Nullability::NonNullable),
503            "Equal structs should be equal"
504        );
505
506        let result_different = lhs_struct
507            .binary(rhs_struct_different, Operator::Eq)
508            .unwrap();
509        assert_eq!(
510            result_different
511                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
512                .vortex_expect("value"),
513            Scalar::bool(false, Nullability::NonNullable),
514            "Different structs should not be equal"
515        );
516    }
517
518    #[test]
519    fn test_or_kleene_validity() {
520        use crate::IntoArray;
521        use crate::arrays::BoolArray;
522        use crate::arrays::StructArray;
523        use crate::expr::col;
524
525        let struct_arr = StructArray::from_fields(&[
526            ("a", BoolArray::from_iter([Some(true)]).into_array()),
527            (
528                "b",
529                BoolArray::from_iter([Option::<bool>::None]).into_array(),
530            ),
531        ])
532        .unwrap()
533        .into_array();
534
535        let expr = or(col("a"), col("b"));
536        let result = struct_arr.apply(&expr).unwrap();
537
538        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
539    }
540
541    #[test]
542    fn test_scalar_subtract_unsigned() {
543        use vortex_buffer::buffer;
544
545        use crate::IntoArray;
546        use crate::arrays::ConstantArray;
547        use crate::arrays::PrimitiveArray;
548
549        let values = buffer![1u16, 2, 3].into_array();
550        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
551        let result = values.binary(rhs, Operator::Sub).unwrap();
552        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
553    }
554
555    #[test]
556    fn test_scalar_subtract_signed() {
557        use vortex_buffer::buffer;
558
559        use crate::IntoArray;
560        use crate::arrays::ConstantArray;
561        use crate::arrays::PrimitiveArray;
562
563        let values = buffer![1i64, 2, 3].into_array();
564        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
565        let result = values.binary(rhs, Operator::Sub).unwrap();
566        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
567    }
568
569    #[test]
570    fn test_scalar_subtract_nullable() {
571        use crate::IntoArray;
572        use crate::arrays::ConstantArray;
573        use crate::arrays::PrimitiveArray;
574
575        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
576        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
577        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
578        assert_arrays_eq!(
579            result,
580            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
581        );
582    }
583
584    #[test]
585    fn test_scalar_subtract_float() {
586        use vortex_buffer::buffer;
587
588        use crate::IntoArray;
589        use crate::arrays::ConstantArray;
590        use crate::arrays::PrimitiveArray;
591
592        let values = buffer![1.0f64, 2.0, 3.0].into_array();
593        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
594        let result = values.binary(rhs, Operator::Sub).unwrap();
595        assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
596    }
597
598    #[test]
599    fn test_scalar_subtract_float_underflow_is_ok() {
600        use vortex_buffer::buffer;
601
602        use crate::IntoArray;
603        use crate::arrays::ConstantArray;
604
605        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
606        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
607        let _results = values.binary(rhs1, Operator::Sub).unwrap();
608        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
609        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
610        let _results = values.binary(rhs2, Operator::Sub).unwrap();
611    }
612}