1use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_dtype::PType;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::builtins::ArrayBuiltins;
15use crate::compute::MinMaxResult;
16use crate::compute::min_max;
17use crate::scalar::Scalar;
18
19fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
21 array.cast(dtype)?.to_canonical().map(|c| c.into_array())
22}
23
24pub fn test_cast_conformance(array: &dyn Array) {
34 let dtype = array.dtype();
35
36 test_cast_identity(array);
38
39 test_cast_to_non_nullable(array);
40 test_cast_to_nullable(array);
41
42 match dtype {
44 DType::Null => test_cast_from_null(array),
45 DType::Primitive(ptype, ..) => match ptype {
46 PType::U8
47 | PType::U16
48 | PType::U32
49 | PType::U64
50 | PType::I8
51 | PType::I16
52 | PType::I32
53 | PType::I64 => test_cast_to_integral_types(array),
54 PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
55 },
56 _ => {}
57 }
58}
59
60fn test_cast_identity(array: &dyn Array) {
61 let result = cast_and_execute(&array.to_array(), array.dtype().clone())
63 .vortex_expect("cast should succeed in conformance test");
64 assert_eq!(result.len(), array.len());
65 assert_eq!(result.dtype(), array.dtype());
66
67 for i in 0..array.len().min(10) {
69 assert_eq!(
70 array
71 .scalar_at(i)
72 .vortex_expect("scalar_at should succeed in conformance test"),
73 result
74 .scalar_at(i)
75 .vortex_expect("scalar_at should succeed in conformance test")
76 );
77 }
78}
79
80fn test_cast_from_null(array: &dyn Array) {
81 let result = cast_and_execute(&array.to_array(), DType::Null)
83 .vortex_expect("cast should succeed in conformance test");
84 assert_eq!(result.len(), array.len());
85 assert_eq!(result.dtype(), &DType::Null);
86
87 let nullable_types = vec![
89 DType::Bool(Nullability::Nullable),
90 DType::Primitive(PType::I32, Nullability::Nullable),
91 DType::Primitive(PType::F64, Nullability::Nullable),
92 DType::Utf8(Nullability::Nullable),
93 DType::Binary(Nullability::Nullable),
94 ];
95
96 for dtype in nullable_types {
97 let result = cast_and_execute(&array.to_array(), dtype.clone())
98 .vortex_expect("cast should succeed in conformance test");
99 assert_eq!(result.len(), array.len());
100 assert_eq!(result.dtype(), &dtype);
101
102 for i in 0..array.len().min(10) {
104 assert!(
105 result
106 .scalar_at(i)
107 .vortex_expect("scalar_at should succeed in conformance test")
108 .is_null()
109 );
110 }
111 }
112
113 let non_nullable_types = vec![
115 DType::Bool(Nullability::NonNullable),
116 DType::Primitive(PType::I32, Nullability::NonNullable),
117 ];
118
119 for dtype in non_nullable_types {
120 assert!(cast_and_execute(&array.to_array(), dtype.clone()).is_err());
121 }
122}
123
124fn test_cast_to_non_nullable(array: &dyn Array) {
125 if array
126 .invalid_count()
127 .vortex_expect("invalid_count should succeed in conformance test")
128 == 0
129 {
130 let non_nullable = cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
131 .vortex_expect("arrays without nulls can cast to non-nullable");
132 assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
133 assert_eq!(non_nullable.len(), array.len());
134
135 for i in 0..array.len().min(10) {
136 assert_eq!(
137 array
138 .scalar_at(i)
139 .vortex_expect("scalar_at should succeed in conformance test"),
140 non_nullable
141 .scalar_at(i)
142 .vortex_expect("scalar_at should succeed in conformance test")
143 );
144 }
145
146 let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
147 .vortex_expect("non-nullable arrays can cast to nullable");
148 assert_eq!(back_to_nullable.dtype(), array.dtype());
149 assert_eq!(back_to_nullable.len(), array.len());
150
151 for i in 0..array.len().min(10) {
152 assert_eq!(
153 array
154 .scalar_at(i)
155 .vortex_expect("scalar_at should succeed in conformance test"),
156 back_to_nullable
157 .scalar_at(i)
158 .vortex_expect("scalar_at should succeed in conformance test")
159 );
160 }
161 } else {
162 if &DType::Null == array.dtype() {
163 return;
166 }
167 cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
168 .err()
169 .unwrap_or_else(|| {
170 vortex_panic!(
171 "arrays with nulls should error when casting to non-nullable {}",
172 array,
173 )
174 });
175 }
176}
177
178fn test_cast_to_nullable(array: &dyn Array) {
179 let nullable = cast_and_execute(&array.to_array(), array.dtype().as_nullable())
180 .vortex_expect("arrays without nulls can cast to nullable");
181 assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
182 assert_eq!(nullable.len(), array.len());
183
184 for i in 0..array.len().min(10) {
185 assert_eq!(
186 array
187 .scalar_at(i)
188 .vortex_expect("scalar_at should succeed in conformance test"),
189 nullable
190 .scalar_at(i)
191 .vortex_expect("scalar_at should succeed in conformance test")
192 );
193 }
194
195 let back = cast_and_execute(&nullable, array.dtype().clone())
196 .vortex_expect("casting to nullable and back should be a no-op");
197 assert_eq!(back.dtype(), array.dtype());
198 assert_eq!(back.len(), array.len());
199
200 for i in 0..array.len().min(10) {
201 assert_eq!(
202 array
203 .scalar_at(i)
204 .vortex_expect("scalar_at should succeed in conformance test"),
205 back.scalar_at(i)
206 .vortex_expect("scalar_at should succeed in conformance test")
207 );
208 }
209}
210
211fn test_cast_from_floating_point_types(array: &dyn Array) {
212 let ptype = array.as_primitive_typed().ptype();
213 test_cast_to_primitive(array, PType::I8, false);
214 test_cast_to_primitive(array, PType::U8, false);
215 test_cast_to_primitive(array, PType::I16, false);
216 test_cast_to_primitive(array, PType::U16, false);
217 test_cast_to_primitive(array, PType::I32, false);
218 test_cast_to_primitive(array, PType::U32, false);
219 test_cast_to_primitive(array, PType::I64, false);
220 test_cast_to_primitive(array, PType::U64, false);
221 test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
222 test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
223 test_cast_to_primitive(array, PType::F64, true);
224}
225
226fn test_cast_to_integral_types(array: &dyn Array) {
227 test_cast_to_primitive(array, PType::I8, true);
228 test_cast_to_primitive(array, PType::U8, true);
229 test_cast_to_primitive(array, PType::I16, true);
230 test_cast_to_primitive(array, PType::U16, true);
231 test_cast_to_primitive(array, PType::I32, true);
232 test_cast_to_primitive(array, PType::U32, true);
233 test_cast_to_primitive(array, PType::I64, true);
234 test_cast_to_primitive(array, PType::U64, true);
235}
236
237fn fits(value: &Scalar, ptype: PType) -> bool {
239 let dtype = DType::Primitive(ptype, value.dtype().nullability());
240 value.cast(&dtype).is_ok()
241}
242
243fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
244 let maybe_min_max = min_max(array).vortex_expect("cast should succeed in conformance test");
245
246 if let Some(MinMaxResult { min, max }) = maybe_min_max
247 && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
248 {
249 cast_and_execute(
250 &array.to_array(),
251 DType::Primitive(target_ptype, array.dtype().nullability()),
252 )
253 .err()
254 .unwrap_or_else(|| {
255 vortex_panic!(
256 "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
257 target_ptype,
258 min,
259 max,
260 array,
261 array.display_values(),
262 )
263 });
264 return;
265 }
266
267 let casted = cast_and_execute(
269 &array.to_array(),
270 DType::Primitive(target_ptype, array.dtype().nullability()),
271 )
272 .unwrap_or_else(|e| {
273 vortex_panic!(
274 "Cast must succeed because all values are within bounds. {} {}: {e}",
275 target_ptype,
276 array.display_values(),
277 )
278 });
279 assert_eq!(
280 array
281 .validity_mask()
282 .vortex_expect("validity_mask should succeed in conformance test"),
283 casted
284 .validity_mask()
285 .vortex_expect("validity_mask should succeed in conformance test")
286 );
287 for i in 0..array.len().min(10) {
288 let original = array
289 .scalar_at(i)
290 .vortex_expect("scalar_at should succeed in conformance test");
291 let casted = casted
292 .scalar_at(i)
293 .vortex_expect("scalar_at should succeed in conformance test");
294 assert_eq!(
295 original
296 .cast(casted.dtype())
297 .vortex_expect("cast should succeed in conformance test"),
298 casted,
299 "{i} {original} {casted}"
300 );
301 if test_round_trip {
302 assert_eq!(
303 original,
304 casted
305 .cast(original.dtype())
306 .vortex_expect("cast should succeed in conformance test"),
307 "{i} {original} {casted}"
308 );
309 }
310 }
311}
312
/// Runs `test_cast_conformance` against arrays of several dtypes.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::validity::Validity;

    // Non-nullable unsigned integers.
    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(values.as_ref());
    }

    // Non-nullable signed integers, including negative values.
    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(values.as_ref());
    }

    // Non-nullable floats, including fractional values.
    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(values.as_ref());
    }

    // Primitive array built from optional values, two of which are `None`.
    #[test]
    fn test_cast_conformance_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(values.as_ref());
    }

    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(flags.as_ref());
    }

    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let dtype = DType::Utf8(Nullability::Nullable);
        let strings = VarBinArray::from_iter(vec![Some("hello"), None, Some("world")], dtype);
        test_cast_conformance(strings.as_ref());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let dtype = DType::Binary(Nullability::Nullable);
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            dtype,
        );
        test_cast_conformance(blobs.as_ref());
    }

    // Struct with an i32 field and a nullable utf8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let record =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(record.as_ref());
    }

    // List array built from six i32 elements and five i64 offsets.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}