1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictEncoding};
12
13impl CompareKernel for DictEncoding {
14 fn compare(
15 &self,
16 lhs: &DictArray,
17 rhs: &dyn Array,
18 operator: Operator,
19 ) -> VortexResult<Option<ArrayRef>> {
20 if lhs.values().len() > lhs.codes().len() {
22 return Ok(None);
23 }
24 if let Some(rhs) = rhs.as_constant() {
26 let compare_result = compare(
27 lhs.values(),
28 &ConstantArray::new(rhs, lhs.values().len()),
29 operator,
30 )?;
31 return if operator == Operator::Eq {
32 let result_nullability =
33 compare_result.dtype().nullability() | lhs.dtype().nullability();
34 dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35 } else {
36 DictArray::try_new(lhs.codes().clone(), compare_result)
37 .map(|a| a.into_array())
38 .map(Some)
39 };
40 }
41
42 Ok(None)
45 }
46}
47
// Registers the `CompareKernel` implementation for `DictEncoding` above (wrapped in
// `CompareKernelAdapter` and lifted) via `register_kernel!`. Presumably this is how `compare`
// discovers the dict-specific kernel; confirm against the `register_kernel!` definition.
register_kernel!(CompareKernelAdapter(DictEncoding).lift());
49
50fn dict_equal_to(
51 values_compare: ArrayRef,
52 codes: &ArrayRef,
53 result_nullability: Nullability,
54) -> VortexResult<ArrayRef> {
55 let bool_result = values_compare.to_bool()?;
56 let result_validity = bool_result.validity_mask()?;
57 let bool_buffer = bool_result.boolean_buffer();
58 let (first_match, second_match) = match result_validity.boolean_buffer() {
59 AllOr::All => {
60 let mut indices_iter = bool_buffer.set_indices();
61 (indices_iter.next(), indices_iter.next())
62 }
63 AllOr::None => (None, None),
64 AllOr::Some(v) => {
65 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
66 (indices_iter.next(), indices_iter.next())
67 }
68 };
69
70 Ok(match (first_match, second_match) {
71 (None, _) => match result_validity {
73 Mask::AllTrue(_) => {
74 let mut result_builder =
75 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
76 result_builder.extend_from_array(
77 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
78 .into_array(),
79 )?;
80 result_builder.set_validity(codes.validity_mask()?);
81 result_builder.finish()
82 }
83 Mask::AllFalse(_) => ConstantArray::new(
84 Scalar::null(DType::Bool(Nullability::Nullable)),
85 codes.len(),
86 )
87 .into_array(),
88 Mask::Values(_) => {
89 let mut result_builder =
90 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
91 result_builder.extend_from_array(
92 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
93 .into_array(),
94 )?;
95 result_builder.set_validity(
96 Validity::from_mask(result_validity, bool_result.dtype().nullability())
97 .take(codes)?
98 .to_mask(codes.len())?,
99 );
100 result_builder.finish()
101 }
102 },
103 (Some(code), None) => cast(
106 &compare(
107 codes,
108 &cast(&ConstantArray::new(code, codes.len()), codes.dtype())?,
109 Operator::Eq,
110 )?,
111 &DType::Bool(result_nullability),
112 )?,
113 _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
115 })
116}
117
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Dictionary with non-nullable codes `[0, 1, 2]` over non-nullable values `[1, 2, 3]`.
    fn plain_dict() -> DictArray {
        DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap()
    }

    /// Compares `dict` against a constant `rhs` (broadcast to `dict`'s length) and returns the
    /// canonical boolean result as (bits, nullability, validity mask).
    fn evaluate(dict: &DictArray, rhs: Scalar, operator: Operator) -> (Vec<bool>, Nullability, Mask) {
        let len = dict.len();
        let bools = compare(dict, &ConstantArray::new(rhs, len), operator)
            .unwrap()
            .to_bool()
            .unwrap();
        let bits = bools.boolean_buffer().iter().collect::<Vec<_>>();
        let nullability = bools.dtype().nullability();
        let validity = bools.validity_mask().unwrap();
        (bits, nullability, validity)
    }

    #[test]
    fn test_compare_value() {
        let dict = plain_dict();

        let (bits, _, _) = evaluate(&dict, Scalar::from(1i32), Operator::Eq);

        assert_eq!(bits, vec![true, false, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let dict = plain_dict();

        let (bits, _, _) = evaluate(&dict, Scalar::from(1i32), Operator::Gt);

        assert_eq!(bits, vec![false, true, true]);
    }

    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        );
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        let (bits, nullability, validity) = evaluate(
            &dict,
            Scalar::primitive(4i32, Nullability::Nullable),
            Operator::Eq,
        );

        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(nullability, Nullability::Nullable);
        assert_eq!(validity, Mask::from_iter([false, true, false]));
    }

    #[test]
    fn test_compare_null_values() {
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        );
        let dict = DictArray::try_new(buffer![0u32, 1, 2].into_array(), values.into_array())
            .unwrap();

        let (bits, nullability, validity) = evaluate(
            &dict,
            Scalar::primitive(4i32, Nullability::NonNullable),
            Operator::Eq,
        );

        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(nullability, Nullability::Nullable);
        assert_eq!(validity, Mask::from_iter([true, true, false]));
    }
}