Skip to main content

vortex_array/expr/exprs/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BooleanArray;
5use arrow_ord::cmp;
6use vortex_error::VortexResult;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::ConstantVTable;
15use crate::arrays::ExactScalarFn;
16use crate::arrays::ScalarFnArrayView;
17use crate::arrays::ScalarFnVTable;
18use crate::arrow::Datum;
19use crate::arrow::IntoArrowArray;
20use crate::arrow::from_arrow_array_with_len;
21use crate::compute::Operator;
22use crate::compute::compare_nested_arrow_arrays;
23use crate::compute::scalar_cmp;
24use crate::expr::Binary;
25use crate::kernel::ExecuteParentKernel;
26use crate::scalar::Scalar;
27use crate::vtable::VTable;
28
29/// Trait for encoding-specific comparison kernels that operate in encoded space.
30///
31/// Implementations can compare an encoded array against another array (typically a constant)
32/// without first decompressing. The adaptor normalizes operand order so `array` is always
33/// the left-hand side, swapping the operator when necessary.
34pub trait CompareKernel: VTable {
35    fn compare(
36        lhs: &Self::Array,
37        rhs: &dyn Array,
38        operator: Operator,
39        ctx: &mut ExecutionCtx,
40    ) -> VortexResult<Option<ArrayRef>>;
41}
42
/// Bridges a [`CompareKernel`] implementation to the [`ExecuteParentKernel`] interface.
///
/// For a `ScalarFnArray(Binary, cmp_op)` whose child implements `CompareKernel`, the
/// adaptor pulls out the comparison operator and the other operand, puts the encoded
/// array on the left-hand side (swapping the operator when the encoded array was the
/// right-hand child), and then hands the comparison to the kernel.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
50
51impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
52where
53    V: CompareKernel,
54{
55    type Parent = ExactScalarFn<Binary>;
56
57    fn execute_parent(
58        &self,
59        array: &V::Array,
60        parent: ScalarFnArrayView<'_, Binary>,
61        child_idx: usize,
62        ctx: &mut ExecutionCtx,
63    ) -> VortexResult<Option<ArrayRef>> {
64        // Only handle comparison operators
65        let Some(cmp_op) = parent.options.maybe_cmp_operator() else {
66            return Ok(None);
67        };
68
69        // Get the ScalarFnArray to access children
70        let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
71            return Ok(None);
72        };
73        let children = scalar_fn_array.children();
74
75        // Normalize so `array` is always LHS, swapping the operator if needed
76        // TODO(joe): should be go this here or in the Rule/Kernel
77        let (cmp_op, other) = match child_idx {
78            0 => (cmp_op, &children[1]),
79            1 => (cmp_op.swap(), &children[0]),
80            _ => return Ok(None),
81        };
82
83        let len = array.len();
84        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
85
86        // Empty array → empty bool result
87        if len == 0 {
88            return Ok(Some(
89                Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array(),
90            ));
91        }
92
93        // Null constant on either side → all-null bool result
94        if other.as_constant().is_some_and(|s| s.is_null()) {
95            return Ok(Some(
96                ConstantArray::new(
97                    Scalar::null(vortex_dtype::DType::Bool(nullable.into())),
98                    len,
99                )
100                .into_array(),
101            ));
102        }
103
104        V::compare(array, other.as_ref(), cmp_op, ctx)
105    }
106}
107
108/// Execute a compare operation between two arrays.
109///
110/// This is the entry point for compare operations from the binary expression.
111/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
112pub(crate) fn execute_compare(
113    lhs: &dyn Array,
114    rhs: &dyn Array,
115    op: Operator,
116) -> VortexResult<ArrayRef> {
117    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
118
119    if lhs.is_empty() {
120        return Ok(Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array());
121    }
122
123    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
124    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
125    if left_constant_null || right_constant_null {
126        return Ok(ConstantArray::new(
127            Scalar::null(vortex_dtype::DType::Bool(nullable.into())),
128            lhs.len(),
129        )
130        .into_array());
131    }
132
133    // Constant-constant fast path
134    if let (Some(lhs_const), Some(rhs_const)) = (
135        lhs.as_opt::<ConstantVTable>(),
136        rhs.as_opt::<ConstantVTable>(),
137    ) {
138        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
139        return Ok(ConstantArray::new(result, lhs.len()).into_array());
140    }
141
142    arrow_compare_arrays(lhs, rhs, op)
143}
144
145/// Fall back to Arrow for comparison.
146fn arrow_compare_arrays(
147    left: &dyn Array,
148    right: &dyn Array,
149    operator: Operator,
150) -> VortexResult<ArrayRef> {
151    assert_eq!(left.len(), right.len());
152
153    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
154
155    // Arrow's vectorized comparison kernels don't support nested types.
156    // For nested types, fall back to `make_comparator` which does element-wise comparison.
157    let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
158        let rhs = right.to_array().into_arrow_preferred()?;
159        let lhs = left.to_array().into_arrow(rhs.data_type())?;
160
161        assert!(
162            lhs.data_type().equals_datatype(rhs.data_type()),
163            "lhs data_type: {}, rhs data_type: {}",
164            lhs.data_type(),
165            rhs.data_type()
166        );
167
168        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
169    } else {
170        // Fast path: use vectorized kernels for primitive types.
171        let lhs = Datum::try_new(left)?;
172        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
173
174        match operator {
175            Operator::Eq => cmp::eq(&lhs, &rhs)?,
176            Operator::NotEq => cmp::neq(&lhs, &rhs)?,
177            Operator::Gt => cmp::gt(&lhs, &rhs)?,
178            Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
179            Operator::Lt => cmp::lt(&lhs, &rhs)?,
180            Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
181        }
182    };
183    from_arrow_array_with_len(&array, left.len(), nullable)
184}