vortex_array/compute/conformance/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_dtype::PType;
7use vortex_error::VortexExpect as _;
8use vortex_error::VortexUnwrap;
9use vortex_error::vortex_panic;
10use vortex_scalar::Scalar;
11
12use crate::Array;
13use crate::compute::MinMaxResult;
14use crate::compute::cast;
15use crate::compute::min_max;
16
17/// Test conformance of the cast compute function for an array.
18///
19/// This function tests various casting scenarios including:
20/// - Casting between numeric types (widening and narrowing)
21/// - Casting between signed and unsigned types
22/// - Casting between integral and floating-point types
23/// - Casting with nullability changes
24/// - Casting between string types (Utf8/Binary)
25/// - Edge cases like overflow behavior
26pub fn test_cast_conformance(array: &dyn Array) {
27    let dtype = array.dtype();
28
29    // Always test identity cast and nullability changes
30    test_cast_identity(array);
31
32    test_cast_to_non_nullable(array);
33    test_cast_to_nullable(array);
34
35    // Test based on the specific DType
36    match dtype {
37        DType::Null => test_cast_from_null(array),
38        DType::Primitive(ptype, ..) => match ptype {
39            PType::U8
40            | PType::U16
41            | PType::U32
42            | PType::U64
43            | PType::I8
44            | PType::I16
45            | PType::I32
46            | PType::I64 => test_cast_to_integral_types(array),
47            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
48        },
49        _ => {}
50    }
51}
52
53fn test_cast_identity(array: &dyn Array) {
54    // Casting to the same type should be a no-op
55    let result = cast(array, array.dtype()).vortex_unwrap();
56    assert_eq!(result.len(), array.len());
57    assert_eq!(result.dtype(), array.dtype());
58
59    // Verify values are unchanged
60    for i in 0..array.len().min(10) {
61        assert_eq!(array.scalar_at(i), result.scalar_at(i));
62    }
63}
64
65fn test_cast_from_null(array: &dyn Array) {
66    // Null can be cast to itself
67    let result = cast(array, &DType::Null).vortex_unwrap();
68    assert_eq!(result.len(), array.len());
69    assert_eq!(result.dtype(), &DType::Null);
70
71    // Null can also be cast to any nullable type
72    let nullable_types = vec![
73        DType::Bool(Nullability::Nullable),
74        DType::Primitive(PType::I32, Nullability::Nullable),
75        DType::Primitive(PType::F64, Nullability::Nullable),
76        DType::Utf8(Nullability::Nullable),
77        DType::Binary(Nullability::Nullable),
78    ];
79
80    for dtype in nullable_types {
81        let result = cast(array, &dtype).vortex_unwrap();
82        assert_eq!(result.len(), array.len());
83        assert_eq!(result.dtype(), &dtype);
84
85        // Verify all values are null
86        for i in 0..array.len().min(10) {
87            assert!(result.scalar_at(i).is_null());
88        }
89    }
90
91    // Casting to non-nullable types should fail
92    let non_nullable_types = vec![
93        DType::Bool(Nullability::NonNullable),
94        DType::Primitive(PType::I32, Nullability::NonNullable),
95    ];
96
97    for dtype in non_nullable_types {
98        assert!(cast(array, &dtype).is_err());
99    }
100}
101
102fn test_cast_to_non_nullable(array: &dyn Array) {
103    if array.invalid_count() == 0 {
104        let non_nullable = cast(array, &array.dtype().as_nonnullable())
105            .vortex_expect("arrays without nulls can cast to non-nullable");
106        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
107        assert_eq!(non_nullable.len(), array.len());
108
109        for i in 0..array.len().min(10) {
110            assert_eq!(array.scalar_at(i), non_nullable.scalar_at(i));
111        }
112
113        let back_to_nullable = cast(&non_nullable, array.dtype())
114            .vortex_expect("non-nullable arrays can cast to nullable");
115        assert_eq!(back_to_nullable.dtype(), array.dtype());
116        assert_eq!(back_to_nullable.len(), array.len());
117
118        for i in 0..array.len().min(10) {
119            assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
120        }
121    } else {
122        if &DType::Null == array.dtype() {
123            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
124            // array can be casted to DType::Null.
125            return;
126        }
127        cast(array, &array.dtype().as_nonnullable())
128            .err()
129            .unwrap_or_else(|| {
130                vortex_panic!(
131                    "arrays with nulls should error when casting to non-nullable {}",
132                    array,
133                )
134            });
135    }
136}
137
138fn test_cast_to_nullable(array: &dyn Array) {
139    let nullable = cast(array, &array.dtype().as_nullable())
140        .vortex_expect("arrays without nulls can cast to nullable");
141    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
142    assert_eq!(nullable.len(), array.len());
143
144    for i in 0..array.len().min(10) {
145        assert_eq!(array.scalar_at(i), nullable.scalar_at(i));
146    }
147
148    let back = cast(&nullable, array.dtype())
149        .vortex_expect("casting to nullable and back should be a no-op");
150    assert_eq!(back.dtype(), array.dtype());
151    assert_eq!(back.len(), array.len());
152
153    for i in 0..array.len().min(10) {
154        assert_eq!(array.scalar_at(i), back.scalar_at(i));
155    }
156}
157
158fn test_cast_from_floating_point_types(array: &dyn Array) {
159    let ptype = array.as_primitive_typed().ptype();
160    test_cast_to_primitive(array, PType::I8, false);
161    test_cast_to_primitive(array, PType::U8, false);
162    test_cast_to_primitive(array, PType::I16, false);
163    test_cast_to_primitive(array, PType::U16, false);
164    test_cast_to_primitive(array, PType::I32, false);
165    test_cast_to_primitive(array, PType::U32, false);
166    test_cast_to_primitive(array, PType::I64, false);
167    test_cast_to_primitive(array, PType::U64, false);
168    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
169    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
170    test_cast_to_primitive(array, PType::F64, true);
171}
172
173fn test_cast_to_integral_types(array: &dyn Array) {
174    test_cast_to_primitive(array, PType::I8, true);
175    test_cast_to_primitive(array, PType::U8, true);
176    test_cast_to_primitive(array, PType::I16, true);
177    test_cast_to_primitive(array, PType::U16, true);
178    test_cast_to_primitive(array, PType::I32, true);
179    test_cast_to_primitive(array, PType::U32, true);
180    test_cast_to_primitive(array, PType::I64, true);
181    test_cast_to_primitive(array, PType::U64, true);
182}
183
184/// Does this scalar fit in this type?
185fn fits(value: &Scalar, ptype: PType) -> bool {
186    let dtype = DType::Primitive(ptype, value.dtype().nullability());
187    value.cast(&dtype).is_ok()
188}
189
190fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
191    let maybe_min_max = min_max(array).vortex_unwrap();
192
193    if let Some(MinMaxResult { min, max }) = maybe_min_max
194        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
195    {
196        cast(
197            array,
198            &DType::Primitive(target_ptype, array.dtype().nullability()),
199        )
200        .err()
201        .unwrap_or_else(|| {
202            vortex_panic!(
203                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
204                target_ptype,
205                min,
206                max,
207                array,
208                array.display_values(),
209            )
210        });
211        return;
212    }
213
214    // Otherwise, all values must fit.
215    let casted = cast(
216        array,
217        &DType::Primitive(target_ptype, array.dtype().nullability()),
218    )
219    .unwrap_or_else(|e| {
220        vortex_panic!(
221            "Cast must succeed because all values are within bounds. {} {}: {e}",
222            target_ptype,
223            array.display_values(),
224        )
225    });
226    assert_eq!(array.validity_mask(), casted.validity_mask());
227    for i in 0..array.len().min(10) {
228        let original = array.scalar_at(i);
229        let casted = casted.scalar_at(i);
230        assert_eq!(
231            original.cast(casted.dtype()).vortex_unwrap(),
232            casted,
233            "{i} {original} {casted}"
234        );
235        if test_round_trip {
236            assert_eq!(
237                original,
238                casted.cast(original.dtype()).vortex_unwrap(),
239                "{i} {original} {casted}"
240            );
241        }
242    }
243}
244
/// Runs the conformance suite against a handful of in-memory arrays covering primitive,
/// bool, null, string, binary, struct and list dtypes.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::validity::Validity;

    #[test]
    fn test_cast_conformance_u32() {
        let unsigned = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(unsigned.as_ref());
    }

    #[test]
    fn test_cast_conformance_i32() {
        let signed = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(signed.as_ref());
    }

    #[test]
    fn test_cast_conformance_f32() {
        let floats = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(floats.as_ref());
    }

    #[test]
    fn test_cast_conformance_nullable() {
        // Mixes valid and missing entries.
        let with_nulls =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(with_nulls.as_ref());
    }

    #[test]
    fn test_cast_conformance_bool() {
        let bools = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(bools.as_ref());
    }

    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let values = vec![Some("hello"), None, Some("world")];
        let strings = VarBinArray::from_iter(values, DType::Utf8(Nullability::Nullable));
        test_cast_conformance(strings.as_ref());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let values = vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())];
        let blobs = VarBinArray::from_iter(values, DType::Binary(Nullability::Nullable));
        test_cast_conformance(blobs.as_ref());
    }

    #[test]
    fn test_cast_conformance_struct() {
        // Two fields: a non-nullable i32 column and a nullable utf8 column.
        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let field_names = FieldNames::from(["a", "b"]);
        let structs =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(structs.as_ref());
    }

    #[test]
    fn test_cast_conformance_list() {
        // Four lists over six elements; the second list is empty.
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}