vortex_array/arrays/dict/compute/
cast.rs1use vortex_error::VortexResult;
5
6use super::Dict;
7use super::DictArray;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::dict::DictArrayExt;
12use crate::arrays::dict::DictArraySlotsExt;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Dict {
18 fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 if !dtype.is_nullable() && !array.values().validity()?.no_nulls() {
22 return Ok(None);
23 }
24 let casted_values = array.values().cast(dtype.clone())?;
26
27 let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
29 array
30 .codes()
31 .cast(array.codes().dtype().with_nullability(dtype.nullability()))?
32 } else {
33 array.codes().clone()
34 };
35
36 Ok(Some(
38 unsafe {
39 DictArray::new_unchecked(casted_codes, casted_values)
40 .set_all_values_referenced(array.has_all_values_referenced())
41 }
42 .into_array(),
43 ))
44 }
45}
46
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::Dict;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArraySlotsExt;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::session::ArraySession;

    /// Shared session used to create execution contexts for `dict_encode`.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Widening a non-nullable `i32` dictionary to `i64` keeps every logical value.
    #[test]
    fn test_cast_dict_to_wider_type() {
        let values = buffer![1i32, 2, 3, 2, 1].into_array();
        let dict = dict_encode(&values, &mut SESSION.create_execution_ctx()).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        #[expect(deprecated)]
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]));
    }

    /// Casting a dictionary containing nulls to a wider nullable type keeps both the
    /// valid values and the null positions.
    #[test]
    fn test_cast_dict_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None]);
        let dict = dict_encode(&values.into_array(), &mut SESSION.create_execution_ctx()).unwrap();

        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Check the data, not just the dtype: the cast must not drop or shift nulls.
        #[expect(deprecated)]
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10i64), None, Some(20), Some(10), None])
        );
    }

    /// Round-trips an all-valid dictionary through non-nullable -> nullable ->
    /// non-nullable casts. Widening nullability should only affect the values; the
    /// codes keep their original nullability.
    #[test]
    fn test_cast_dict_allvalid_to_nonnullable_and_back() {
        let values = buffer![10i32, 20, 30, 40].into_array();
        let dict = dict_encode(&values, &mut SESSION.create_execution_ctx()).unwrap();

        assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
        assert_eq!(
            dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        let non_nullable = dict
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let non_nullable_dict = non_nullable.as_::<Dict>();
        assert_eq!(
            non_nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            non_nullable_dict.values().dtype().nullability(),
            Nullability::NonNullable
        );

        let nullable = non_nullable
            .cast(DType::Primitive(PType::I32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        // Only the values become nullable; the codes are reused unchanged.
        let nullable_dict = nullable.as_::<Dict>();
        assert_eq!(
            nullable_dict.codes().dtype().nullability(),
            Nullability::NonNullable
        );
        assert_eq!(
            nullable_dict.values().dtype().nullability(),
            Nullability::Nullable
        );

        let back_to_non_nullable = nullable
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            back_to_non_nullable.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        #[expect(deprecated)]
        let original_values = dict.as_array().to_primitive();
        #[expect(deprecated)]
        let final_values = back_to_non_nullable.to_primitive();
        assert_arrays_eq!(original_values, final_values);
    }

    /// Runs the shared cast conformance suite over a few dictionary encodings.
    #[rstest]
    #[case(dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    #[case(dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array(), &mut SESSION.create_execution_ctx()).unwrap().into_array())]
    fn test_cast_dict_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(&array);
    }

    /// A dictionary whose values contain a null that no code references is logically
    /// null-free, so casting it to a non-nullable type must succeed.
    #[test]
    fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
        use crate::arrays::DictArray;
        use crate::validity::Validity;

        // Value index 1 is null, but the codes only reference indices 0 and 2.
        let values = PrimitiveArray::new(
            buffer![1.0f64, 0.0f64, 3.0f64],
            Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true])),
        )
        .into_array();
        let codes = buffer![0u32, 2, 0].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        assert_eq!(
            dict.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );

        // `expect` (unlike `is_ok()` + `unwrap()`) reports the underlying error on failure.
        let casted = dict
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .expect("cast to NonNullable should succeed for dict with only unreferenced null values");
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );
        #[expect(deprecated)]
        let casted_prim = casted.to_primitive();
        assert_arrays_eq!(casted_prim, PrimitiveArray::from_iter([1.0f64, 3.0, 1.0]));
    }
}