Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
48#[derive(Clone)]
49pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
54    fn id(&self) -> ScalarFnId {
55        ScalarFnId::from("vortex.binary")
56    }
57
58    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::BinaryOpts {
61                op: (*instance).into(),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let opts = pb::BinaryOpts::decode(_metadata)?;
73        Operator::try_from(opts.op)
74    }
75
76    fn arity(&self, _options: &Self::Options) -> Arity {
77        Arity::Exact(2)
78    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("lhs"),
83            1 => ChildName::from("rhs"),
84            _ => unreachable!("Binary has only two children"),
85        }
86    }
87
88    fn fmt_sql(
89        &self,
90        operator: &Operator,
91        expr: &Expression,
92        f: &mut Formatter<'_>,
93    ) -> std::fmt::Result {
94        write!(f, "(")?;
95        expr.child(0).fmt_sql(f)?;
96        write!(f, " {} ", operator)?;
97        expr.child(1).fmt_sql(f)?;
98        write!(f, ")")
99    }
100
101    fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102        let lhs = &args[0];
103        let rhs = &args[1];
104        if operator.is_arithmetic() || operator.is_comparison() {
105            let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106                vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107            })?;
108            Ok(vec![supertype.clone(), supertype])
109        } else {
110            // Boolean And/Or: no coercion
111            Ok(args.to_vec())
112        }
113    }
114
115    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116        let lhs = &arg_dtypes[0];
117        let rhs = &arg_dtypes[1];
118
119        if operator.is_arithmetic() {
120            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122            }
123            vortex_bail!(
124                "incompatible types for arithmetic operation: {} {}",
125                lhs,
126                rhs
127            );
128        }
129
130        if operator.is_comparison()
131            && !lhs.eq_ignore_nullability(rhs)
132            && !lhs.is_extension()
133            && !rhs.is_extension()
134        {
135            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136        }
137
138        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139    }
140
141    fn execute(
142        &self,
143        op: &Operator,
144        args: &dyn ExecutionArgs,
145        _ctx: &mut ExecutionCtx,
146    ) -> VortexResult<ArrayRef> {
147        let lhs = args.get(0)?;
148        let rhs = args.get(1)?;
149
150        match op {
151            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157            Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163        }
164    }
165
166    fn stat_falsification(
167        &self,
168        operator: &Operator,
169        expr: &Expression,
170        catalog: &dyn StatsCatalog,
171    ) -> Option<Expression> {
172        // Wrap another predicate with an optional NaNCount check, if the stat is available.
173        //
174        // For example, regular pruning conversion for `A >= B` would be
175        //
176        //      A.max < B.min
177        //
178        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
179        // in:
180        //
181        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
182        //
183        // Non-floating point column and literal expressions should be unaffected as they do not
184        // have a nan_count statistic defined.
185        fn with_nan_predicate(
186            lhs: &Expression,
187            rhs: &Expression,
188            value_predicate: Expression,
189            catalog: &dyn StatsCatalog,
190        ) -> Expression {
191            let nan_predicate = and_collect(
192                lhs.stat_expression(Stat::NaNCount, catalog)
193                    .into_iter()
194                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
195                    .map(|nans| eq(nans, lit(0u64))),
196            );
197
198            if let Some(nan_check) = nan_predicate {
199                and(nan_check, value_predicate)
200            } else {
201                value_predicate
202            }
203        }
204
205        let lhs = expr.child(0);
206        let rhs = expr.child(1);
207        match operator {
208            Operator::Eq => {
209                let min_lhs = lhs.stat_min(catalog);
210                let max_lhs = lhs.stat_max(catalog);
211
212                let min_rhs = rhs.stat_min(catalog);
213                let max_rhs = rhs.stat_max(catalog);
214
215                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
216                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
217
218                let min_max_check = or_collect(left.into_iter().chain(right))?;
219
220                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
221                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
222            }
223            Operator::NotEq => {
224                let min_lhs = lhs.stat_min(catalog)?;
225                let max_lhs = lhs.stat_max(catalog)?;
226
227                let min_rhs = rhs.stat_min(catalog)?;
228                let max_rhs = rhs.stat_max(catalog)?;
229
230                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
231
232                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
233            }
234            Operator::Gt => {
235                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
236
237                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
238            }
239            Operator::Gte => {
240                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
241                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
242
243                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
244            }
245            Operator::Lt => {
246                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
247                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
248
249                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
250            }
251            Operator::Lte => {
252                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
253                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
254
255                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
256            }
257            Operator::And => or_collect(
258                lhs.stat_falsification(catalog)
259                    .into_iter()
260                    .chain(rhs.stat_falsification(catalog)),
261            ),
262            Operator::Or => Some(and(
263                lhs.stat_falsification(catalog)?,
264                rhs.stat_falsification(catalog)?,
265            )),
266            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
267        }
268    }
269
270    fn validity(
271        &self,
272        operator: &Operator,
273        expression: &Expression,
274    ) -> VortexResult<Option<Expression>> {
275        let lhs = expression.child(0).validity()?;
276        let rhs = expression.child(1).validity()?;
277
278        Ok(match operator {
279            // AND and OR are kleene logic.
280            Operator::And => None,
281            Operator::Or => None,
282            _ => {
283                // All other binary operators are null if either side is null.
284                Some(and(lhs, rhs))
285            }
286        })
287    }
288
289    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
290        false
291    }
292
293    fn is_fallible(&self, operator: &Operator) -> bool {
294        // Opt-in not out for fallibility.
295        // Arithmetic operations could be better modelled here.
296        let infallible = matches!(
297            operator,
298            Operator::Eq
299                | Operator::NotEq
300                | Operator::Gt
301                | Operator::Gte
302                | Operator::Lt
303                | Operator::Lte
304                | Operator::And
305                | Operator::Or
306        );
307
308        !infallible
309    }
310}
311
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    #[test]
    fn and_collect_balanced() {
        // 5 elements: and(and(1, 2), and(and(3, 4), 5)) - the odd element lands on the right
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values.into_iter()).is_none());
    }

    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    #[test]
    fn dtype() {
        // Per the assertions below, the harness's `bool1`/`bool2` columns are non-nullable, while
        // at least one of `col1`/`col2` is nullable, so every comparison result is nullable.
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::PrimitiveArray;
        use crate::arrays::StructArray;

        // Create a struct array with one element for testing.
        let lhs_struct = StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([1i32]).into_array()),
            ("b", PrimitiveArray::from_iter([3i32]).into_array()),
        ])
        .unwrap()
        .into_array();

        let rhs_struct_equal = StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([1i32]).into_array()),
            ("b", PrimitiveArray::from_iter([3i32]).into_array()),
        ])
        .unwrap()
        .into_array();

        // Differs from `lhs_struct` only in field `b`.
        let rhs_struct_different = StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([1i32]).into_array()),
            ("b", PrimitiveArray::from_iter([4i32]).into_array()),
        ])
        .unwrap()
        .into_array();

        // Test using binary method directly
        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;

        let struct_arr = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            (
                "b",
                BoolArray::from_iter([Option::<bool>::None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        // Under Kleene logic `true OR null` is `true`, so the result must be valid (not null)
        // even though `b` is null.
        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
    }

    #[test]
    fn test_scalar_subtract_unsigned() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1u16, 2, 3].into_array();
        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    #[test]
    fn test_scalar_subtract_signed() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1i64, 2, 3].into_array();
        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        // Null inputs stay null in the output; valid inputs are subtracted normally.
        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    #[test]
    fn test_scalar_subtract_float() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let values = buffer![1.0f64, 2.0, 3.0].into_array();
        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
        let result = values.binary(rhs, Operator::Sub).unwrap();
        assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;

        // Float subtraction at the edge of the range must not be reported as an error:
        // `f32::MIN - 1.0` rounds back to `f32::MIN`, and `f32::MIN - f32::MAX` overflows to
        // negative infinity under IEEE 754. Only the absence of an error is checked.
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
        let _results = values.binary(rhs1, Operator::Sub).unwrap();
        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
        let _results = values.binary(rhs2, Operator::Sub).unwrap();
    }
}