Skip to main content

vortex_array/arrays/struct_/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools;
5use vortex_error::VortexResult;
6use vortex_error::vortex_ensure;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::ConstantArray;
13use crate::arrays::Struct;
14use crate::arrays::StructArray;
15use crate::arrays::struct_::StructArrayExt;
16use crate::builtins::ArrayBuiltins;
17use crate::dtype::DType;
18use crate::scalar::Scalar;
19use crate::scalar_fn::fns::cast::CastKernel;
20
21impl CastKernel for Struct {
22    fn cast(
23        array: ArrayView<'_, Struct>,
24        dtype: &DType,
25        ctx: &mut ExecutionCtx,
26    ) -> VortexResult<Option<ArrayRef>> {
27        let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
28            return Ok(None);
29        };
30
31        let source_sdtype = array.struct_fields();
32
33        let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
34            && target_sdtype
35                .names()
36                .iter()
37                .zip(source_sdtype.names().iter())
38                .all(|(f1, f2)| f1 == f2);
39
40        let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
41        if fields_match_order {
42            for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields())
43            {
44                let cast_field = field.cast(target_type)?;
45                cast_fields.push(cast_field);
46            }
47        } else {
48            // Re-order, handle fields by value instead.
49            for (target_name, target_type) in
50                target_sdtype.names().iter().zip_eq(target_sdtype.fields())
51            {
52                match source_sdtype.find(target_name) {
53                    None => {
54                        // No source field with this name => evolve the schema compatibly.
55                        // If the field is nullable, we add a new ConstantArray field with the type.
56                        vortex_ensure!(
57                            target_type.is_nullable(),
58                            "CAST for struct only supports added nullable fields"
59                        );
60
61                        cast_fields.push(
62                            ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
63                        );
64                    }
65                    Some(src_field_idx) => {
66                        // Field exists in source field. Cast it to the target type.
67                        let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
68                        cast_fields.push(cast_field);
69                    }
70                }
71            }
72        }
73
74        let validity = array
75            .validity()?
76            .cast_nullability(dtype.nullability(), array.len(), ctx)?;
77
78        StructArray::try_new(
79            target_sdtype.names().clone(),
80            cast_fields,
81            array.len(),
82            validity,
83        )
84        .map(|a| Some(a.into_array()))
85    }
86}
87
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Runs the shared cast conformance suite against several struct shapes.
    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        let array = array.into_array();
        test_cast_conformance(&array);
    }

    /// Builds a 3-row struct `{a: i32, b: nullable utf8}`.
    ///
    /// The struct-level validity is `AllValid` when `nullable` is set, otherwise
    /// `NonNullable`.
    fn create_test_struct(nullable: bool) -> StructArray {
        let validity = if nullable {
            Validity::AllValid
        } else {
            Validity::NonNullable
        };
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        );
        let children = vec![buffer![1i32, 2, 3].into_array(), strings.into_array()];

        StructArray::try_new(FieldNames::from(["a", "b"]), children, 3, validity).unwrap()
    }

    /// Builds a 3-row non-nullable struct `{id: i64, point: {x: f32, y: f32}}`.
    fn create_nested_struct() -> StructArray {
        let point = StructArray::try_new(
            FieldNames::from(["x", "y"]),
            vec![
                buffer![1.0f32, 2.0, 3.0].into_array(),
                buffer![4.0f32, 5.0, 6.0].into_array(),
            ],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        StructArray::try_new(
            FieldNames::from(["id", "point"]),
            vec![buffer![100i64, 200, 300].into_array(), point.into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a single-row, non-nullable struct with one `u8` field named `value`.
    fn create_simple_struct() -> StructArray {
        let value = buffer![42u8].into_array();
        StructArray::try_new(
            FieldNames::from(["value"]),
            vec![value],
            1,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Casts an empty struct whose validity is `AllInvalid` to a non-nullable struct type.
    #[test]
    fn cast_nullable_all_invalid() {
        let child = PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).into_array();
        let source = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![child],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();
        let target = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let cast = source.cast(target.clone()).unwrap();

        assert_eq!(cast.dtype(), &target);
        assert_eq!(cast.len(), 0);
    }

    /// Casting to the nullable form of the same dtype keeps both same-named fields.
    #[test]
    fn cast_duplicate_field_names_to_nullable() {
        let source = StructArray::try_new(
            FieldNames::from(["a", "a"]),
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![10i64, 20, 30].into_array(),
            ],
            3,
            Validity::NonNullable,
        )
        .unwrap();
        let target = source.dtype().as_nullable();

        let cast = source.into_array().cast(target.clone()).unwrap();
        assert_eq!(cast.dtype(), &target);
        assert_eq!(cast.len(), 3);

        #[expect(deprecated)]
        let field_count = cast.to_struct().struct_fields().nfields();
        assert_eq!(field_count, 2);
    }

    /// Casting to a target with an extra nullable field appends that field.
    #[test]
    fn cast_add_fields() {
        let a = buffer![1i32, 2, 3].into_array();
        let b = buffer![10i64, 20, 30].into_array();
        let target = DType::struct_(
            [
                ("a", a.dtype().clone()),
                ("b", b.dtype().clone()),
                (
                    "c",
                    DType::Decimal(DecimalDType::new(38, 10), Nullability::Nullable),
                ),
            ],
            Nullability::NonNullable,
        );
        let source =
            StructArray::try_new(FieldNames::from(["a", "b"]), vec![a, b], 3, Validity::NonNullable)
                .unwrap();

        let cast = source.into_array().cast(target.clone()).unwrap();
        assert_eq!(cast.dtype(), &target);
        assert_eq!(cast.len(), 3);

        #[expect(deprecated)]
        let field_count = cast.to_struct().struct_fields().nfields();
        assert_eq!(field_count, 3);
    }
}