vortex_dict/compute/
cast.rs1use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CastKernel for DictVTable {
12 fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 let casted_values = cast(array.values(), dtype)?;
15
16 let casted_codes = if dtype.nullability() != array.codes().dtype().nullability() {
17 cast(
18 array.codes(),
19 &array.codes().dtype().with_nullability(dtype.nullability()),
20 )?
21 } else {
22 array.codes().clone()
23 };
24
25 Ok(Some(
27 DictArray::try_new(casted_codes, casted_values)?.into_array(),
28 ))
29 }
30}
31
32register_kernel!(CastKernelAdapter(DictVTable).lift());
33
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::DictVTable;
    use crate::builders::dict_encode;

    /// Asserts that `array` is still dictionary-encoded and that both its codes and its values
    /// report the `expected` nullability.
    fn assert_dict_parts_nullability(array: &vortex_array::ArrayRef, expected: Nullability) {
        let parts = array.as_::<DictVTable>();
        assert_eq!(parts.codes().dtype().nullability(), expected);
        assert_eq!(parts.values().dtype().nullability(), expected);
    }

    #[test]
    fn test_cast_dict_to_wider_type() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let dict = dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap();

        let casted = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Decoding the widened dictionary must reproduce the input sequence as i64.
        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 2, 1]);
    }

    #[test]
    fn test_cast_dict_nullable() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(source.as_ref()).unwrap();

        let casted = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let non_nullable_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

        let dict = dict_encode(&buffer![10i32, 20, 30, 40].into_array()).unwrap();
        // Encoding an all-valid, non-nullable input yields non-nullable codes and values.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Same-type cast: the result stays dictionary-encoded and non-nullable.
        let non_nullable = cast(dict.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(non_nullable.dtype(), &non_nullable_i32);
        assert_dict_parts_nullability(&non_nullable, Nullability::NonNullable);

        // Relaxing to nullable is reflected in both the codes and the values.
        let nullable = cast(non_nullable.as_ref(), &nullable_i32).unwrap();
        assert_eq!(nullable.dtype(), &nullable_i32);
        assert_dict_parts_nullability(&nullable, Nullability::Nullable);

        // Tightening back succeeds because no element is actually null.
        let back_to_non_nullable = cast(nullable.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &non_nullable_i32);
        assert_dict_parts_nullability(&back_to_non_nullable, Nullability::NonNullable);

        // The full round trip must leave the decoded data unchanged.
        let before = dict.to_canonical().unwrap().into_primitive().unwrap();
        let after = back_to_non_nullable
            .to_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();
        assert_eq!(before.as_slice::<i32>(), after.as_slice::<i32>());
    }

    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(
        dict_encode(
            &PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None])
                .into_array()
        )
        .unwrap()
        .into_array()
    )]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: vortex_array::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}