Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Unit type implementing the `vortex.binary` scalar function.
///
/// The struct itself carries no state; the operator to apply is supplied through the
/// function's options, which are an `Operator`.
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
    /// Returns the identifier of this scalar function, `vortex.binary`.
    fn id(&self) -> ScalarFnId {
        ScalarFnId::from("vortex.binary")
    }
57
58    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::BinaryOpts {
61                op: (*instance).into(),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let opts = pb::BinaryOpts::decode(_metadata)?;
73        Operator::try_from(opts.op)
74    }
75
    /// Every operator takes exactly two arguments, regardless of the options.
    fn arity(&self, _options: &Self::Options) -> Arity {
        Arity::Exact(2)
    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("lhs"),
83            1 => ChildName::from("rhs"),
84            _ => unreachable!("Binary has only two children"),
85        }
86    }
87
88    fn fmt_sql(
89        &self,
90        operator: &Operator,
91        expr: &Expression,
92        f: &mut Formatter<'_>,
93    ) -> std::fmt::Result {
94        write!(f, "(")?;
95        expr.child(0).fmt_sql(f)?;
96        write!(f, " {} ", operator)?;
97        expr.child(1).fmt_sql(f)?;
98        write!(f, ")")
99    }
100
101    fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
102        let lhs = &args[0];
103        let rhs = &args[1];
104        if operator.is_arithmetic() || operator.is_comparison() {
105            let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
106                vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
107            })?;
108            Ok(vec![supertype.clone(), supertype])
109        } else {
110            // Boolean And/Or: no coercion
111            Ok(args.to_vec())
112        }
113    }
114
115    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
116        let lhs = &arg_dtypes[0];
117        let rhs = &arg_dtypes[1];
118
119        if operator.is_arithmetic() {
120            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
121                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
122            }
123            vortex_bail!(
124                "incompatible types for arithmetic operation: {} {}",
125                lhs,
126                rhs
127            );
128        }
129
130        if operator.is_comparison()
131            && !lhs.eq_ignore_nullability(rhs)
132            && !lhs.is_extension()
133            && !rhs.is_extension()
134        {
135            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
136        }
137
138        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
139    }
140
141    fn execute(
142        &self,
143        op: &Operator,
144        args: &dyn ExecutionArgs,
145        _ctx: &mut ExecutionCtx,
146    ) -> VortexResult<ArrayRef> {
147        let lhs = args.get(0)?;
148        let rhs = args.get(1)?;
149
150        match op {
151            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq),
152            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq),
153            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt),
154            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte),
155            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt),
156            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
157            Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
158            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
159            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
160            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
161            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
162            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
163        }
164    }
165
166    fn stat_falsification(
167        &self,
168        operator: &Operator,
169        expr: &Expression,
170        catalog: &dyn StatsCatalog,
171    ) -> Option<Expression> {
172        // Wrap another predicate with an optional NaNCount check, if the stat is available.
173        //
174        // For example, regular pruning conversion for `A >= B` would be
175        //
176        //      A.max < B.min
177        //
178        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
179        // in:
180        //
181        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
182        //
183        // Non-floating point column and literal expressions should be unaffected as they do not
184        // have a nan_count statistic defined.
185        #[inline]
186        fn with_nan_predicate(
187            lhs: &Expression,
188            rhs: &Expression,
189            value_predicate: Expression,
190            catalog: &dyn StatsCatalog,
191        ) -> Expression {
192            let nan_predicate = and_collect(
193                lhs.stat_expression(Stat::NaNCount, catalog)
194                    .into_iter()
195                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
196                    .map(|nans| eq(nans, lit(0u64))),
197            );
198
199            if let Some(nan_check) = nan_predicate {
200                and(nan_check, value_predicate)
201            } else {
202                value_predicate
203            }
204        }
205
206        let lhs = expr.child(0);
207        let rhs = expr.child(1);
208        match operator {
209            Operator::Eq => {
210                let min_lhs = lhs.stat_min(catalog);
211                let max_lhs = lhs.stat_max(catalog);
212
213                let min_rhs = rhs.stat_min(catalog);
214                let max_rhs = rhs.stat_max(catalog);
215
216                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
217                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
218
219                let min_max_check = or_collect(left.into_iter().chain(right))?;
220
221                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
222                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
223            }
224            Operator::NotEq => {
225                let min_lhs = lhs.stat_min(catalog)?;
226                let max_lhs = lhs.stat_max(catalog)?;
227
228                let min_rhs = rhs.stat_min(catalog)?;
229                let max_rhs = rhs.stat_max(catalog)?;
230
231                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
232
233                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
234            }
235            Operator::Gt => {
236                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
237
238                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
239            }
240            Operator::Gte => {
241                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
242                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
243
244                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
245            }
246            Operator::Lt => {
247                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
248                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
249
250                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
251            }
252            Operator::Lte => {
253                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
254                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
255
256                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
257            }
258            Operator::And => or_collect(
259                lhs.stat_falsification(catalog)
260                    .into_iter()
261                    .chain(rhs.stat_falsification(catalog)),
262            ),
263            Operator::Or => Some(and(
264                lhs.stat_falsification(catalog)?,
265                rhs.stat_falsification(catalog)?,
266            )),
267            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
268        }
269    }
270
271    fn validity(
272        &self,
273        operator: &Operator,
274        expression: &Expression,
275    ) -> VortexResult<Option<Expression>> {
276        let lhs = expression.child(0).validity()?;
277        let rhs = expression.child(1).validity()?;
278
279        Ok(match operator {
280            // AND and OR are kleene logic.
281            Operator::And => None,
282            Operator::Or => None,
283            _ => {
284                // All other binary operators are null if either side is null.
285                Some(and(lhs, rhs))
286            }
287        })
288    }
289
    /// Returns `false` for every operator; the operator argument is not inspected.
    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
        false
    }
