1use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_buffer::BooleanBuffer;
11use arrow_ord::cmp;
12use vortex_dtype::{DType, NativePType, Nullability};
13use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_scalar::Scalar;
15
16use crate::arrays::ConstantArray;
17use crate::arrow::{Datum, from_arrow_array_with_len};
18use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19use crate::vtable::VTable;
20use crate::{Array, ArrayRef, Canonical, IntoArray};
21
22static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
23 let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
24 for kernel in inventory::iter::<CompareKernelRef> {
25 compute.register_kernel(kernel.0.clone());
26 }
27 compute
28});
29
30pub(crate) fn warm_up_vtable() -> usize {
31 COMPARE_FN.kernels().len()
32}
33
34pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
37 COMPARE_FN
38 .invoke(&InvocationArgs {
39 inputs: &[left.into(), right.into()],
40 options: &operator,
41 })?
42 .unwrap_array()
43}
44
/// Binary comparison operators understood by the compare compute function.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equality, displayed as `=`.
    Eq,
    /// Inequality, displayed as `!=`.
    NotEq,
    /// Strictly greater than, displayed as `>`.
    Gt,
    /// Greater than or equal, displayed as `>=`.
    Gte,
    /// Strictly less than, displayed as `<`.
    Lt,
    /// Less than or equal, displayed as `<=`.
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol, honouring any width/alignment flags of the formatter.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        // `pad` is what `Display for str` itself delegates to.
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the complementary operator: `Eq`/`NotEq`, `Gt`/`Lte` and `Gte`/`Lt` map to
    /// each other.
    ///
    /// For totally ordered, non-null operands `a op b` holds exactly when
    /// `a op.inverse() b` does not.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Lte => Self::Gt,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
        }
    }

    /// Returns the operator to use when the two operands trade places, so that `a op b`
    /// and `b op.swap() a` express the same comparison.
    ///
    /// `Eq` and `NotEq` are symmetric and map to themselves; the ordering operators are
    /// mirrored (`Gt`/`Lt`, `Gte`/`Lte`).
    pub fn swap(self) -> Self {
        match self {
            Self::Eq | Self::NotEq => self,
            Self::Gt => Self::Lt,
            Self::Lt => Self::Gt,
            Self::Gte => Self::Lte,
            Self::Lte => Self::Gte,
        }
    }
}
99
/// Handle to a type-erased compare kernel.
///
/// Values of this type are collected through `inventory` so that the `compare` compute
/// function can register every submitted kernel when it is first initialized.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CompareKernelRef);
102
/// Comparison implemented by a specific vtable for the case where the left-hand side is
/// that vtable's own array type.
pub trait CompareKernel: VTable {
    /// Compares `lhs` with `rhs` using `operator`.
    ///
    /// The result is optional: `Ok(None)` is forwarded unchanged by `CompareKernelAdapter`,
    /// and the `compare` dispatcher then moves on to its next candidate implementation.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
111
/// Adapter that exposes a [`CompareKernel`] implementation `V` through the generic
/// `Kernel` interface.
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);
114
impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
    /// Wraps a `'static` reference to this adapter in a [`CompareKernelRef`].
    ///
    /// Being a `const fn`, this can be evaluated in constant contexts.
    pub const fn lift(&'static self) -> CompareKernelRef {
        CompareKernelRef(ArcRef::new_ref(self))
    }
}
120
121impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
122 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
123 let inputs = CompareArgs::try_from(args)?;
124 let Some(array) = inputs.lhs.as_opt::<V>() else {
125 return Ok(None);
126 };
127 Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
128 }
129}
130
/// Vtable implementing the `compare` compute function.
struct Compare;

impl ComputeFnVTable for Compare {
    /// Dispatches a comparison, trying progressively more general implementations.
    ///
    /// Order of evaluation:
    /// 1. An empty left-hand side yields an empty canonical array of the return dtype.
    /// 2. If either side is a constant null, the result is a constant null array.
    /// 3. A constant left-hand side with a non-constant right-hand side is re-dispatched
    ///    with the operands exchanged and the operator swapped.
    /// 4. The registered kernels, then the left-hand array itself, are tried with the
    ///    original arguments.
    /// 5. The same is tried with the operands exchanged and the operator swapped.
    /// 6. Otherwise the Arrow-based `arrow_compare` fallback is used.
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;

        // Also validates that the dtypes are comparable (see `return_dtype`).
        let return_dtype = self.return_dtype(args)?;

        if lhs.is_empty() {
            return Ok(Canonical::empty(&return_dtype).into_array().into());
        }

