1use vortex_array::arrays::ConstantArray;
5use vortex_array::builders::builder_with_capacity;
6use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_dtype::{DType, Nullability};
10use vortex_error::VortexResult;
11use vortex_mask::{AllOr, Mask};
12use vortex_scalar::Scalar;
13
14use crate::{DictArray, DictVTable};
15
16impl CompareKernel for DictVTable {
17 fn compare(
18 &self,
19 lhs: &DictArray,
20 rhs: &dyn Array,
21 operator: Operator,
22 ) -> VortexResult<Option<ArrayRef>> {
23 if lhs.values().len() > lhs.codes().len() {
25 return Ok(None);
26 }
27 if let Some(rhs) = rhs.as_constant() {
29 let compare_result = compare(
30 lhs.values(),
31 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
32 operator,
33 )?;
34 return if operator == Operator::Eq {
35 let result_nullability =
36 compare_result.dtype().nullability() | lhs.dtype().nullability();
37 dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
38 } else {
39 DictArray::try_new(lhs.codes().clone(), compare_result)
40 .map(|a| a.into_array())
41 .map(Some)
42 };
43 }
44
45 Ok(None)
48 }
49}
50
51register_kernel!(CompareKernelAdapter(DictVTable).lift());
52
53fn dict_equal_to(
54 values_compare: ArrayRef,
55 codes: &ArrayRef,
56 result_nullability: Nullability,
57) -> VortexResult<ArrayRef> {
58 let bool_result = values_compare.to_bool()?;
59 let result_validity = bool_result.validity_mask()?;
60 let bool_buffer = bool_result.boolean_buffer();
61 let (first_match, second_match) = match result_validity.boolean_buffer() {
62 AllOr::All => {
63 let mut indices_iter = bool_buffer.set_indices();
64 (indices_iter.next(), indices_iter.next())
65 }
66 AllOr::None => (None, None),
67 AllOr::Some(v) => {
68 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
69 (indices_iter.next(), indices_iter.next())
70 }
71 };
72
73 Ok(match (first_match, second_match) {
74 (None, _) => match result_validity {
76 Mask::AllTrue(_) => {
77 let mut result_builder =
78 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
79 result_builder.extend_from_array(
80 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
81 .into_array(),
82 )?;
83 result_builder.set_validity(codes.validity_mask()?);
84 result_builder.finish()
85 }
86 Mask::AllFalse(_) => ConstantArray::new(
87 Scalar::null(DType::Bool(Nullability::Nullable)),
88 codes.len(),
89 )
90 .into_array(),
91 Mask::Values(_) => {
92 let mut result_builder =
93 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
94 result_builder.extend_from_array(
95 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
96 .into_array(),
97 )?;
98 result_builder.set_validity(
99 Validity::from_mask(result_validity, bool_result.dtype().nullability())
100 .take(codes)?
101 .to_mask(codes.len())?,
102 );
103 result_builder.finish()
104 }
105 },
106 (Some(code), None) => cast(
109 &compare(
110 codes,
111 &cast(
112 ConstantArray::new(code, codes.len()).as_ref(),
113 codes.dtype(),
114 )?,
115 Operator::Eq,
116 )?,
117 &DType::Bool(result_nullability),
118 )?,
119 _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
121 })
122}
123
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Equality against a constant that matches exactly one dictionary value.
    #[test]
    fn test_compare_value() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let outcome = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// A non-equality operator is evaluated through the dictionary values.
    #[test]
    fn test_compare_non_eq() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let outcome = compare(dict.as_ref(), rhs.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Null codes surface as nulls in the result when no value matches.
    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);
        let outcome = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([false, true, false])
        );
    }

    /// Null dictionary values surface as nulls in the result when no value matches.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);
        let outcome = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([true, true, false])
        );
    }
}