vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18 fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 if !dtype.is_nullable()
22 && array.values().dtype().is_nullable()
23 && !array.values().all_valid()?
24 {
25 return Ok(None);
26 }
27 let casted_values = array.values().cast(dtype.clone())?;
29
30 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
32 array
33 .codes()
34 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
35 } else {
36 array.codes().clone()
37 };
38
39 Ok(Some(
41 unsafe {
42 DictArray::new_unchecked(casted_codes, casted_values)
43 .set_all_values_referenced(array.has_all_values_referenced())
44 }
45 .into_array(),
46 ))
47 }
48}
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::Dict;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Casting a dictionary-encoded `i32` array to non-nullable `i64` yields the requested
    /// dtype and the same logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let source = buffer![1i32, 2, 3, 2, 1].into_array();
        let encoded = dict_encode(&source).unwrap();

        let widened = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);

        let canonical = widened.to_primitive();
        assert_arrays_eq!(canonical, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary built from nullable `i32` input to nullable `i64` reports the
    /// requested dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);

        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let encoded = dict_encode(&source.into_array()).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips an all-valid dictionary through non-nullable -> nullable -> non-nullable
    /// casts, checking the nullability of codes and values after every step and the final
    /// values against the original.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let non_nullable_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

        let source = buffer![10i32, 20, 30, 40].into_array();
        let original = dict_encode(&source).unwrap();

        // The freshly encoded dictionary starts out non-nullable in both children.
        assert_eq!(
            original.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            original.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 1: cast to the same non-nullable dtype.
        let first = original
            .clone()
            .into_array()
            .cast(non_nullable_i32.clone())
            .unwrap();
        assert_eq!(first.dtype(), &non_nullable_i32);

        let first_dict = first.as_::<Dict>();
        assert_eq!(
            first_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            first_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 2: widen to nullable; only the values are expected to become nullable.
        let second = first.cast(nullable_i32.clone()).unwrap();
        assert_eq!(second.dtype(), &nullable_i32);

        let second_dict = second.as_::<Dict>();
        assert_eq!(
            second_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            second_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Step 3: narrow back to non-nullable.
        let third = second.cast(non_nullable_i32.clone()).unwrap();
        assert_eq!(third.dtype(), &non_nullable_i32);

        let third_dict = third.as_::<Dict>();
        assert_eq!(
            third_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            third_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // The round trip must not change the logical values.
        let expected = original.as_array().to_primitive();
        let actual = third_dict.array().to_primitive();
        assert_arrays_eq!(expected, actual);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dictionary whose only null value is never referenced by any code can still be cast
    /// to a non-nullable dtype.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        let target = DType::Primitive(PType::F64, Nullability::NonNullable);

        // values[1] is null, but the codes below only reference indices 0 and 2.
        let dictionary_values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let dictionary_codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(dictionary_codes, dictionary_values).unwrap();

        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        let result = dict.into_array().cast(target.clone());
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );
        let casted = result.unwrap();
        assert_eq!(casted.dtype(), &target);
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}