        // Comparing against a constant null produces null for every element.
        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
        if left_constant_null || right_constant_null {
            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
                .into_array()
                .into());
        }

        let right_is_constant = rhs.is_constant();

        // Move a lone constant to the right-hand side; the operator is swapped to match.
        if lhs.is_constant() && !right_is_constant {
            return Ok(compare(rhs, lhs, operator.swap())?.into());
        }

        // First attempt: `lhs op rhs` via the registered kernels, then via the array itself.
        for kernel in kernels {
            if let Some(output) = kernel.invoke(args)? {
                return Ok(output);
            }
        }
        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
            return Ok(output);
        }

        // Second attempt: `rhs op.swap() lhs`, so that kernels keyed on the right-hand
        // side's encoding get a chance to handle it.
        let inverted_args = InvocationArgs {
            inputs: &[rhs.into(), lhs.into()],
            options: &operator.swap(),
        };
        for kernel in kernels {
            if let Some(output) = kernel.invoke(&inverted_args)? {
                return Ok(output);
            }
        }
        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
            return Ok(output);
        }

        // Only log when falling back to Arrow was not the expected path, i.e. unless the
        // left-hand side is Arrow-compatible and the right-hand side is Arrow-compatible
        // or constant.
        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
            log::debug!(
                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
                lhs.encoding_id(),
                rhs.encoding_id(),
                operator,
            );
        }

        // Last resort: compare through Arrow.
        Ok(arrow_compare(lhs, rhs, operator)?.into())
    }

    /// Returns the `Bool` dtype of the comparison result, with the nullabilities of both
    /// inputs combined via `|`.
    ///
    /// Fails if the input dtypes differ (ignoring nullability) or if they are struct dtypes.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;

        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
            vortex_bail!(
                "Cannot compare different DTypes {} and {}",
                lhs.dtype(),
                rhs.dtype()
            );
        }

        // Struct comparison is rejected up front.
        if lhs.dtype().is_struct() {
            vortex_bail!(
                "Compare does not support arrays with Struct DType, got: {} and {}",
                lhs.dtype(),
                rhs.dtype()
            )
        }

        Ok(DType::Bool(
            lhs.dtype().nullability() | rhs.dtype().nullability(),
        ))
    }

    /// Returns the common length of both inputs, failing if the lengths differ.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
        if lhs.len() != rhs.len() {
            vortex_bail!(
                "Compare operations only support arrays of the same length, got {} and {}",
                lhs.len(),
                rhs.len()
            );
        }
        Ok(lhs.len())
    }

    /// Reports this compute function as element-wise.
    fn is_elementwise(&self) -> bool {
        true
    }
}
242
/// Typed view of the arguments of a compare invocation.
struct CompareArgs<'a> {
    /// Left-hand operand (first input).
    lhs: &'a dyn Array,
    /// Right-hand operand (second input).
    rhs: &'a dyn Array,
    /// Comparison operator, taken from the invocation options.
    operator: Operator,
}
248
impl Options for Operator {
    /// Exposes the operator as `&dyn Any` so it can be recovered from type-erased options
    /// by downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
254
255impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
256 type Error = VortexError;
257
258 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
259 if value.inputs.len() != 2 {
260 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
261 }
262 let lhs = value.inputs[0]
263 .array()
264 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
265 let rhs = value.inputs[1]
266 .array()
267 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
268 let operator = *value
269 .options
270 .as_any()
271 .downcast_ref::<Operator>()
272 .vortex_expect("Expected options to be an operator");
273
274 Ok(CompareArgs { lhs, rhs, operator })
275 }
276}
277
278pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
281where
282 P: NativePType,
283 I: Iterator<Item = P>,
284{
285 let cmp_fn = match op {
287 Operator::Eq | Operator::Lte => |v| v == P::zero(),
288 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
289 Operator::Gte => |_| true,
290 Operator::Lt => |_| false,
291 };
292
293 lengths.map(cmp_fn).collect::<BooleanBuffer>()
294}
295
296fn arrow_compare(
298 left: &dyn Array,
299 right: &dyn Array,
300 operator: Operator,
301) -> VortexResult<ArrayRef> {
302 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
303 let lhs = Datum::try_new(left)?;
304 let rhs = Datum::try_new(right)?;
305
306 let array = match operator {
307 Operator::Eq => cmp::eq(&lhs, &rhs)?,
308 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
309 Operator::Gt => cmp::gt(&lhs, &rhs)?,
310 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
311 Operator::Lt => cmp::lt(&lhs, &rhs)?,
312 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
313 };
314 Ok(from_arrow_array_with_len(&array, left.len(), nullable))
315}
316
317pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
318 if lhs.is_null() | rhs.is_null() {
319 Scalar::null(DType::Bool(Nullability::Nullable))
320 } else {
321 let b = match operator {
322 Operator::Eq => lhs == rhs,
323 Operator::NotEq => lhs != rhs,
324 Operator::Gt => lhs > rhs,
325 Operator::Gte => lhs >= rhs,
326 Operator::Lt => lhs < rhs,
327 Operator::Lte => lhs <= rhs,
328 };
329
330 Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
331 }
332}
333
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on boolean arrays whose first element is marked invalid in
    /// the validity; the expected index lists never contain index 0.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // An array compared with itself is equal at indices 1 through 4.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // ...and is not-equal nowhere.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // With the operands exchanged, `Gte`/`Gt` must mirror the `Lte`/`Lt` results above.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Comparing two constant arrays yields a constant result of the same length, both
    /// through `compare` and through the Arrow fallback called directly.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Note: the local binding shadows the `compare` function from here on.
        let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    /// Checks `compare_lengths_to_empty` for every operator on lengths `[1, 5, 7, 0]`;
    /// only the last length is zero.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// Arrays holding the same string/binary values in different encodings must compare
    /// equal at every position, whichever encoding is on the left.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        // Every result bit set means every element compared equal.
        assert_eq!(res.to_bool().boolean_buffer().count_set_bits(), left.len());
    }
}