293
294    fn is_fallible(&self, operator: &Operator) -> bool {
295        // Opt-in not out for fallibility.
296        // Arithmetic operations could be better modelled here.
297        let infallible = matches!(
298            operator,
299            Operator::Eq
300                | Operator::NotEq
301                | Operator::Gt
302                | Operator::Gte
303                | Operator::Lt
304                | Operator::Lte
305                | Operator::And
306                | Operator::Or
307        );
308
309        !infallible
310    }
311}
312
313#[cfg(test)]
314mod tests {
315    use vortex_error::VortexExpect;
316
317    use super::*;
318    use crate::assert_arrays_eq;
319    use crate::builtins::ArrayBuiltins;
320    use crate::dtype::DType;
321    use crate::dtype::Nullability;
322    use crate::expr::Expression;
323    use crate::expr::and_collect;
324    use crate::expr::col;
325    use crate::expr::lit;
326    use crate::expr::lt;
327    use crate::expr::not_eq;
328    use crate::expr::or;
329    use crate::expr::or_collect;
330    use crate::expr::test_harness;
331    use crate::scalar::Scalar;
332    #[test]
333    fn and_collect_balanced() {
334        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
335
336        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
337        vortex.binary(and)
338        ├── lhs: vortex.binary(and)
339        │   ├── lhs: vortex.literal(1i32)
340        │   └── rhs: vortex.literal(2i32)
341        └── rhs: vortex.binary(and)
342            ├── lhs: vortex.binary(and)
343            │   ├── lhs: vortex.literal(3i32)
344            │   └── rhs: vortex.literal(4i32)
345            └── rhs: vortex.literal(5i32)
346        ");
347
348        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
349        let values = vec![lit(1), lit(2), lit(3), lit(4)];
350        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
351        vortex.binary(and)
352        ├── lhs: vortex.binary(and)
353        │   ├── lhs: vortex.literal(1i32)
354        │   └── rhs: vortex.literal(2i32)
355        └── rhs: vortex.binary(and)
356            ├── lhs: vortex.literal(3i32)
357            └── rhs: vortex.literal(4i32)
358        ");
359
360        // 1 element: just the element
361        let values = vec![lit(1)];
362        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");
363
364        // 0 elements: None
365        let values: Vec<Expression> = vec![];
366        assert!(and_collect(values.into_iter()).is_none());
367    }
368
369    #[test]
370    fn or_collect_balanced() {
371        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
372        let values = vec![lit(1), lit(2), lit(3), lit(4)];
373        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
374        vortex.binary(or)
375        ├── lhs: vortex.binary(or)
376        │   ├── lhs: vortex.literal(1i32)
377        │   └── rhs: vortex.literal(2i32)
378        └── rhs: vortex.binary(or)
379            ├── lhs: vortex.literal(3i32)
380            └── rhs: vortex.literal(4i32)
381        ");
382    }
383
384    #[test]
385    fn dtype() {
386        let dtype = test_harness::struct_dtype();
387        let bool1: Expression = col("bool1");
388        let bool2: Expression = col("bool2");
389        assert_eq!(
390            and(bool1.clone(), bool2.clone())
391                .return_dtype(&dtype)
392                .unwrap(),
393            DType::Bool(Nullability::NonNullable)
394        );
395        assert_eq!(
396            or(bool1, bool2).return_dtype(&dtype).unwrap(),
397            DType::Bool(Nullability::NonNullable)
398        );
399
400        let col1: Expression = col("col1");
401        let col2: Expression = col("col2");
402
403        assert_eq!(
404            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
405            DType::Bool(Nullability::Nullable)
406        );
407        assert_eq!(
408            not_eq(col1.clone(), col2.clone())
409                .return_dtype(&dtype)
410                .unwrap(),
411            DType::Bool(Nullability::Nullable)
412        );
413        assert_eq!(
414            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
415            DType::Bool(Nullability::Nullable)
416        );
417        assert_eq!(
418            gt_eq(col1.clone(), col2.clone())
419                .return_dtype(&dtype)
420                .unwrap(),
421            DType::Bool(Nullability::Nullable)
422        );
423        assert_eq!(
424            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
425            DType::Bool(Nullability::Nullable)
426        );
427        assert_eq!(
428            lt_eq(col1.clone(), col2.clone())
429                .return_dtype(&dtype)
430                .unwrap(),
431            DType::Bool(Nullability::Nullable)
432        );
433
434        assert_eq!(
435            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
436                .return_dtype(&dtype)
437                .unwrap(),
438            DType::Bool(Nullability::Nullable)
439        );
440    }
441
442    #[test]
443    fn test_display_print() {
444        let expr = gt(lit(1), lit(2));
445        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
446    }
447
448    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
449    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
450    #[test]
451    fn test_struct_comparison() {
452        use crate::IntoArray;
453        use crate::arrays::StructArray;
454
455        // Create a struct array with one element for testing.
456        let lhs_struct = StructArray::from_fields(&[
457            (
458                "a",
459                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
460            ),
461            (
462                "b",
463                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
464            ),
465        ])
466        .unwrap()
467        .into_array();
468
469        let rhs_struct_equal = StructArray::from_fields(&[
470            (
471                "a",
472                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
473            ),
474            (
475                "b",
476                crate::arrays::PrimitiveArray::from_iter([3i32]).into_array(),
477            ),
478        ])
479        .unwrap()
480        .into_array();
481
482        let rhs_struct_different = StructArray::from_fields(&[
483            (
484                "a",
485                crate::arrays::PrimitiveArray::from_iter([1i32]).into_array(),
486            ),
487            (
488                "b",
489                crate::arrays::PrimitiveArray::from_iter([4i32]).into_array(),
490            ),
491        ])
492        .unwrap()
493        .into_array();
494
495        // Test using binary method directly
496        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
497        assert_eq!(
498            result_equal.scalar_at(0).vortex_expect("value"),
499            Scalar::bool(true, Nullability::NonNullable),
500            "Equal structs should be equal"
501        );
502
503        let result_different = lhs_struct
504            .binary(rhs_struct_different, Operator::Eq)
505            .unwrap();
506        assert_eq!(
507            result_different.scalar_at(0).vortex_expect("value"),
508            Scalar::bool(false, Nullability::NonNullable),
509            "Different structs should not be equal"
510        );
511    }
512
513    #[test]
514    fn test_or_kleene_validity() {
515        use crate::IntoArray;
516        use crate::arrays::BoolArray;
517        use crate::arrays::StructArray;
518        use crate::expr::col;
519
520        let struct_arr = StructArray::from_fields(&[
521            ("a", BoolArray::from_iter([Some(true)]).into_array()),
522            (
523                "b",
524                BoolArray::from_iter([Option::<bool>::None]).into_array(),
525            ),
526        ])
527        .unwrap()
528        .into_array();
529
530        let expr = or(col("a"), col("b"));
531        let result = struct_arr.apply(&expr).unwrap();
532
533        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
534    }
535
536    #[test]
537    fn test_scalar_subtract_unsigned() {
538        use vortex_buffer::buffer;
539
540        use crate::IntoArray;
541        use crate::arrays::ConstantArray;
542        use crate::arrays::PrimitiveArray;
543
544        let values = buffer![1u16, 2, 3].into_array();
545        let rhs = ConstantArray::new(Scalar::from(1u16), 3).into_array();
546        let result = values.binary(rhs, Operator::Sub).unwrap();
547        assert_arrays_eq!(result, PrimitiveArray::from_iter([0u16, 1, 2]));
548    }
549
550    #[test]
551    fn test_scalar_subtract_signed() {
552        use vortex_buffer::buffer;
553
554        use crate::IntoArray;
555        use crate::arrays::ConstantArray;
556        use crate::arrays::PrimitiveArray;
557
558        let values = buffer![1i64, 2, 3].into_array();
559        let rhs = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
560        let result = values.binary(rhs, Operator::Sub).unwrap();
561        assert_arrays_eq!(result, PrimitiveArray::from_iter([2i64, 3, 4]));
562    }
563
564    #[test]
565    fn test_scalar_subtract_nullable() {
566        use crate::IntoArray;
567        use crate::arrays::ConstantArray;
568        use crate::arrays::PrimitiveArray;
569
570        let values = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
571        let rhs = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
572        let result = values.into_array().binary(rhs, Operator::Sub).unwrap();
573        assert_arrays_eq!(
574            result,
575            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
576        );
577    }
578
579    #[test]
580    fn test_scalar_subtract_float() {
581        use vortex_buffer::buffer;
582
583        use crate::IntoArray;
584        use crate::arrays::ConstantArray;
585        use crate::arrays::PrimitiveArray;
586
587        let values = buffer![1.0f64, 2.0, 3.0].into_array();
588        let rhs = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
589        let result = values.binary(rhs, Operator::Sub).unwrap();
590        assert_arrays_eq!(result, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
591    }
592
593    #[test]
594    fn test_scalar_subtract_float_underflow_is_ok() {
595        use vortex_buffer::buffer;
596
597        use crate::IntoArray;
598        use crate::arrays::ConstantArray;
599
600        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
601        let rhs1 = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
602        let _results = values.binary(rhs1, Operator::Sub).unwrap();
603        let values = buffer![f32::MIN, 2.0, 3.0].into_array();
604        let rhs2 = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
605        let _results = values.binary(rhs2, Operator::Sub).unwrap();
606    }
607}