Skip to main content

vortex_array/expr/exprs/binary/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::compute;
16use crate::compute::BooleanOperator;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::StatsCatalog;
22use crate::expr::VTable;
23use crate::expr::VTableExt;
24use crate::expr::expression::Expression;
25use crate::expr::exprs::literal::lit;
26use crate::expr::exprs::operators::Operator;
27use crate::expr::stats::Stat;
28
29mod boolean;
30pub(crate) use boolean::*;
31mod compare;
32pub use compare::*;
33mod numeric;
34pub(crate) use numeric::*;
35
36pub struct Binary;
37
38impl VTable for Binary {
39    type Options = Operator;
40
41    fn id(&self) -> ExprId {
42        ExprId::from("vortex.binary")
43    }
44
45    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
46        Ok(Some(
47            pb::BinaryOpts {
48                op: (*instance).into(),
49            }
50            .encode_to_vec(),
51        ))
52    }
53
54    fn deserialize(
55        &self,
56        _metadata: &[u8],
57        _session: &VortexSession,
58    ) -> VortexResult<Self::Options> {
59        let opts = pb::BinaryOpts::decode(_metadata)?;
60        Operator::try_from(opts.op)
61    }
62
63    fn arity(&self, _options: &Self::Options) -> Arity {
64        Arity::Exact(2)
65    }
66
67    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
68        match child_idx {
69            0 => ChildName::from("lhs"),
70            1 => ChildName::from("rhs"),
71            _ => unreachable!("Binary has only two children"),
72        }
73    }
74
75    fn fmt_sql(
76        &self,
77        operator: &Operator,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> std::fmt::Result {
81        write!(f, "(")?;
82        expr.child(0).fmt_sql(f)?;
83        write!(f, " {} ", operator)?;
84        expr.child(1).fmt_sql(f)?;
85        write!(f, ")")
86    }
87
88    fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
89        let lhs = &arg_dtypes[0];
90        let rhs = &arg_dtypes[1];
91
92        if operator.is_arithmetic() {
93            if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
94                return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
95            }
96            vortex_bail!(
97                "incompatible types for arithmetic operation: {} {}",
98                lhs,
99                rhs
100            );
101        }
102
103        if operator.is_comparison()
104            && !lhs.eq_ignore_nullability(rhs)
105            && !lhs.is_extension()
106            && !rhs.is_extension()
107        {
108            vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
109        }
110
111        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
112    }
113
114    fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<ArrayRef> {
115        let [lhs, rhs] = &args.inputs[..] else {
116            vortex_bail!("Wrong arg count")
117        };
118
119        match op {
120            Operator::Eq => execute_compare(lhs, rhs, compute::Operator::Eq),
121            Operator::NotEq => execute_compare(lhs, rhs, compute::Operator::NotEq),
122            Operator::Lt => execute_compare(lhs, rhs, compute::Operator::Lt),
123            Operator::Lte => execute_compare(lhs, rhs, compute::Operator::Lte),
124            Operator::Gt => execute_compare(lhs, rhs, compute::Operator::Gt),
125            Operator::Gte => execute_compare(lhs, rhs, compute::Operator::Gte),
126            Operator::And => execute_boolean(lhs, rhs, BooleanOperator::AndKleene),
127            Operator::Or => execute_boolean(lhs, rhs, BooleanOperator::OrKleene),
128            Operator::Add => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Add),
129            Operator::Sub => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Sub),
130            Operator::Mul => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Mul),
131            Operator::Div => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Div),
132        }
133    }
134
135    fn stat_falsification(
136        &self,
137        operator: &Operator,
138        expr: &Expression,
139        catalog: &dyn StatsCatalog,
140    ) -> Option<Expression> {
141        // Wrap another predicate with an optional NaNCount check, if the stat is available.
142        //
143        // For example, regular pruning conversion for `A >= B` would be
144        //
145        //      A.max < B.min
146        //
147        // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting
148        // in:
149        //
150        //      (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min
151        //
152        // Non-floating point column and literal expressions should be unaffected as they do not
153        // have a nan_count statistic defined.
154        #[inline]
155        fn with_nan_predicate(
156            lhs: &Expression,
157            rhs: &Expression,
158            value_predicate: Expression,
159            catalog: &dyn StatsCatalog,
160        ) -> Expression {
161            let nan_predicate = lhs
162                .stat_expression(Stat::NaNCount, catalog)
163                .into_iter()
164                .chain(rhs.stat_expression(Stat::NaNCount, catalog))
165                .map(|nans| eq(nans, lit(0u64)))
166                .reduce(and);
167
168            if let Some(nan_check) = nan_predicate {
169                and(nan_check, value_predicate)
170            } else {
171                value_predicate
172            }
173        }
174
175        let lhs = expr.child(0);
176        let rhs = expr.child(1);
177        match operator {
178            Operator::Eq => {
179                let min_lhs = lhs.stat_min(catalog);
180                let max_lhs = lhs.stat_max(catalog);
181
182                let min_rhs = rhs.stat_min(catalog);
183                let max_rhs = rhs.stat_max(catalog);
184
185                let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
186                let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
187
188                let min_max_check = left.into_iter().chain(right).reduce(or)?;
189
190                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
191                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
192            }
193            Operator::NotEq => {
194                let min_lhs = lhs.stat_min(catalog)?;
195                let max_lhs = lhs.stat_max(catalog)?;
196
197                let min_rhs = rhs.stat_min(catalog)?;
198                let max_rhs = rhs.stat_max(catalog)?;
199
200                let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
201
202                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
203            }
204            Operator::Gt => {
205                let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
206
207                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
208            }
209            Operator::Gte => {
210                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
211                let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
212
213                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
214            }
215            Operator::Lt => {
216                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
217                let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
218
219                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
220            }
221            Operator::Lte => {
222                // NaN is not captured by the min/max stat, so we must check NaNCount before pruning
223                let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
224
225                Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
226            }
227            Operator::And => lhs
228                .stat_falsification(catalog)
229                .into_iter()
230                .chain(rhs.stat_falsification(catalog))
231                .reduce(or),
232            Operator::Or => Some(and(
233                lhs.stat_falsification(catalog)?,
234                rhs.stat_falsification(catalog)?,
235            )),
236            Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
237        }
238    }
239
240    fn validity(
241        &self,
242        operator: &Operator,
243        expression: &Expression,
244    ) -> VortexResult<Option<Expression>> {
245        let lhs = expression.child(0).validity()?;
246        let rhs = expression.child(1).validity()?;
247
248        Ok(match operator {
249            // AND and OR are kleene logic.
250            Operator::And => None,
251            Operator::Or => None,
252            _ => {
253                // All other binary operators are null if either side is null.
254                Some(and(lhs, rhs))
255            }
256        })
257    }
258
259    fn is_null_sensitive(&self, _operator: &Operator) -> bool {
260        false
261    }
262
263    fn is_fallible(&self, operator: &Operator) -> bool {
264        // Opt-in not out for fallibility.
265        // Arithmetic operations could be better modelled here.
266        let infallible = matches!(
267            operator,
268            Operator::Eq
269                | Operator::NotEq
270                | Operator::Gt
271                | Operator::Gte
272                | Operator::Lt
273                | Operator::Lte
274                | Operator::And
275                | Operator::Or
276        );
277
278        !infallible
279    }
280}
281
282/// Create a new [`Binary`] using the [`Eq`](crate::expr::exprs::operators::Operator::Eq) operator.
283///
284/// ## Example usage
285///
286/// ```
287/// # use vortex_array::arrays::{BoolArray, PrimitiveArray};
288/// # use vortex_array::{Array, IntoArray, ToCanonical};
289/// # use vortex_array::validity::Validity;
290/// # use vortex_buffer::buffer;
291/// # use vortex_array::expr::{eq, root, lit};
292/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
293/// let result = xs.to_array().apply(&eq(root(), lit(3))).unwrap();
294///
295/// assert_eq!(
296///     result.to_bool().to_bit_buffer(),
297///     BoolArray::from_iter(vec![false, false, true]).to_bit_buffer(),
298/// );
299/// ```
300pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
301    Binary
302        .try_new_expr(Operator::Eq, [lhs, rhs])
303        .vortex_expect("Failed to create Eq binary expression")
304}
305
306/// Create a new [`Binary`] using the [`NotEq`](crate::expr::exprs::operators::Operator::NotEq) operator.
307///
308/// ## Example usage
309///
310/// ```
311/// # use vortex_array::arrays::{BoolArray, PrimitiveArray};
312/// # use vortex_array::{Array, IntoArray, ToCanonical};
313/// # use vortex_array::validity::Validity;
314/// # use vortex_buffer::buffer;
315/// # use vortex_array::expr::{root, lit, not_eq};
316/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
317/// let result = xs.to_array().apply(&not_eq(root(), lit(3))).unwrap();
318///
319/// assert_eq!(
320///     result.to_bool().to_bit_buffer(),
321///     BoolArray::from_iter(vec![true, true, false]).to_bit_buffer(),
322/// );
323/// ```
324pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
325    Binary
326        .try_new_expr(Operator::NotEq, [lhs, rhs])
327        .vortex_expect("Failed to create NotEq binary expression")
328}
329
330/// Create a new [`Binary`] using the [`Gte`](crate::expr::exprs::operators::Operator::Gte) operator.
331///
332/// ## Example usage
333///
334/// ```
335/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
336/// # use vortex_array::{Array, IntoArray, ToCanonical};
337/// # use vortex_array::validity::Validity;
338/// # use vortex_buffer::buffer;
339/// # use vortex_array::expr::{gt_eq, root, lit};
340/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
341/// let result = xs.to_array().apply(&gt_eq(root(), lit(3))).unwrap();
342///
343/// assert_eq!(
344///     result.to_bool().to_bit_buffer(),
345///     BoolArray::from_iter(vec![false, false, true]).to_bit_buffer(),
346/// );
347/// ```
348pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
349    Binary
350        .try_new_expr(Operator::Gte, [lhs, rhs])
351        .vortex_expect("Failed to create Gte binary expression")
352}
353
354/// Create a new [`Binary`] using the [`Gt`](crate::expr::exprs::operators::Operator::Gt) operator.
355///
356/// ## Example usage
357///
358/// ```
359/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
360/// # use vortex_array::{Array, IntoArray, ToCanonical};
361/// # use vortex_array::validity::Validity;
362/// # use vortex_buffer::buffer;
363/// # use vortex_array::expr::{gt, root, lit};
364/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
365/// let result = xs.to_array().apply(&gt(root(), lit(2))).unwrap();
366///
367/// assert_eq!(
368///     result.to_bool().to_bit_buffer(),
369///     BoolArray::from_iter(vec![false, false, true]).to_bit_buffer(),
370/// );
371/// ```
372pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
373    Binary
374        .try_new_expr(Operator::Gt, [lhs, rhs])
375        .vortex_expect("Failed to create Gt binary expression")
376}
377
378/// Create a new [`Binary`] using the [`Lte`](crate::expr::exprs::operators::Operator::Lte) operator.
379///
380/// ## Example usage
381///
382/// ```
383/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
384/// # use vortex_array::{Array, IntoArray, ToCanonical};
385/// # use vortex_array::validity::Validity;
386/// # use vortex_buffer::buffer;
387/// # use vortex_array::expr::{root, lit, lt_eq};
388/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
389/// let result = xs.to_array().apply(&lt_eq(root(), lit(2))).unwrap();
390///
391/// assert_eq!(
392///     result.to_bool().to_bit_buffer(),
393///     BoolArray::from_iter(vec![true, true, false]).to_bit_buffer(),
394/// );
395/// ```
396pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
397    Binary
398        .try_new_expr(Operator::Lte, [lhs, rhs])
399        .vortex_expect("Failed to create Lte binary expression")
400}
401
402/// Create a new [`Binary`] using the [`Lt`](crate::expr::exprs::operators::Operator::Lt) operator.
403///
404/// ## Example usage
405///
406/// ```
407/// # use vortex_array::arrays::{BoolArray, PrimitiveArray };
408/// # use vortex_array::{Array, IntoArray, ToCanonical};
409/// # use vortex_array::validity::Validity;
410/// # use vortex_buffer::buffer;
411/// # use vortex_array::expr::{root, lit, lt};
412/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
413/// let result = xs.to_array().apply(&lt(root(), lit(3))).unwrap();
414///
415/// assert_eq!(
416///     result.to_bool().to_bit_buffer(),
417///     BoolArray::from_iter(vec![true, true, false]).to_bit_buffer(),
418/// );
419/// ```
420pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
421    Binary
422        .try_new_expr(Operator::Lt, [lhs, rhs])
423        .vortex_expect("Failed to create Lt binary expression")
424}
425
426/// Create a new [`Binary`] using the [`Or`](crate::expr::exprs::operators::Operator::Or) operator.
427///
428/// ## Example usage
429///
430/// ```
431/// # use vortex_array::arrays::BoolArray;
432/// # use vortex_array::{Array, IntoArray, ToCanonical};
433/// # use vortex_array::expr::{root, lit, or};
434/// let xs = BoolArray::from_iter(vec![true, false, true]);
435/// let result = xs.to_array().apply(&or(root(), lit(false))).unwrap();
436///
437/// assert_eq!(
438///     result.to_bool().to_bit_buffer(),
439///     BoolArray::from_iter(vec![true, false, true]).to_bit_buffer(),
440/// );
441/// ```
442pub fn or(lhs: Expression, rhs: Expression) -> Expression {
443    Binary
444        .try_new_expr(Operator::Or, [lhs, rhs])
445        .vortex_expect("Failed to create Or binary expression")
446}
447
448/// Collects a list of `or`ed values into a single expression using a balanced tree.
449///
450/// This creates a balanced binary tree to avoid deep nesting that could cause
451/// stack overflow during drop or evaluation.
452///
453/// [a, b, c, d] => or(or(a, b), or(c, d))
454pub fn or_collect<I>(iter: I) -> Option<Expression>
455where
456    I: IntoIterator<Item = Expression>,
457{
458    let exprs: Vec<_> = iter.into_iter().collect();
459    balanced_reduce(exprs, or)
460}
461
462/// Create a new [`Binary`] using the [`And`](crate::expr::exprs::operators::Operator::And) operator.
463///
464/// ## Example usage
465///
466/// ```
467/// # use vortex_array::arrays::BoolArray;
468/// # use vortex_array::{Array, IntoArray, ToCanonical};
469/// # use vortex_array::expr::{and, root, lit};
470/// let xs = BoolArray::from_iter(vec![true, false, true]);
471/// let result = xs.to_array().apply(&and(root(), lit(true))).unwrap();
472///
473/// assert_eq!(
474///     result.to_bool().to_bit_buffer(),
475///     BoolArray::from_iter(vec![true, false, true]).to_bit_buffer(),
476/// );
477/// ```
478pub fn and(lhs: Expression, rhs: Expression) -> Expression {
479    Binary
480        .try_new_expr(Operator::And, [lhs, rhs])
481        .vortex_expect("Failed to create And binary expression")
482}
483
484/// Collects a list of `and`ed values into a single expression using a balanced tree.
485///
486/// This creates a balanced binary tree to avoid deep nesting that could cause
487/// stack overflow during drop or evaluation.
488///
489/// [a, b, c, d] => and(and(a, b), and(c, d))
490pub fn and_collect<I>(iter: I) -> Option<Expression>
491where
492    I: IntoIterator<Item = Expression>,
493{
494    let exprs: Vec<_> = iter.into_iter().collect();
495    balanced_reduce(exprs, and)
496}
497
498/// Helper function to reduce a list of expressions into a balanced binary tree.
499fn balanced_reduce<F>(mut exprs: Vec<Expression>, combine: F) -> Option<Expression>
500where
501    F: Fn(Expression, Expression) -> Expression + Copy,
502{
503    if exprs.is_empty() {
504        return None;
505    }
506    if exprs.len() == 1 {
507        return exprs.pop();
508    }
509
510    while exprs.len() > 1 {
511        let exprs_len = exprs.len();
512
513        for target_idx in 0..(exprs.len() / 2) {
514            let item_idx = target_idx * 2;
515            let new = combine(exprs[item_idx].clone(), exprs[item_idx + 1].clone());
516            exprs[target_idx] = new;
517        }
518
519        if !exprs.len().is_multiple_of(2) {
520            // We want the odd nodes to be inside the tree and not at root
521            let lhs = exprs[(exprs.len() / 2) - 1].clone();
522            let rhs = exprs[exprs.len() - 1].clone();
523            exprs[exprs_len / 2 - 1] = combine(lhs, rhs);
524        }
525
526        exprs.truncate(exprs_len / 2);
527    }
528
529    exprs.pop()
530}
531
532/// Create a new [`Binary`] using the [`Add`](crate::expr::exprs::operators::Operator::Add) operator.
533///
534/// ## Example usage
535///
536/// ```
537/// # use vortex_array::{Array, IntoArray};
538/// # use vortex_array::arrow::IntoArrowArray as _;
539/// # use vortex_buffer::buffer;
540/// # use vortex_array::expr::{checked_add, lit, root};
541/// let xs = buffer![1, 2, 3].into_array();
542/// let result = xs.apply(&checked_add(root(), lit(5))).unwrap();
543///
544/// assert_eq!(
545///     &result.into_arrow_preferred().unwrap(),
546///     &buffer![6, 7, 8]
547///         .into_array()
548///         .into_arrow_preferred()
549///         .unwrap()
550/// );
551/// ```
552pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
553    Binary
554        .try_new_expr(Operator::Add, [lhs, rhs])
555        .vortex_expect("Failed to create Add binary expression")
556}
557
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::compute::compare;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::literal::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Builds a single-row struct array with `i32` fields `a` and `b`.
    fn struct_of(a: i32, b: i32) -> ArrayRef {
        StructArray::from_fields(&[
            ("a", PrimitiveArray::from_iter([a]).into_array()),
            ("b", PrimitiveArray::from_iter([b]).into_array()),
        ])
        .unwrap()
        .into_array()
    }

    #[test]
    fn and_collect_balanced() {
        let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];

        insta::assert_snapshot!(and_collect(values).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // 4 elements: and(and(1, 2), and(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(and_collect(values).unwrap().display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // 1 element: just the element
        let values = vec![lit(1)];
        insta::assert_snapshot!(and_collect(values).unwrap().display_tree(), @"vortex.literal(1i32)");

        // 0 elements: None
        let values: Vec<Expression> = vec![];
        assert!(and_collect(values).is_none());
    }

    #[test]
    fn or_collect_balanced() {
        // 4 elements: or(or(1, 2), or(3, 4)) - perfectly balanced
        let values = vec![lit(1), lit(2), lit(3), lit(4)];
        insta::assert_snapshot!(or_collect(values).unwrap().display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    #[test]
    fn or_collect_odd_and_empty() {
        // 3 elements: the unpaired trailing element is folded into the last pair.
        let collected = or_collect(vec![lit(1), lit(2), lit(3)]).unwrap();
        let expected = or(or(lit(1), lit(2)), lit(3));
        assert_eq!(collected.to_string(), expected.to_string());

        // 0 elements: None
        assert!(or_collect(Vec::<Expression>::new()).is_none());
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            and(bool1.clone(), bool2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
        assert_eq!(
            or(bool1, bool2).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(
            eq(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone())
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1, col2))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    /// Regression test for GitHub issue #5947: struct comparison in filter expressions should work
    /// using `make_comparator` instead of Arrow's `cmp` functions which don't support nested types.
    #[test]
    fn test_struct_comparison() {
        let lhs_struct = struct_of(1, 3);
        let rhs_struct_equal = struct_of(1, 3);
        let rhs_struct_different = struct_of(1, 4);

        // Test using compare compute function directly
        let result_equal = compare(&lhs_struct, &rhs_struct_equal, compute::Operator::Eq).unwrap();
        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different =
            compare(&lhs_struct, &rhs_struct_different, compute::Operator::Eq).unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// Kleene OR: `true OR null` must be `true` (valid), not null.
    #[test]
    fn test_or_kleene_validity() {
        let struct_arr = StructArray::from_fields(&[
            ("a", BoolArray::from_iter([Some(true)]).into_array()),
            (
                "b",
                BoolArray::from_iter([Option::<bool>::None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array());
    }
}