1use vortex_dtype::{DType, Nullability, PType};
5use vortex_error::VortexUnwrap;
6
7use crate::Array;
8use crate::compute::cast;
9
10pub fn test_cast_conformance(array: &dyn Array) {
20 let dtype = array.dtype();
21
22 test_cast_identity(array);
24
25 test_cast_allvalid_to_nonnullable_and_back(array);
27
28 match dtype {
30 DType::Null => test_cast_from_null(array),
31 DType::Bool(nullability) => test_cast_from_bool(array, *nullability),
32 DType::Primitive(ptype, nullability) => {
33 test_cast_nullability_changes_primitive(array, *ptype, *nullability);
34 match ptype {
35 PType::U8 => test_cast_from_u8(array),
36 PType::U16 => test_cast_from_u16(array),
37 PType::U32 => test_cast_from_u32(array),
38 PType::U64 => test_cast_from_u64(array),
39 PType::I8 => test_cast_from_i8(array),
40 PType::I16 => test_cast_from_i16(array),
41 PType::I32 => test_cast_from_i32(array),
42 PType::I64 => test_cast_from_i64(array),
43 PType::F16 => test_cast_from_f16(array),
44 PType::F32 => test_cast_from_f32(array),
45 PType::F64 => test_cast_from_f64(array),
46 }
47 }
48 DType::Decimal(_, nullability) => test_cast_from_decimal(array, *nullability),
49 DType::Utf8(nullability) => test_cast_from_utf8(array, *nullability),
50 DType::Binary(nullability) => test_cast_from_binary(array, *nullability),
51 DType::Struct(_, nullability) => test_cast_from_struct(array, *nullability),
52 DType::List(_, nullability) => test_cast_from_list(array, *nullability),
53 DType::FixedSizeList(.., nullability) => {
54 test_cast_from_fixed_size_list(array, *nullability)
55 }
56 DType::Extension(_) => test_cast_from_extension(array),
57 }
58}
59
60fn test_cast_identity(array: &dyn Array) {
61 let result = cast(array, array.dtype()).vortex_unwrap();
63 assert_eq!(result.len(), array.len());
64 assert_eq!(result.dtype(), array.dtype());
65
66 for i in 0..array.len().min(10) {
68 assert_eq!(array.scalar_at(i), result.scalar_at(i),);
69 }
70}
71
72fn test_cast_from_null(array: &dyn Array) {
73 let result = cast(array, &DType::Null).vortex_unwrap();
75 assert_eq!(result.len(), array.len());
76 assert_eq!(result.dtype(), &DType::Null);
77
78 let nullable_types = vec![
80 DType::Bool(Nullability::Nullable),
81 DType::Primitive(PType::I32, Nullability::Nullable),
82 DType::Primitive(PType::F64, Nullability::Nullable),
83 DType::Utf8(Nullability::Nullable),
84 DType::Binary(Nullability::Nullable),
85 ];
86
87 for dtype in nullable_types {
88 let result = cast(array, &dtype).vortex_unwrap();
89 assert_eq!(result.len(), array.len());
90 assert_eq!(result.dtype(), &dtype);
91
92 for i in 0..array.len().min(10) {
94 assert!(result.scalar_at(i).is_null());
95 }
96 }
97
98 let non_nullable_types = vec![
100 DType::Bool(Nullability::NonNullable),
101 DType::Primitive(PType::I32, Nullability::NonNullable),
102 ];
103
104 for dtype in non_nullable_types {
105 assert!(cast(array, &dtype).is_err());
106 }
107}
108
109fn test_cast_from_bool(array: &dyn Array, nullability: Nullability) {
110 test_cast_nullability_changes(array, &DType::Bool(Nullability::Nullable));
112 if nullability == Nullability::Nullable {
113 let _ = cast(array, &DType::Bool(Nullability::NonNullable));
115 }
116
117 test_cast_to_primitive(array, PType::U8);
119 test_cast_to_primitive(array, PType::I32);
120 test_cast_to_primitive(array, PType::F32);
121}
122
123fn test_cast_from_decimal(array: &dyn Array, nullability: Nullability) {
124 if let DType::Decimal(decimal_type, _) = array.dtype() {
126 test_cast_nullability_changes(array, &DType::Decimal(*decimal_type, Nullability::Nullable));
127 if nullability == Nullability::Nullable {
128 let _ = cast(
130 array,
131 &DType::Decimal(*decimal_type, Nullability::NonNullable),
132 );
133 }
134 }
135}
136
137fn test_cast_from_utf8(array: &dyn Array, nullability: Nullability) {
138 test_cast_nullability_changes(array, &DType::Utf8(Nullability::Nullable));
140 if nullability == Nullability::Nullable {
141 let _ = cast(array, &DType::Utf8(Nullability::NonNullable));
143 }
144
145 test_cast_to_type_safe(array, &DType::Binary(nullability));
147}
148
149fn test_cast_from_binary(array: &dyn Array, nullability: Nullability) {
150 test_cast_nullability_changes(array, &DType::Binary(Nullability::Nullable));
152 if nullability == Nullability::Nullable {
153 let _ = cast(array, &DType::Binary(Nullability::NonNullable));
155 }
156
157 test_cast_to_type_safe(array, &DType::Utf8(nullability));
159}
160
161fn test_cast_from_struct(array: &dyn Array, nullability: Nullability) {
162 if let DType::Struct(fields, _) = array.dtype() {
164 test_cast_nullability_changes(array, &DType::Struct(fields.clone(), Nullability::Nullable));
165 if nullability == Nullability::Nullable {
166 let _ = cast(
168 array,
169 &DType::Struct(fields.clone(), Nullability::NonNullable),
170 );
171 }
172 }
173}
174
175fn test_cast_from_list(array: &dyn Array, nullability: Nullability) {
176 if let DType::List(element_type, _) = array.dtype() {
178 test_cast_nullability_changes(
179 array,
180 &DType::List(element_type.clone(), Nullability::Nullable),
181 );
182 if nullability == Nullability::Nullable {
183 let _ = cast(
185 array,
186 &DType::List(element_type.clone(), Nullability::NonNullable),
187 );
188 }
189 }
190}
191
192fn test_cast_from_fixed_size_list(array: &dyn Array, nullability: Nullability) {
193 if let DType::FixedSizeList(element_type, list_size, ..) = array.dtype() {
195 test_cast_nullability_changes(
196 array,
197 &DType::FixedSizeList(element_type.clone(), *list_size, Nullability::Nullable),
198 );
199 if nullability == Nullability::Nullable {
200 let _ = cast(
202 array,
203 &DType::FixedSizeList(element_type.clone(), *list_size, Nullability::NonNullable),
204 );
205 }
206 }
207}
208
209fn test_cast_from_extension(array: &dyn Array) {
210 if let DType::Extension(ext_dtype) = array.dtype() {
213 let result = cast(array, &DType::Extension(ext_dtype.clone())).vortex_unwrap();
214 assert_eq!(result.len(), array.len());
215 assert_eq!(result.dtype(), array.dtype());
216 }
217}
218
219fn test_cast_allvalid_to_nonnullable_and_back(array: &dyn Array) {
220 if array.dtype() == &DType::Null {
222 return;
223 }
224
225 if array.invalid_count() == 0 {
227 if array.dtype().nullability() == Nullability::Nullable {
229 let non_nullable_dtype = array.dtype().with_nullability(Nullability::NonNullable);
230
231 if let Ok(non_nullable) = cast(array, &non_nullable_dtype) {
233 assert_eq!(non_nullable.dtype(), &non_nullable_dtype);
234 assert_eq!(non_nullable.len(), array.len());
235
236 let nullable_dtype = array.dtype().with_nullability(Nullability::Nullable);
238 let back_to_nullable = cast(&non_nullable, &nullable_dtype).vortex_unwrap();
239 assert_eq!(back_to_nullable.dtype(), &nullable_dtype);
240 assert_eq!(back_to_nullable.len(), array.len());
241
242 for i in 0..array.len().min(10) {
244 assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
245 }
246 }
247 }
248 else if array.dtype().nullability() == Nullability::NonNullable {
250 let nullable_dtype = array.dtype().with_nullability(Nullability::Nullable);
251
252 let nullable = cast(array, &nullable_dtype).vortex_unwrap();
254 assert_eq!(nullable.dtype(), &nullable_dtype);
255 assert_eq!(nullable.len(), array.len());
256
257 let non_nullable_dtype = array.dtype().with_nullability(Nullability::NonNullable);
259 let back_to_non_nullable = cast(&nullable, &non_nullable_dtype).vortex_unwrap();
260 assert_eq!(back_to_non_nullable.dtype(), &non_nullable_dtype);
261 assert_eq!(back_to_non_nullable.len(), array.len());
262
263 for i in 0..array.len().min(10) {
265 assert_eq!(array.scalar_at(i), back_to_non_nullable.scalar_at(i));
266 }
267 }
268 }
269}
270
271fn test_cast_nullability_changes(array: &dyn Array, nullable_version: &DType) {
272 if array.dtype().nullability() == Nullability::NonNullable {
274 let result = cast(array, nullable_version).vortex_unwrap();
275 assert_eq!(result.len(), array.len());
276 assert_eq!(result.dtype(), nullable_version);
277
278 assert_eq!(
280 result.encoding().id(),
281 array.encoding().id(),
282 "Nullability cast should preserve encoding"
283 );
284
285 for i in 0..array.len().min(10) {
287 assert_eq!(array.scalar_at(i), result.scalar_at(i),);
288 }
289 }
290}
291
292fn test_cast_nullability_changes_primitive(
293 array: &dyn Array,
294 ptype: PType,
295 nullability: Nullability,
296) {
297 if nullability == Nullability::NonNullable {
299 let nullable_dtype = DType::Primitive(ptype, Nullability::Nullable);
300 let result = cast(array, &nullable_dtype).vortex_unwrap();
301 assert_eq!(result.len(), array.len());
302 assert_eq!(result.dtype(), &nullable_dtype);
303
304 assert_eq!(
306 result.encoding().id(),
307 array.encoding().id(),
308 "Nullability cast should preserve encoding"
309 );
310
311 for i in 0..array.len().min(10) {
313 assert_eq!(array.scalar_at(i), result.scalar_at(i),);
314 }
315 }
316
317 if nullability == Nullability::Nullable {
319 let non_nullable_dtype = DType::Primitive(ptype, Nullability::NonNullable);
321 if let Ok(result) = cast(array, &non_nullable_dtype) {
322 assert_eq!(result.len(), array.len());
323 assert_eq!(result.dtype(), &non_nullable_dtype);
324
325 assert_eq!(
327 result.encoding().id(),
328 array.encoding().id(),
329 "Nullability cast should preserve encoding"
330 );
331
332 for i in 0..array.len().min(10) {
334 assert_eq!(array.scalar_at(i), result.scalar_at(i),);
335 }
336 }
337 }
338}
339
340fn test_cast_from_u8(array: &dyn Array) {
341 test_cast_to_primitive(array, PType::U16);
343 test_cast_to_primitive(array, PType::U32);
344 test_cast_to_primitive(array, PType::U64);
345 test_cast_to_primitive(array, PType::I16);
346 test_cast_to_primitive(array, PType::I32);
347 test_cast_to_primitive(array, PType::I64);
348 test_cast_to_primitive(array, PType::F32);
349 test_cast_to_primitive(array, PType::F64);
350
351 test_cast_to_primitive(array, PType::I8);
353}
354
355fn test_cast_from_u16(array: &dyn Array) {
356 test_cast_to_primitive(array, PType::U8);
358
359 test_cast_to_primitive(array, PType::U32);
361 test_cast_to_primitive(array, PType::U64);
362 test_cast_to_primitive(array, PType::I32);
363 test_cast_to_primitive(array, PType::I64);
364 test_cast_to_primitive(array, PType::F32);
365 test_cast_to_primitive(array, PType::F64);
366
367 test_cast_to_primitive(array, PType::I16);
369}
370
371fn test_cast_from_u32(array: &dyn Array) {
372 test_cast_to_primitive(array, PType::U8);
374 test_cast_to_primitive(array, PType::U16);
375 test_cast_to_primitive(array, PType::I8);
376 test_cast_to_primitive(array, PType::I16);
377
378 test_cast_to_primitive(array, PType::U64);
380 test_cast_to_primitive(array, PType::I64);
381 test_cast_to_primitive(array, PType::F64);
382
383 test_cast_to_primitive(array, PType::I32);
385 test_cast_to_primitive(array, PType::F32);
386}
387
388fn test_cast_from_u64(array: &dyn Array) {
389 test_cast_to_primitive(array, PType::U8);
391 test_cast_to_primitive(array, PType::U16);
392 test_cast_to_primitive(array, PType::U32);
393 test_cast_to_primitive(array, PType::I8);
394 test_cast_to_primitive(array, PType::I16);
395 test_cast_to_primitive(array, PType::I32);
396 test_cast_to_primitive(array, PType::F32);
397
398 test_cast_to_primitive(array, PType::I64);
400 test_cast_to_primitive(array, PType::F64);
401}
402
403fn test_cast_from_i8(array: &dyn Array) {
404 test_cast_to_primitive(array, PType::I16);
406 test_cast_to_primitive(array, PType::I32);
407 test_cast_to_primitive(array, PType::I64);
408 test_cast_to_primitive(array, PType::F32);
409 test_cast_to_primitive(array, PType::F64);
410
411 test_cast_to_primitive(array, PType::U8);
413}
414
415fn test_cast_from_i16(array: &dyn Array) {
416 test_cast_to_primitive(array, PType::I8);
418
419 test_cast_to_primitive(array, PType::I32);
421 test_cast_to_primitive(array, PType::I64);
422 test_cast_to_primitive(array, PType::F32);
423 test_cast_to_primitive(array, PType::F64);
424
425 test_cast_to_primitive(array, PType::U16);
427}
428
429fn test_cast_from_i32(array: &dyn Array) {
430 test_cast_to_primitive(array, PType::I8);
432 test_cast_to_primitive(array, PType::I16);
433
434 test_cast_to_primitive(array, PType::I64);
436 test_cast_to_primitive(array, PType::F64);
437
438 test_cast_to_primitive(array, PType::F32);
440 test_cast_to_primitive(array, PType::U32);
441}
442
443fn test_cast_from_i64(array: &dyn Array) {
444 test_cast_to_primitive(array, PType::I8);
446 test_cast_to_primitive(array, PType::I16);
447 test_cast_to_primitive(array, PType::I32);
448 test_cast_to_primitive(array, PType::F32);
449
450 test_cast_to_primitive(array, PType::F64);
452 test_cast_to_primitive(array, PType::U64);
453}
454
455fn test_cast_from_f16(array: &dyn Array) {
456 test_cast_to_primitive(array, PType::F32);
458 test_cast_to_primitive(array, PType::F64);
459}
460
461fn test_cast_from_f32(array: &dyn Array) {
462 test_cast_to_primitive(array, PType::F16);
464
465 test_cast_to_primitive(array, PType::F64);
467
468 test_cast_to_integral_types(array);
470}
471
472fn test_cast_from_f64(array: &dyn Array) {
473 test_cast_to_primitive(array, PType::F16);
475 test_cast_to_primitive(array, PType::F32);
476
477 test_cast_to_integral_types(array);
479}
480
481fn test_cast_to_integral_types(array: &dyn Array) {
482 test_cast_to_primitive(array, PType::I8);
485 test_cast_to_primitive(array, PType::U8);
486 test_cast_to_primitive(array, PType::I16);
487 test_cast_to_primitive(array, PType::U16);
488 test_cast_to_primitive(array, PType::I32);
489 test_cast_to_primitive(array, PType::U32);
490 test_cast_to_primitive(array, PType::I64);
491 test_cast_to_primitive(array, PType::U64);
492}
493
494fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType) {
495 let target_dtype = DType::Primitive(target_ptype, array.dtype().nullability());
496 test_cast_to_type_safe(array, &target_dtype);
497}
498
499fn test_cast_to_type_safe(array: &dyn Array, target_dtype: &DType) {
500 let result = match cast(array, target_dtype) {
502 Ok(r) => r,
503 Err(_) => {
504 return;
507 }
508 };
509
510 assert_eq!(result.len(), array.len());
511 assert_eq!(result.dtype(), target_dtype);
512
513 for i in 0..array.len().min(10) {
516 let original = array.scalar_at(i);
517 let casted = result.scalar_at(i);
518
519 if array.dtype().eq_ignore_nullability(target_dtype) {
521 assert_eq!(
522 original, casted,
523 "Value at index {i} changed during nullability cast"
524 );
525 } else {
526 if original.is_null() {
529 assert!(
530 casted.is_null(),
531 "Null value at index {i} became non-null after cast"
532 );
533 } else {
534 assert!(
535 !casted.is_null(),
536 "Non-null value at index {i} became null after cast"
537 );
538 }
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::{
        BoolArray, ListArray, NullArray, PrimitiveArray, StructArray, VarBinArray,
    };
    use crate::validity::Validity;

    // Primitive sources.

    #[test]
    fn test_cast_conformance_u32() {
        let unsigned = buffer![0u32, 100, 200, 65535, 1_000_000].into_array();
        test_cast_conformance(unsigned.as_ref());
    }

    #[test]
    fn test_cast_conformance_i32() {
        let signed = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(signed.as_ref());
    }

    #[test]
    fn test_cast_conformance_f32() {
        let floats = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(floats.as_ref());
    }

    #[test]
    fn test_cast_conformance_nullable() {
        let with_nulls =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(with_nulls.as_ref());
    }

    // Non-primitive scalar sources.

    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(flags.as_ref());
    }

    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(strings.as_ref());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(blobs.as_ref());
    }

    // Nested sources.

    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);
        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let structs =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(structs.as_ref());
    }

    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}