vortex_dict/compute/
cast.rs1use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CastKernel for DictVTable {
12 fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 let casted_values = cast(array.values(), dtype)?;
15
16 let casted_codes = if dtype.nullability() != array.codes().dtype().nullability() {
17 cast(
18 array.codes(),
19 &array.codes().dtype().with_nullability(dtype.nullability()),
20 )?
21 } else {
22 array.codes().clone()
23 };
24
25 unsafe {
27 Ok(Some(
28 DictArray::new_unchecked(casted_codes, casted_values).into_array(),
29 ))
30 }
31 }
32}
33
34register_kernel!(CastKernelAdapter(DictVTable).lift());
35
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::{ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::builders::dict_encode;
    use crate::{DictArray, DictVTable};

    /// Builds a primitive `DType` from a `PType` and a `Nullability`.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// Asserts that both the codes and the values of `dict` have `expected` nullability.
    fn assert_dict_nullability(dict: &DictArray, expected: Nullability) {
        assert_eq!(dict.codes().dtype().nullability(), expected);
        assert_eq!(dict.values().dtype().nullability(), expected);
    }

    /// Dictionary-encodes `array` and returns the result as an `ArrayRef`.
    fn dict_encoded(array: ArrayRef) -> ArrayRef {
        dict_encode(&array).unwrap().into_array()
    }

    /// Casting an `i32` dictionary to `i64` yields the `i64` dtype and the same decoded values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let source = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&source).unwrap();

        let target = primitive(PType::I64, Nullability::NonNullable);
        let casted = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 2, 1]);
    }

    /// Casting a dictionary built from nullable `i32` input to nullable `i64` reports that dtype.
    #[test]
    fn test_cast_dict_nullable() {
        let source =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(source.as_ref()).unwrap();

        let target = primitive(PType::I64, Nullability::Nullable);
        let casted = cast(dict.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Round-trips an all-valid dictionary through non-nullable -> nullable -> non-nullable
    /// casts, checking after each step that the array dtype and the nullability of both the
    /// codes and the values follow the target, and finally that the decoded values are unchanged.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let source = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&source).unwrap();

        // Starting point: codes and values are both non-nullable.
        assert_dict_nullability(&dict, Nullability::NonNullable);

        let non_nullable_i32 = primitive(PType::I32, Nullability::NonNullable);
        let nullable_i32 = primitive(PType::I32, Nullability::Nullable);

        // Step 1: cast to the dtype the dictionary already has.
        let non_nullable = cast(dict.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(non_nullable.dtype(), &non_nullable_i32);
        assert_dict_nullability(
            non_nullable.as_::<DictVTable>(),
            Nullability::NonNullable,
        );

        // Step 2: widen to nullable; codes and values must both become nullable.
        let nullable = cast(non_nullable.as_ref(), &nullable_i32).unwrap();
        assert_eq!(nullable.dtype(), &nullable_i32);
        assert_dict_nullability(nullable.as_::<DictVTable>(), Nullability::Nullable);

        // Step 3: narrow back to non-nullable.
        let back_to_non_nullable = cast(nullable.as_ref(), &non_nullable_i32).unwrap();
        assert_eq!(back_to_non_nullable.dtype(), &non_nullable_i32);
        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_dict_nullability(back_dict, Nullability::NonNullable);

        // The round trip must not change the decoded data.
        let expected = dict.to_primitive();
        let actual = back_dict.to_primitive();
        assert_eq!(expected.as_slice::<i32>(), actual.as_slice::<i32>());
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs.
    #[rstest]
    #[case(dict_encoded(buffer![1i32, 2, 3, 2, 1, 3].into_array()))]
    #[case(dict_encoded(buffer![100u32, 200, 100, 300, 200].into_array()))]
    #[case(dict_encoded(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
    ))]
    #[case(dict_encoded(buffer![1.5f32, 2.5, 1.5, 3.5].into_array()))]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}