vortex_array/arrays/struct_/compute/
cast.rs1use itertools::Itertools;
5use vortex_error::VortexResult;
6use vortex_error::vortex_ensure;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::ConstantArray;
13use crate::arrays::Struct;
14use crate::arrays::StructArray;
15use crate::arrays::struct_::StructArrayExt;
16use crate::builtins::ArrayBuiltins;
17use crate::dtype::DType;
18use crate::scalar::Scalar;
19use crate::scalar_fn::fns::cast::CastKernel;
20
21impl CastKernel for Struct {
22 fn cast(
23 array: ArrayView<'_, Struct>,
24 dtype: &DType,
25 ctx: &mut ExecutionCtx,
26 ) -> VortexResult<Option<ArrayRef>> {
27 let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
28 return Ok(None);
29 };
30
31 let source_sdtype = array.struct_fields();
32
33 let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
34 && target_sdtype
35 .names()
36 .iter()
37 .zip(source_sdtype.names().iter())
38 .all(|(f1, f2)| f1 == f2);
39
40 let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
41 if fields_match_order {
42 for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields())
43 {
44 let cast_field = field.cast(target_type)?;
45 cast_fields.push(cast_field);
46 }
47 } else {
48 for (target_name, target_type) in
50 target_sdtype.names().iter().zip_eq(target_sdtype.fields())
51 {
52 match source_sdtype.find(target_name) {
53 None => {
54 vortex_ensure!(
57 target_type.is_nullable(),
58 "CAST for struct only supports added nullable fields"
59 );
60
61 cast_fields.push(
62 ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
63 );
64 }
65 Some(src_field_idx) => {
66 let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
68 cast_fields.push(cast_field);
69 }
70 }
71 }
72 }
73
74 let validity = array
75 .validity()?
76 .cast_nullability(dtype.nullability(), array.len(), ctx)?;
77
78 StructArray::try_new(
79 target_sdtype.names().clone(),
80 cast_fields,
81 array.len(),
82 validity,
83 )
84 .map(|a| Some(a.into_array()))
85 }
86}
87
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Runs the shared cast conformance checks against each struct fixture below.
    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        let erased = array.into_array();
        test_cast_conformance(&erased);
    }

    /// Three-row struct with an `i32` field `a` and a nullable UTF-8 field `b`.
    ///
    /// `nullable` selects `Validity::AllValid` (true) or `Validity::NonNullable` (false)
    /// for the struct itself.
    fn create_test_struct(nullable: bool) -> StructArray {
        let struct_validity = if nullable {
            Validity::AllValid
        } else {
            Validity::NonNullable
        };

        let ints = buffer![1i32, 2, 3].into_array();
        let strings = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        let field_names: FieldNames = ["a", "b"].into();
        StructArray::try_new(field_names, vec![ints, strings], 3, struct_validity).unwrap()
    }

    /// Three-row struct `{ id: i64, point: { x: f32, y: f32 } }`, non-nullable throughout.
    fn create_nested_struct() -> StructArray {
        let xs = buffer![1.0f32, 2.0, 3.0].into_array();
        let ys = buffer![4.0f32, 5.0, 6.0].into_array();
        let point_names: FieldNames = ["x", "y"].into();
        let points = StructArray::try_new(point_names, vec![xs, ys], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        let ids = buffer![100i64, 200, 300].into_array();
        let top_names = FieldNames::from(["id", "point"]);

        StructArray::try_new(top_names, vec![ids, points], 3, Validity::NonNullable).unwrap()
    }

    /// Single-row struct with one `u8` field named `value`.
    fn create_simple_struct() -> StructArray {
        let column = buffer![42u8].into_array();
        let field_names: FieldNames = ["value"].into();

        StructArray::try_new(field_names, vec![column], 1, Validity::NonNullable).unwrap()
    }

    /// An empty, all-invalid struct can be cast to a fully non-nullable struct dtype.
    #[test]
    fn cast_nullable_all_invalid() {
        let no_values = PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).into_array();
        let field_names: FieldNames = ["a"].into();
        let empty_struct =
            StructArray::try_new(field_names, vec![no_values], 0, Validity::AllInvalid)
                .unwrap()
                .into_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let casted = empty_struct.cast(target_dtype.clone()).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
        assert_eq!(casted.len(), 0);
    }

    /// Casting a struct with two fields both named `a` to its nullable dtype keeps
    /// both fields.
    #[test]
    fn cast_duplicate_field_names_to_nullable() {
        let first = buffer![1i32, 2, 3].into_array();
        let second = buffer![10i64, 20, 30].into_array();
        let duplicated: FieldNames = ["a", "a"].into();

        let source =
            StructArray::try_new(duplicated, vec![first, second], 3, Validity::NonNullable)
                .unwrap();

        let target_dtype = source.dtype().as_nullable();

        let casted = source.into_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
        assert_eq!(casted.len(), 3);
        #[expect(deprecated)]
        let field_count = casted.to_struct().struct_fields().nfields();
        assert_eq!(field_count, 2);
    }

    /// Casting to a struct dtype that adds a nullable field `c` yields three fields.
    #[test]
    fn cast_add_fields() {
        let first = buffer![1i32, 2, 3].into_array();
        let second = buffer![10i64, 20, 30].into_array();

        // The target dtype must be built before the field arrays are moved into the struct.
        let added_type = DType::Decimal(DecimalDType::new(38, 10), Nullability::Nullable);
        let target_dtype = DType::struct_(
            [
                ("a", first.dtype().clone()),
                ("b", second.dtype().clone()),
                ("c", added_type),
            ],
            Nullability::NonNullable,
        );

        let source_names: FieldNames = ["a", "b"].into();
        let source =
            StructArray::try_new(source_names, vec![first, second], 3, Validity::NonNullable)
                .unwrap();

        let casted = source.into_array().cast(target_dtype.clone()).unwrap();
        assert_eq!(casted.dtype(), &target_dtype);
        assert_eq!(casted.len(), 3);
        #[expect(deprecated)]
        let field_count = casted.to_struct().struct_fields().nfields();
        assert_eq!(field_count, 3);
    }
}