Skip to main content

vortex_array/compute/conformance/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6use vortex_error::vortex_panic;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::LEGACY_SESSION;
11use crate::RecursiveCanonical;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::fns::min_max::MinMaxResult;
14use crate::aggregate_fn::fns::min_max::min_max;
15use crate::builtins::ArrayBuiltins;
16use crate::dtype::DType;
17use crate::dtype::Nullability;
18use crate::dtype::PType;
19use crate::scalar::Scalar;
20
21/// Cast and force execution via `to_canonical`, returning the canonical array.
22fn cast_and_execute(array: &ArrayRef, dtype: DType) -> VortexResult<ArrayRef> {
23    Ok(array
24        .cast(dtype)?
25        .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())?
26        .0
27        .into_array())
28}
29
30/// Test conformance of the cast compute function for an array.
31///
32/// This function tests various casting scenarios including:
33/// - Casting between numeric types (widening and narrowing)
34/// - Casting between signed and unsigned types
35/// - Casting between integral and floating-point types
36/// - Casting with nullability changes
37/// - Casting between string types (Utf8/Binary)
38/// - Edge cases like overflow behavior
39pub fn test_cast_conformance(array: &ArrayRef) {
40    let dtype = array.dtype();
41
42    // Always test identity cast and nullability changes
43    test_cast_identity(array);
44
45    test_cast_to_non_nullable(array);
46    test_cast_to_nullable(array);
47
48    // Test based on the specific DType
49    match dtype {
50        DType::Null => test_cast_from_null(array),
51        DType::Primitive(ptype, ..) => match ptype {
52            PType::U8
53            | PType::U16
54            | PType::U32
55            | PType::U64
56            | PType::I8
57            | PType::I16
58            | PType::I32
59            | PType::I64 => test_cast_to_integral_types(array),
60            PType::F16 | PType::F32 | PType::F64 => test_cast_from_floating_point_types(array),
61        },
62        _ => {}
63    }
64}
65
66fn test_cast_identity(array: &ArrayRef) {
67    // Casting to the same type should be a no-op
68    let result = cast_and_execute(&array.clone(), array.dtype().clone())
69        .vortex_expect("cast should succeed in conformance test");
70    assert_eq!(result.len(), array.len());
71    assert_eq!(result.dtype(), array.dtype());
72
73    // Verify values are unchanged
74    for i in 0..array.len().min(10) {
75        assert_eq!(
76            array
77                .scalar_at(i)
78                .vortex_expect("scalar_at should succeed in conformance test"),
79            result
80                .scalar_at(i)
81                .vortex_expect("scalar_at should succeed in conformance test")
82        );
83    }
84}
85
86fn test_cast_from_null(array: &ArrayRef) {
87    // Null can be cast to itself
88    let result = cast_and_execute(&array.clone(), DType::Null)
89        .vortex_expect("cast should succeed in conformance test");
90    assert_eq!(result.len(), array.len());
91    assert_eq!(result.dtype(), &DType::Null);
92
93    // Null can also be cast to any nullable type
94    let nullable_types = vec![
95        DType::Bool(Nullability::Nullable),
96        DType::Primitive(PType::I32, Nullability::Nullable),
97        DType::Primitive(PType::F64, Nullability::Nullable),
98        DType::Utf8(Nullability::Nullable),
99        DType::Binary(Nullability::Nullable),
100    ];
101
102    for dtype in nullable_types {
103        let result = cast_and_execute(&array.clone(), dtype.clone())
104            .vortex_expect("cast should succeed in conformance test");
105        assert_eq!(result.len(), array.len());
106        assert_eq!(result.dtype(), &dtype);
107
108        // Verify all values are null
109        for i in 0..array.len().min(10) {
110            assert!(
111                result
112                    .scalar_at(i)
113                    .vortex_expect("scalar_at should succeed in conformance test")
114                    .is_null()
115            );
116        }
117    }
118
119    // Casting to non-nullable types should fail
120    let non_nullable_types = vec![
121        DType::Bool(Nullability::NonNullable),
122        DType::Primitive(PType::I32, Nullability::NonNullable),
123    ];
124
125    for dtype in non_nullable_types {
126        assert!(cast_and_execute(&array.clone(), dtype.clone()).is_err());
127    }
128}
129
130fn test_cast_to_non_nullable(array: &ArrayRef) {
131    if array
132        .invalid_count()
133        .vortex_expect("invalid_count should succeed in conformance test")
134        == 0
135    {
136        let non_nullable = cast_and_execute(&array.clone(), array.dtype().as_nonnullable())
137            .vortex_expect("arrays without nulls can cast to non-nullable");
138        assert_eq!(non_nullable.dtype(), &array.dtype().as_nonnullable());
139        assert_eq!(non_nullable.len(), array.len());
140
141        for i in 0..array.len().min(10) {
142            assert_eq!(
143                array
144                    .scalar_at(i)
145                    .vortex_expect("scalar_at should succeed in conformance test"),
146                non_nullable
147                    .scalar_at(i)
148                    .vortex_expect("scalar_at should succeed in conformance test")
149            );
150        }
151
152        let back_to_nullable = cast_and_execute(&non_nullable, array.dtype().clone())
153            .vortex_expect("non-nullable arrays can cast to nullable");
154        assert_eq!(back_to_nullable.dtype(), array.dtype());
155        assert_eq!(back_to_nullable.len(), array.len());
156
157        for i in 0..array.len().min(10) {
158            assert_eq!(
159                array
160                    .scalar_at(i)
161                    .vortex_expect("scalar_at should succeed in conformance test"),
162                back_to_nullable
163                    .scalar_at(i)
164                    .vortex_expect("scalar_at should succeed in conformance test")
165            );
166        }
167    } else {
168        if &DType::Null == array.dtype() {
169            // DType::Null.as_nonnullable() (confusingly) returns DType:Null. Of course, a null
170            // array can be casted to DType::Null.
171            return;
172        }
173        cast_and_execute(&array.clone(), array.dtype().as_nonnullable())
174            .err()
175            .unwrap_or_else(|| {
176                vortex_panic!(
177                    "arrays with nulls should error when casting to non-nullable {}",
178                    array,
179                )
180            });
181    }
182}
183
184fn test_cast_to_nullable(array: &ArrayRef) {
185    let nullable = cast_and_execute(&array.clone(), array.dtype().as_nullable())
186        .vortex_expect("arrays without nulls can cast to nullable");
187    assert_eq!(nullable.dtype(), &array.dtype().as_nullable());
188    assert_eq!(nullable.len(), array.len());
189
190    for i in 0..array.len().min(10) {
191        assert_eq!(
192            array
193                .scalar_at(i)
194                .vortex_expect("scalar_at should succeed in conformance test"),
195            nullable
196                .scalar_at(i)
197                .vortex_expect("scalar_at should succeed in conformance test")
198        );
199    }
200
201    let back = cast_and_execute(&nullable, array.dtype().clone())
202        .vortex_expect("casting to nullable and back should be a no-op");
203    assert_eq!(back.dtype(), array.dtype());
204    assert_eq!(back.len(), array.len());
205
206    for i in 0..array.len().min(10) {
207        assert_eq!(
208            array
209                .scalar_at(i)
210                .vortex_expect("scalar_at should succeed in conformance test"),
211            back.scalar_at(i)
212                .vortex_expect("scalar_at should succeed in conformance test")
213        );
214    }
215}
216
217fn test_cast_from_floating_point_types(array: &ArrayRef) {
218    let ptype = array.as_primitive_typed().ptype();
219    test_cast_to_primitive(array, PType::I8, false);
220    test_cast_to_primitive(array, PType::U8, false);
221    test_cast_to_primitive(array, PType::I16, false);
222    test_cast_to_primitive(array, PType::U16, false);
223    test_cast_to_primitive(array, PType::I32, false);
224    test_cast_to_primitive(array, PType::U32, false);
225    test_cast_to_primitive(array, PType::I64, false);
226    test_cast_to_primitive(array, PType::U64, false);
227    test_cast_to_primitive(array, PType::F16, matches!(ptype, PType::F16));
228    test_cast_to_primitive(array, PType::F32, matches!(ptype, PType::F16 | PType::F32));
229    test_cast_to_primitive(array, PType::F64, true);
230}
231
232fn test_cast_to_integral_types(array: &ArrayRef) {
233    test_cast_to_primitive(array, PType::I8, true);
234    test_cast_to_primitive(array, PType::U8, true);
235    test_cast_to_primitive(array, PType::I16, true);
236    test_cast_to_primitive(array, PType::U16, true);
237    test_cast_to_primitive(array, PType::I32, true);
238    test_cast_to_primitive(array, PType::U32, true);
239    test_cast_to_primitive(array, PType::I64, true);
240    test_cast_to_primitive(array, PType::U64, true);
241}
242
243/// Does this scalar fit in this type?
244fn fits(value: &Scalar, ptype: PType) -> bool {
245    let dtype = DType::Primitive(ptype, value.dtype().nullability());
246    value.cast(&dtype).is_ok()
247}
248
249fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) {
250    let mut ctx = LEGACY_SESSION.create_execution_ctx();
251    let maybe_min_max =
252        min_max(array, &mut ctx).vortex_expect("cast should succeed in conformance test");
253
254    if let Some(MinMaxResult { min, max }) = maybe_min_max
255        && (!fits(&min, target_ptype) || !fits(&max, target_ptype))
256    {
257        cast_and_execute(
258            &array.clone(),
259            DType::Primitive(target_ptype, array.dtype().nullability()),
260        )
261        .err()
262        .unwrap_or_else(|| {
263            vortex_panic!(
264                "Cast must fail because some values are out of bounds. {} {:?} {:?} {} {}",
265                target_ptype,
266                min,
267                max,
268                array,
269                array.display_values(),
270            )
271        });
272        return;
273    }
274
275    // Otherwise, all values must fit.
276    let casted = cast_and_execute(
277        &array.clone(),
278        DType::Primitive(target_ptype, array.dtype().nullability()),
279    )
280    .unwrap_or_else(|e| {
281        vortex_panic!(
282            "Cast must succeed because all values are within bounds. {} {}: {e}",
283            target_ptype,
284            array.display_values(),
285        )
286    });
287    assert_eq!(
288        array
289            .validity_mask()
290            .vortex_expect("validity_mask should succeed in conformance test"),
291        casted
292            .validity_mask()
293            .vortex_expect("validity_mask should succeed in conformance test")
294    );
295    for i in 0..array.len().min(10) {
296        let original = array
297            .scalar_at(i)
298            .vortex_expect("scalar_at should succeed in conformance test");
299        let casted = casted
300            .scalar_at(i)
301            .vortex_expect("scalar_at should succeed in conformance test");
302        assert_eq!(
303            original
304                .cast(casted.dtype())
305                .vortex_expect("cast should succeed in conformance test"),
306            casted,
307            "{i} {original} {casted}"
308        );
309        if test_round_trip {
310            assert_eq!(
311                original,
312                casted
313                    .cast(original.dtype())
314                    .vortex_expect("cast should succeed in conformance test"),
315                "{i} {original} {casted}"
316            );
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    //! Runs the cast conformance harness against representative arrays of each dtype.

    use vortex_buffer::buffer;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::NullArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;

    #[test]
    fn test_cast_conformance_u32() {
        let array = buffer![0u32, 100, 200, 65535, 1000000].into_array();
        test_cast_conformance(&array);
    }

    #[test]
    fn test_cast_conformance_i32() {
        let array = buffer![-100i32, -1, 0, 1, 100].into_array();
        test_cast_conformance(&array);
    }

    /// Values outside the 32-bit range: every narrower integral cast, and the cast to `u64`
    /// (negative minimum), must be rejected.
    #[test]
    fn test_cast_conformance_i64() {
        let array =
            buffer![-1_000_000_000_000i64, -1, 0, 1, 1_000_000_000_000].into_array();
        test_cast_conformance(&array);
    }

    #[test]
    fn test_cast_conformance_f32() {
        let array = buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(&array);
    }

    /// Exercises the `F64` source branch, where only the `F64` target is round-tripped.
    #[test]
    fn test_cast_conformance_f64() {
        let array = buffer![0.0f64, 1.5, -2.5, 100.0, 1e6].into_array();
        test_cast_conformance(&array);
    }

    #[test]
    fn test_cast_conformance_nullable() {
        let array = PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]);
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_bool() {
        let array = BoolArray::from_iter(vec![true, false, true, false]);
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_null() {
        let array = NullArray::new(5);
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_utf8() {
        let array = VarBinArray::from_iter(
            vec![Some("hello"), None, Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_binary() {
        let array = VarBinArray::from_iter(
            vec![Some(b"data".as_slice()), None, Some(b"bytes".as_slice())],
            DType::Binary(Nullability::Nullable),
        );
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_struct() {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let array =
            StructArray::try_new(names, vec![a, b], 3, crate::validity::Validity::NonNullable)
                .unwrap();
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_conformance_list() {
        let data = buffer![1i32, 2, 3, 4, 5, 6].into_array();
        let offsets = buffer![0i64, 2, 2, 5, 6].into_array();

        let array =
            ListArray::try_new(data, offsets, crate::validity::Validity::NonNullable).unwrap();
        test_cast_conformance(&array.into_array());
    }
}