Skip to main content

vortex_array/scalar_fn/fns/binary/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::dtype::DType;
18use crate::expr::StatsCatalog;
19use crate::expr::and;
20use crate::expr::and_collect;
21use crate::expr::eq;
22use crate::expr::expression::Expression;
23use crate::expr::gt;
24use crate::expr::gt_eq;
25use crate::expr::lit;
26use crate::expr::lt;
27use crate::expr::lt_eq;
28use crate::expr::or_collect;
29use crate::expr::stats::Stat;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::fns::operators::CompareOperator;
36use crate::scalar_fn::fns::operators::Operator;
37
38pub(crate) mod boolean;
39pub(crate) use boolean::*;
40mod compare;
41pub use compare::*;
42mod numeric;
43pub(crate) use numeric::*;
44
45#[derive(Clone)]
46pub struct Binary;
47
48impl ScalarFnVTable for Binary {
49    type Options = Operator;
50
51    fn id(&self) -> ScalarFnId {
52        ScalarFnId::from("vortex.binary")
53    }
54
55    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56        Ok(Some(
57            pb::BinaryOpts {
58                op: (*instance).into(),
59            }
60            .encode_to_vec(),
61        ))
62    }
63
64    fn deserialize(
65        &self,
66        _metadata: &[u8],
67        _session: &VortexSession,
68    ) -> VortexResult<Self::Options> {
69        let opts = pb::BinaryOpts::decode(_metadata)?;
70        Operator::try_from(opts.op)
71    }
72
73    fn arity(&self, _options: &Self::Options) -> Arity {
74        Arity::Exact(2)
75    }
76
77    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
78        match child_idx {
79            0 => ChildName::from("lhs"),
80            1 => ChildName::from("rhs"),
81            _ => unreachable!("Binary has only two children"),
82        }
83    }
84
85    fn fmt_sql(
86        &self,
87        operator: &Operator,
88        expr: &Expression,
89        f: &mut Formatter<'_>,
90    ) -> std::fmt::Result {
91        write!(f, "(")?;
92        expr.child(0).fmt_sql(f)?;
93        write!(f, " {} ", operator)?;
94        expr.child(1).fmt_sql(f)?;
95        write!(f, ")")
96    }
97
98    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
99        let lhs = &arg_dtypes[0];
100        let rhs = &arg_dtypes[1];
101
102        if operator.is_arithmetic() {
103            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
104                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
105            }
106            vortex_bail!(
107                "incompatible types for arithmetic operation: {} {}",
108                lhs,
109                rhs
110            );
111        }
112
113        if operator.is_comparison()
114            && !lhs.eq_ignore_nullability(rhs)
115            && !lhs.is_extension()
116            && !rhs.is_extension()
117        {
118            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
119        }
120
121        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
122    }
123
124    fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<ArrayRef> {
125        let [lhs, rhs] = &args.inputs[..] else {
126            vortex_bail!("Wrong arg count")
127        };
128
129        match op {
130            Operator::Eq => execute_compare(lhs, rhs, CompareOperator::Eq),
131            Operator::NotEq => execute_compare(lhs, rhs, CompareOperator::NotEq),
132            Operator::Lt => execute_compare(lhs, rhs, CompareOperator::Lt),
133            Operator::Lte => execute_compare(lhs, rhs, CompareOperator::Lte),
134            Operator::Gt => execute_compare(lhs, rhs, CompareOperator::Gt),
135            Operator::Gte => execute_compare(lhs, rhs, CompareOperator::Gte),
136            Operator::And => execute_boolean(lhs, rhs, Operator::And),
137            Operator::Or => execute_boolean(lhs, rhs, Operator::Or),
138            Operator::Add => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Add),
139            Operator::Sub => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Sub),
140            Operator::Mul => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Mul),
141            Operator::Div => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Div),
142        }
143    }
144
145    fn stat_falsification(
146        &self,
147        operator: &Operator,
148        expr: &Expression,
149        catalog: &dyn StatsCatalog,
150    ) -> Option<Expression> {
151        // Wrap another predicate with an optional NaNCount check, if the stat is available.
152        //
153        // For example, regular pruning conversion for `A >= B` would be
154        //
155        //      A.max < B.min
156        //
157        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
158        // in:
159        //
160        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
161        //
162        // Non-floating point column and literal expressions should be unaffected as they do not
163        // have a nan_count statistic defined.
164        #[inline]
165        fn with_nan_predicate(
166            lhs: &Expression,
167            rhs: &Expression,
168            value_predicate: Expression,
169            catalog: &dyn StatsCatalog,
170        ) -> Expression {
171            let nan_predicate = and_collect(
172                lhs.stat_expression(Stat::NaNCount, catalog)
173                    .into_iter()
174                    .chain(rhs.stat_expression(Stat::NaNCount, catalog))
175                    .map(|nans| eq(nans, lit(0u64))),
176            );
177
178            if let Some(nan_check) = nan_predicate {
179                and(nan_check, value_predicate)
180            } else {
181                value_predicate
182            }
183        }
184
185        let lhs = expr.child(0);
186        let rhs = expr.child(1);
187        match operator {
188            Operator::Eq => {
189                let min_lhs = lhs.stat_min(catalog);
190                let max_lhs = lhs.stat_max(catalog);
191
192                let min_rhs = rhs.stat_min(catalog);
193                let max_rhs = rhs.stat_max(catalog);
194
195                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
196                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
197
198                let min_max_check = or_collect(left.into_iter().chain(right))?;
199
200                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
201                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
202            }
203            Operator::NotEq => {
204                let min_lhs = lhs.stat_min(catalog)?;
205                let max_lhs = lhs.stat_max(catalog)?;
206
207                let min_rhs = rhs.stat_min(catalog)?;
208                let max_rhs = rhs.stat_max(catalog)?;
209
210                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
211
212                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
213            }
214            Operator::Gt => {
215                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
216
217                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
218            }
219            Operator::Gte => {
220                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
221                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
222
223                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224            }
225            Operator::Lt => {
226                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
227                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
228
229                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
230            }
231            Operator::Lte => {
232                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
233                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
234
235                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
236            }
237            Operator::And => or_collect(
238                lhs.stat_falsification(catalog)
239                    .into_iter()
240                    .chain(rhs.stat_falsification(catalog)),
241            ),
242            Operator::Or => Some(and(
243                lhs.stat_falsification(catalog)?,
244                rhs.stat_falsification(catalog)?,
245            )),
246            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
247        }
248    }
249
250    fn validity(
251        &self,
252        operator: &Operator,
253        expression: &Expression,
254    ) -> VortexResult<Option<Expression>> {
255        let lhs = expression.child(0).validity()?;
256        let rhs = expression.child(1).validity()?;
257
258        Ok(match operator {
259            // AND and OR are kleene logic.
260            Operator::And => None,
261            Operator::Or => None,
262            _ => {
263                // All other binary operators are null if either side is null.
264                Some(and(lhs, rhs))
265            }
266        })
267    }
268
269    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
270        false
271    }
272
273    fn is_fallible(&self, operator: &Operator) -> bool {
274        // Opt-in not out for fallibility.
275        // Arithmetic operations could be better modelled here.
276        let infallible = matches!(
277            operator,
278            Operator::Eq
279                | Operator::NotEq
280                | Operator::Gt
281                | Operator::Gte
282                | Operator::Lt
283                | Operator::Lte
284                | Operator::And
285                | Operator::Or
286        );
287
288        !infallible
289    }
290}
291
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    #[test]
    fn and_collect_balanced() {
        // 5 elements: and(and(1, 2), and(and(3, 4), 5))
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values.into_iter()).unwrap().display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values.into_iter()).is_none());
    }

    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values.into_iter()).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::PrimitiveArray;
        use crate::arrays::StructArray;

        // Builds a one-row struct array `{a, b}` with `i32` fields.
        let make_struct = |a: i32, b: i32| {
            StructArray::from_fields(&[
                ("a", PrimitiveArray::from_iter([a]).into_array()),
                ("b", PrimitiveArray::from_iter([b]).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs_struct = make_struct(1, 3);
        let rhs_struct_equal = make_struct(1, 3);
        let rhs_struct_different = make_struct(1, 4);

        // Test using binary method directly
        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// `true OR null` must be `true` under Kleene logic, not null.
    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;

        let struct_arr = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            (
                "b",
                BoolArray::from_iter([Option::<bool>::None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array())
    }
}