vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::DictArray;
7use super::DictVTable;
8use crate::ArrayRef;
9use crate::DynArray;
10use crate::IntoArray;
11use crate::builtins::ArrayBuiltins;
12use crate::dtype::DType;
13use crate::scalar_fn::fns::cast::CastReduce;
14
15impl CastReduce for DictVTable {
16 fn cast(array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17 if !dtype.is_nullable()
20 && array.values().dtype().is_nullable()
21 && !array.values().all_valid()?
22 {
23 return Ok(None);
24 }
25 let casted_values = array.values().cast(dtype.clone())?;
27
28 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
30 array
31 .codes()
32 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
33 } else {
34 array.codes().clone()
35 };
36
37 Ok(Some(
39 unsafe {
40 DictArray::new_unchecked(casted_codes, casted_values)
41 .set_all_values_referenced(array.has_all_values_referenced())
42 }
43 .into_array(),
44 ))
45 }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::DictArray;
    use crate::arrays::DictVTable;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Widening the element type (i32 -> i64) must keep the dtype non-nullable and preserve
    /// every decoded value.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary that contains nulls to a nullable wider type must report the target
    /// dtype and keep both the non-null values and the null positions intact.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(&values.into_array()).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Checking the dtype alone would not catch corrupted values or misplaced nulls.
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    /// Round-trips an all-valid dictionary through NonNullable -> Nullable -> NonNullable and
    /// checks, at each step, that the result is still a dictionary with the expected code and
    /// value nullability, and that the data is unchanged at the end.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Encoding non-nullable data yields non-nullable codes and values.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Same-nullability cast: the dictionary encoding must be preserved.
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Widening to nullable only changes the values; the codes stay non-nullable.
        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Narrowing back is allowed because every value is valid.
        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // The round trip must not alter the decoded data.
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over several dictionary shapes: signed and
    /// unsigned integers, nullable data, and floats.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// Null values that no code references must not block a cast to a non-nullable dtype.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        // Value index 1 is null, but the codes only reference indices 0 and 2.
        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // Nullable values make the dictionary's dtype nullable.
        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        // Surface the underlying error on failure instead of a bare `is_ok()` assertion.
        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .unwrap_or_else(|err| {
                panic!(
                    "cast to NonNullable should succeed for dict with only unreferenced null values: {err}"
                )
            });
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}