1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictVTable};
12
13impl CompareKernel for DictVTable {
14 fn compare(
15 &self,
16 lhs: &DictArray,
17 rhs: &dyn Array,
18 operator: Operator,
19 ) -> VortexResult<Option<ArrayRef>> {
20 if lhs.values().len() > lhs.codes().len() {
22 return Ok(None);
23 }
24 if let Some(rhs) = rhs.as_constant() {
26 let compare_result = compare(
27 lhs.values(),
28 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
29 operator,
30 )?;
31 return if operator == Operator::Eq {
32 let result_nullability =
33 compare_result.dtype().nullability() | lhs.dtype().nullability();
34 dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35 } else {
36 DictArray::try_new(lhs.codes().clone(), compare_result)
37 .map(|a| a.into_array())
38 .map(Some)
39 };
40 }
41
42 Ok(None)
45 }
46}
47
48register_kernel!(CompareKernelAdapter(DictVTable).lift());
49
50fn dict_equal_to(
51 values_compare: ArrayRef,
52 codes: &ArrayRef,
53 result_nullability: Nullability,
54) -> VortexResult<ArrayRef> {
55 let bool_result = values_compare.to_bool()?;
56 let result_validity = bool_result.validity_mask()?;
57 let bool_buffer = bool_result.boolean_buffer();
58 let (first_match, second_match) = match result_validity.boolean_buffer() {
59 AllOr::All => {
60 let mut indices_iter = bool_buffer.set_indices();
61 (indices_iter.next(), indices_iter.next())
62 }
63 AllOr::None => (None, None),
64 AllOr::Some(v) => {
65 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
66 (indices_iter.next(), indices_iter.next())
67 }
68 };
69
70 Ok(match (first_match, second_match) {
71 (None, _) => match result_validity {
73 Mask::AllTrue(_) => {
74 let mut result_builder =
75 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
76 result_builder.extend_from_array(
77 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
78 .into_array(),
79 )?;
80 result_builder.set_validity(codes.validity_mask()?);
81 result_builder.finish()
82 }
83 Mask::AllFalse(_) => ConstantArray::new(
84 Scalar::null(DType::Bool(Nullability::Nullable)),
85 codes.len(),
86 )
87 .into_array(),
88 Mask::Values(_) => {
89 let mut result_builder =
90 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
91 result_builder.extend_from_array(
92 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
93 .into_array(),
94 )?;
95 result_builder.set_validity(
96 Validity::from_mask(result_validity, bool_result.dtype().nullability())
97 .take(codes)?
98 .to_mask(codes.len())?,
99 );
100 result_builder.finish()
101 }
102 },
103 (Some(code), None) => cast(
106 &compare(
107 codes,
108 &cast(
109 ConstantArray::new(code, codes.len()).as_ref(),
110 codes.dtype(),
111 )?,
112 Operator::Eq,
113 )?,
114 &DType::Bool(result_nullability),
115 )?,
116 _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
118 })
119}
120
121#[cfg(test)]
122mod tests {
123 use vortex_array::arrays::{ConstantArray, PrimitiveArray};
124 use vortex_array::compute::{Operator, compare};
125 use vortex_array::validity::Validity;
126 use vortex_array::{IntoArray, ToCanonical};
127 use vortex_buffer::buffer;
128 use vortex_dtype::Nullability;
129 use vortex_mask::Mask;
130 use vortex_scalar::Scalar;
131
132 use crate::DictArray;
133
134 #[test]
135 fn test_compare_value() {
136 let dict = DictArray::try_new(
137 buffer![0u32, 1, 2].into_array(),
138 buffer![1i32, 2, 3].into_array(),
139 )
140 .unwrap();
141
142 let res = compare(
143 dict.as_ref(),
144 ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
145 Operator::Eq,
146 )
147 .unwrap();
148 let res = res.to_bool().unwrap();
149 assert_eq!(
150 res.boolean_buffer().iter().collect::<Vec<_>>(),
151 vec![true, false, false]
152 );
153 }
154
155 #[test]
156 fn test_compare_non_eq() {
157 let dict = DictArray::try_new(
158 buffer![0u32, 1, 2].into_array(),
159 buffer![1i32, 2, 3].into_array(),
160 )
161 .unwrap();
162
163 let res = compare(
164 dict.as_ref(),
165 ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
166 Operator::Gt,
167 )
168 .unwrap();
169 let res = res.to_bool().unwrap();
170 assert_eq!(
171 res.boolean_buffer().iter().collect::<Vec<_>>(),
172 vec![false, true, true]
173 );
174 }
175
176 #[test]
177 fn test_compare_nullable() {
178 let dict = DictArray::try_new(
179 PrimitiveArray::new(
180 buffer![0u32, 1, 2],
181 Validity::from_iter([false, true, false]),
182 )
183 .into_array(),
184 PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(),
185 )
186 .unwrap();
187
188 let res = compare(
189 dict.as_ref(),
190 ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3).as_ref(),
191 Operator::Eq,
192 )
193 .unwrap();
194 let res = res.to_bool().unwrap();
195 assert_eq!(
196 res.boolean_buffer().iter().collect::<Vec<_>>(),
197 vec![false, false, false]
198 );
199 assert_eq!(res.dtype().nullability(), Nullability::Nullable);
200 assert_eq!(
201 res.validity_mask().unwrap(),
202 Mask::from_iter([false, true, false])
203 );
204 }
205
206 #[test]
207 fn test_compare_null_values() {
208 let dict = DictArray::try_new(
209 buffer![0u32, 1, 2].into_array(),
210 PrimitiveArray::new(
211 buffer![1i32, 2, 0],
212 Validity::from_iter([true, true, false]),
213 )
214 .into_array(),
215 )
216 .unwrap();
217
218 let res = compare(
219 dict.as_ref(),
220 ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3).as_ref(),
221 Operator::Eq,
222 )
223 .unwrap();
224 let res = res.to_bool().unwrap();
225 assert_eq!(
226 res.boolean_buffer().iter().collect::<Vec<_>>(),
227 vec![false, false, false]
228 );
229 assert_eq!(res.dtype().nullability(), Nullability::Nullable);
230 assert_eq!(
231 res.validity_mask().unwrap(),
232 Mask::from_iter([true, true, false])
233 );
234 }
235}