vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18 fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 if !dtype.is_nullable() && !array.values().validity()?.definitely_no_nulls() {
22 return Ok(None);
23 }
24 let casted_values = array.values().cast(dtype.clone())?;
26
27 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
29 array
30 .codes()
31 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
32 } else {
33 array.codes().clone()
34 };
35
36 Ok(Some(
38 unsafe {
39 DictArray::new_unchecked(casted_codes, casted_values)
40 .set_all_values_referenced(array.has_all_values_referenced())
41 }
42 .into_array(),
43 ))
44 }
45}
46
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Session shared by all tests in this module for building execution contexts.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Widening a non-nullable `i32` dictionary to `i64` preserves the logical values.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let ctx = &mut SESSION.create_execution_ctx();
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values, ctx).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.into_array().execute::<PrimitiveArray>(ctx).unwrap();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]), ctx);
    }

    /// Casting a dictionary containing nulls to a nullable wider type keeps both the
    /// nullable dtype and the original null positions.
    #[test]
    fn test_cast_dict_nullable() {
        let ctx = &mut SESSION.create_execution_ctx();
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(&values.into_array(), ctx).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Check the decoded contents too, not just the dtype: values and nulls must
        // survive the cast unchanged.
        let decoded = casted.execute::<PrimitiveArray>(ctx).unwrap();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None]),
            ctx
        );
    }

    /// Round-trips an all-valid dictionary through non-nullable -> nullable ->
    /// non-nullable casts. At each step it checks which child carries the
    /// nullability, and at the end that the logical contents are unchanged.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let ctx = &mut SESSION.create_execution_ctx();
        // All values are valid, so both dictionary children start out non-nullable.
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values, ctx).unwrap();

        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Casting to the same non-nullable dtype keeps a dictionary with
        // non-nullable children.
        let non_nullable = dict
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        // Widening to nullable only affects the values; the codes stay non-nullable.
        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        // The full round trip must not change the logical contents.
        let original_values = dict.into_array().execute::<PrimitiveArray>(ctx).unwrap();
        let final_values = back_to_non_nullable.execute::<PrimitiveArray>(ctx).unwrap();
        assert_arrays_eq!(original_values, final_values, ctx);
    }

    /// Runs the shared cast conformance suite over several dictionary-encoded inputs:
    /// signed, unsigned, nullable and floating-point values.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dictionary whose only null value is never referenced by any code is
    /// logically null-free, so casting it to a non-nullable dtype must succeed.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        let ctx = &mut SESSION.create_execution_ctx();
        // The value at index 1 is null, but no code points at it.
        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // The values are nullable, so the dictionary's dtype is nullable too.
        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        let result = dict
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable));
        assert!(
            result.is_ok(),
            "cast to NonNullable should succeed for dict with only unreferenced null values"
        );
        let casted = result.unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        assert_arrays_eq!(casted, PrimitiveArray::from_iter([1.0f64, 3.0, 1.0]), ctx);
    }
}