vortex_array/compute/conformance/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::{DType, Nullability, PType};
5use vortex_error::VortexUnwrap;
6
7use crate::Array;
8use crate::compute::cast;
9
10/// Test conformance of the cast compute function for an array.
11///
12/// This function tests various casting scenarios including:
13/// - Casting between numeric types (widening and narrowing)
14/// - Casting between signed and unsigned types
15/// - Casting between integral and floating-point types
16/// - Casting with nullability changes
17/// - Casting between string types (Utf8/Binary)
18/// - Edge cases like overflow behavior
19pub fn test_cast_conformance(array: &dyn Array) {
20    let dtype = array.dtype();
21
22    // Always test identity cast and nullability changes
23    test_cast_identity(array);
24
25    // Test AllValid to NonNullable and back if applicable
26    test_cast_allvalid_to_nonnullable_and_back(array);
27
28    // Test based on the specific DType
29    match dtype {
30        DType::Null => test_cast_from_null(array),
31        DType::Bool(nullability) => test_cast_from_bool(array, *nullability),
32        DType::Primitive(ptype, nullability) => {
33            test_cast_nullability_changes_primitive(array, *ptype, *nullability);
34            match ptype {
35                PType::U8 => test_cast_from_u8(array),
36                PType::U16 => test_cast_from_u16(array),
37                PType::U32 => test_cast_from_u32(array),
38                PType::U64 => test_cast_from_u64(array),
39                PType::I8 => test_cast_from_i8(array),
40                PType::I16 => test_cast_from_i16(array),
41                PType::I32 => test_cast_from_i32(array),
42                PType::I64 => test_cast_from_i64(array),
43                PType::F16 => test_cast_from_f16(array),
44                PType::F32 => test_cast_from_f32(array),
45                PType::F64 => test_cast_from_f64(array),
46            }
47        }
48        DType::Decimal(_, nullability) => test_cast_from_decimal(array, *nullability),
49        DType::Utf8(nullability) => test_cast_from_utf8(array, *nullability),
50        DType::Binary(nullability) => test_cast_from_binary(array, *nullability),
51        DType::Struct(_, nullability) => test_cast_from_struct(array, *nullability),
52        DType::List(_, nullability) => test_cast_from_list(array, *nullability),
53        DType::FixedSizeList(.., nullability) => {
54            test_cast_from_fixed_size_list(array, *nullability)
55        }
56        DType::Extension(_) => test_cast_from_extension(array),
57    }
58}
59
60fn test_cast_identity(array: &dyn Array) {
61    // Casting to the same type should be a no-op
62    let result = cast(array, array.dtype()).vortex_unwrap();
63    assert_eq!(result.len(), array.len());
64    assert_eq!(result.dtype(), array.dtype());
65
66    // Verify values are unchanged
67    for i in 0..array.len().min(10) {
68        assert_eq!(array.scalar_at(i), result.scalar_at(i),);
69    }
70}
71
72fn test_cast_from_null(array: &dyn Array) {
73    // Null can be cast to itself
74    let result = cast(array, &DType::Null).vortex_unwrap();
75    assert_eq!(result.len(), array.len());
76    assert_eq!(result.dtype(), &DType::Null);
77
78    // Null can also be cast to any nullable type
79    let nullable_types = vec![
80        DType::Bool(Nullability::Nullable),
81        DType::Primitive(PType::I32, Nullability::Nullable),
82        DType::Primitive(PType::F64, Nullability::Nullable),
83        DType::Utf8(Nullability::Nullable),
84        DType::Binary(Nullability::Nullable),
85    ];
86
87    for dtype in nullable_types {
88        let result = cast(array, &dtype).vortex_unwrap();
89        assert_eq!(result.len(), array.len());
90        assert_eq!(result.dtype(), &dtype);
91
92        // Verify all values are null
93        for i in 0..array.len().min(10) {
94            assert!(result.scalar_at(i).is_null());
95        }
96    }
97
98    // Casting to non-nullable types should fail
99    let non_nullable_types = vec![
100        DType::Bool(Nullability::NonNullable),
101        DType::Primitive(PType::I32, Nullability::NonNullable),
102    ];
103
104    for dtype in non_nullable_types {
105        assert!(cast(array, &dtype).is_err());
106    }
107}
108
109fn test_cast_from_bool(array: &dyn Array, nullability: Nullability) {
110    // Test nullability changes
111    test_cast_nullability_changes(array, &DType::Bool(Nullability::Nullable));
112    if nullability == Nullability::Nullable {
113        // Try casting to non-nullable (may fail if nulls present)
114        let _ = cast(array, &DType::Bool(Nullability::NonNullable));
115    }
116
117    // Test bool to numeric casts (true -> 1, false -> 0)
118    test_cast_to_primitive(array, PType::U8);
119    test_cast_to_primitive(array, PType::I32);
120    test_cast_to_primitive(array, PType::F32);
121}
122
123fn test_cast_from_decimal(array: &dyn Array, nullability: Nullability) {
124    // Test nullability changes for the same decimal type
125    if let DType::Decimal(decimal_type, _) = array.dtype() {
126        test_cast_nullability_changes(array, &DType::Decimal(*decimal_type, Nullability::Nullable));
127        if nullability == Nullability::Nullable {
128            // Try casting to non-nullable (may fail if nulls present)
129            let _ = cast(
130                array,
131                &DType::Decimal(*decimal_type, Nullability::NonNullable),
132            );
133        }
134    }
135}
136
137fn test_cast_from_utf8(array: &dyn Array, nullability: Nullability) {
138    // Test nullability changes
139    test_cast_nullability_changes(array, &DType::Utf8(Nullability::Nullable));
140    if nullability == Nullability::Nullable {
141        // Try casting to non-nullable (may fail if nulls present)
142        let _ = cast(array, &DType::Utf8(Nullability::NonNullable));
143    }
144
145    // UTF-8 strings can potentially be cast to Binary
146    test_cast_to_type_safe(array, &DType::Binary(nullability));
147}
148
149fn test_cast_from_binary(array: &dyn Array, nullability: Nullability) {
150    // Test nullability changes
151    test_cast_nullability_changes(array, &DType::Binary(Nullability::Nullable));
152    if nullability == Nullability::Nullable {
153        // Try casting to non-nullable (may fail if nulls present)
154        let _ = cast(array, &DType::Binary(Nullability::NonNullable));
155    }
156
157    // Binary might be castable to UTF-8 if it contains valid UTF-8
158    test_cast_to_type_safe(array, &DType::Utf8(nullability));
159}
160
161fn test_cast_from_struct(array: &dyn Array, nullability: Nullability) {
162    // Test nullability changes for the same struct type
163    if let DType::Struct(fields, _) = array.dtype() {
164        test_cast_nullability_changes(array, &DType::Struct(fields.clone(), Nullability::Nullable));
165        if nullability == Nullability::Nullable {
166            // Try casting to non-nullable (may fail if nulls present)
167            let _ = cast(
168                array,
169                &DType::Struct(fields.clone(), Nullability::NonNullable),
170            );
171        }
172    }
173}
174
175fn test_cast_from_list(array: &dyn Array, nullability: Nullability) {
176    // Test nullability changes for the same list type
177    if let DType::List(element_type, _) = array.dtype() {
178        test_cast_nullability_changes(
179            array,
180            &DType::List(element_type.clone(), Nullability::Nullable),
181        );
182        if nullability == Nullability::Nullable {
183            // Try casting to non-nullable (may fail if nulls present)
184            let _ = cast(
185                array,
186                &DType::List(element_type.clone(), Nullability::NonNullable),
187            );
188        }
189    }
190}
191
192fn test_cast_from_fixed_size_list(array: &dyn Array, nullability: Nullability) {
193    // Test nullability changes for the same fixed-size list type
194    if let DType::FixedSizeList(element_type, list_size, ..) = array.dtype() {
195        test_cast_nullability_changes(
196            array,
197            &DType::FixedSizeList(element_type.clone(), *list_size, Nullability::Nullable),
198        );
199        if nullability == Nullability::Nullable {
200            // Try casting to non-nullable (may fail if nulls present)
201            let _ = cast(
202                array,
203                &DType::FixedSizeList(element_type.clone(), *list_size, Nullability::NonNullable),
204            );
205        }
206    }
207}
208
209fn test_cast_from_extension(array: &dyn Array) {
210    // Extension types typically only cast to themselves
211    // The specific casting rules depend on the extension type
212    if let DType::Extension(ext_dtype) = array.dtype() {
213        let result = cast(array, &DType::Extension(ext_dtype.clone())).vortex_unwrap();
214        assert_eq!(result.len(), array.len());
215        assert_eq!(result.dtype(), array.dtype());
216    }
217}
218
219fn test_cast_allvalid_to_nonnullable_and_back(array: &dyn Array) {
220    // Skip if array is null type (special case)
221    if array.dtype() == &DType::Null {
222        return;
223    }
224
225    // Only test if array has no nulls
226    if array.invalid_count() == 0 {
227        // Test casting to NonNullable if currently Nullable
228        if array.dtype().nullability() == Nullability::Nullable {
229            let non_nullable_dtype = array.dtype().with_nullability(Nullability::NonNullable);
230
231            // Cast to NonNullable
232            if let Ok(non_nullable) = cast(array, &non_nullable_dtype) {
233                assert_eq!(non_nullable.dtype(), &non_nullable_dtype);
234                assert_eq!(non_nullable.len(), array.len());
235
236                // Cast back to Nullable
237                let nullable_dtype = array.dtype().with_nullability(Nullability::Nullable);
238                let back_to_nullable = cast(&non_nullable, &nullable_dtype).vortex_unwrap();
239                assert_eq!(back_to_nullable.dtype(), &nullable_dtype);
240                assert_eq!(back_to_nullable.len(), array.len());
241
242                // Verify values are unchanged
243                for i in 0..array.len().min(10) {
244                    assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
245                }
246            }
247        }
248        // Test casting to Nullable if currently NonNullable
249        else if array.dtype().nullability() == Nullability::NonNullable {
250            let nullable_dtype = array.dtype().with_nullability(Nullability::Nullable);
251
252            // Cast to Nullable
253            let nullable = cast(array, &nullable_dtype).vortex_unwrap();
254            assert_eq!(nullable.dtype(), &nullable_dtype);
255            assert_eq!(nullable.len(), array.len());
256
257            // Cast back to NonNullable
258            let non_nullable_dtype = array.dtype().with_nullability(Nullability::NonNullable);
259            let back_to_non_nullable = cast(&nullable, &non_nullable_dtype).vortex_unwrap();
260            assert_eq!(back_to_non_nullable.dtype(), &non_nullable_dtype);
261            assert_eq!(back_to_non_nullable.len(), array.len());
262
263            // Verify values are unchanged
264            for i in 0..array.len().min(10) {
265                assert_eq!(array.scalar_at(i), back_to_non_nullable.scalar_at(i));
266            }
267        }
268    }
269}
270
271fn test_cast_nullability_changes(array: &dyn Array, nullable_version: &DType) {
272    // Test casting to nullable version
273    if array.dtype().nullability() == Nullability::NonNullable {
274        let result = cast(array, nullable_version).vortex_unwrap();
275        assert_eq!(result.len(), array.len());
276        assert_eq!(result.dtype(), nullable_version);
277
278        // IMPORTANT: Nullability casting should preserve the encoding
279        assert_eq!(
280            result.encoding().id(),
281            array.encoding().id(),
282            "Nullability cast should preserve encoding"
283        );
284
285        // Values should be unchanged
286        for i in 0..array.len().min(10) {
287            assert_eq!(array.scalar_at(i), result.scalar_at(i),);
288        }
289    }
290}
291
292fn test_cast_nullability_changes_primitive(
293    array: &dyn Array,
294    ptype: PType,
295    nullability: Nullability,
296) {
297    // Test casting to nullable version
298    if nullability == Nullability::NonNullable {
299        let nullable_dtype = DType::Primitive(ptype, Nullability::Nullable);
300        let result = cast(array, &nullable_dtype).vortex_unwrap();
301        assert_eq!(result.len(), array.len());
302        assert_eq!(result.dtype(), &nullable_dtype);
303
304        // IMPORTANT: Nullability casting should preserve the encoding
305        assert_eq!(
306            result.encoding().id(),
307            array.encoding().id(),
308            "Nullability cast should preserve encoding"
309        );
310
311        // Values should be unchanged
312        for i in 0..array.len().min(10) {
313            assert_eq!(array.scalar_at(i), result.scalar_at(i),);
314        }
315    }
316
317    // Test casting from nullable to non-nullable (only if no nulls present)
318    if nullability == Nullability::Nullable {
319        // Try to cast to non-nullable and see if it succeeds
320        let non_nullable_dtype = DType::Primitive(ptype, Nullability::NonNullable);
321        if let Ok(result) = cast(array, &non_nullable_dtype) {
322            assert_eq!(result.len(), array.len());
323            assert_eq!(result.dtype(), &non_nullable_dtype);
324
325            // IMPORTANT: Nullability casting should preserve the encoding
326            assert_eq!(
327                result.encoding().id(),
328                array.encoding().id(),
329                "Nullability cast should preserve encoding"
330            );
331
332            // Values should be unchanged
333            for i in 0..array.len().min(10) {
334                assert_eq!(array.scalar_at(i), result.scalar_at(i),);
335            }
336        }
337    }
338}
339
340fn test_cast_from_u8(array: &dyn Array) {
341    // Test widening casts
342    test_cast_to_primitive(array, PType::U16);
343    test_cast_to_primitive(array, PType::U32);
344    test_cast_to_primitive(array, PType::U64);
345    test_cast_to_primitive(array, PType::I16);
346    test_cast_to_primitive(array, PType::I32);
347    test_cast_to_primitive(array, PType::I64);
348    test_cast_to_primitive(array, PType::F32);
349    test_cast_to_primitive(array, PType::F64);
350
351    // Test same-width cast
352    test_cast_to_primitive(array, PType::I8);
353}
354
355fn test_cast_from_u16(array: &dyn Array) {
356    // Test narrowing cast
357    test_cast_to_primitive(array, PType::U8);
358
359    // Test widening casts
360    test_cast_to_primitive(array, PType::U32);
361    test_cast_to_primitive(array, PType::U64);
362    test_cast_to_primitive(array, PType::I32);
363    test_cast_to_primitive(array, PType::I64);
364    test_cast_to_primitive(array, PType::F32);
365    test_cast_to_primitive(array, PType::F64);
366
367    // Test same-width cast
368    test_cast_to_primitive(array, PType::I16);
369}
370
371fn test_cast_from_u32(array: &dyn Array) {
372    // Test narrowing casts
373    test_cast_to_primitive(array, PType::U8);
374    test_cast_to_primitive(array, PType::U16);
375    test_cast_to_primitive(array, PType::I8);
376    test_cast_to_primitive(array, PType::I16);
377
378    // Test widening casts
379    test_cast_to_primitive(array, PType::U64);
380    test_cast_to_primitive(array, PType::I64);
381    test_cast_to_primitive(array, PType::F64);
382
383    // Test same-width casts
384    test_cast_to_primitive(array, PType::I32);
385    test_cast_to_primitive(array, PType::F32);
386}
387
388fn test_cast_from_u64(array: &dyn Array) {
389    // Test narrowing casts
390    test_cast_to_primitive(array, PType::U8);
391    test_cast_to_primitive(array, PType::U16);
392    test_cast_to_primitive(array, PType::U32);
393    test_cast_to_primitive(array, PType::I8);
394    test_cast_to_primitive(array, PType::I16);
395    test_cast_to_primitive(array, PType::I32);
396    test_cast_to_primitive(array, PType::F32);
397
398    // Test same-width casts
399    test_cast_to_primitive(array, PType::I64);
400    test_cast_to_primitive(array, PType::F64);
401}
402
403fn test_cast_from_i8(array: &dyn Array) {
404    // Test widening casts
405    test_cast_to_primitive(array, PType::I16);
406    test_cast_to_primitive(array, PType::I32);
407    test_cast_to_primitive(array, PType::I64);
408    test_cast_to_primitive(array, PType::F32);
409    test_cast_to_primitive(array, PType::F64);
410
411    // Test same-width cast (may fail for negative values)
412    test_cast_to_primitive(array, PType::U8);
413}
414
415fn test_cast_from_i16(array: &dyn Array) {
416    // Test narrowing cast
417    test_cast_to_primitive(array, PType::I8);
418
419    // Test widening casts
420    test_cast_to_primitive(array, PType::I32);
421    test_cast_to_primitive(array, PType::I64);
422    test_cast_to_primitive(array, PType::F32);
423    test_cast_to_primitive(array, PType::F64);
424
425    // Test same-width cast (may fail for negative values)
426    test_cast_to_primitive(array, PType::U16);
427}
428
429fn test_cast_from_i32(array: &dyn Array) {
430    // Test narrowing casts
431    test_cast_to_primitive(array, PType::I8);
432    test_cast_to_primitive(array, PType::I16);
433
434    // Test widening casts
435    test_cast_to_primitive(array, PType::I64);
436    test_cast_to_primitive(array, PType::F64);
437
438    // Test same-width casts
439    test_cast_to_primitive(array, PType::F32);
440    test_cast_to_primitive(array, PType::U32);
441}
442
443fn test_cast_from_i64(array: &dyn Array) {
444    // Test narrowing casts
445    test_cast_to_primitive(array, PType::I8);
446    test_cast_to_primitive(array, PType::I16);
447    test_cast_to_primitive(array, PType::I32);
448    test_cast_to_primitive(array, PType::F32);
449
450    // Test same-width cast
451    test_cast_to_primitive(array, PType::F64);
452    test_cast_to_primitive(array, PType::U64);
453}
454
455fn test_cast_from_f16(array: &dyn Array) {
456    // Test casts to other float types
457    test_cast_to_primitive(array, PType::F32);
458    test_cast_to_primitive(array, PType::F64);
459}
460
461fn test_cast_from_f32(array: &dyn Array) {
462    // Test narrowing cast
463    test_cast_to_primitive(array, PType::F16);
464
465    // Test widening cast
466    test_cast_to_primitive(array, PType::F64);
467
468    // Test casts to integer types (truncation)
469    test_cast_to_integral_types(array);
470}
471
472fn test_cast_from_f64(array: &dyn Array) {
473    // Test narrowing casts
474    test_cast_to_primitive(array, PType::F16);
475    test_cast_to_primitive(array, PType::F32);
476
477    // Test casts to integer types (truncation)
478    test_cast_to_integral_types(array);
479}
480
481fn test_cast_to_integral_types(array: &dyn Array) {
482    // Test casting to all integral types
483    // Some may fail due to out-of-range values
484    test_cast_to_primitive(array, PType::I8);
485    test_cast_to_primitive(array, PType::U8);
486    test_cast_to_primitive(array, PType::I16);
487    test_cast_to_primitive(array, PType::U16);
488    test_cast_to_primitive(array, PType::I32);
489    test_cast_to_primitive(array, PType::U32);
490    test_cast_to_primitive(array, PType::I64);
491    test_cast_to_primitive(array, PType::U64);
492}
493
494fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType) {
495    let target_dtype = DType::Primitive(target_ptype, array.dtype().nullability());
496    test_cast_to_type_safe(array, &target_dtype);
497}
498
499fn test_cast_to_type_safe(array: &dyn Array, target_dtype: &DType) {
500    // Attempt the cast
501    let result = match cast(array, target_dtype) {
502        Ok(r) => r,
503        Err(_) => {
504            // Some casts may fail (e.g., negative to unsigned, out-of-range values)
505            // This is expected behavior
506            return;
507        }
508    };
509
510    assert_eq!(result.len(), array.len());
511    assert_eq!(result.dtype(), target_dtype);
512
513    // For valid casts, verify the values are correctly converted
514    // We verify up to the first 10 values (or all if less than 10)
515    for i in 0..array.len().min(10) {
516        let original = array.scalar_at(i);
517        let casted = result.scalar_at(i);
518
519        // For nullability-only changes, values should be identical
520        if array.dtype().eq_ignore_nullability(target_dtype) {
521            assert_eq!(
522                original, casted,
523                "Value at index {i} changed during nullability cast"
524            );
525        } else {
526            // For type conversions, at least verify we can retrieve the values
527            // and that null values remain null
528            if original.is_null() {
529                assert!(
530                    casted.is_null(),
531                    "Null value at index {i} became non-null after cast"
532                );
533            } else {
534                assert!(
535                    !casted.is_null(),
536                    "Non-null value at index {i} became null after cast"
537                );
538            }
539        }
540    }
541}
542
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::{
        BoolArray, ListArray, NullArray, PrimitiveArray, StructArray, VarBinArray,
    };
    use crate::validity::Validity;

    /// Non-nullable `u32` values, some too large for 8- and 16-bit integers.
    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable `i32` values on both sides of zero.
    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable `f32` values, including fractional and negative ones.
    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Nullable `u8` values with nulls in between.
    #[test]
    fn test_cast_conformance_nullable() {
        let entries = [Some(1u8), None, Some(255), Some(0), None];
        let values = PrimitiveArray::from_option_iter(entries);
        test_cast_conformance(values.as_ref());
    }

    /// Boolean values built from a plain `bool` vector.
    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(flags.as_ref());
    }

    /// A `NullArray` of length five.
    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    /// Nullable UTF-8 strings with one null entry.
    #[test]
    fn test_cast_conformance_utf8() {
        let entries = vec![Some("hello"), None, Some("world")];
        let strings = VarBinArray::from_iter(entries, DType::Utf8(Nullability::Nullable));
        test_cast_conformance(strings.as_ref());
    }

    /// Nullable binary values with one null entry.
    #[test]
    fn test_cast_conformance_binary() {
        let entries = vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())];
        let blobs = VarBinArray::from_iter(entries, DType::Binary(Nullability::Nullable));
        test_cast_conformance(blobs.as_ref());
    }

    /// A non-nullable struct with an `i32` field and a nullable UTF-8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let field_a = buffer![1i32, 2, 3].into_array();
        let field_b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let record =
            StructArray::try_new(field_names, vec![field_a, field_b], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(record.as_ref());
    }

    /// A non-nullable list array over `i32` elements with `i64` offsets.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}