1use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_buffer::BooleanBuffer;
11use arrow_ord::cmp;
12use vortex_dtype::{DType, NativePType, Nullability};
13use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_scalar::Scalar;
15
16use crate::arrays::ConstantArray;
17use crate::arrow::{Datum, from_arrow_array_with_len};
18use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19use crate::vtable::VTable;
20use crate::{Array, ArrayRef, Canonical, IntoArray};
21
22static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
23 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
24 for kernel in inventory::iter::<CompareKernelRef> {
25 compute.register_kernel(kernel.0.clone());
26 }
27 compute
28});
29
30pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
33 COMPARE_FN
34 .invoke(&InvocationArgs {
35 inputs: &[left.into(), right.into()],
36 options: &operator,
37 })?
38 .unwrap_array()
39}
40
/// A binary comparison operator.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equal to (`=`).
    Eq,
    /// Not equal to (`!=`).
    NotEq,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal to (`>=`).
    Gte,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal to (`<=`).
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol (for example `>=`), honouring any width, alignment or
    /// precision requested by the formatter.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        // `pad` is what `Display for str` does, so formatter flags are applied to the symbol.
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the logically negated operator, e.g. `Gt` becomes `Lte`.
    ///
    /// Applying `inverse` twice yields the original operator.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
            Self::Lte => Self::Gt,
        }
    }

    /// Returns the operator to use once the two operands have traded places, e.g. `a > b`
    /// becomes `b < a`.
    ///
    /// `Eq` and `NotEq` are symmetric and map to themselves.
    pub fn swap(self) -> Self {
        match self {
            Self::Eq => Self::Eq,
            Self::NotEq => Self::NotEq,
            Self::Gt => Self::Lt,
            Self::Gte => Self::Lte,
            Self::Lt => Self::Gt,
            Self::Lte => Self::Gte,
        }
    }
}
95
96pub struct CompareKernelRef(ArcRef<dyn Kernel>);
97inventory::collect!(CompareKernelRef);
98
99pub trait CompareKernel: VTable {
100 fn compare(
101 &self,
102 lhs: &Self::Array,
103 rhs: &dyn Array,
104 operator: Operator,
105 ) -> VortexResult<Option<ArrayRef>>;
106}
107
108#[derive(Debug)]
109pub struct CompareKernelAdapter<V: VTable>(pub V);
110
111impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
112 pub const fn lift(&'static self) -> CompareKernelRef {
113 CompareKernelRef(ArcRef::new_ref(self))
114 }
115}
116
117impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
118 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
119 let inputs = CompareArgs::try_from(args)?;
120 let Some(array) = inputs.lhs.as_opt::<V>() else {
121 return Ok(None);
122 };
123 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
124 }
125}
126
/// Vtable backing the `compare` compute function (dispatch, result dtype and result length).
struct Compare;
128
129impl ComputeFnVTable for Compare {
130 fn invoke(
131 &self,
132 args: &InvocationArgs,
133 kernels: &[ArcRef<dyn Kernel>],
134 ) -> VortexResult<Output> {
135 let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
136
137 let return_dtype = self.return_dtype(args)?;
138
139 if lhs.is_empty() {
140 return Ok(Canonical::empty(&return_dtype).into_array().into());
141 }
142
143 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
144 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
145 if left_constant_null || right_constant_null {
146 return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
147 .into_array()
148 .into());
149 }
150
151 let right_is_constant = rhs.is_constant();
152
153 if lhs.is_constant() && !right_is_constant {
155 return Ok(compare(rhs, lhs, operator.swap())?.into());
156 }
157
158 for kernel in kernels {
160 if let Some(output) = kernel.invoke(args)? {
161 return Ok(output);
162 }
163 }
164 if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
165 return Ok(output);
166 }
167
168 let inverted_args = InvocationArgs {
170 inputs: &[rhs.into(), lhs.into()],
171 options: &operator.swap(),
172 };
173 for kernel in kernels {
174 if let Some(output) = kernel.invoke(&inverted_args)? {
175 return Ok(output);
176 }
177 }
178 if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
179 return Ok(output);
180 }
181
182 if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
185 log::debug!(
186 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
187 lhs.encoding_id(),
188 rhs.encoding_id(),
189 operator,
190 );
191 }
192
193 Ok(arrow_compare(lhs, rhs, operator)?.into())
195 }
196
197 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
198 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
199
200 if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
201 vortex_bail!(
202 "Cannot compare different DTypes {} and {}",
203 lhs.dtype(),
204 rhs.dtype()
205 );
206 }
207
208 if lhs.dtype().is_struct() {
210 vortex_bail!(
211 "Compare does not support arrays with Struct DType, got: {} and {}",
212 lhs.dtype(),
213 rhs.dtype()
214 )
215 }
216
217 Ok(DType::Bool(
218 lhs.dtype().nullability() | rhs.dtype().nullability(),
219 ))
220 }
221
222 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
223 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
224 if lhs.len() != rhs.len() {
225 vortex_bail!(
226 "Compare operations only support arrays of the same length, got {} and {}",
227 lhs.len(),
228 rhs.len()
229 );
230 }
231 Ok(lhs.len())
232 }
233
234 fn is_elementwise(&self) -> bool {
235 true
236 }
237}
238
239struct CompareArgs<'a> {
240 lhs: &'a dyn Array,
241 rhs: &'a dyn Array,
242 operator: Operator,
243}
244
245impl Options for Operator {
246 fn as_any(&self) -> &dyn Any {
247 self
248 }
249}
250
251impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
252 type Error = VortexError;
253
254 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
255 if value.inputs.len() != 2 {
256 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
257 }
258 let lhs = value.inputs[0]
259 .array()
260 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
261 let rhs = value.inputs[1]
262 .array()
263 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
264 let operator = *value
265 .options
266 .as_any()
267 .downcast_ref::<Operator>()
268 .vortex_expect("Expected options to be an operator");
269
270 Ok(CompareArgs { lhs, rhs, operator })
271 }
272}
273
274pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
277where
278 P: NativePType,
279 I: Iterator<Item = P>,
280{
281 let cmp_fn = match op {
283 Operator::Eq | Operator::Lte => |v| v == P::zero(),
284 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
285 Operator::Gte => |_| true,
286 Operator::Lt => |_| false,
287 };
288
289 lengths.map(cmp_fn).collect::<BooleanBuffer>()
290}
291
292fn arrow_compare(
294 left: &dyn Array,
295 right: &dyn Array,
296 operator: Operator,
297) -> VortexResult<ArrayRef> {
298 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
299 let lhs = Datum::try_new(left)?;
300 let rhs = Datum::try_new(right)?;
301
302 let array = match operator {
303 Operator::Eq => cmp::eq(&lhs, &rhs)?,
304 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
305 Operator::Gt => cmp::gt(&lhs, &rhs)?,
306 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
307 Operator::Lt => cmp::lt(&lhs, &rhs)?,
308 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
309 };
310 from_arrow_array_with_len(&array, left.len(), nullable)
311}
312
313pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
314 if lhs.is_null() | rhs.is_null() {
315 Scalar::null(DType::Bool(Nullability::Nullable))
316 } else {
317 let b = match operator {
318 Operator::Eq => lhs == rhs,
319 Operator::NotEq => lhs != rhs,
320 Operator::Gt => lhs > rhs,
321 Operator::Gte => lhs >= rhs,
322 Operator::Lt => lhs < rhs,
323 Operator::Lte => lhs <= rhs,
324 };
325
326 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
327 }
328}
329
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on boolean arrays whose first element is invalid.
    #[test]
    fn test_bool_basic_comparisons() {
        // Runs the comparison and returns what `to_int_indices` reports for the boolean result.
        let matching = |left: &BoolArray, right: &BoolArray, op: Operator| {
            let mask = compare(left.as_ref(), right.as_ref(), op)
                .unwrap()
                .to_bool()
                .unwrap();
            to_int_indices(mask).unwrap()
        };

        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(matching(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(matching(&arr, &arr, Operator::NotEq), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(matching(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(matching(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(matching(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(matching(&other, &arr, Operator::Gt), [4u64]);
    }

    /// Comparing two constants yields a constant result, via both `compare` and the Arrow
    /// fallback.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let dispatched = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let scalar = dispatched.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(dispatched.len(), 10);

        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let scalar = via_arrow.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(via_arrow.len(), 10);
    }

    /// Checks the per-operator outcome of comparing the lengths `[1, 5, 7, 0]` against zero.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths = [1i32, 5, 7, 0];

        let mask = compare_lengths_to_empty(lengths.into_iter(), op);
        assert_eq!(mask.iter().collect::<Vec<bool>>(), expected);
    }

    /// Equal contents held in different encodings must compare equal at every position.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let equal_mask = compare(&left, &right, Operator::Eq).unwrap();
        let set_bits = equal_mask
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .count_set_bits();
        assert_eq!(set_bits, left.len());
    }
}