// vortex_array/compute/conformance/cast.rs
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_dtype::{DType, Nullability, PType};
5use vortex_error::{VortexExpect as _, VortexUnwrap, vortex_panic};
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::compute::{MinMaxResult, cast, min_max};
10
11/// Test conformance of the cast compute function for an array.
12///
13/// This function tests various casting scenarios including:
14/// - Casting between numeric types (widening and narrowing)
15/// - Casting between signed and unsigned types
16/// - Casting between integral and floating-point types
17/// - Casting with nullability changes
18/// - Casting between string types (Utf8/Binary)
19/// - Edge cases like overflow behavior
20pub fn test_cast_conformance(array: &dyn Array) {
21    let dtype = array.dtype();
22
23    // Always test identity cast and nullability changes
24    test_cast_identity(array);
25
26    test_cast_to_non_nullable(array);
27    test_cast_to_nullable(array);
28
29    // Test based on the specific DType
30    match dtype {
31        DType::Null => test_cast_from_null(array),
32        DType::Primitive(ptype, ..) => match ptype {
33            PType::U8
34            | PType::U16
35            | PType::U32
36            | PType::U64
37            | PType::I8
38            | PType::I16
39            | PType::I32
40            | PType::I64 => test_cast_to_integral_types(array),
41            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
42        },
43        _ => {}
44    }
45}
46
47fn test_cast_identity(array: &dyn Array) {
48    // Casting to the same type should be a no-op
49    let result = cast(array, array.dtype()).vortex_unwrap();
50    assert_eq!(result.len(), array.len());
51    assert_eq!(result.dtype(), array.dtype());
52
53    // Verify values are unchanged
54    for i in 0..array.len().min(10) {
55        assert_eq!(array.scalar_at(i), result.scalar_at(i));
56    }
57}
58
59fn test_cast_from_null(array: &dyn Array) {
60    // Null can be cast to itself
61    let result = cast(array, &DType::Null).vortex_unwrap();
62    assert_eq!(result.len(), array.len());
63    assert_eq!(result.dtype(), &DType::Null);
64
65    // Null can also be cast to any nullable type
66    let nullable_types = vec![
67        DType::Bool(Nullability::Nullable),
68        DType::Primitive(PType::I32, Nullability::Nullable),
69        DType::Primitive(PType::F64, Nullability::Nullable),
70        DType::Utf8(Nullability::Nullable),
71        DType::Binary(Nullability::Nullable),
72    ];
73
74    for dtype in nullable_types {
75        let result = cast(array, &dtype).vortex_unwrap();
76        assert_eq!(result.len(), array.len());
77        assert_eq!(result.dtype(), &dtype);
78
79        // Verify all values are null
80        for i in 0..array.len().min(10) {
81            assert!(result.scalar_at(i).is_null());
82        }
83    }
84
85    // Casting to non-nullable types should fail
86    let non_nullable_types = vec![
87        DType::Bool(Nullability::NonNullable),
88        DType::Primitive(PType::I32, Nullability::NonNullable),
89    ];
90
91    for dtype in non_nullable_types {
92        assert!(cast(array, &dtype).is_err());
93    }
94}
95
96fn test_cast_to_non_nullable(array: &dyn Array) {
97    if array.invalid_count() == 0 {
98        let non_nullable = cast(array, &array.dtype().as_nonnullable())
99            .vortex_expect("arrays without nulls can cast to non-nullable");
100        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
101        assert_eq!(non_nullable.len(), array.len());
102
103        for i in 0..array.len().min(10) {
104            assert_eq!(array.scalar_at(i), non_nullable.scalar_at(i));
105        }
106
107        let back_to_nullable = cast(&non_nullable, array.dtype())
108            .vortex_expect("non-nullable arrays can cast to nullable");
109        assert_eq!(back_to_nullable.dtype(), array.dtype());
110        assert_eq!(back_to_nullable.len(), array.len());
111
112        for i in 0..array.len().min(10) {
113            assert_eq!(array.scalar_at(i), back_to_nullable.scalar_at(i));
114        }
115    } else {
116        if &DType::Null == array.dtype() {
117            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
118            // array can be casted to DType::Null.
119            return;
120        }
121        cast(array, &array.dtype().as_nonnullable())
122            .err()
123            .unwrap_or_else(|| {
124                vortex_panic!(
125                    "arrays with nulls should error when casting to non-nullable {}",
126                    array,
127                )
128            });
129    }
130}
131
132fn test_cast_to_nullable(array: &dyn Array) {
133    let nullable = cast(array, &array.dtype().as_nullable())
134        .vortex_expect("arrays without nulls can cast to nullable");
135    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
136    assert_eq!(nullable.len(), array.len());
137
138    for i in 0..array.len().min(10) {
139        assert_eq!(array.scalar_at(i), nullable.scalar_at(i));
140    }
141
142    let back = cast(&nullable, array.dtype())
143        .vortex_expect("casting to nullable and back should be a no-op");
144    assert_eq!(back.dtype(), array.dtype());
145    assert_eq!(back.len(), array.len());
146
147    for i in 0..array.len().min(10) {
148        assert_eq!(array.scalar_at(i), back.scalar_at(i));
149    }
150}
151
152fn test_cast_from_floating_point_types(array: &dyn Array) {
153    let ptype = array.as_primitive_typed().ptype();
154    test_cast_to_primitive(array, PType::I8, false);
155    test_cast_to_primitive(array, PType::U8, false);
156    test_cast_to_primitive(array, PType::I16, false);
157    test_cast_to_primitive(array, PType::U16, false);
158    test_cast_to_primitive(array, PType::I32, false);
159    test_cast_to_primitive(array, PType::U32, false);
160    test_cast_to_primitive(array, PType::I64, false);
161    test_cast_to_primitive(array, PType::U64, false);
162    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
163    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
164    test_cast_to_primitive(array, PType::F64, true);
165}
166
167fn test_cast_to_integral_types(array: &dyn Array) {
168    test_cast_to_primitive(array, PType::I8, true);
169    test_cast_to_primitive(array, PType::U8, true);
170    test_cast_to_primitive(array, PType::I16, true);
171    test_cast_to_primitive(array, PType::U16, true);
172    test_cast_to_primitive(array, PType::I32, true);
173    test_cast_to_primitive(array, PType::U32, true);
174    test_cast_to_primitive(array, PType::I64, true);
175    test_cast_to_primitive(array, PType::U64, true);
176}
177
178/// Does this scalar fit in this type?
179fn fits(value: &Scalar, ptype: PType) -> bool {
180    let dtype = DType::Primitive(ptype, value.dtype().nullability());
181    value.cast(&dtype).is_ok()
182}
183
184fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
185    let maybe_min_max = min_max(array).vortex_unwrap();
186
187    if let Some(MinMaxResult { min, max }) = maybe_min_max
188        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
189    {
190        cast(
191            array,
192            &DType::Primitive(target_ptype, array.dtype().nullability()),
193        )
194        .err()
195        .unwrap_or_else(|| {
196            vortex_panic!(
197                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
198                target_ptype,
199                min,
200                max,
201                array,
202                array.display_values(),
203            )
204        });
205        return;
206    }
207
208    // Otherwise, all values must fit.
209    let casted = cast(
210        array,
211        &DType::Primitive(target_ptype, array.dtype().nullability()),
212    )
213    .unwrap_or_else(|e| {
214        vortex_panic!(
215            "Cast must succeed because all values are within bounds. {} {}: {e}",
216            target_ptype,
217            array.display_values(),
218        )
219    });
220    assert_eq!(array.validity_mask(), casted.validity_mask());
221    for i in 0..array.len().min(10) {
222        let original = array.scalar_at(i);
223        let casted = casted.scalar_at(i);
224        assert_eq!(
225            original.cast(casted.dtype()).vortex_unwrap(),
226            casted,
227            "{i} {original} {casted}"
228        );
229        if test_round_trip {
230            assert_eq!(
231                original,
232                casted.cast(original.dtype()).vortex_unwrap(),
233                "{i} {original} {casted}"
234            );
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability};

    use super::*;
    use crate::IntoArray;
    use crate::arrays::{
        BoolArray, ListArray, NullArray, PrimitiveArray, StructArray, VarBinArray,
    };
    use crate::validity::Validity;

    /// Non-nullable `u32` values, some of which do not fit the narrower integer types.
    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65_535, 1_000_000].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable `i32` values including negatives.
    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable `f32` values with fractional parts and one large magnitude.
    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Nullable `u8` values spanning `0..=255`, interleaved with nulls.
    #[test]
    fn test_cast_conformance_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable booleans (identity and nullability checks only).
    #[test]
    fn test_cast_conformance_bool() {
        let values = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(values.as_ref());
    }

    /// A five-element null array.
    #[test]
    fn test_cast_conformance_null() {
        let values = NullArray::new(5);
        test_cast_conformance(values.as_ref());
    }

    /// Nullable UTF-8 strings.
    #[test]
    fn test_cast_conformance_utf8() {
        let values = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(values.as_ref());
    }

    /// Nullable binary blobs.
    #[test]
    fn test_cast_conformance_binary() {
        let values = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(values.as_ref());
    }

    /// A non-nullable struct with an `i32` field and a nullable UTF-8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_a = buffer![1i32, 2, 3].into_array();
        let field_b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let array = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            vec![field_a, field_b],
            3,
            Validity::NonNullable,
        )
        .unwrap();
        test_cast_conformance(array.as_ref());
    }

    /// A non-nullable list array over `i32` elements, including an empty list.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(array.as_ref());
    }
}