vortex_array/arrays/struct_/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7
8use crate::arrays::{StructArray, StructVTable};
9use crate::compute::{CastKernel, CastKernelAdapter, cast};
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, IntoArray, register_kernel};
12
13impl CastKernel for StructVTable {
14    fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
16            return Ok(None);
17        };
18
19        let source_sdtype = array
20            .dtype()
21            .as_struct_fields_opt()
22            .vortex_expect("struct array must have struct dtype");
23
24        if target_sdtype.names() != source_sdtype.names() {
25            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
26        }
27
28        let validity = array
29            .validity()
30            .clone()
31            .cast_nullability(dtype.nullability(), array.len())?;
32
33        StructArray::try_new(
34            target_sdtype.names().clone(),
35            array
36                .fields()
37                .iter()
38                .zip_eq(target_sdtype.fields())
39                .map(|(field, dtype)| cast(field, &dtype))
40                .try_collect()?,
41            array.len(),
42            validity,
43        )
44        .map(|a| Some(a.into_array()))
45    }
46}
47
48register_kernel!(CastKernelAdapter(StructVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a 3-row struct `{a: i32, b: utf8?}`.
    ///
    /// The struct itself is nullable with every row valid when `nullable` is
    /// true, and non-nullable otherwise. Field `b` contains one null.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Builds a 3-row non-nullable struct `{id: i64, point: {x: f32, y: f32}}`,
    /// so that casting recurses through a nested struct field.
    fn create_nested_struct() -> StructArray {
        // Inner struct `point` with two f32 fields.
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Outer struct: `id` (i64) plus the inner struct as `point`.
        let outer_names: FieldNames = ["id", "point"].into();

        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a single-row non-nullable struct with one `u8` field, `value`.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);
        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    /// An empty struct whose validity is `AllInvalid` has no rows, so casting
    /// it to a non-nullable struct dtype must succeed.
    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).into_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// Each child is cast to its own target dtype (here `a` widens i32 -> i64),
    /// and the struct nullability follows the target.
    #[test]
    fn cast_casts_each_field_to_target_dtype() {
        let array = create_test_struct(false).into_array();

        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I64, Nullability::NonNullable)),
                ("b", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&array, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 3);
    }

    /// The kernel refuses to cast between structs whose field names differ.
    #[test]
    fn cast_mismatched_field_names_fails() {
        let array = create_test_struct(false).into_array();

        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("c", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }

    /// A non-empty struct containing null rows cannot become non-nullable.
    #[test]
    fn cast_struct_with_nulls_to_non_nullable_fails() {
        let array = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32, 2].into_array()],
            2,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }
}