vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18 fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 if !dtype.is_nullable() && !array.values().validity()?.no_nulls() {
22 return Ok(None);
23 }
24 let casted_values = array.values().cast(dtype.clone())?;
26
27 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
29 array
30 .codes()
31 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
32 } else {
33 array.codes().clone()
34 };
35
36 Ok(Some(
38 unsafe {
39 DictArray::new_unchecked(casted_codes, casted_values)
40 .set_all_values_referenced(array.has_all_values_referenced())
41 }
42 .into_array(),
43 ))
44 }
45}
46
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::Dict;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Widening a non-nullable i32 dictionary to i64 keeps every logical value.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        #[expect(deprecated)]
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Widening a dictionary that contains nulls to a nullable i64 keeps both the values and
    /// the positions of the nulls.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(&values.into_array()).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Checking only the dtype would let a cast that drops or misplaces nulls slip through.
        #[expect(deprecated)]
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    /// A null-free dictionary round-trips NonNullable -> Nullable -> NonNullable. Each step
    /// stays dictionary-encoded, and the codes remain non-nullable throughout.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        let non_nullable = dict
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        #[expect(deprecated)]
        let original_values = dict.as_array().to_primitive();
        #[expect(deprecated)]
        let final_values = back_to_non_nullable.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over a few dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// In this dictionary, the only null sits in `values` at an index no code references.
    /// It must still cast to a non-nullable dtype, because the logical array has no nulls.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        // Codes only reference indices 0 and 2, skipping the null at index 1.
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        let result = dict
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable));
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );
        let casted = result.unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        #[expect(deprecated)]
        let casted_prim = casted.to_primitive();
        assert_arrays_eq!(casted_prim, PrimitiveArray::from_iter([1.0f64, 3.0, 1.0]));
    }
}