Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::dtype::DType;
19use crate::expr::StatsCatalog;
20use crate::expr::and;
21use crate::expr::and_collect;
22use crate::expr::eq;
23use crate::expr::expression::Expression;
24use crate::expr::gt;
25use crate::expr::gt_eq;
26use crate::expr::lit;
27use crate::expr::lt;
28use crate::expr::lt_eq;
29use crate::expr::or_collect;
30use crate::expr::stats::Stat;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39pub(crate) mod boolean;
40pub(crate) use boolean::*;
41mod compare;
42pub use compare::*;
43mod numeric;
44pub(crate) use numeric::*;
45
46use crate::scalar::NumericOperator;
47
/// Scalar function that applies a binary [`Operator`] to two child expressions.
///
/// The type itself carries no state: the operator to apply is supplied as the
/// function's options (see the [`ScalarFnVTable`] implementation for this type).
#[derive(Clone)]
pub struct Binary;
50
51impl ScalarFnVTable for Binary {
52    type Options = Operator;
53
54    fn id(&self) -> ScalarFnId {
55        ScalarFnId::new("vortex.binary")
56    }
57
58    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::BinaryOpts {
61                op: (*instance).into(),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let opts = pb::BinaryOpts::decode(_metadata)?;
73        Operator::try_from(opts.op)
74    }
75
76    fn arity(&self, _options: &Self::Options) -> Arity {
77        Arity::Exact(2)
78    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        match child_idx {
82            0 => ChildName::from("lhs"),
83            1 => ChildName::from("rhs"),
84            _ => unreachable!("Binary has only two children"),
85        }
86    }
87
88    fn fmt_sql(
89        &self,
90        operator: &Operator,
91        expr: &Expression,
92        f: &mut Formatter<'_>,
93    ) -> std::fmt::Result {
94        write!(f, "(")?;
95        expr.child(0).fmt_sql(f)?;
96        write!(f, " {} ", operator)?;
97        expr.child(1).fmt_sql(f)?;
98        write!(f, ")")
99    }
100
101    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
102        let lhs = &arg_dtypes[0];
103        let rhs = &arg_dtypes[1];
104
105        if operator.is_arithmetic() {
106            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
107                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
108            }
109            vortex_bail!(
110                "incompatible types for arithmetic operation: {} {}",
111                lhs,
112                rhs
113            );
114        }
115
116        if operator.is_comparison()
117            && !lhs.eq_ignore_nullability(rhs)
118            && !lhs.is_extension()
119            && !rhs.is_extension()
120        {
121            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
122        }
123
124        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
125    }
126
127    fn execute(
128        &self,
129        op: &Operator,
130        args: &dyn ExecutionArgs,
131        ctx: &mut ExecutionCtx,
132    ) -> VortexResult<ArrayRef> {
133        let lhs = args.get(0)?;
134        let rhs = args.get(1)?;
135
136        match op {
137            Operator::Eq => execute_compare(&lhs, &rhs, CompareOperator::Eq, ctx),
138            Operator::NotEq => execute_compare(&lhs, &rhs, CompareOperator::NotEq, ctx),
139            Operator::Lt => execute_compare(&lhs, &rhs, CompareOperator::Lt, ctx),
140            Operator::Lte => execute_compare(&lhs, &rhs, CompareOperator::Lte, ctx),
141            Operator::Gt => execute_compare(&lhs, &rhs, CompareOperator::Gt, ctx),
142            Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte, ctx),
143            Operator::And => execute_boolean(&lhs, &rhs, Operator::And, ctx),
144            Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or, ctx),
145            Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
146            Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
147            Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
148            Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
149        }
150    }
151
152    fn stat_falsification(
153        &self,
154        operator: &Operator,
155        expr: &Expression,
156        catalog: &dyn StatsCatalog,
157    ) -> Option<Expression> {
158        // Wrap another predicate with an optional NaNCount check, if the stat is available.
159        //
160        // For example, regular pruning conversion for `A >= B` would be
161        //
162        //      A.max < B.min
163        //
164        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
165        // in:
166        //
167        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
168        //
169        // Non-floating point column and literal expressions should be unaffected as they do not
170        // have a nan_count statistic defined.
171        fn with_nan_predicate(
172            lhs: &Expression,
173            rhs: &Expression,
174            value_predicate: Expression,
175            catalog: &dyn StatsCatalog,
176        ) -> Expression {
177            let nan_predicate = and_collect(
178                lhs.stat_expression(Stat::NaNCount, catalog)
179                    .into_iter()
180                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
181                    .map(|nans| eq(nans, lit(0u64))),
182            );
183
184            if let Some(nan_check) = nan_predicate {
185                and(nan_check, value_predicate)
186            } else {
187                value_predicate
188            }
189        }
190
191        let lhs = expr.child(0);
192        let rhs = expr.child(1);
193        match operator {
194            Operator::Eq => {
195                let min_lhs = lhs.stat_min(catalog);
196                let max_lhs = lhs.stat_max(catalog);
197
198                let min_rhs = rhs.stat_min(catalog);
199                let max_rhs = rhs.stat_max(catalog);
200
201                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
202                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
203
204                let min_max_check = or_collect(left.into_iter().chain(right))?;
205
206                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
207                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
208            }
209            Operator::NotEq => {
210                let min_lhs = lhs.stat_min(catalog)?;
211                let max_lhs = lhs.stat_max(catalog)?;
212
213                let min_rhs = rhs.stat_min(catalog)?;
214                let max_rhs = rhs.stat_max(catalog)?;
215
216                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
217
218                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
219            }
220            Operator::Gt => {
221                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
222
223                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224            }
225            Operator::Gte => {
226                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
227                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
228
229                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
230            }
231            Operator::Lt => {
232                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
233                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
234
235                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
236            }
237            Operator::Lte => {
238                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
239                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
240
241                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
242            }
243            Operator::And => or_collect(
244                lhs.stat_falsification(catalog)
245                    .into_iter()
246                    .chain(rhs.stat_falsification(catalog)),
247            ),
248            Operator::Or => Some(and(
249                lhs.stat_falsification(catalog)?,
250                rhs.stat_falsification(catalog)?,
251            )),
252            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
253        }
254    }
255
256    fn validity(
257        &self,
258        operator: &Operator,
259        expression: &Expression,
260    ) -> VortexResult<Option<Expression>> {
261        let lhs = expression.child(0).validity()?;
262        let rhs = expression.child(1).validity()?;
263
264        Ok(match operator {
265            // AND and OR are kleene logic.
266            Operator::And => None,
267            Operator::Or => None,
268            _ => {
269                // All other binary operators are null if either side is null.
270                Some(and(lhs, rhs))
271            }
272        })
273    }
274
275    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
276        false
277    }
278
279    fn is_fallible(&self, operator: &Operator) -> bool {
280        // Opt-in not out for fallibility.
281        // Arithmetic operations could be better modelled here.
282        let infallible = matches!(
283            operator,
284            Operator::Eq
285                | Operator::NotEq
286                | Operator::Gt
287                | Operator::Gte
288                | Operator::Lt
289                | Operator::Lte
290                | Operator::And
291                | Operator::Or
292        );
293
294        !infallible
295    }
296}
297
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `and_collect` should fold its inputs into a balanced tree of `and` nodes.
    #[test]
    fn and_collect_balanced() {
        // 5 elements: the odd one out ends up on the right-hand side.
        let five = and_collect([lit(1), lit(2), lit(3), lit(4), lit(5)].into_iter()).unwrap();
        insta::assert_snapshot!(five.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let four = and_collect([lit(1), lit(2), lit(3), lit(4)].into_iter()).unwrap();
        insta::assert_snapshot!(four.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let single = and_collect([lit(1)].into_iter()).unwrap();
        insta::assert_snapshot!(single.display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let nothing = Vec::<Expression>::new();
        assert!(and_collect(nothing.into_iter()).is_none());
    }

    /// `or_collect` should fold its inputs into a balanced tree of `or` nodes.
    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let four = or_collect([lit(1), lit(2), lit(3), lit(4)].into_iter()).unwrap();
        insta::assert_snapshot!(four.display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    /// Checks the return dtype of boolean and comparison expressions against the harness
    /// struct dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // Boolean connectives over the two bool columns.
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&scope)
                .unwrap(),
            non_nullable_bool
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&scope).unwrap(),
            non_nullable_bool
        );

        // Every comparison over `col1` and `col2` yields a nullable bool.
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        let nullable_bool = DType::Bool(Nullability::Nullable);

        let comparisons = [
            eq(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
            gt(col1.clone(), col2.clone()),
            gt_eq(col1.clone(), col2.clone()),
            lt(col1.clone(), col2.clone()),
            lt_eq(col1.clone(), col2.clone()),
        ];
        for comparison in comparisons {
            assert_eq!(comparison.return_dtype(&scope).unwrap(), nullable_bool);
        }

        // A connective over comparisons is nullable as well.
        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(combined.return_dtype(&scope).unwrap(), nullable_bool);
    }

    /// A binary expression displays as `(<lhs> <op> <rhs>)`.
    #[test]
    fn test_display_print() {
        let one_gt_two = gt(lit(1), lit(2));
        let rendered = format!("{one_gt_two}");
        assert_eq!(rendered, "(1i32 > 2i32)");
    }

    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
    #[test]
    fn test_struct_comparison() {
        // Builds a single-row struct array with i32 fields `a` and `b`.
        let single_row_struct = |a: i32, b: i32| {
            StructArray::from_fields(&[
                ("a", PrimitiveArray::from_iter([a]).into_array()),
                ("b", PrimitiveArray::from_iter([b]).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs_struct = single_row_struct(1, 3);
        let rhs_struct_equal = single_row_struct(1, 3);
        let rhs_struct_different = single_row_struct(1, 4);

        // Test using binary method directly
        let equal_outcome = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
        let equal_scalar = equal_outcome
            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
            .vortex_expect("value");
        assert_eq!(
            equal_scalar,
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let different_outcome = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        let different_scalar = different_outcome
            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
            .vortex_expect("value");
        assert_eq!(
            different_scalar,
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// `true OR null` evaluates to a valid `true`.
    #[test]
    fn test_or_kleene_validity() {
        let truthy = BoolArray::from_iter([Some(true)]).into_array();
        let missing = BoolArray::from_iter([Option::<bool>::None]).into_array();
        let struct_arr = StructArray::from_fields(&[("a", truthy), ("b", missing)])
            .unwrap()
            .into_array();

        let disjunction = or(col("a"), col("b"));
        let evaluated = struct_arr.apply(&disjunction).unwrap();

        assert_arrays_eq!(evaluated, BoolArray::from_iter([Some(true)]).into_array());
    }

    /// Subtracting a constant from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let minuend = buffer![1u16, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(1u16), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([0u16, 1, 2]));
    }

    /// Subtracting a negative constant from a signed array.
    #[test]
    fn test_scalar_subtract_signed() {
        let minuend = buffer![1i64, 2, 3].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1i64), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2i64, 3, 4]));
    }

    /// Null entries on the left-hand side stay null in the result.
    #[test]
    fn test_scalar_subtract_nullable() {
        let minuend =
            PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]).into_array();
        let subtrahend = ConstantArray::new(Scalar::from(Some(1u16)), 4).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(
            difference,
            PrimitiveArray::from_option_iter([Some(0u16), Some(1), None, Some(2)])
        );
    }

    /// Subtracting a negative constant from a float array.
    #[test]
    fn test_scalar_subtract_float() {
        let minuend = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = ConstantArray::new(Scalar::from(-1f64), 3).into_array();

        let difference = minuend.binary(subtrahend, Operator::Sub).unwrap();

        assert_arrays_eq!(difference, PrimitiveArray::from_iter([2.0f64, 3.0, 4.0]));
    }

    /// Float subtraction that leaves the finite range must not return an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let minus_one = ConstantArray::new(Scalar::from(1.0f32), 3).into_array();
        let first = buffer![f32::MIN, 2.0, 3.0].into_array();
        let _results = first.binary(minus_one, Operator::Sub).unwrap();

        let minus_max = ConstantArray::new(Scalar::from(f32::MAX), 3).into_array();
        let second = buffer![f32::MIN, 2.0, 3.0].into_array();
        let _results = second.binary(minus_max, Operator::Sub).unwrap();
    }
}