Skip to main content

vortex_array/compute/conformance/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_error::vortex_panic;
7
8use crate::ArrayRef;
9use crate::DynArray;
10use crate::IntoArray;
11use crate::LEGACY_SESSION;
12use crate::RecursiveCanonical;
13use crate::VortexSessionExecute;
14use crate::aggregate_fn::fns::min_max::MinMaxResult;
15use crate::aggregate_fn::fns::min_max::min_max;
16use crate::builtins::ArrayBuiltins;
17use crate::dtype::DType;
18use crate::dtype::Nullability;
19use crate::dtype::PType;
20use crate::scalar::Scalar;
21
22/// Cast and force execution via `to_canonical`, returning the canonical array.
23fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
24    Ok(array
25        .cast(dtype)?
26        .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())?
27        .0
28        .into_array())
29}
30
31/// Test conformance of the cast compute function for an array.
32///
33/// This function tests various casting scenarios including:
34/// - Casting between numeric types (widening and narrowing)
35/// - Casting between signed and unsigned types
36/// - Casting between integral and floating-point types
37/// - Casting with nullability changes
38/// - Casting between string types (Utf8/Binary)
39/// - Edge cases like overflow behavior
40pub fn test_cast_conformance(array: &ArrayRef) {
41    let dtype = array.dtype();
42
43    // Always test identity cast and nullability changes
44    test_cast_identity(array);
45
46    test_cast_to_non_nullable(array);
47    test_cast_to_nullable(array);
48
49    // Test based on the specific DType
50    match dtype {
51        DType::Null => test_cast_from_null(array),
52        DType::Primitive(ptype, ..) => match ptype {
53            PType::U8
54            | PType::U16
55            | PType::U32
56            | PType::U64
57            | PType::I8
58            | PType::I16
59            | PType::I32
60            | PType::I64 => test_cast_to_integral_types(array),
61            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
62        },
63        _ => {}
64    }
65}
66
67fn test_cast_identity(array: &ArrayRef) {
68    // Casting to the same type should be a no-op
69    let result = cast_and_execute(&array.to_array(), array.dtype().clone())
70        .vortex_expect("cast should succeed in conformance test");
71    assert_eq!(result.len(), array.len());
72    assert_eq!(result.dtype(), array.dtype());
73
74    // Verify values are unchanged
75    for i in 0..array.len().min(10) {
76        assert_eq!(
77            array
78                .scalar_at(i)
79                .vortex_expect("scalar_at should succeed in conformance test"),
80            result
81                .scalar_at(i)
82                .vortex_expect("scalar_at should succeed in conformance test")
83        );
84    }
85}
86
87fn test_cast_from_null(array: &ArrayRef) {
88    // Null can be cast to itself
89    let result = cast_and_execute(&array.to_array(), DType::Null)
90        .vortex_expect("cast should succeed in conformance test");
91    assert_eq!(result.len(), array.len());
92    assert_eq!(result.dtype(), &DType::Null);
93
94    // Null can also be cast to any nullable type
95    let nullable_types = vec![
96        DType::Bool(Nullability::Nullable),
97        DType::Primitive(PType::I32, Nullability::Nullable),
98        DType::Primitive(PType::F64, Nullability::Nullable),
99        DType::Utf8(Nullability::Nullable),
100        DType::Binary(Nullability::Nullable),
101    ];
102
103    for dtype in nullable_types {
104        let result = cast_and_execute(&array.to_array(), dtype.clone())
105            .vortex_expect("cast should succeed in conformance test");
106        assert_eq!(result.len(), array.len());
107        assert_eq!(result.dtype(), &dtype);
108
109        // Verify all values are null
110        for i in 0..array.len().min(10) {
111            assert!(
112                result
113                    .scalar_at(i)
114                    .vortex_expect("scalar_at should succeed in conformance test")
115                    .is_null()
116            );
117        }
118    }
119
120    // Casting to non-nullable types should fail
121    let non_nullable_types = vec![
122        DType::Bool(Nullability::NonNullable),
123        DType::Primitive(PType::I32, Nullability::NonNullable),
124    ];
125
126    for dtype in non_nullable_types {
127        assert!(cast_and_execute(&array.to_array(), dtype.clone()).is_err());
128    }
129}
130
131fn test_cast_to_non_nullable(array: &ArrayRef) {
132    if array
133        .invalid_count()
134        .vortex_expect("invalid_count should succeed in conformance test")
135        == 0
136    {
137        let non_nullable = cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
138            .vortex_expect("arrays without nulls can cast to non-nullable");
139        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
140        assert_eq!(non_nullable.len(), array.len());
141
142        for i in 0..array.len().min(10) {
143            assert_eq!(
144                array
145                    .scalar_at(i)
146                    .vortex_expect("scalar_at should succeed in conformance test"),
147                non_nullable
148                    .scalar_at(i)
149                    .vortex_expect("scalar_at should succeed in conformance test")
150            );
151        }
152
153        let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
154            .vortex_expect("non-nullable arrays can cast to nullable");
155        assert_eq!(back_to_nullable.dtype(), array.dtype());
156        assert_eq!(back_to_nullable.len(), array.len());
157
158        for i in 0..array.len().min(10) {
159            assert_eq!(
160                array
161                    .scalar_at(i)
162                    .vortex_expect("scalar_at should succeed in conformance test"),
163                back_to_nullable
164                    .scalar_at(i)
165                    .vortex_expect("scalar_at should succeed in conformance test")
166            );
167        }
168    } else {
169        if &DType::Null == array.dtype() {
170            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
171            // array can be casted to DType::Null.
172            return;
173        }
174        cast_and_execute(&array.to_array(), array.dtype().as_nonnullable())
175            .err()
176            .unwrap_or_else(|| {
177                vortex_panic!(
178                    "arrays with nulls should error when casting to non-nullable {}",
179                    array,
180                )
181            });
182    }
183}
184
185fn test_cast_to_nullable(array: &ArrayRef) {
186    let nullable = cast_and_execute(&array.to_array(), array.dtype().as_nullable())
187        .vortex_expect("arrays without nulls can cast to nullable");
188    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
189    assert_eq!(nullable.len(), array.len());
190
191    for i in 0..array.len().min(10) {
192        assert_eq!(
193            array
194                .scalar_at(i)
195                .vortex_expect("scalar_at should succeed in conformance test"),
196            nullable
197                .scalar_at(i)
198                .vortex_expect("scalar_at should succeed in conformance test")
199        );
200    }
201
202    let back = cast_and_execute(&nullable, array.dtype().clone())
203        .vortex_expect("casting to nullable and back should be a no-op");
204    assert_eq!(back.dtype(), array.dtype());
205    assert_eq!(back.len(), array.len());
206
207    for i in 0..array.len().min(10) {
208        assert_eq!(
209            array
210                .scalar_at(i)
211                .vortex_expect("scalar_at should succeed in conformance test"),
212            back.scalar_at(i)
213                .vortex_expect("scalar_at should succeed in conformance test")
214        );
215    }
216}
217
218fn test_cast_from_floating_point_types(array: &ArrayRef) {
219    let ptype = array.as_primitive_typed().ptype();
220    test_cast_to_primitive(array, PType::I8, false);
221    test_cast_to_primitive(array, PType::U8, false);
222    test_cast_to_primitive(array, PType::I16, false);
223    test_cast_to_primitive(array, PType::U16, false);
224    test_cast_to_primitive(array, PType::I32, false);
225    test_cast_to_primitive(array, PType::U32, false);
226    test_cast_to_primitive(array, PType::I64, false);
227    test_cast_to_primitive(array, PType::U64, false);
228    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
229    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
230    test_cast_to_primitive(array, PType::F64, true);
231}
232
233fn test_cast_to_integral_types(array: &ArrayRef) {
234    test_cast_to_primitive(array, PType::I8, true);
235    test_cast_to_primitive(array, PType::U8, true);
236    test_cast_to_primitive(array, PType::I16, true);
237    test_cast_to_primitive(array, PType::U16, true);
238    test_cast_to_primitive(array, PType::I32, true);
239    test_cast_to_primitive(array, PType::U32, true);
240    test_cast_to_primitive(array, PType::I64, true);
241    test_cast_to_primitive(array, PType::U64, true);
242}
243
244/// Does this scalar fit in this type?
245fn fits(value: &Scalar, ptype: PType) -> bool {
246    let dtype = DType::Primitive(ptype, value.dtype().nullability());
247    value.cast(&dtype).is_ok()
248}
249
250fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) {
251    let mut ctx = LEGACY_SESSION.create_execution_ctx();
252    let maybe_min_max =
253        min_max(array, &mut ctx).vortex_expect("cast should succeed in conformance test");
254
255    if let Some(MinMaxResult { min, max }) = maybe_min_max
256        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
257    {
258        cast_and_execute(
259            &array.to_array(),
260            DType::Primitive(target_ptype, array.dtype().nullability()),
261        )
262        .err()
263        .unwrap_or_else(|| {
264            vortex_panic!(
265                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
266                target_ptype,
267                min,
268                max,
269                array,
270                array.display_values(),
271            )
272        });
273        return;
274    }
275
276    // Otherwise, all values must fit.
277    let casted = cast_and_execute(
278        &array.to_array(),
279        DType::Primitive(target_ptype, array.dtype().nullability()),
280    )
281    .unwrap_or_else(|e| {
282        vortex_panic!(
283            "Cast must succeed because all values are within bounds. {} {}: {e}",
284            target_ptype,
285            array.display_values(),
286        )
287    });
288    assert_eq!(
289        array
290            .validity_mask()
291            .vortex_expect("validity_mask should succeed in conformance test"),
292        casted
293            .validity_mask()
294            .vortex_expect("validity_mask should succeed in conformance test")
295    );
296    for i in 0..array.len().min(10) {
297        let original = array
298            .scalar_at(i)
299            .vortex_expect("scalar_at should succeed in conformance test");
300        let casted = casted
301            .scalar_at(i)
302            .vortex_expect("scalar_at should succeed in conformance test");
303        assert_eq!(
304            original
305                .cast(casted.dtype())
306                .vortex_expect("cast should succeed in conformance test"),
307            casted,
308            "{i} {original} {casted}"
309        );
310        if test_round_trip {
311            assert_eq!(
312                original,
313                casted
314                    .cast(original.dtype())
315                    .vortex_expect("cast should succeed in conformance test"),
316                "{i} {original} {casted}"
317            );
318        }
319    }
320}
321
#[cfg(test)]
mod tests {
    // Runs the cast conformance suite against a handful of array types built in this crate.

    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    // Non-nullable unsigned integers.
    #[test]
    fn test_cast_conformance_u32() {
        test_cast_conformance(&buffer![0u32, 100, 200, 65535, 1000000].into_array());
    }

