vortex_dict/compute/
compare.rs1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{CompareFn, Operator, compare, take, try_cast};
3use vortex_array::{Array, ArrayRef, ToCanonical};
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6use vortex_scalar::Scalar;
7
8use crate::{DictArray, DictEncoding};
9
10impl CompareFn<&DictArray> for DictEncoding {
11 fn compare(
12 &self,
13 lhs: &DictArray,
14 rhs: &dyn Array,
15 operator: Operator,
16 ) -> VortexResult<Option<ArrayRef>> {
17 if let Some(rhs) = rhs.as_constant() {
19 let compare_result = compare(
20 lhs.values(),
21 &ConstantArray::new(rhs, lhs.values().len()),
22 operator,
23 )?;
24
25 let bool = compare_result.to_bool()?;
26 let bool_buffer = bool.boolean_buffer();
27 let mut indices_iter = bool_buffer.set_indices();
28
29 let result = match (indices_iter.next(), indices_iter.next()) {
30 (None, _) => ConstantArray::new(
32 Scalar::bool(false, lhs.dtype().nullability()),
33 lhs.codes().len(),
34 )
35 .into_array(),
36 (Some(code), None) => try_cast(
39 &compare(
40 lhs.codes(),
41 &try_cast(&ConstantArray::new(code, lhs.len()), lhs.codes().dtype())?,
42 operator,
43 )?,
44 &DType::Bool(lhs.dtype().nullability()),
45 )?,
46 _ => take(&bool, lhs.codes())?,
48 };
49 return Ok(Some(result));
50 }
51
52 Ok(None)
55 }
56}
57
#[cfg(test)]
mod tests {
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::{Operator, compare};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// An equality comparison against a constant that matches a single dictionary value
    /// marks only the rows whose code refers to that value.
    #[test]
    fn test_compare_value() {
        // Codes 0, 1, 2 over values 1, 2, 3.
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap();

        let needle = ConstantArray::new(Scalar::from(1i32), 3);
        let mask = compare(&dict, &needle, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(mask.len(), 3);

        let bits: Vec<bool> = mask.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }
}