1use vortex_array::arrays::ConstantArray;
5use vortex_array::builders::builder_with_capacity;
6use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_dtype::{DType, Nullability};
10use vortex_error::VortexResult;
11use vortex_mask::{AllOr, Mask};
12use vortex_scalar::Scalar;
13
14use crate::{DictArray, DictVTable};
15
16impl CompareKernel for DictVTable {
17 fn compare(
18 &self,
19 lhs: &DictArray,
20 rhs: &dyn Array,
21 operator: Operator,
22 ) -> VortexResult<Option<ArrayRef>> {
23 if lhs.values().len() > lhs.codes().len() {
25 return Ok(None);
26 }
27 if let Some(rhs) = rhs.as_constant() {
29 let compare_result = compare(
30 lhs.values(),
31 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
32 operator,
33 )?;
34 return if operator == Operator::Eq {
35 let result_nullability =
36 compare_result.dtype().nullability() | lhs.dtype().nullability();
37 dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
38 } else {
39 unsafe {
41 Ok(Some(
42 DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array(),
43 ))
44 }
45 };
46 }
47
48 Ok(None)
51 }
52}
53
54register_kernel!(CompareKernelAdapter(DictVTable).lift());
55
56fn dict_equal_to(
57 values_compare: ArrayRef,
58 codes: &ArrayRef,
59 result_nullability: Nullability,
60) -> VortexResult<ArrayRef> {
61 let bool_result = values_compare.to_bool()?;
62 let result_validity = bool_result.validity_mask()?;
63 let bool_buffer = bool_result.boolean_buffer();
64 let (first_match, second_match) = match result_validity.boolean_buffer() {
65 AllOr::All => {
66 let mut indices_iter = bool_buffer.set_indices();
67 (indices_iter.next(), indices_iter.next())
68 }
69 AllOr::None => (None, None),
70 AllOr::Some(v) => {
71 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
72 (indices_iter.next(), indices_iter.next())
73 }
74 };
75
76 Ok(match (first_match, second_match) {
77 (None, _) => match result_validity {
79 Mask::AllTrue(_) => {
80 let mut result_builder =
81 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
82 result_builder.extend_from_array(
83 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
84 .into_array(),
85 )?;
86 result_builder.set_validity(codes.validity_mask()?);
87 result_builder.finish()
88 }
89 Mask::AllFalse(_) => ConstantArray::new(
90 Scalar::null(DType::Bool(Nullability::Nullable)),
91 codes.len(),
92 )
93 .into_array(),
94 Mask::Values(_) => {
95 let mut result_builder =
96 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
97 result_builder.extend_from_array(
98 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
99 .into_array(),
100 )?;
101 result_builder.set_validity(
102 Validity::from_mask(result_validity, bool_result.dtype().nullability())
103 .take(codes)?
104 .to_mask(codes.len())?,
105 );
106 result_builder.finish()
107 }
108 },
109 (Some(code), None) => cast(
112 &compare(
113 codes,
114 &cast(
115 ConstantArray::new(code, codes.len()).as_ref(),
116 codes.dtype(),
117 )?,
118 Operator::Eq,
119 )?,
120 &DType::Bool(result_nullability),
121 )?,
122 _ => unsafe {
124 DictArray::new_unchecked(codes.clone(), bool_result.into_array()).into_array()
125 },
126 })
127}
128
129#[cfg(test)]
130mod tests {
131 use vortex_array::arrays::{ConstantArray, PrimitiveArray};
132 use vortex_array::compute::{Operator, compare};
133 use vortex_array::validity::Validity;
134 use vortex_array::{IntoArray, ToCanonical};
135 use vortex_buffer::buffer;
136 use vortex_dtype::Nullability;
137 use vortex_mask::Mask;
138 use vortex_scalar::Scalar;
139
140 use crate::DictArray;
141
142 #[test]
143 fn test_compare_value() {
144 let dict = DictArray::try_new(
145 buffer![0u32, 1, 2].into_array(),
146 buffer![1i32, 2, 3].into_array(),
147 )
148 .unwrap();
149
150 let res = compare(
151 dict.as_ref(),
152 ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
153 Operator::Eq,
154 )
155 .unwrap();
156 let res = res.to_bool().unwrap();
157 assert_eq!(
158 res.boolean_buffer().iter().collect::<Vec<_>>(),
159 vec![true, false, false]
160 );
161 }
162
163 #[test]
164 fn test_compare_non_eq() {
165 let dict = DictArray::try_new(
166 buffer![0u32, 1, 2].into_array(),
167 buffer![1i32, 2, 3].into_array(),
168 )
169 .unwrap();
170
171 let res = compare(
172 dict.as_ref(),
173 ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
174 Operator::Gt,
175 )
176 .unwrap();
177 let res = res.to_bool().unwrap();
178 assert_eq!(
179 res.boolean_buffer().iter().collect::<Vec<_>>(),
180 vec![false, true, true]
181 );
182 }
183
184 #[test]
185 fn test_compare_nullable() {
186 let dict = DictArray::try_new(
187 PrimitiveArray::new(
188 buffer![0u32, 1, 2],
189 Validity::from_iter([false, true, false]),
190 )
191 .into_array(),
192 PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(),
193 )
194 .unwrap();
195
196 let res = compare(
197 dict.as_ref(),
198 ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3).as_ref(),
199 Operator::Eq,
200 )
201 .unwrap();
202 let res = res.to_bool().unwrap();
203 assert_eq!(
204 res.boolean_buffer().iter().collect::<Vec<_>>(),
205 vec![false, false, false]
206 );
207 assert_eq!(res.dtype().nullability(), Nullability::Nullable);
208 assert_eq!(
209 res.validity_mask().unwrap(),
210 Mask::from_iter([false, true, false])
211 );
212 }
213
214 #[test]
215 fn test_compare_null_values() {
216 let dict = DictArray::try_new(
217 buffer![0u32, 1, 2].into_array(),
218 PrimitiveArray::new(
219 buffer![1i32, 2, 0],
220 Validity::from_iter([true, true, false]),
221 )
222 .into_array(),
223 )
224 .unwrap();
225
226 let res = compare(
227 dict.as_ref(),
228 ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3).as_ref(),
229 Operator::Eq,
230 )
231 .unwrap();
232 let res = res.to_bool().unwrap();
233 assert_eq!(
234 res.boolean_buffer().iter().collect::<Vec<_>>(),
235 vec![false, false, false]
236 );
237 assert_eq!(res.dtype().nullability(), Nullability::Nullable);
238 assert_eq!(
239 res.validity_mask().unwrap(),
240 Mask::from_iter([true, true, false])
241 );
242 }
243}