vortex_array/compute/conformance/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_dtype::PType;
7use vortex_error::VortexExpect;
8use vortex_error::vortex_panic;
9use vortex_scalar::Scalar;
10
11use crate::Array;
12use crate::compute::MinMaxResult;
13use crate::compute::cast;
14use crate::compute::min_max;
15
16/// Test conformance of the cast compute function for an array.
17///
18/// This function tests various casting scenarios including:
19/// - Casting between numeric types (widening and narrowing)
20/// - Casting between signed and unsigned types
21/// - Casting between integral and floating-point types
22/// - Casting with nullability changes
23/// - Casting between string types (Utf8/Binary)
24/// - Edge cases like overflow behavior
25pub fn test_cast_conformance(array: &dyn Array) {
26    let dtype = array.dtype();
27
28    // Always test identity cast and nullability changes
29    test_cast_identity(array);
30
31    test_cast_to_non_nullable(array);
32    test_cast_to_nullable(array);
33
34    // Test based on the specific DType
35    match dtype {
36        DType::Null => test_cast_from_null(array),
37        DType::Primitive(ptype, ..) => match ptype {
38            PType::U8
39            | PType::U16
40            | PType::U32
41            | PType::U64
42            | PType::I8
43            | PType::I16
44            | PType::I32
45            | PType::I64 => test_cast_to_integral_types(array),
46            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
47        },
48        _ => {}
49    }
50}
51
52fn test_cast_identity(array: &dyn Array) {
53    // Casting to the same type should be a no-op
54    let result =
55        cast(array, array.dtype()).vortex_expect("cast should succeed in conformance test");
56    assert_eq!(result.len(), array.len());
57    assert_eq!(result.dtype(), array.dtype());
58
59    // Verify values are unchanged
60    for i in 0..array.len().min(10) {
61        assert_eq!(array.scalar_at(i), result.scalar_at(i));
62    }
63}
64
65fn test_cast_from_null(array: &dyn Array) {
66    // Null can be cast to itself
67    let result = cast(array, &DType::Null).vortex_expect("cast should succeed in conformance test");
68    assert_eq!(result.len(), array.len());
69    assert_eq!(result.dtype(), &DType::Null);
70
71    // Null can also be cast to any nullable type
72    let nullable_types = vec![
73        DType::Bool(Nullability::Nullable),
74        DType::Primitive(PType::I32, Nullability::Nullable),
75        DType::Primitive(PType::F64, Nullability::Nullable),
76        DType::Utf8(Nullability::Nullable),
77        DType::Binary(Nullability::Nullable),
78    ];
79
80    for dtype in nullable_types {
81        let result = cast(array, &dtype).vortex_expect("cast should succeed in conformance test");
82        assert_eq!(result.len(), array.len());
83        assert_eq!(result.dtype(), &dtype);
84
85        // Verify all values are null
86        for i in 0..array.len().min(10) {
87            assert!(result.scalar_at(i).is_null());
88        }
89    }
90
91    // Casting to non-nullable types should fail
92    let non_nullable_types = vec![
93        DType::Bool(Nullability::NonNullable),
94        DType::Primitive(PType::I32, Nullability::NonNullable),
95    ];
96
97    for dtype in non_nullable_types {
98        assert!(cast(array, &dtype).is_err());
99    }
100}
101
102fn test_cast_to_non_nullable(array: &dyn Array) {
103    if array.invalid_count() == 0 {
104        let non_nullable = cast(array, &array.dtype().as_nonnullable())
105            .vortex_expect("arrays without nulls can cast to non-nullable");
106        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
107        assert_eq!(non_nullable.len(), array.len());
108
109        for i in 0..array.len().min(10) {
110            assert_eq!(array.scalar_at(i), non_nullable.scalar_at(i));
111        }
112
113        let back_to_nullable = cast(&non_nullable, array.dtype())
114            .vortex_expect("non-nullable arrays can cast to nullable");
115        assert_eq!(back_to_nullable.dtype(), array.dtype());
116        assert_eq!(back_to_nullable.len(), array.len());
117
118        for i in 0..array.len().min(10) {
119            assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
120        }
121    } else {
122        if &DType::Null == array.dtype() {
123            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
124            // array can be casted to DType::Null.
125            return;
126        }
127        cast(array, &array.dtype().as_nonnullable())
128            .err()
129            .unwrap_or_else(|| {
130                vortex_panic!(
131                    "arrays with nulls should error when casting to non-nullable {}",
132                    array,
133                )
134            });
135    }
136}
137
138fn test_cast_to_nullable(array: &dyn Array) {
139    let nullable = cast(array, &array.dtype().as_nullable())
140        .vortex_expect("arrays without nulls can cast to nullable");
141    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
142    assert_eq!(nullable.len(), array.len());
143
144    for i in 0..array.len().min(10) {
145        assert_eq!(array.scalar_at(i), nullable.scalar_at(i));
146    }
147
148    let back = cast(&nullable, array.dtype())
149        .vortex_expect("casting to nullable and back should be a no-op");
150    assert_eq!(back.dtype(), array.dtype());
151    assert_eq!(back.len(), array.len());
152
153    for i in 0..array.len().min(10) {
154        assert_eq!(array.scalar_at(i), back.scalar_at(i));
155    }
156}
157
158fn test_cast_from_floating_point_types(array: &dyn Array) {
159    let ptype = array.as_primitive_typed().ptype();
160    test_cast_to_primitive(array, PType::I8, false);
161    test_cast_to_primitive(array, PType::U8, false);
162    test_cast_to_primitive(array, PType::I16, false);
163    test_cast_to_primitive(array, PType::U16, false);
164    test_cast_to_primitive(array, PType::I32, false);
165    test_cast_to_primitive(array, PType::U32, false);
166    test_cast_to_primitive(array, PType::I64, false);
167    test_cast_to_primitive(array, PType::U64, false);
168    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
169    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
170    test_cast_to_primitive(array, PType::F64, true);
171}
172
173fn test_cast_to_integral_types(array: &dyn Array) {
174    test_cast_to_primitive(array, PType::I8, true);
175    test_cast_to_primitive(array, PType::U8, true);
176    test_cast_to_primitive(array, PType::I16, true);
177    test_cast_to_primitive(array, PType::U16, true);
178    test_cast_to_primitive(array, PType::I32, true);
179    test_cast_to_primitive(array, PType::U32, true);
180    test_cast_to_primitive(array, PType::I64, true);
181    test_cast_to_primitive(array, PType::U64, true);
182}
183
184/// Does this scalar fit in this type?
185fn fits(value: &Scalar, ptype: PType) -> bool {
186    let dtype = DType::Primitive(ptype, value.dtype().nullability());
187    value.cast(&dtype).is_ok()
188}
189
190fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
191    let maybe_min_max = min_max(array).vortex_expect("cast should succeed in conformance test");
192
193    if let Some(MinMaxResult { min, max }) = maybe_min_max
194        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
195    {
196        cast(
197            array,
198            &DType::Primitive(target_ptype, array.dtype().nullability()),
199        )
200        .err()
201        .unwrap_or_else(|| {
202            vortex_panic!(
203                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
204                target_ptype,
205                min,
206                max,
207                array,
208                array.display_values(),
209            )
210        });
211        return;
212    }
213
214    // Otherwise, all values must fit.
215    let casted = cast(
216        array,
217        &DType::Primitive(target_ptype, array.dtype().nullability()),
218    )
219    .unwrap_or_else(|e| {
220        vortex_panic!(
221            "Cast must succeed because all values are within bounds. {} {}: {e}",
222            target_ptype,
223            array.display_values(),
224        )
225    });
226    assert_eq!(array.validity_mask(), casted.validity_mask());
227    for i in 0..array.len().min(10) {
228        let original = array.scalar_at(i);
229        let casted = casted.scalar_at(i);
230        assert_eq!(
231            original
232                .cast(casted.dtype())
233                .vortex_expect("cast should succeed in conformance test"),
234            casted,
235            "{i} {original} {casted}"
236        );
237        if test_round_trip {
238            assert_eq!(
239                original,
240                casted
241                    .cast(original.dtype())
242                    .vortex_expect("cast should succeed in conformance test"),
243                "{i} {original} {casted}"
244            );
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::validity::Validity;

    /// Non-nullable unsigned integers.
    #[test]
    fn test_cast_conformance_u32() {
        let unsigned = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(unsigned.as_ref());
    }

    /// Non-nullable signed integers, including negative values.
    #[test]
    fn test_cast_conformance_i32() {
        let signed = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(signed.as_ref());
    }

    /// Non-nullable floats, including fractional and negative values.
    #[test]
    fn test_cast_conformance_f32() {
        let floats = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(floats.as_ref());
    }

    /// Nullable primitive array containing nulls.
    #[test]
    fn test_cast_conformance_nullable() {
        let with_nulls =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(with_nulls.as_ref());
    }

    /// Boolean array.
    #[test]
    fn test_cast_conformance_bool() {
        let bools = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(bools.as_ref());
    }

    /// Null array.
    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    /// Nullable UTF-8 strings.
    #[test]
    fn test_cast_conformance_utf8() {
        let values = vec![Some("hello"), None, Some("world")];
        let strings = VarBinArray::from_iter(values, DType::Utf8(Nullability::Nullable));
        test_cast_conformance(strings.as_ref());
    }

    /// Nullable binary values.
    #[test]
    fn test_cast_conformance_binary() {
        let values = vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())];
        let blobs = VarBinArray::from_iter(values, DType::Binary(Nullability::Nullable));
        test_cast_conformance(blobs.as_ref());
    }

    /// Non-nullable struct with an integer field and a nullable string field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let structs =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(structs.as_ref());
    }

    /// Non-nullable list of integers, including an empty list.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let list_offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, list_offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}