Skip to main content

vortex_array/arrays/struct_/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexResult;
7use vortex_error::vortex_ensure;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::builtins::ArrayBuiltins;
16use crate::compute::CastKernel;
17use crate::scalar::Scalar;
18use crate::vtable::ValidityHelper;
19
20impl CastKernel for StructVTable {
21    fn cast(
22        array: &StructArray,
23        dtype: &DType,
24        _ctx: &mut ExecutionCtx,
25    ) -> VortexResult<Option<ArrayRef>> {
26        let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
27            return Ok(None);
28        };
29
30        let source_sdtype = array.struct_fields();
31
32        let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
33            && target_sdtype
34                .names()
35                .iter()
36                .zip(source_sdtype.names().iter())
37                .all(|(f1, f2)| f1 == f2);
38
39        let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
40        if fields_match_order {
41            for (field, target_type) in array
42                .unmasked_fields()
43                .iter()
44                .zip_eq(target_sdtype.fields())
45            {
46                let cast_field = field.cast(target_type)?;
47                cast_fields.push(cast_field);
48            }
49        } else {
50            // Re-order, handle fields by value instead.
51            for (target_name, target_type) in
52                target_sdtype.names().iter().zip_eq(target_sdtype.fields())
53            {
54                match source_sdtype.find(target_name) {
55                    None => {
56                        // No source field with this name => evolve the schema compatibly.
57                        // If the field is nullable, we add a new ConstantArray field with the type.
58                        vortex_ensure!(
59                            target_type.is_nullable(),
60                            "CAST for struct only supports added nullable fields"
61                        );
62
63                        cast_fields.push(
64                            ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
65                        );
66                    }
67                    Some(src_field_idx) => {
68                        // Field exists in source field. Cast it to the target type.
69                        let cast_field =
70                            array.unmasked_fields()[src_field_idx].cast(target_type)?;
71                        cast_fields.push(cast_field);
72                    }
73                }
74            }
75        }
76
77        let validity = array
78            .validity()
79            .clone()
80            .cast_nullability(dtype.nullability(), array.len())?;
81
82        StructArray::try_new(
83            target_sdtype.names().clone(),
84            cast_fields,
85            array.len(),
86            validity,
87        )
88        .map(|a| Some(a.into_array()))
89    }
90}
91
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Runs the shared cast conformance suite over several struct shapes.
    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a 3-row struct `{a: i32, b: utf8?}`.
    ///
    /// The struct itself is nullable (all rows valid) when `nullable` is true, and
    /// non-nullable otherwise.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Builds a 3-row struct `{id: i64, point: {x: f32, y: f32}}` with a nested struct field.
    fn create_nested_struct() -> StructArray {
        // Create inner struct
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Create outer struct with fields: id (I64) and point (inner struct)
        let outer_names: FieldNames = ["id", "point"].into();

        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a 1-row, non-nullable struct with a single `u8` field named `value`.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);

        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    /// An empty, all-invalid nullable struct can be cast to a non-nullable struct type.
    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = empty_struct.cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// Duplicate field names survive a cast that only changes the struct's nullability.
    #[test]
    fn cast_duplicate_field_names_to_nullable() {
        let names = FieldNames::from(["a", "a"]);
        let field1 = buffer![1i32, 2, 3].into_array();
        let field2 = buffer![10i64, 20, 30].into_array();

        let struct_array =
            StructArray::try_new(names, vec![field1, field2], 3, Validity::NonNullable).unwrap();

        let target_dtype = struct_array.dtype().as_nullable();

        let result = struct_array.to_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 3);
        assert_eq!(result.to_struct().unmasked_fields().len(), 2);
    }

    /// A nullable field that is absent from the source is added to the result.
    #[test]
    fn cast_add_fields() {
        let names = FieldNames::from(["a", "b"]);
        let field1 = buffer![1i32, 2, 3].into_array();
        let field2 = buffer![10i64, 20, 30].into_array();
        let target_dtype = DType::struct_(
            [
                ("a", field1.dtype().clone()),
                ("b", field2.dtype().clone()),
                (
                    "c",
                    DType::Decimal(DecimalDType::new(38, 10), Nullability::Nullable),
                ),
            ],
            Nullability::NonNullable,
        );

        let struct_array =
            StructArray::try_new(names, vec![field1, field2], 3, Validity::NonNullable).unwrap();

        let result = struct_array.to_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 3);
        assert_eq!(result.to_struct().unmasked_fields().len(), 3);
    }

    /// Casting to the same fields in a different order takes the name-based path.
    /// The child arrays must be permuted to follow the target order.
    #[test]
    fn cast_reorder_fields() {
        let names = FieldNames::from(["a", "b"]);
        let field_a = buffer![1i32, 2, 3].into_array();
        let field_b = buffer![10i64, 20, 30].into_array();
        let target_dtype = DType::struct_(
            [
                ("b", DType::Primitive(PType::I64, Nullability::NonNullable)),
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
            ],
            Nullability::NonNullable,
        );

        let struct_array =
            StructArray::try_new(names, vec![field_a, field_b], 3, Validity::NonNullable)
                .unwrap();

        let result = struct_array.to_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 3);

        let result = result.to_struct();
        let fields = result.unmasked_fields();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].to_primitive().as_slice::<i64>(), &[10, 20, 30]);
        assert_eq!(fields[1].to_primitive().as_slice::<i32>(), &[1, 2, 3]);
    }
}