1use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_buffer::BooleanBuffer;
11use arrow_ord::cmp;
12use vortex_dtype::{DType, NativePType, Nullability};
13use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_scalar::Scalar;
15
16use crate::arrays::ConstantArray;
17use crate::arrow::{Datum, from_arrow_array_with_len};
18use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19use crate::vtable::VTable;
20use crate::{Array, ArrayRef, Canonical, IntoArray};
21
22pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
25 COMPARE_FN
26 .invoke(&InvocationArgs {
27 inputs: &[left.into(), right.into()],
28 options: &operator,
29 })?
30 .unwrap_array()
31}
32
/// A binary comparison operator, applied as `lhs <op> rhs`.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum Operator {
    /// `lhs == rhs`
    Eq,
    /// `lhs != rhs`
    NotEq,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Gte,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Lte,
}
42
43impl Display for Operator {
44 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
45 let display = match &self {
46 Operator::Eq => "=",
47 Operator::NotEq => "!=",
48 Operator::Gt => ">",
49 Operator::Gte => ">=",
50 Operator::Lt => "<",
51 Operator::Lte => "<=",
52 };
53 Display::fmt(display, f)
54 }
55}
56
57impl Operator {
58 pub fn inverse(self) -> Self {
59 match self {
60 Operator::Eq => Operator::NotEq,
61 Operator::NotEq => Operator::Eq,
62 Operator::Gt => Operator::Lte,
63 Operator::Gte => Operator::Lt,
64 Operator::Lt => Operator::Gte,
65 Operator::Lte => Operator::Gt,
66 }
67 }
68
69 pub fn swap(self) -> Self {
71 match self {
72 Operator::Eq => Operator::Eq,
73 Operator::NotEq => Operator::NotEq,
74 Operator::Gt => Operator::Lt,
75 Operator::Gte => Operator::Lte,
76 Operator::Lt => Operator::Gt,
77 Operator::Lte => Operator::Gte,
78 }
79 }
80}
81
82pub struct CompareKernelRef(ArcRef<dyn Kernel>);
83inventory::collect!(CompareKernelRef);
84
85pub trait CompareKernel: VTable {
86 fn compare(
87 &self,
88 lhs: &Self::Array,
89 rhs: &dyn Array,
90 operator: Operator,
91 ) -> VortexResult<Option<ArrayRef>>;
92}
93
94#[derive(Debug)]
95pub struct CompareKernelAdapter<V: VTable>(pub V);
96
97impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
98 pub const fn lift(&'static self) -> CompareKernelRef {
99 CompareKernelRef(ArcRef::new_ref(self))
100 }
101}
102
103impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
104 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
105 let inputs = CompareArgs::try_from(args)?;
106 let Some(array) = inputs.lhs.as_opt::<V>() else {
107 return Ok(None);
108 };
109 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
110 }
111}
112
113pub static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
114 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
115 for kernel in inventory::iter::<CompareKernelRef> {
116 compute.register_kernel(kernel.0.clone());
117 }
118 compute
119});
120
/// Stateless [`ComputeFnVTable`] implementation behind [`COMPARE_FN`].
struct Compare;
122
123impl ComputeFnVTable for Compare {
124 fn invoke(
125 &self,
126 args: &InvocationArgs,
127 kernels: &[ArcRef<dyn Kernel>],
128 ) -> VortexResult<Output> {
129 let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
130
131 let return_dtype = self.return_dtype(args)?;
132
133 if lhs.is_empty() {
134 return Ok(Canonical::empty(&return_dtype).into_array().into());
135 }
136
137 let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
138 let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
139 if left_constant_null || right_constant_null {
140 return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
141 .into_array()
142 .into());
143 }
144
145 let right_is_constant = rhs.is_constant();
146
147 if lhs.is_constant() && !right_is_constant {
149 return Ok(compare(rhs, lhs, operator.swap())?.into());
150 }
151
152 for kernel in kernels {
154 if let Some(output) = kernel.invoke(args)? {
155 return Ok(output);
156 }
157 }
158 if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
159 return Ok(output);
160 }
161
162 let inverted_args = InvocationArgs {
164 inputs: &[rhs.into(), lhs.into()],
165 options: &operator.swap(),
166 };
167 for kernel in kernels {
168 if let Some(output) = kernel.invoke(&inverted_args)? {
169 return Ok(output);
170 }
171 }
172 if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
173 return Ok(output);
174 }
175
176 if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
179 log::debug!(
180 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
181 lhs.encoding_id(),
182 rhs.encoding_id(),
183 operator,
184 );
185 }
186
187 Ok(arrow_compare(lhs, rhs, operator)?.into())
189 }
190
191 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
192 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
193
194 if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
195 vortex_bail!(
196 "Cannot compare different DTypes {} and {}",
197 lhs.dtype(),
198 rhs.dtype()
199 );
200 }
201
202 if lhs.dtype().is_struct() {
204 vortex_bail!(
205 "Compare does not support arrays with Struct DType, got: {} and {}",
206 lhs.dtype(),
207 rhs.dtype()
208 )
209 }
210
211 Ok(DType::Bool(
212 lhs.dtype().nullability() | rhs.dtype().nullability(),
213 ))
214 }
215
216 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
217 let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
218 if lhs.len() != rhs.len() {
219 vortex_bail!(
220 "Compare operations only support arrays of the same length, got {} and {}",
221 lhs.len(),
222 rhs.len()
223 );
224 }
225 Ok(lhs.len())
226 }
227
228 fn is_elementwise(&self) -> bool {
229 true
230 }
231}
232
233struct CompareArgs<'a> {
234 lhs: &'a dyn Array,
235 rhs: &'a dyn Array,
236 operator: Operator,
237}
238
239impl Options for Operator {
240 fn as_any(&self) -> &dyn Any {
241 self
242 }
243}
244
245impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
246 type Error = VortexError;
247
248 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
249 if value.inputs.len() != 2 {
250 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
251 }
252 let lhs = value.inputs[0]
253 .array()
254 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
255 let rhs = value.inputs[1]
256 .array()
257 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
258 let operator = *value
259 .options
260 .as_any()
261 .downcast_ref::<Operator>()
262 .vortex_expect("Expected options to be an operator");
263
264 Ok(CompareArgs { lhs, rhs, operator })
265 }
266}
267
268pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
271where
272 P: NativePType,
273 I: Iterator<Item = P>,
274{
275 let cmp_fn = match op {
277 Operator::Eq | Operator::Lte => |v| v == P::zero(),
278 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
279 Operator::Gte => |_| true,
280 Operator::Lt => |_| false,
281 };
282
283 lengths.map(cmp_fn).collect::<BooleanBuffer>()
284}
285
286fn arrow_compare(
288 left: &dyn Array,
289 right: &dyn Array,
290 operator: Operator,
291) -> VortexResult<ArrayRef> {
292 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
293 let lhs = Datum::try_new(left)?;
294 let rhs = Datum::try_new(right)?;
295
296 let array = match operator {
297 Operator::Eq => cmp::eq(&lhs, &rhs)?,
298 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
299 Operator::Gt => cmp::gt(&lhs, &rhs)?,
300 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
301 Operator::Lt => cmp::lt(&lhs, &rhs)?,
302 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
303 };
304 from_arrow_array_with_len(&array, left.len(), nullable)
305}
306
307pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
308 if lhs.is_null() | rhs.is_null() {
309 Scalar::null(DType::Bool(Nullability::Nullable))
310 } else {
311 let b = match operator {
312 Operator::Eq => lhs == rhs,
313 Operator::NotEq => lhs != rhs,
314 Operator::Gt => lhs > rhs,
315 Operator::Gte => lhs >= rhs,
316 Operator::Lt => lhs < rhs,
317 Operator::Lte => lhs <= rhs,
318 };
319
320 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
321 }
322}
323
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    #[test]
    fn test_bool_basic_comparisons() {
        // Index 0 is null in both arrays and is absent from every expected result.
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );
        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );
        let (arr_ref, other_ref): (&dyn Array, &dyn Array) = (arr.as_ref(), other.as_ref());

        // Indices at which `lhs <op> rhs` evaluates to true.
        let true_indices = |lhs: &dyn Array, rhs: &dyn Array, op: Operator| {
            let result = compare(lhs, rhs, op).unwrap().to_bool().unwrap();
            to_int_indices(result).unwrap()
        };

        assert_eq!(true_indices(arr_ref, arr_ref, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(true_indices(arr_ref, arr_ref, Operator::NotEq), empty);

        assert_eq!(true_indices(arr_ref, other_ref, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(true_indices(arr_ref, other_ref, Operator::Lt), [4u64]);

        // Same relations with the operands exchanged and the operators mirrored.
        assert_eq!(true_indices(other_ref, arr_ref, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(true_indices(other_ref, arr_ref, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Through the compute function: `2 > 10` is false at every position.
        let via_compute = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let compute_scalar = via_compute.as_constant().unwrap();
        assert_eq!(compute_scalar.as_bool().value(), Some(false));
        assert_eq!(via_compute.len(), 10);

        // Directly through the Arrow fallback.
        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let arrow_scalar = via_arrow.as_constant().unwrap();
        assert_eq!(arrow_scalar.as_bool().value(), Some(false));
        assert_eq!(via_arrow.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let buffer = compare_lengths_to_empty(lengths.iter().copied(), op);
        let actual: Vec<bool> = buffer.iter().collect();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        // Identical values in different encodings must compare equal everywhere.
        let equal = compare(&left, &right, Operator::Eq).unwrap();
        let matched = equal.to_bool().unwrap().boolean_buffer().count_set_bits();
        assert_eq!(matched, left.len());
    }
}