// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_dtype::PType;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_panic;
10
11use crate::Array;
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::builtins::ArrayBuiltins;
15use crate::compute::MinMaxResult;
16use crate::compute::min_max;
17use crate::scalar::Scalar;
18
19/// Cast and force execution via `to_canonical`, returning the canonical array.
20fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
21    array.cast(dtype)?.to_canonical().map(|c| c.into_array())
22}
23
24/// Test conformance of the cast compute function for an array.
25///
26/// This function tests various casting scenarios including:
27/// - Casting between numeric types (widening and narrowing)
28/// - Casting between signed and unsigned types
29/// - Casting between integral and floating-point types
30/// - Casting with nullability changes
31/// - Casting between string types (Utf8/Binary)
32/// - Edge cases like overflow behavior
33pub fn test_cast_conformance(array: &dyn Array) {
34    let dtype = array.dtype();
35
36    // Always test identity cast and nullability changes
37    test_cast_identity(array);
38
39    test_cast_to_non_nullable(array);
40    test_cast_to_nullable(array);
41
42    // Test based on the specific DType
43    match dtype {
44        DType::Null => test_cast_from_null(array),
45        DType::Primitive(ptype, ..) => match ptype {
46            PType::U8
47            | PType::U16
48            | PType::U32
49            | PType::U64
50            | PType::I8
51            | PType::I16
52            | PType::I32
53            | PType::I64 => test_cast_to_integral_types(array),
54            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
55        },
56        _ => {}
57    }
58}
59
60fn test_cast_identity(array: &dyn Array) {
61    // Casting to the same type should be a no-op
62    let result = cast_and_execute(&array.to_array(), array.dtype().clone())
63        .vortex_expect("cast should succeed in conformance test");
64    assert_eq!(result.len(), array.len());
65    assert_eq!(result.dtype(), array.dtype());
66
67    // Verify values are unchanged
68    for i in 0..array.len().min(10) {
69        assert_eq!(
70            array
71                .scalar_at(i)
72                .vortex_expect("scalar_at should succeed in conformance test"),
73            result
74                .scalar_at(i)
75                .vortex_expect("scalar_at should succeed in conformance test")
76        );
77    }
78}
79
80fn test_cast_from_null(array: &dyn Array) {
81    // Null can be cast to itself
82    let result = cast_and_execute(&array.to_array(), DType::Null)
83        .vortex_expect("cast should succeed in conformance test");
84    assert_eq!(result.len(), array.len());
85    assert_eq!(result.dtype(), &DType::Null);
86
87    // Null can also be cast to any nullable type
88    let nullable_types = vec![
89        DType::Bool(Nullability::Nullable),
90        DType::Primitive(PType::I32, Nullability::Nullable),
91        DType::Primitive(PType::F64, Nullability::Nullable),
92        DType::Utf8(Nullability::Nullable),
93        DType::Binary(Nullability::Nullable),
94    ];
95
96    for dtype in nullable_types {
97        let result = cast_and_execute(&array.to_array(), dtype.clone())
98            .vortex_expect("cast should succeed in conformance test");
99        assert_eq!(result.len(), array.len());
100        assert_eq!(result.dtype(), &dtype);
101
102        // Verify all values are null
103        for i in 0..array.len().min(10) {
104            assert!(
105                result
106                    .scalar_at(i)
107                    .vortex_expect("scalar_at should succeed in conformance test")
108                    .is_null()
109            );
110        }
111    }
112
113    // Casting to non-nullable types should fail
114    let non_nullable_types = vec![
115        DType::Bool(Nullability::NonNullable),
116        DType::Primitive(PType::I32, Nullability::NonNullable),
117    ];
118
119    for dtype in non_nullable_types {
120        assert!(cast_and_execute(&array.to_array(), dtype.clone()).is_err());
121    }
122}
123
124fn test_cast_to_non_nullable(array: &dyn Array) {
125    if array
126        .invalid_count()
127        .vortex_expect("invalid_count should succeed in conformance test")
128        == 0
129    {
130        let non_nullable = cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
131            .vortex_expect("arrays without nulls can cast to non-nullable");
132        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
133        assert_eq!(non_nullable.len(), array.len());
134
135        for i in 0..array.len().min(10) {
136            assert_eq!(
137                array
138                    .scalar_at(i)
139                    .vortex_expect("scalar_at should succeed in conformance test"),
140                non_nullable
141                    .scalar_at(i)
142                    .vortex_expect("scalar_at should succeed in conformance test")
143            );
144        }
145
146        let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
147            .vortex_expect("non-nullable arrays can cast to nullable");
148        assert_eq!(back_to_nullable.dtype(), array.dtype());
149        assert_eq!(back_to_nullable.len(), array.len());
150
151        for i in 0..array.len().min(10) {
152            assert_eq!(
153                array
154                    .scalar_at(i)
155                    .vortex_expect("scalar_at should succeed in conformance test"),
156                back_to_nullable
157                    .scalar_at(i)
158                    .vortex_expect("scalar_at should succeed in conformance test")
159            );
160        }
161    } else {
162        if &DType::Null == array.dtype() {
163            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
164            // array can be casted to DType::Null.
165            return;
166        }
167        cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
168            .err()
169            .unwrap_or_else(|| {
170                vortex_panic!(
171                    "arrays with nulls should error when casting to non-nullable {}",
172                    array,
173                )
174            });
175    }
176}
177
178fn test_cast_to_nullable(array: &dyn Array) {
179    let nullable = cast_and_execute(&array.to_array(), array.dtype().as_nullable())
180        .vortex_expect("arrays without nulls can cast to nullable");
181    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
182    assert_eq!(nullable.len(), array.len());
183
184    for i in 0..array.len().min(10) {
185        assert_eq!(
186            array
187                .scalar_at(i)
188                .vortex_expect("scalar_at should succeed in conformance test"),
189            nullable
190                .scalar_at(i)
191                .vortex_expect("scalar_at should succeed in conformance test")
192        );
193    }
194
195    let back = cast_and_execute(&nullable, array.dtype().clone())
196        .vortex_expect("casting to nullable and back should be a no-op");
197    assert_eq!(back.dtype(), array.dtype());
198    assert_eq!(back.len(), array.len());
199
200    for i in 0..array.len().min(10) {
201        assert_eq!(
202            array
203                .scalar_at(i)
204                .vortex_expect("scalar_at should succeed in conformance test"),
205            back.scalar_at(i)
206                .vortex_expect("scalar_at should succeed in conformance test")
207        );
208    }
209}
210
211fn test_cast_from_floating_point_types(array: &dyn Array) {
212    let ptype = array.as_primitive_typed().ptype();
213    test_cast_to_primitive(array, PType::I8, false);
214    test_cast_to_primitive(array, PType::U8, false);
215    test_cast_to_primitive(array, PType::I16, false);
216    test_cast_to_primitive(array, PType::U16, false);
217    test_cast_to_primitive(array, PType::I32, false);
218    test_cast_to_primitive(array, PType::U32, false);
219    test_cast_to_primitive(array, PType::I64, false);
220    test_cast_to_primitive(array, PType::U64, false);
221    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
222    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
223    test_cast_to_primitive(array, PType::F64, true);
224}
225
226fn test_cast_to_integral_types(array: &dyn Array) {
227    test_cast_to_primitive(array, PType::I8, true);
228    test_cast_to_primitive(array, PType::U8, true);
229    test_cast_to_primitive(array, PType::I16, true);
230    test_cast_to_primitive(array, PType::U16, true);
231    test_cast_to_primitive(array, PType::I32, true);
232    test_cast_to_primitive(array, PType::U32, true);
233    test_cast_to_primitive(array, PType::I64, true);
234    test_cast_to_primitive(array, PType::U64, true);
235}
236
237/// Does this scalar fit in this type?
238fn fits(value: &Scalar, ptype: PType) -> bool {
239    let dtype = DType::Primitive(ptype, value.dtype().nullability());
240    value.cast(&dtype).is_ok()
241}
242
243fn test_cast_to_primitive(array: &dyn Array, target_ptype: PType, test_round_trip: bool) {
244    let maybe_min_max = min_max(array).vortex_expect("cast should succeed in conformance test");
245
246    if let Some(MinMaxResult { min, max }) = maybe_min_max
247        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
248    {
249        cast_and_execute(
250            &array.to_array(),
251            DType::Primitive(target_ptype, array.dtype().nullability()),
252        )
253        .err()
254        .unwrap_or_else(|| {
255            vortex_panic!(
256                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
257                target_ptype,
258                min,
259                max,
260                array,
261                array.display_values(),
262            )
263        });
264        return;
265    }
266
267    // Otherwise, all values must fit.
268    let casted = cast_and_execute(
269        &array.to_array(),
270        DType::Primitive(target_ptype, array.dtype().nullability()),
271    )
272    .unwrap_or_else(|e| {
273        vortex_panic!(
274            "Cast must succeed because all values are within bounds. {} {}: {e}",
275            target_ptype,
276            array.display_values(),
277        )
278    });
279    assert_eq!(
280        array
281            .validity_mask()
282            .vortex_expect("validity_mask should succeed in conformance test"),
283        casted
284            .validity_mask()
285            .vortex_expect("validity_mask should succeed in conformance test")
286    );
287    for i in 0..array.len().min(10) {
288        let original = array
289            .scalar_at(i)
290            .vortex_expect("scalar_at should succeed in conformance test");
291        let casted = casted
292            .scalar_at(i)
293            .vortex_expect("scalar_at should succeed in conformance test");
294        assert_eq!(
295            original
296                .cast(casted.dtype())
297                .vortex_expect("cast should succeed in conformance test"),
298            casted,
299            "{i} {original} {casted}"
300        );
301        if test_round_trip {
302            assert_eq!(
303                original,
304                casted
305                    .cast(original.dtype())
306                    .vortex_expect("cast should succeed in conformance test"),
307                "{i} {original} {casted}"
308            );
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::validity::Validity;

    /// Unsigned integers whose values exceed the `u8` and `u16` ranges.
    #[test]
    fn test_cast_conformance_u32() {
        let values = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Signed integers including negative values.
    #[test]
    fn test_cast_conformance_i32() {
        let values = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Floats with fractional parts, a negative value and a large magnitude.
    #[test]
    fn test_cast_conformance_f32() {
        let values = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(values.as_ref());
    }

    /// Nullable `u8` values containing nulls and both ends of the `u8` range.
    #[test]
    fn test_cast_conformance_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(values.as_ref());
    }

    /// Non-nullable booleans.
    #[test]
    fn test_cast_conformance_bool() {
        let flags = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(flags.as_ref());
    }

    /// An all-null array of length 5.
    #[test]
    fn test_cast_conformance_null() {
        let nulls = NullArray::new(5);
        test_cast_conformance(nulls.as_ref());
    }

    /// Nullable UTF-8 strings containing a null.
    #[test]
    fn test_cast_conformance_utf8() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(strings.as_ref());
    }

    /// Nullable binary values containing a null.
    #[test]
    fn test_cast_conformance_binary() {
        let blobs = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(blobs.as_ref());
    }

    /// A non-nullable struct with an `i32` field and a nullable UTF-8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let record =
            StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
                .unwrap();
        test_cast_conformance(record.as_ref());
    }

    /// A non-nullable list of `i32` using `i64` offsets, including an empty list.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let lists = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
        test_cast_conformance(lists.as_ref());
    }
}