    // Non-nullable signed integers, including negative values.
    #[test]
    fn test_cast_conformance_i32() {
        test_cast_conformance(&buffer![-100i32, -1, 0, 1, 100].into_array());
    }

    // Non-nullable floats, including fractional values.
    #[test]
    fn test_cast_conformance_f32() {
        test_cast_conformance(&buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array());
    }

    // Nullable primitive array that contains nulls.
    #[test]
    fn test_cast_conformance_nullable() {
        let values = [Some(1u8), None, Some(255), Some(0), None];
        let input = PrimitiveArray::from_option_iter(values).into_array();
        test_cast_conformance(&input);
    }

    #[test]
    fn test_cast_conformance_bool() {
        let input = BoolArray::from_iter(vec![true, false, true, false]).into_array();
        test_cast_conformance(&input);
    }

    #[test]
    fn test_cast_conformance_null() {
        let input = NullArray::new(5).into_array();
        test_cast_conformance(&input);
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let strings = vec![Some("hello"), None, Some("world")];
        let input =
            VarBinArray::from_iter(strings, DType::Utf8(Nullability::Nullable)).into_array();
        test_cast_conformance(&input);
    }

    #[test]
    fn test_cast_conformance_binary() {
        let blobs = vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())];
        let input =
            VarBinArray::from_iter(blobs, DType::Binary(Nullability::Nullable)).into_array();
        test_cast_conformance(&input);
    }

    // Non-nullable struct with an i32 field and a nullable utf8 field.
    #[test]
    fn test_cast_conformance_struct() {
        let field_names = FieldNames::from(["a", "b"]);

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let input = StructArray::try_new(field_names, vec![ints, strings], 3, Validity::NonNullable)
            .unwrap()
            .into_array();
        test_cast_conformance(&input);
    }

    // Non-nullable list of i32 with i64 offsets; the second list is empty.
    #[test]
    fn test_cast_conformance_list() {
        let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let input = ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array();
        test_cast_conformance(&input);
    }
}