vortex_array/arrays/struct_/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_ensure;
9use vortex_scalar::Scalar;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::StructArray;
15use crate::arrays::StructVTable;
16use crate::compute::CastKernel;
17use crate::compute::CastKernelAdapter;
18use crate::compute::cast;
19use crate::register_kernel;
20use crate::vtable::ValidityHelper;
21
22impl CastKernel for StructVTable {
23    fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24        let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
25            return Ok(None);
26        };
27
28        let source_sdtype = array
29            .dtype()
30            .as_struct_fields_opt()
31            .vortex_expect("struct array must have struct dtype");
32
33        // Re-order, handle fields by value instead.
34        let mut cast_fields = vec![];
35        for (target_name, target_type) in
36            target_sdtype.names().iter().zip_eq(target_sdtype.fields())
37        {
38            match source_sdtype.find(target_name) {
39                None => {
40                    // No source field with this name => evolve the schema compatibly.
41                    // If the field is nullable, we add a new ConstantArray field with the type.
42                    vortex_ensure!(
43                        target_type.is_nullable(),
44                        "CAST for struct only supports added nullable fields"
45                    );
46
47                    cast_fields.push(
48                        ConstantArray::new(Scalar::null(target_type), array.len).into_array(),
49                    );
50                }
51                Some(src_field_idx) => {
52                    // Field exists in source field. Cast it to the target type.
53                    let cast_field = cast(array.fields()[src_field_idx].as_ref(), &target_type)?;
54                    cast_fields.push(cast_field);
55                }
56            }
57        }
58
59        let validity = array
60            .validity()
61            .clone()
62            .cast_nullability(dtype.nullability(), array.len())?;
63
64        StructArray::try_new(
65            target_sdtype.names().clone(),
66            cast_fields,
67            array.len(),
68            validity,
69        )
70        .map(|a| Some(a.into_array()))
71    }
72}
73
74register_kernel!(CastKernelAdapter(StructVTable).lift());
75
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Runs the shared cast conformance checks against each struct fixture below.
    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Three-row struct `{a: i32, b: nullable utf8}`; `nullable` selects whether the
    /// struct itself is `AllValid` or `NonNullable`.
    fn create_test_struct(nullable: bool) -> StructArray {
        let struct_validity = if nullable {
            Validity::AllValid
        } else {
            Validity::NonNullable
        };

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let field_names: FieldNames = ["a", "b"].into();
        StructArray::try_new(field_names, vec![ints, strings], 3, struct_validity).unwrap()
    }

    /// Three-row struct `{id: i64, point: {x: f32, y: f32}}`, non-nullable at both levels.
    fn create_nested_struct() -> StructArray {
        // Inner `point` struct.
        let xs = buffer![1.0f32, 2.0, 3.0].into_array();
        let ys = buffer![4.0f32, 5.0, 6.0].into_array();
        let point = StructArray::try_new(
            FieldNames::from(["x", "y"]),
            vec![xs, ys],
            3,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        // Outer struct wrapping the ids and the inner struct.
        let ids = buffer![100i64, 200, 300].into_array();
        StructArray::try_new(
            FieldNames::from(["id", "point"]),
            vec![ids, point],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Single-row struct with one `u8` field named `value`.
    fn create_simple_struct() -> StructArray {
        let column = buffer![42u8].into_array();
        let field_names: FieldNames = ["value"].into();

        StructArray::try_new(field_names, vec![column], 1, Validity::NonNullable).unwrap()
    }

    /// An empty, all-invalid struct can be cast to a fully non-nullable struct dtype.
    #[test]
    fn cast_nullable_all_invalid() {
        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let empty_field = PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array();
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![empty_field],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.len(), 0);
        assert_eq!(result.dtype(), &target_dtype);
    }
}