vortex_array/arrays/dict/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use super::DictArray;
8use super::DictVTable;
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::compute::CastKernel;
13use crate::compute::CastKernelAdapter;
14use crate::compute::cast;
15use crate::register_kernel;
16
17impl CastKernel for DictVTable {
18 fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 let casted_values = cast(array.values(), dtype)?;
21
22 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
24 cast(
25 array.codes(),
26 &array.codes().dtype().with_nullability(dtype.nullability()),
27 )?
28 } else {
29 array.codes().clone()
30 };
31
32 Ok(Some(
34 unsafe {
35 DictArray::new_unchecked(casted_codes, casted_values)
36 .set_all_values_referenced(array.has_all_values_referenced())
37 }
38 .into_array(),
39 ))
40 }
41}
42
43register_kernel!(CastKernelAdapter(DictVTable).lift());
44
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictVTable;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;

    /// Widening the value type must keep the decoded elements intact.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary containing nulls to a nullable type must preserve values and nulls.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Check the decoded contents too, not only the dtype: values and null positions must
        // survive the cast.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    /// Round-trips an all-valid dictionary through nullable and back, checking that the
    /// result stays dictionary-encoded and that codes/values carry the expected nullability.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Encoding non-nullable input yields non-nullable codes and values.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Casting to the same non-nullable type keeps both parts non-nullable.
        let non_nullable = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Widening to nullable only changes the values; the codes are reused untouched.
        let nullable = cast(
            non_nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Narrowing back to non-nullable succeeds because every value is valid.
        let back_to_non_nullable = cast(
            nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // The round trip must not change the logical contents.
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}