1use core::fmt;
2use std::any::Any;
3use std::fmt::{Display, Formatter};
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use arrow_buffer::BooleanBuffer;
8use arrow_ord::cmp;
9use vortex_dtype::{DType, NativePType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::Scalar;
12
13use crate::arrays::ConstantArray;
14use crate::arrow::{Datum, from_arrow_array_with_len};
15use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
16use crate::vtable::VTable;
17use crate::{Array, ArrayRef, Canonical, IntoArray};
18
19pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
22 COMPARE_FN
23 .invoke(&InvocationArgs {
24 inputs: &[left.into(), right.into()],
25 options: &operator,
26 })?
27 .unwrap_array()
28}
29
/// A binary comparison operator.
///
/// The derived `PartialOrd` simply follows declaration order
/// (`Eq < NotEq < Gt < Gte < Lt < Lte`); it orders the variants, not their meaning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum Operator {
    /// Equal (`=`).
    Eq,
    /// Not equal (`!=`).
    NotEq,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Gte,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol, honouring width/fill/alignment flags.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        // `Formatter::pad` is what `<str as Display>::fmt` does, so padding flags apply.
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the logical negation of this operator.
    ///
    /// `a op b` holds exactly when `a op.inverse() b` does not (for non-null operands).
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Lte => Self::Gt,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
        }
    }

    /// Returns the operator to use when the operands are exchanged.
    ///
    /// `a op b` is equivalent to `b op.swap() a`. Symmetric operators map to themselves.
    pub fn swap(self) -> Self {
        match self {
            Self::Eq | Self::NotEq => self,
            Self::Gt => Self::Lt,
            Self::Gte => Self::Lte,
            Self::Lt => Self::Gt,
            Self::Lte => Self::Gte,
        }
    }
}
78
/// Type-erased, registrable handle to a compare [`Kernel`].
///
/// Instances are gathered with `inventory` (presumably submitted by encodings elsewhere via
/// `inventory::submit!`). Each one's kernel is registered on [`COMPARE_FN`] when that lazily
/// built function is first initialised.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CompareKernelRef);
81
/// Encoding-specific comparison, implemented on an array [`VTable`].
///
/// Wrap an implementation in [`CompareKernelAdapter`] to expose it as a generic [`Kernel`].
pub trait CompareKernel: VTable {
    /// Compares `lhs`, an array of this encoding, against `rhs` using `lhs operator rhs`.
    ///
    /// Return `Ok(None)` when this encoding cannot handle the inputs. The dispatcher then
    /// tries other kernels, the swapped operand order, and finally an Arrow fallback.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
90
/// Adapts a [`CompareKernel`] implementation on vtable `V` into a generic compute [`Kernel`].
///
/// The adapter only handles invocations whose left-hand input is encoded as `V`.
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);
93
impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
    /// Wraps this `'static` adapter in a [`CompareKernelRef`] by reference (`ArcRef::new_ref`).
    ///
    /// Being a `const fn`, it can be used in constant contexts such as kernel registration.
    pub const fn lift(&'static self) -> CompareKernelRef {
        CompareKernelRef(ArcRef::new_ref(self))
    }
}
99
100impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
101 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
102 let inputs = CompareArgs::try_from(args)?;
103 let Some(array) = inputs.lhs.as_opt::<V>() else {
104 return Ok(None);
105 };
106 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
107 }
108}
109
110pub static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
111 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
112 for kernel in inventory::iter::<CompareKernelRef> {
113 compute.register_kernel(kernel.0.clone());
114 }
115 compute
116});
117
118struct Compare;
119
120impl ComputeFnVTable for Compare {
121 fn invoke(
122 &self,
123 args: &InvocationArgs,
124 kernels: &[ArcRef<dyn Kernel>],
125 ) -> VortexResult<Output> {
126 let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
127
128 let return_dtype = self.return_dtype(args)?;
129
130 if lhs.is_empty() {
131 return Ok(Canonical::empty(&return_dtype).into_array().into());
132 }
133
134 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
135 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
136 if left_constant_null || right_constant_null {
137 return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
138 .into_array()
139 .into());
140 }
141
142 let right_is_constant = rhs.is_constant();
143
144 if lhs.is_constant() && !right_is_constant {
146 return Ok(compare(rhs, lhs, operator.swap())?.into());
147 }
148
149 for kernel in kernels {
151 if let Some(output) = kernel.invoke(args)? {
152 return Ok(output);
153 }
154 }
155 if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
156 return Ok(output);
157 }
158
159 let inverted_args = InvocationArgs {
161 inputs: &[rhs.into(), lhs.into()],
162 options: &operator.swap(),
163 };
164 for kernel in kernels {
165 if let Some(output) = kernel.invoke(&inverted_args)? {
166 return Ok(output);
167 }
168 }
169 if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
170 return Ok(output);
171 }
172
173 if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
176 log::debug!(
177 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
178 lhs.encoding_id(),
179 rhs.encoding_id(),
180 operator,
181 );
182 }
183
184 Ok(arrow_compare(lhs, rhs, operator)?.into())
186 }
187
188 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
189 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
190
191 if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
192 vortex_bail!(
193 "Cannot compare different DTypes {} and {}",
194 lhs.dtype(),
195 rhs.dtype()
196 );
197 }
198
199 if lhs.dtype().is_struct() {
201 vortex_bail!(
202 "Compare does not support arrays with Struct DType, got: {} and {}",
203 lhs.dtype(),
204 rhs.dtype()
205 )
206 }
207
208 Ok(DType::Bool(
209 lhs.dtype().nullability() | rhs.dtype().nullability(),
210 ))
211 }
212
213 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
214 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
215 if lhs.len() != rhs.len() {
216 vortex_bail!(
217 "Compare operations only support arrays of the same length, got {} and {}",
218 lhs.len(),
219 rhs.len()
220 );
221 }
222 Ok(lhs.len())
223 }
224
225 fn is_elementwise(&self) -> bool {
226 true
227 }
228}
229
/// Typed view of the [`InvocationArgs`] of a compare call.
struct CompareArgs<'a> {
    /// Left-hand operand (first input).
    lhs: &'a dyn Array,
    /// Right-hand operand (second input).
    rhs: &'a dyn Array,
    /// Comparison applied as `lhs operator rhs`.
    operator: Operator,
}
235
/// Lets an [`Operator`] be passed as the options of a compute invocation.
///
/// It is recovered later by downcasting the value returned from `as_any`.
impl Options for Operator {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
241
242impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
243 type Error = VortexError;
244
245 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
246 if value.inputs.len() != 2 {
247 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
248 }
249 let lhs = value.inputs[0]
250 .array()
251 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
252 let rhs = value.inputs[1]
253 .array()
254 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
255 let operator = *value
256 .options
257 .as_any()
258 .downcast_ref::<Operator>()
259 .vortex_expect("Expected options to be an operator");
260
261 Ok(CompareArgs { lhs, rhs, operator })
262 }
263}
264
265pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
268where
269 P: NativePType,
270 I: Iterator<Item = P>,
271{
272 let cmp_fn = match op {
274 Operator::Eq | Operator::Lte => |v| v == P::zero(),
275 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
276 Operator::Gte => |_| true,
277 Operator::Lt => |_| false,
278 };
279
280 lengths.map(cmp_fn).collect::<BooleanBuffer>()
281}
282
283fn arrow_compare(
285 left: &dyn Array,
286 right: &dyn Array,
287 operator: Operator,
288) -> VortexResult<ArrayRef> {
289 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
290 let lhs = Datum::try_new(left)?;
291 let rhs = Datum::try_new(right)?;
292
293 let array = match operator {
294 Operator::Eq => cmp::eq(&lhs, &rhs)?,
295 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
296 Operator::Gt => cmp::gt(&lhs, &rhs)?,
297 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
298 Operator::Lt => cmp::lt(&lhs, &rhs)?,
299 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
300 };
301 from_arrow_array_with_len(&array, left.len(), nullable)
302}
303
304pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
305 if lhs.is_null() | rhs.is_null() {
306 Scalar::null(DType::Bool(Nullability::Nullable))
307 } else {
308 let b = match operator {
309 Operator::Eq => lhs == rhs,
310 Operator::NotEq => lhs != rhs,
311 Operator::Gt => lhs > rhs,
312 Operator::Gte => lhs >= rhs,
313 Operator::Lt => lhs < rhs,
314 Operator::Lte => lhs <= rhs,
315 };
316
317 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
318 }
319}
320
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Runs every operator on nullable boolean arrays and checks which positions are set.
    #[test]
    fn test_bool_basic_comparisons() {
        // Position 0 is invalid (null) here and in `other`.
        // None of the expected index sets below contain it.
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // With the operands exchanged, `other >= arr` must match `arr <= other`,
        // and `other > arr` must match `arr < other`.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Two constant inputs yield a constant result, both via the dispatcher and via the Arrow
    /// fallback called directly.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Note: from here on the local `compare` shadows the `compare` function.
        let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    /// Lengths `[1, 5, 7, 0]` compared against the empty value; only the last one is empty.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// Equal string/binary data in VarBin and VarBinView encodings, in both operand orders,
    /// must compare equal at every position.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(
            res.to_bool().unwrap().boolean_buffer().count_set_bits(),
            left.len()
        );
    }
}