vortex_array/arrays/dict/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6
7use super::{DictArray, DictVTable};
8use crate::compute::{CastKernel, CastKernelAdapter, cast};
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl CastKernel for DictVTable {
12 fn cast(&self, array: &DictArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13 let casted_values = cast(array.values(), dtype)?;
15
16 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
18 cast(
19 array.codes(),
20 &array.codes().dtype().with_nullability(dtype.nullability()),
21 )?
22 } else {
23 array.codes().clone()
24 };
25
26 Ok(Some(
28 unsafe { DictArray::new_unchecked(casted_codes, casted_values) }.into_array(),
29 ))
30 }
31}
32
// Registers `DictVTable`'s `CastKernel` implementation (wrapped in `CastKernelAdapter`) with the
// kernel registry; presumably this is how the generic `cast` entry point finds the
// dictionary-specific path above — confirm against `register_kernel!`.
register_kernel!(CastKernelAdapter(DictVTable).lift());
34
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictVTable;
    use crate::builders::dict::dict_encode;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::{IntoArray, ToCanonical, assert_arrays_eq};

    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        let casted = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // The dtype alone does not prove the cast is correct: also check the decoded values,
        // including that the nulls stay at the same positions.
        assert_arrays_eq!(
            casted.to_primitive(),
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    #[test]
    fn test_cast_dict_with_nulls_to_non_nullable_fails() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(values.as_ref()).unwrap();

        // Whether the nulls live in the codes or in the values, they cannot be represented in
        // a non-nullable dtype, so the cast must fail instead of silently dropping them.
        let result = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values).unwrap();

        // Encoding an all-valid, non-nullable input yields non-nullable codes and values.
        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Same-type, non-nullable cast: the result stays a dictionary with non-nullable parts.
        let non_nullable = cast(
            dict.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<DictVTable>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Widening to nullable only touches the values; the codes remain non-nullable.
        let nullable = cast(
            non_nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let nullable_dict = nullable.as_::<DictVTable>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        // Narrowing back to non-nullable succeeds because every value is valid.
        let back_to_non_nullable = cast(
            nullable.as_ref(),
            &DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let back_dict = back_to_non_nullable.as_::<DictVTable>();
        assert_eq!(
            back_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            back_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // The round trip must not change the logical contents.
        let original_values = dict.to_primitive();
        let final_values = back_dict.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}