1use vortex_dtype::{DType, Nullability, PType};
5use vortex_error::{VortexExpect as _, VortexUnwrap, vortex_panic};
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::compute::{MinMaxResult, cast, min_max};
10
11pub fn test_cast_conformance(array: &dyn Array) {
21 let dtype = array.dtype();
22
23 test_cast_identity(array);
25
26 test_cast_to_non_nullable(array);
27 test_cast_to_nullable(array);
28
29 match dtype {
31 DType::Null => test_cast_from_null(array),
32 DType::Primitive(ptype, ..) => match ptype {
33 PType::U8
34 | PType::U16
35 | PType::U32
36 | PType::U64
37 | PType::I8
38 | PType::I16
39 | PType::I32
40 | PType::I64 => test_cast_to_integral_types(array),
41 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
42 },
43 _ => {}
44 }
45}
46
47fn test_cast_identity(array: &dyn Array) {
48 let result = cast(array, array.dtype()).vortex_unwrap();
50 assert_eq!(result.len(), array.len());
51 assert_eq!(result.dtype(), array.dtype());
52
53 for i in 0..array.len().min(10) {
55 assert_eq!(array.scalar_at(i), result.scalar_at(i));
56 }
57}
58
59fn test_cast_from_null(array: &dyn Array) {
60 let result = cast(array, &DType::Null).vortex_unwrap();
62 assert_eq!(result.len(), array.len());
63 assert_eq!(result.dtype(), &DType::Null);
64
65 let nullable_types = vec![
67 DType::Bool(Nullability::Nullable),
68 DType::Primitive(PType::I32, Nullability::Nullable),
69 DType::Primitive(PType::F64, Nullability::Nullable),
70 DType::Utf8(Nullability::Nullable),
71 DType::Binary(Nullability::Nullable),
72 ];
73
74 for dtype in nullable_types {
75 let result = cast(array, &dtype).vortex_unwrap();
76 assert_eq!(result.len(), array.len());
77 assert_eq!(result.dtype(), &dtype);
78
79 for i in 0..array.len().min(10) {
81 assert!(result.scalar_at(i).is_null());
82 }
83 }
84
85 let non_nullable_types = vec![
87 DType::Bool(Nullability::NonNullable),
88 DType::Primitive(PType::I32, Nullability::NonNullable),
89 ];
90
91 for dtype in non_nullable_types {
92 assert!(cast(array, &dtype).is_err());
93 }
94}
95
96fn test_cast_to_non_nullable(array: &dyn Array) {
97 if array.invalid_count() == 0 {
98 let non_nullable = cast(array, &array.dtype().as_nonnullable())
99 .vortex_expect("arrays without nulls can cast to non-nullable");
100 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
101 assert_eq!(non_nullable.len(), array.len());
102
103 for i in 0..array.len().min(10) {
104 assert_eq!(array.scalar_at(i), non_nullable.scalar_at(i));
105 }
106
107 let back_to_nullable = cast(&non_nullable, array.dtype())
108 .vortex_expect("non-nullable arrays can cast to nullable");
109 assert_eq!(back_to_nullable.dtype(), array.dtype());
110 assert_eq!(back_to_nullable.len(), array.len());
111
112 for i in 0..array.len().min(10) {
113 assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
114 }
115 } else {
116 if &DType::Null == array.dtype() {
117 return;
120 }
121 cast(array, &array.dtype().as_nonnullable())
122 .err()
123 .unwrap_or_else(|| {
124 vortex_panic!(
125 "arrays with nulls should error when casting to non-nullable {}",
126 array,
127 )
128 });
129 }
130}
131
132fn test_cast_to_nullable(array: &dyn Array) {
133 let nullable = cast(array, &array.dtype().as_nullable())
134 .vortex_expect("arrays without nulls can cast to nullable");
135 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
136 assert_eq!(nullable.len(), array.len());
137
138 for i in 0..array.len().min(10) {
139 assert_eq!(array.scalar_at(i), nullable.scalar_at(i));
140 }
141
142 let back = cast(&nullable, array.dtype())
143 .vortex_expect("casting to nullable and back should be a no-op");
144 assert_eq!(back.dtype(), array.dtype());
145 assert_eq!(back.len(), array.len());
146
147 for i in 0..array.len().min(10) {
148 assert_eq!(array.scalar_at(i), back.scalar_at(i));
149 }
150}
151
152fn test_cast_from_floating_point_types(array: &dyn Array) {
153 let ptype = array.as_primitive_typed().ptype();
154 test_cast_to_primitive(array, PType::I8, false);
155 test_cast_to_primitive(array, PType::U8, false);
156 test_cast_to_primitive(array, PType::I16, false);
157 test_cast_to_primitive(array, PType::U16, false);
158 test_cast_to_primitive(array, PType::I32, false);
159 test_cast_to_primitive(array, PType::U32, false);
160 test_cast_to_primitive(array, PType::I64, false);
161 test_cast_to_primitive(array, PType::U64, false);
162 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
163 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
164 test_cast_to_primitive(array, PType::F64, true);
165}
166
167fn test_cast_to_integral_types(array: &dyn Array) {
168 test_cast_to_primitive(array, PType::I8, true);
169 test_cast_to_primitive(array, PType::U8, true);
170 test_cast_to_primitive(array, PType::I16, true);
171 test_cast_to_primitive(array, PType::U16, true);
172 test_cast_to_primitive(array, PType::I32, true);
173 test_cast_to_primitive(array, PType::U32, true);
174 test_cast_to_primitive(array, PType::I64, true);
175 test_cast_to_primitive(array, PType::U64, true);
176}
177
178fn fits(value: &Scalar, ptype: PType) -> bool {
180 let dtype = DType::Primitive(ptype, value.dtype().nullability());
181 value.cast(&dtype).is_ok()
182}
183
184fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
185 let maybe_min_max = min_max(array).vortex_unwrap();
186
187 if let Some(MinMaxResult { min, max }) = maybe_min_max
188 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
189 {
190 cast(
191 array,
192 &DType::Primitive(target_ptype, array.dtype().nullability()),
193 )
194 .err()
195 .unwrap_or_else(|| {
196 vortex_panic!(
197 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
198 target_ptype,
199 min,
200 max,
201 array,
202 array.display_values(),
203 )
204 });
205 return;
206 }
207
208 let casted = cast(
210 array,
211 &DType::Primitive(target_ptype, array.dtype().nullability()),
212 )
213 .unwrap_or_else(|e| {
214 vortex_panic!(
215 "Cast must succeed because all values are within bounds. {} {}: {e}",
216 target_ptype,
217 array.display_values(),
218 )
219 });
220 assert_eq!(array.validity_mask(), casted.validity_mask());
221 for i in 0..array.len().min(10) {
222 let original = array.scalar_at(i);
223 let casted = casted.scalar_at(i);
224 assert_eq!(
225 original.cast(casted.dtype()).vortex_unwrap(),
226 casted,
227 "{i} {original} {casted}"
228 );
229 if test_round_trip {
230 assert_eq!(
231 original,
232 casted.cast(original.dtype()).vortex_unwrap(),
233 "{i} {original} {casted}"
234 );
235 }
236 }
237}
238
/// Runs the conformance suite against a handful of canonical array types.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::{
        BoolArray, ListArray, NullArray, PrimitiveArray, StructArray, VarBinArray,
    };
    use crate::validity::Validity;

    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(values.as_ref());
    }

    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(values.as_ref());
    }

    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(values.as_ref());
    }

    #[test]
    fn test_cast_conformance_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(values.as_ref());
    }

    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(flags.as_ref());
    }

    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let utf8 = DType::Utf8(Nullability::Nullable);
        let strings = VarBinArray::from_iter(vec![Some("hello"), None, Some("world")], utf8);
        test_cast_conformance(strings.as_ref());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let binary = DType::Binary(Nullability::Nullable);
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            binary,
        );
        test_cast_conformance(blobs.as_ref());
    }

    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let fields = vec![ints, strings];
        let record = StructArray::try_new(field_names, fields, 3, Validity::NonNullable).unwrap();
        test_cast_conformance(record.as_ref());
    }

    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}