1use core::fmt;
2use std::fmt::{Display, Formatter};
3
4use arrow_buffer::BooleanBuffer;
5use arrow_ord::cmp;
6use vortex_dtype::{DType, NativePType, Nullability};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::arrays::ConstantArray;
11use crate::arrow::{Datum, from_arrow_array_with_len};
12use crate::encoding::Encoding;
13use crate::{Array, ArrayRef, Canonical, IntoArray};
14
/// Binary comparison operators understood by the compare functions in this module.
///
/// `PartialOrd` is derived, so the ordering between variants follows their declaration
/// order; do not reorder the variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum Operator {
    /// Equal to.
    Eq,
    /// Not equal to.
    NotEq,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal to.
    Gte,
    /// Strictly less than.
    Lt,
    /// Less than or equal to.
    Lte,
}
24
25impl Display for Operator {
26 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
27 let display = match &self {
28 Operator::Eq => "=",
29 Operator::NotEq => "!=",
30 Operator::Gt => ">",
31 Operator::Gte => ">=",
32 Operator::Lt => "<",
33 Operator::Lte => "<=",
34 };
35 Display::fmt(display, f)
36 }
37}
38
39impl Operator {
40 pub fn inverse(self) -> Self {
41 match self {
42 Operator::Eq => Operator::NotEq,
43 Operator::NotEq => Operator::Eq,
44 Operator::Gt => Operator::Lte,
45 Operator::Gte => Operator::Lt,
46 Operator::Lt => Operator::Gte,
47 Operator::Lte => Operator::Gt,
48 }
49 }
50
51 pub fn swap(self) -> Self {
53 match self {
54 Operator::Eq => Operator::Eq,
55 Operator::NotEq => Operator::NotEq,
56 Operator::Gt => Operator::Lt,
57 Operator::Gte => Operator::Lte,
58 Operator::Lt => Operator::Gt,
59 Operator::Lte => Operator::Gte,
60 }
61 }
62}
63
64pub trait CompareFn<A> {
65 fn compare(
68 &self,
69 lhs: A,
70 rhs: &dyn Array,
71 operator: Operator,
72 ) -> VortexResult<Option<ArrayRef>>;
73}
74
75impl<E: Encoding> CompareFn<&dyn Array> for E
76where
77 E: for<'a> CompareFn<&'a E::Array>,
78{
79 fn compare(
80 &self,
81 lhs: &dyn Array,
82 rhs: &dyn Array,
83 operator: Operator,
84 ) -> VortexResult<Option<ArrayRef>> {
85 let array_ref = lhs
86 .as_any()
87 .downcast_ref::<E::Array>()
88 .vortex_expect("Failed to downcast array");
89
90 CompareFn::compare(self, array_ref, rhs, operator)
91 }
92}
93
94pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
95 if left.len() != right.len() {
96 vortex_bail!("Compare operations only support arrays of the same length");
97 }
98 if !left.dtype().eq_ignore_nullability(right.dtype()) {
99 vortex_bail!(
100 "Cannot compare different DTypes {} and {}",
101 left.dtype(),
102 right.dtype()
103 );
104 }
105
106 if left.dtype().is_struct() {
108 vortex_bail!(
109 "Compare does not support arrays with Struct DType, got: {} and {}",
110 left.dtype(),
111 right.dtype()
112 )
113 }
114
115 let result_dtype =
116 DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into());
117
118 if left.is_empty() {
119 return Ok(Canonical::empty(&result_dtype).into_array());
120 }
121
122 let left_constant_null = left.as_constant().map(|l| l.is_null()).unwrap_or(false);
123 let right_constant_null = right.as_constant().map(|r| r.is_null()).unwrap_or(false);
124 if left_constant_null || right_constant_null {
125 return Ok(ConstantArray::new(Scalar::null(result_dtype), left.len()).into_array());
126 }
127
128 let right_is_constant = right.is_constant();
129
130 if left.is_constant() && !right_is_constant {
132 return compare(right, left, operator.swap());
133 }
134
135 if let Some(result) = left
136 .vtable()
137 .compare_fn()
138 .and_then(|f| f.compare(left, right, operator).transpose())
139 .transpose()?
140 {
141 check_compare_result(&result, left, right);
142 return Ok(result);
143 }
144
145 if let Some(result) = right
146 .vtable()
147 .compare_fn()
148 .and_then(|f| f.compare(right, left, operator.swap()).transpose())
149 .transpose()?
150 {
151 check_compare_result(&result, left, right);
152 return Ok(result);
153 }
154
155 if !(left.is_arrow() && (right.is_arrow() || right_is_constant)) {
158 log::debug!(
159 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
160 right.encoding(),
161 left.encoding(),
162 operator.swap(),
163 );
164 }
165
166 let result = arrow_compare(left, right, operator)?;
168 check_compare_result(&result, left, right);
169 Ok(result)
170}
171
172pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
175where
176 P: NativePType,
177 I: Iterator<Item = P>,
178{
179 let cmp_fn = match op {
181 Operator::Eq | Operator::Lte => |v| v == P::zero(),
182 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
183 Operator::Gte => |_| true,
184 Operator::Lt => |_| false,
185 };
186
187 lengths.map(cmp_fn).collect::<BooleanBuffer>()
188}
189
190fn arrow_compare(
192 left: &dyn Array,
193 right: &dyn Array,
194 operator: Operator,
195) -> VortexResult<ArrayRef> {
196 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
197 let lhs = Datum::try_new(left.to_array())?;
198 let rhs = Datum::try_new(right.to_array())?;
199
200 let array = match operator {
201 Operator::Eq => cmp::eq(&lhs, &rhs)?,
202 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
203 Operator::Gt => cmp::gt(&lhs, &rhs)?,
204 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
205 Operator::Lt => cmp::lt(&lhs, &rhs)?,
206 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
207 };
208 from_arrow_array_with_len(&array, left.len(), nullable)
209}
210
211#[inline(always)]
212fn check_compare_result(result: &dyn Array, lhs: &dyn Array, rhs: &dyn Array) {
213 debug_assert_eq!(
214 result.len(),
215 lhs.len(),
216 "CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
217 result.len(),
218 lhs.encoding(),
219 lhs.len(),
220 rhs.encoding(),
221 rhs.len()
222 );
223 debug_assert_eq!(
224 result.dtype(),
225 &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
226 "CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
227 result.dtype(),
228 lhs.encoding(),
229 rhs.encoding(),
230 );
231}
232
233pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
234 if lhs.is_null() | rhs.is_null() {
235 Scalar::null(DType::Bool(Nullability::Nullable))
236 } else {
237 let b = match operator {
238 Operator::Eq => lhs == rhs,
239 Operator::NotEq => lhs != rhs,
240 Operator::Gt => lhs > rhs,
241 Operator::Gte => lhs >= rhs,
242 Operator::Lt => lhs < rhs,
243 Operator::Lte => lhs <= rhs,
244 };
245
246 Scalar::bool(
247 b,
248 (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
249 )
250 }
251}
252
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray};
    use crate::validity::Validity;

    /// Returns the positions at which `indices_bits` is both valid and `true`.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let values = indices_bits.boolean_buffer();
        let nulls = indices_bits
            .validity()
            .to_logical(indices_bits.len())
            .unwrap()
            .to_null_buffer();
        // No null buffer means every position is valid.
        let valid_at = |idx: usize| nulls.as_ref().map_or(true, |buffer| buffer.is_valid(idx));

        values
            .iter()
            .enumerate()
            .filter_map(|(idx, set)| (set && valid_at(idx)).then_some(idx as u64))
            .collect_vec()
    }

    /// Runs `compare` and returns the positions where the result is valid and `true`.
    fn matching_indices(lhs: &dyn Array, rhs: &dyn Array, op: Operator) -> Vec<u64> {
        let matches = compare(lhs, rhs, op).unwrap().to_bool().unwrap();
        to_int_indices(matches)
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // Position 0 is null on both sides, so it never shows up in the matches.
        assert_eq!(matching_indices(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(matching_indices(&arr, &arr, Operator::NotEq), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(matching_indices(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(matching_indices(&arr, &other, Operator::Lt), [4u64]);

        // Same comparisons expressed from the other side.
        assert_eq!(matching_indices(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(matching_indices(&other, &arr, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Through the public entry point.
        let result = compare(&left, &right, Operator::Gt).unwrap();
        let scalar = result.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);

        // Directly through the Arrow fallback.
        let result = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let scalar = result.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);
    }

    #[rstest::rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(output.iter().collect::<Vec<bool>>(), expected);
    }
}