vortex_array/arrays/struct_/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::StructArray;
13use crate::arrays::StructVTable;
14use crate::compute::CastKernel;
15use crate::compute::CastKernelAdapter;
16use crate::compute::cast;
17use crate::register_kernel;
18use crate::vtable::ValidityHelper;
19
20impl CastKernel for StructVTable {
21    fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
22        let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
23            return Ok(None);
24        };
25
26        let source_sdtype = array
27            .dtype()
28            .as_struct_fields_opt()
29            .vortex_expect("struct array must have struct dtype");
30
31        if target_sdtype.names() != source_sdtype.names() {
32            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
33        }
34
35        let validity = array
36            .validity()
37            .clone()
38            .cast_nullability(dtype.nullability(), array.len())?;
39
40        StructArray::try_new(
41            target_sdtype.names().clone(),
42            array
43                .fields()
44                .iter()
45                .zip_eq(target_sdtype.fields())
46                .map(|(field, dtype)| cast(field, &dtype))
47                .collect::<Result<Vec<_>, _>>()?,
48            array.len(),
49            validity,
50        )
51        .map(|a| Some(a.into_array()))
52    }
53}
54
55register_kernel!(CastKernelAdapter(StructVTable).lift());
56
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Length-3 struct `{ a: i32, b: nullable utf8 }`.
    ///
    /// `nullable` selects `AllValid` (nullable struct) versus `NonNullable` struct validity.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Length-3 struct `{ id: i64, point: { x: f32, y: f32 } }` with non-nullable validity at
    /// both levels.
    fn create_nested_struct() -> StructArray {
        // Inner struct: `point`.
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Outer struct fields: `id` (i64) and `point` (the inner struct).
        let outer_names: FieldNames = ["id", "point"].into();

        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Length-1 struct with a single non-nullable `u8` field named `value`.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);

        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    /// An empty, all-invalid nullable struct can be cast to a non-nullable struct dtype.
    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// Casting to a struct dtype whose field names differ from the source must fail.
    #[test]
    fn cast_mismatched_field_names_fails() {
        let array = create_test_struct(false).into_array();

        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("c", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }

    /// A non-empty struct with struct-level nulls cannot be cast to a non-nullable struct dtype.
    #[test]
    fn cast_struct_with_nulls_to_non_nullable_fails() {
        let array = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32, 2].into_array()],
            2,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }
}