1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareFn, Operator, compare, try_cast};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, ToCanonical};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictEncoding};
12
13impl CompareFn<&DictArray> for DictEncoding {
14 fn compare(
15 &self,
16 lhs: &DictArray,
17 rhs: &dyn Array,
18 operator: Operator,
19 ) -> VortexResult<Option<ArrayRef>> {
20 if lhs.values().len() > lhs.codes().len() {
22 return Ok(None);
23 }
24 if let Some(rhs) = rhs.as_constant() {
26 let compare_result = compare(
27 lhs.values(),
28 &ConstantArray::new(rhs, lhs.values().len()),
29 operator,
30 )?;
31 return if operator == Operator::Eq {
32 let result_nullability =
33 compare_result.dtype().nullability() | lhs.dtype().nullability();
34 dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35 } else {
36 DictArray::try_new(lhs.codes().clone(), compare_result)
37 .map(|a| a.into_array())
38 .map(Some)
39 };
40 }
41
42 Ok(None)
45 }
46}
47
48fn dict_equal_to(
49 values_compare: ArrayRef,
50 codes: &ArrayRef,
51 result_nullability: Nullability,
52) -> VortexResult<ArrayRef> {
53 let bool_result = values_compare.to_bool()?;
54 let result_validity = bool_result.validity_mask()?;
55 let bool_buffer = bool_result.boolean_buffer();
56 let (first_match, second_match) = match result_validity.boolean_buffer() {
57 AllOr::All => {
58 let mut indices_iter = bool_buffer.set_indices();
59 (indices_iter.next(), indices_iter.next())
60 }
61 AllOr::None => (None, None),
62 AllOr::Some(v) => {
63 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
64 (indices_iter.next(), indices_iter.next())
65 }
66 };
67
68 Ok(match (first_match, second_match) {
69 (None, _) => match result_validity {
71 Mask::AllTrue(_) => {
72 let mut result_builder =
73 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
74 result_builder.extend_from_array(
75 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
76 .into_array(),
77 )?;
78 result_builder.set_validity(codes.validity_mask()?);
79 result_builder.finish()
80 }
81 Mask::AllFalse(_) => ConstantArray::new(
82 Scalar::null(DType::Bool(Nullability::Nullable)),
83 codes.len(),
84 )
85 .into_array(),
86 Mask::Values(_) => {
87 let mut result_builder =
88 builder_with_capacity(&DType::Bool(result_nullability), codes.len());
89 result_builder.extend_from_array(
90 &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
91 .into_array(),
92 )?;
93 result_builder.set_validity(
94 Validity::from_mask(result_validity, bool_result.dtype().nullability())
95 .take(codes)?
96 .to_mask(codes.len())?,
97 );
98 result_builder.finish()
99 }
100 },
101 (Some(code), None) => try_cast(
104 &compare(
105 codes,
106 &try_cast(&ConstantArray::new(code, codes.len()), codes.dtype())?,
107 Operator::Eq,
108 )?,
109 &DType::Bool(result_nullability),
110 )?,
111 _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
113 })
114}
115
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Dictionary with codes `[0, 1, 2]` over the values `[1, 2, 3]`, built from
    /// plain buffers without explicit validity.
    fn plain_dict() -> DictArray {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        DictArray::try_new(codes, values).unwrap()
    }

    /// `== 1` matches exactly one dictionary value.
    #[test]
    fn test_compare_value() {
        let dict = plain_dict();
        let rhs = ConstantArray::new(Scalar::from(1i32), 3);

        let outcome = compare(&dict, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// A non-equality operator is evaluated per dictionary value.
    #[test]
    fn test_compare_non_eq() {
        let dict = plain_dict();
        let rhs = ConstantArray::new(Scalar::from(1i32), 3);

        let outcome = compare(&dict, &rhs, Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Null codes with no matching value: all `false`, validity follows the codes.
    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);

        let outcome = compare(&dict, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([false, true, false])
        );
    }

    /// A null dictionary value with no match: validity follows the values through the codes.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);

        let outcome = compare(&dict, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([true, true, false])
        );
    }
}