vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16use crate::validity::Validity;
17
18impl CastReduce for Dict {
19 fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20 if !dtype.is_nullable()
23 && array.values().dtype().is_nullable()
24 && !matches!(
25 array.values().validity()?,
26 Validity::NonNullable | Validity::AllValid
27 )
28 {
29 return Ok(None);
30 }
31 let casted_values = array.values().cast(dtype.clone())?;
33
34 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
36 array
37 .codes()
38 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
39 } else {
40 array.codes().clone()
41 };
42
43 Ok(Some(
45 unsafe {
46 DictArray::new_unchecked(casted_codes, casted_values)
47 .set_all_values_referenced(array.has_all_values_referenced())
48 }
49 .into_array(),
50 ))
51 }
52}
53
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Casting a dict-encoded i32 array to i64 yields the i64 dtype and the same logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let encoded = dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap();

        let widened = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);

        let decoded = widened.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dict built from nullable input to a nullable wider type reports that dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None])
                .into_array();
        let encoded = dict_encode(&source).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips a null-free dict through non-nullable -> nullable -> non-nullable casts,
    /// checking the nullability of codes and values after each step and the final values.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let non_nullable_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

        let source = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&source).unwrap();

        // Starting point: neither codes nor values are nullable.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Step 1: cast to the (already matching) non-nullable dtype.
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(non_nullable_i32.clone())
            .unwrap();
        assert_eq!(non_nullable.dtype(), &non_nullable_i32);
        {
            let view = non_nullable.as_::<Dict>();
            assert_eq!(view.codes().dtype().nullability(), Nullability::NonNullable);
            assert_eq!(
                view.values().dtype().nullability(),
                Nullability::NonNullable
            );
        }

        // Step 2: widen to nullable; only the values are expected to become nullable.
        let nullable = non_nullable.cast(nullable_i32.clone()).unwrap();
        assert_eq!(nullable.dtype(), &nullable_i32);
        {
            let view = nullable.as_::<Dict>();
            assert_eq!(view.codes().dtype().nullability(), Nullability::NonNullable);
            assert_eq!(view.values().dtype().nullability(), Nullability::Nullable);
        }

        // Step 3: narrow back to non-nullable.
        let back_to_non_nullable = nullable.cast(non_nullable_i32.clone()).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &non_nullable_i32);

        // The round trip must not change the logical values.
        let original_values = dict.as_array().to_primitive();
        let final_values = back_to_non_nullable.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance checks over several dict-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dict whose only null value is never referenced by a code can still be cast to a
    /// non-nullable dtype.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        let nullable_f64 = DType::Primitive(PType::F64, Nullability::Nullable);
        let non_nullable_f64 = DType::Primitive(PType::F64, Nullability::NonNullable);

        // The value at index 1 is null, but the codes below only point at indices 0 and 2.
        let validity = Validity::from(BitBuffer::from(vec![true, false, true]));
        let values = PrimitiveArray::new(buffer![1.0f64, 0.0f64, 3.0f64], validity).into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        assert_eq!(dict.dtype(), &nullable_f64);

        let result = dict.into_array().cast(non_nullable_f64.clone());
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );

        let casted = result.unwrap();
        assert_eq!(casted.dtype(), &non_nullable_f64);
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_iter([1.0f64, 3.0, 1.0])
        );
    }
}