Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::expr::StatsCatalog;
21use crate::expr::and;
22use crate::expr::and_collect;
23use crate::expr::eq;
24use crate::expr::expression::Expression;
25use crate::expr::gt;
26use crate::expr::gt_eq;
27use crate::expr::lit;
28use crate::expr::lt;
29use crate::expr::lt_eq;
30use crate::expr::or_collect;
31use crate::expr::stats::Stat;
32use crate::scalar_fn::Arity;
33use crate::scalar_fn::ChildName;
34use crate::scalar_fn::ExecutionArgs;
35use crate::scalar_fn::ScalarFnId;
36use crate::scalar_fn::ScalarFnVTable;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40pub(crate) mod boolean;
41pub(crate) use boolean::*;
42mod compare;
43pub use compare::*;
44mod numeric;
45pub(crate) use numeric::*;
46
47use crate::scalar::NumericOperator;
48
/// Scalar function with exactly two children, `lhs` and `rhs`, combined by an [`Operator`].
///
/// The operator is carried as the function's options and selects between comparison
/// (`Eq`, `NotEq`, `Lt`, `Lte`, `Gt`, `Gte`), boolean (`And`, `Or`) and arithmetic
/// (`Add`, `Sub`, `Mul`, `Div`) behaviour.
#[derive(Clone)]
pub struct Binary;
51
52impl ScalarFnVTable for Binary {
53    type Options = Operator;
54
55    fn id(&self) -> ScalarFnId {
56        static ID: CachedId = CachedId::new("vortex.binary");
57        *ID
58    }
59
60    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
61        Ok(Some(
62            pb::BinaryOpts {
63                op: (*instance).into(),
64            }
65            .encode_to_vec(),
66        ))
67    }
68
69    fn deserialize(
70        &self,
71        _metadata: &[u8],
72        _session: &VortexSession,
73    ) -> VortexResult<Self::Options> {
74        let opts = pb::BinaryOpts::decode(_metadata)?;
75        Operator::try_from(opts.op)
76    }
77
78    fn arity(&self, _options: &Self::Options) -> Arity {
79        Arity::Exact(2)
80    }
81
82    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
83        match child_idx {
84            0 => ChildName::from("lhs"),
85            1 => ChildName::from("rhs"),
86            _ => unreachable!("Binary has only two children"),
87        }
88    }
89
90    fn fmt_sql(
91        &self,
92        operator: &Operator,
93        expr: &Expression,
94        f: &mut Formatter<'_>,
95    ) -> std::fmt::Result {
96        write!(f, "(")?;
97        expr.child(0).fmt_sql(f)?;
98        write!(f, " {} ", operator)?;
99        expr.child(1).fmt_sql(f)?;
100        write!(f, ")")
101    }
102
103    fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
104        let lhs = &args[0];
105        let rhs = &args[1];
106        if operator.is_arithmetic() || operator.is_comparison() {
107            let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
108                vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
109            })?;
110            Ok(vec![supertype.clone(), supertype])
111        } else {
112            // Boolean And/Or: no coercion
113            Ok(args.to_vec())
114        }
115    }
116
117    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
118        let lhs = &arg_dtypes[0];
119        let rhs = &arg_dtypes[1];
120
121        if operator.is_arithmetic() {
122            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
123                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
124            }
125            vortex_bail!(
126                "incompatible types for arithmetic operation: {} {}",
127                lhs,
128                rhs
129            );
130        }
131
132        if operator.is_comparison()
133            && !lhs.eq_ignore_nullability(rhs)
134            && !lhs.is_extension()
135            && !rhs.is_extension()
136        {
137            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
138        }
139
140        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
141    }
142
143    fn execute(
144        &self,
145        op: &Operator,
146        args: &dyn ExecutionArgs,
147        ctx: &mut ExecutionCtx,
148    ) -> VortexResult<ArrayRef> {
149        let lhs = args.get(0)?;
150        let rhs = args.get(1)?;
151
152        match op {
153            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
154            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
155            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
156            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
157            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
158            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
159            Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
160            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
161            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
162            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
163            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
164            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
165        }
166    }
167
168    fn stat_falsification(
169        &self,
170        operator: &Operator,
171        expr: &Expression,
172        catalog: &dyn StatsCatalog,
173    ) -> Option<Expression> {
174        // Wrap another predicate with an optional NaNCount check, if the stat is available.
175        //
176        // For example, regular pruning conversion for `A >= B` would be
177        //
178        //      A.max < B.min
179        //
180        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
181        // in:
182        //
183        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
184        //
185        // Non-floating point column and literal expressions should be unaffected as they do not
186        // have a nan_count statistic defined.
187        fn with_nan_predicate(
188            lhs: &Expression,
189            rhs: &Expression,
190            value_predicate: Expression,
191            catalog: &dyn StatsCatalog,
192        ) -> Expression {
193            let nan_predicate = and_collect(
194                lhs.stat_expression(Stat::NaNCount, catalog)
195                    .into_iter()
196                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
197                    .map(|nans| eq(nans, lit(0u64))),
198            );
199
200            if let Some(nan_check) = nan_predicate {
201                and(nan_check, value_predicate)
202            } else {
203                value_predicate
204            }
205        }
206
207        let lhs = expr.child(0);
208        let rhs = expr.child(1);
209        match operator {
210            Operator::Eq => {
211                let min_lhs = lhs.stat_min(catalog);
212                let max_lhs = lhs.stat_max(catalog);
213
214                let min_rhs = rhs.stat_min(catalog);
215                let max_rhs = rhs.stat_max(catalog);
216
217                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
218                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
219
220                let min_max_check = or_collect(left.into_iter().chain(right))?;
221
222                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
223                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224            }
225            Operator::NotEq => {
226                let min_lhs = lhs.stat_min(catalog)?;
227                let max_lhs = lhs.stat_max(catalog)?;
228
229                let min_rhs = rhs.stat_min(catalog)?;
230                let max_rhs = rhs.stat_max(catalog)?;
231
232                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
233
234                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
235            }
236            Operator::Gt => {
237                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
238
239                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
240            }
241            Operator::Gte => {
242                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
243                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
244
245                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
246            }
247            Operator::Lt => {
248                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
249                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
250
251                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
252            }
253            Operator::Lte => {
254                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
255                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
256
257                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
258            }
259            Operator::And => or_collect(
260                lhs.stat_falsification(catalog)
261                    .into_iter()
262                    .chain(rhs.stat_falsification(catalog)),
263            ),
264            Operator::Or => Some(and(
265                lhs.stat_falsification(catalog)?,
266                rhs.stat_falsification(catalog)?,
267            )),
268            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
269        }
270    }
271
272    fn validity(
273        &self,
274        operator: &Operator,
275        expression: &Expression,
276    ) -> VortexResult<Option<Expression>> {
277        let lhs = expression.child(0).validity()?;
278        let rhs = expression.child(1).validity()?;
279
280        Ok(match operator {
281            // AND and OR are kleene logic.
282            Operator::And => None,
283            Operator::Or => None,
284            _ => {
285                // All other binary operators are null if either side is null.
286                Some(and(lhs, rhs))
287            }
288        })
289    }
290
291    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
292        false
293    }
294
295    fn is_fallible(&self, operator: &Operator) -> bool {
296        // Opt-in not out for fallibility.
297        // Arithmetic operations could be better modelled here.
298        let infallible = matches!(
299            operator,
300            Operator::Eq
301                | Operator::NotEq
302                | Operator::Gt
303                | Operator::Gte
304                | Operator::Lt
305                | Operator::Lte
306                | Operator::And
307                | Operator::Or
308        );
309
310        !infallible
311    }
312}
313
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `and_collect` should fold its inputs into a balanced tree of `and` nodes.
    #[test]
    fn and_collect_balanced() {
        // 5 elements: and(and(1, 2), and(and(3, 4), 5))
        let five = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
        let five_tree = and_collect(five.into_iter()).unwrap();
        insta::assert_snapshot!(five_tree.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        let four_tree = and_collect(four.into_iter()).unwrap();
        insta::assert_snapshot!(four_tree.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let single = vec![lit(1)];
        let single_tree = and_collect(single.into_iter()).unwrap();
        insta::assert_snapshot!(single_tree.display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let nothing: Vec<Expression> = Vec::new();
        assert!(and_collect(nothing.into_iter()).is_none());
    }

    /// `or_collect` should fold its inputs into a balanced tree of `or` nodes.
    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        let four_tree = or_collect(four.into_iter()).unwrap();
        insta::assert_snapshot!(four_tree.display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    /// Boolean and comparison expressions should all resolve to `Bool`, with the
    /// nullability expected for the test-harness struct columns.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let resolve = |expr: Expression| expr.return_dtype(&scope).unwrap();

        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(resolve(and(bool1.clone(), bool2.clone())), non_nullable_bool);
        assert_eq!(resolve(or(bool1, bool2)), non_nullable_bool);

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(resolve(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(not_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(gt_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(resolve(lt_eq(col1.clone(), col2.clone())), nullable_bool);

        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(resolve(combined), nullable_bool);
    }

    /// The `Display` output of a binary expression is the parenthesised infix form.
    #[test]
    fn test_display_print() {
        let rendered = gt(lit(1), lit(2)).to_string();
        assert_eq!(rendered, "(1i32 > 2i32)");
    }

    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::PrimitiveArray;
        use crate::arrays::StructArray;

        // Builds a single-row struct array with i32 fields `a` and `b`.
        let single_row = |a: i32, b: i32| {
            StructArray::from_fields(&[
                ("a", PrimitiveArray::from_iter([a]).into_array()),
                ("b", PrimitiveArray::from_iter([b]).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs_struct = single_row(1, 3);
        let rhs_struct_equal = single_row(1, 3);
        let rhs_struct_different = single_row(1, 4);

        // Test using binary method directly
        let equal_scalar = lhs_struct
            .binary(rhs_struct_equal, Operator::Eq)
            .unwrap()
            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
            .vortex_expect("value");
        assert_eq!(
            equal_scalar,
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let different_scalar = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap()
            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
            .vortex_expect("value");
        assert_eq!(
            different_scalar,
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// `true OR null` evaluates to a valid `true` under Kleene logic.
    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;

        let input = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            ("b", BoolArray::from_iter([None::<bool>]).into_array()),
        ])
        .unwrap()
        .into_array();

        let disjunction = or(col("a"), col("b"));
        let actual = input.apply(&disjunction).unwrap();

        assert_arrays_eq!(actual, BoolArray::from_iter([Some(true)]).into_array());
    }

    /// Subtracting a constant from unsigned values.
    #[test]
    fn test_scalar_subtract_unsigned() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let minuend = buffer![1u16, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(1u16), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    /// Subtracting a negative constant from signed values.
    #[test]
    fn test_scalar_subtract_signed() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let minuend = buffer![1i64, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1i64), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    /// Null inputs stay null in the subtraction result.
    #[test]
    fn test_scalar_subtract_nullable() {
        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let minuend = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let subtrahend = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();
        let difference = minuend
            .into_array()
            .binary(subtrahend, Operator::Sub)
            .unwrap();
        assert_arrays_eq!(
            difference,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    /// Subtracting a negative constant from float values.
    #[test]
    fn test_scalar_subtract_float() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;
        use crate::arrays::PrimitiveArray;

        let minuend = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1f64), 3).into_array();
        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();
        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    /// Float subtraction that leaves the finite range must not be reported as an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        use vortex_buffer::buffer;

        use crate::IntoArray;
        use crate::arrays::ConstantArray;

        for subtrahend in [1.0f32, f32::MAX] {
            let minuend = buffer![f32::MIN, 2.0, 3.0].into_array();
            let rhs = ConstantArray::new(Scalar::from(subtrahend), 3).into_array();
            let _difference = minuend.binary(rhs, Operator::Sub).unwrap();
        }
    }
}