vortex_array/expr/exprs/binary/
compare.rs1use arrow_array::BooleanArray;
5use arrow_ord::cmp;
6use vortex_error::VortexResult;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::ConstantVTable;
15use crate::arrays::ExactScalarFn;
16use crate::arrays::ScalarFnArrayView;
17use crate::arrays::ScalarFnVTable;
18use crate::arrow::Datum;
19use crate::arrow::IntoArrowArray;
20use crate::arrow::from_arrow_array_with_len;
21use crate::compute::Operator;
22use crate::compute::compare_nested_arrow_arrays;
23use crate::compute::scalar_cmp;
24use crate::expr::Binary;
25use crate::kernel::ExecuteParentKernel;
26use crate::scalar::Scalar;
27use crate::vtable::VTable;
28
29pub trait CompareKernel: VTable {
35 fn compare(
36 lhs: &Self::Array,
37 rhs: &dyn Array,
38 operator: Operator,
39 ctx: &mut ExecutionCtx,
40 ) -> VortexResult<Option<ArrayRef>>;
41}
42
/// Wraps a [`CompareKernel`] implementation so it can run as an [`ExecuteParentKernel`]
/// whenever an array of that encoding is a child of a binary comparison expression.
#[derive(Debug, Default)]
pub struct CompareExecuteAdaptor<V>(pub V);
50
51impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
52where
53 V: CompareKernel,
54{
55 type Parent = ExactScalarFn<Binary>;
56
57 fn execute_parent(
58 &self,
59 array: &V::Array,
60 parent: ScalarFnArrayView<'_, Binary>,
61 child_idx: usize,
62 ctx: &mut ExecutionCtx,
63 ) -> VortexResult<Option<ArrayRef>> {
64 let Some(cmp_op) = parent.options.maybe_cmp_operator() else {
66 return Ok(None);
67 };
68
69 let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
71 return Ok(None);
72 };
73 let children = scalar_fn_array.children();
74
75 let (cmp_op, other) = match child_idx {
78 0 => (cmp_op, &children[1]),
79 1 => (cmp_op.swap(), &children[0]),
80 _ => return Ok(None),
81 };
82
83 let len = array.len();
84 let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
85
86 if len == 0 {
88 return Ok(Some(
89 Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array(),
90 ));
91 }
92
93 if other.as_constant().is_some_and(|s| s.is_null()) {
95 return Ok(Some(
96 ConstantArray::new(
97 Scalar::null(vortex_dtype::DType::Bool(nullable.into())),
98 len,
99 )
100 .into_array(),
101 ));
102 }
103
104 V::compare(array, other.as_ref(), cmp_op, ctx)
105 }
106}
107
108pub(crate) fn execute_compare(
113 lhs: &dyn Array,
114 rhs: &dyn Array,
115 op: Operator,
116) -> VortexResult<ArrayRef> {
117 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
118
119 if lhs.is_empty() {
120 return Ok(Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array());
121 }
122
123 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
124 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
125 if left_constant_null || right_constant_null {
126 return Ok(ConstantArray::new(
127 Scalar::null(vortex_dtype::DType::Bool(nullable.into())),
128 lhs.len(),
129 )
130 .into_array());
131 }
132
133 if let (Some(lhs_const), Some(rhs_const)) = (
135 lhs.as_opt::<ConstantVTable>(),
136 rhs.as_opt::<ConstantVTable>(),
137 ) {
138 let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
139 return Ok(ConstantArray::new(result, lhs.len()).into_array());
140 }
141
142 arrow_compare_arrays(lhs, rhs, op)
143}
144
145fn arrow_compare_arrays(
147 left: &dyn Array,
148 right: &dyn Array,
149 operator: Operator,
150) -> VortexResult<ArrayRef> {
151 assert_eq!(left.len(), right.len());
152
153 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
154
155 let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
158 let rhs = right.to_array().into_arrow_preferred()?;
159 let lhs = left.to_array().into_arrow(rhs.data_type())?;
160
161 assert!(
162 lhs.data_type().equals_datatype(rhs.data_type()),
163 "lhs data_type: {}, rhs data_type: {}",
164 lhs.data_type(),
165 rhs.data_type()
166 );
167
168 compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
169 } else {
170 let lhs = Datum::try_new(left)?;
172 let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
173
174 match operator {
175 Operator::Eq => cmp::eq(&lhs, &rhs)?,
176 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
177 Operator::Gt => cmp::gt(&lhs, &rhs)?,
178 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
179 Operator::Lt => cmp::lt(&lhs, &rhs)?,
180 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
181 }
182 };
183 from_arrow_array_with_len(&array, left.len(), nullable)
184}