vortex_array/arrays/struct_/compute/
cast.rs1use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexResult;
7use vortex_error::vortex_ensure;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::builtins::ArrayBuiltins;
16use crate::compute::CastKernel;
17use crate::scalar::Scalar;
18use crate::vtable::ValidityHelper;
19
20impl CastKernel for StructVTable {
21 fn cast(
22 array: &StructArray,
23 dtype: &DType,
24 _ctx: &mut ExecutionCtx,
25 ) -> VortexResult<Option<ArrayRef>> {
26 let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
27 return Ok(None);
28 };
29
30 let source_sdtype = array.struct_fields();
31
32 let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
33 && target_sdtype
34 .names()
35 .iter()
36 .zip(source_sdtype.names().iter())
37 .all(|(f1, f2)| f1 == f2);
38
39 let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
40 if fields_match_order {
41 for (field, target_type) in array
42 .unmasked_fields()
43 .iter()
44 .zip_eq(target_sdtype.fields())
45 {
46 let cast_field = field.cast(target_type)?;
47 cast_fields.push(cast_field);
48 }
49 } else {
50 for (target_name, target_type) in
52 target_sdtype.names().iter().zip_eq(target_sdtype.fields())
53 {
54 match source_sdtype.find(target_name) {
55 None => {
56 vortex_ensure!(
59 target_type.is_nullable(),
60 "CAST for struct only supports added nullable fields"
61 );
62
63 cast_fields.push(
64 ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
65 );
66 }
67 Some(src_field_idx) => {
68 let cast_field =
70 array.unmasked_fields()[src_field_idx].cast(target_type)?;
71 cast_fields.push(cast_field);
72 }
73 }
74 }
75 }
76
77 let validity = array
78 .validity()
79 .clone()
80 .cast_nullability(dtype.nullability(), array.len())?;
81
82 StructArray::try_new(
83 target_sdtype.names().clone(),
84 cast_fields,
85 array.len(),
86 validity,
87 )
88 .map(|a| Some(a.into_array()))
89 }
90}
91
92#[cfg(test)]
93mod tests {
94 use rstest::rstest;
95 use vortex_buffer::buffer;
96 use vortex_dtype::DType;
97 use vortex_dtype::DecimalDType;
98 use vortex_dtype::FieldNames;
99 use vortex_dtype::Nullability;
100 use vortex_dtype::PType;
101
102 use crate::Array;
103 use crate::IntoArray;
104 use crate::ToCanonical;
105 use crate::arrays::PrimitiveArray;
106 use crate::arrays::StructArray;
107 use crate::arrays::VarBinArray;
108 use crate::builtins::ArrayBuiltins;
109 use crate::compute::conformance::cast::test_cast_conformance;
110 use crate::validity::Validity;
111
112 #[rstest]
113 #[case(create_test_struct(false))]
114 #[case(create_test_struct(true))]
115 #[case(create_nested_struct())]
116 #[case(create_simple_struct())]
117 fn test_cast_struct_conformance(#[case] array: StructArray) {
118 test_cast_conformance(array.as_ref());
119 }
120
121 fn create_test_struct(nullable: bool) -> StructArray {
122 let names = FieldNames::from(["a", "b"]);
123
124 let a = buffer![1i32, 2, 3].into_array();
125 let b = VarBinArray::from_iter(
126 vec![Some("x"), None, Some("z")],
127 DType::Utf8(Nullability::Nullable),
128 )
129 .into_array();
130
131 StructArray::try_new(
132 names,
133 vec![a, b],
134 3,
135 if nullable {
136 Validity::AllValid
137 } else {
138 Validity::NonNullable
139 },
140 )
141 .unwrap()
142 }
143
144 fn create_nested_struct() -> StructArray {
145 let inner_names = FieldNames::from(["x", "y"]);
147
148 let x = buffer![1.0f32, 2.0, 3.0].into_array();
149 let y = buffer![4.0f32, 5.0, 6.0].into_array();
150 let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
151 .unwrap()
152 .into_array();
153
154 let outer_names: FieldNames = ["id", "point"].into();
156 let ids = buffer![100i64, 200, 300].into_array();
159
160 StructArray::try_new(
161 outer_names,
162 vec![ids, inner_struct],
163 3,
164 Validity::NonNullable,
165 )
166 .unwrap()
167 }
168
169 fn create_simple_struct() -> StructArray {
170 let names = FieldNames::from(["value"]);
171 let values = buffer![42u8].into_array();
174
175 StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
176 }
177
/// An empty struct built with `Validity::AllInvalid` (at both the struct and the field
/// level) can be cast to a fully non-nullable struct type.
#[test]
fn cast_nullable_all_invalid() {
    let empty_field = PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array();
    let empty_struct = StructArray::try_new(
        FieldNames::from(["a"]),
        vec![empty_field],
        0,
        Validity::AllInvalid,
    )
    .unwrap()
    .to_array();

    let non_null_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
    let target_dtype = DType::struct_([("a", non_null_i32)], Nullability::NonNullable);

    let cast = empty_struct.cast(target_dtype.clone()).unwrap();

    assert_eq!(cast.dtype(), &target_dtype);
    assert_eq!(cast.len(), 0);
}
198
/// A struct with two fields sharing the name `a` keeps both fields when cast to the
/// nullable variant of its own dtype.
#[test]
fn cast_duplicate_field_names_to_nullable() {
    let struct_array = StructArray::try_new(
        FieldNames::from(["a", "a"]),
        vec![
            buffer![1i32, 2, 3].into_array(),
            buffer![10i64, 20, 30].into_array(),
        ],
        3,
        Validity::NonNullable,
    )
    .unwrap();

    let target_dtype = struct_array.dtype().as_nullable();

    let cast = struct_array.to_array().cast(target_dtype.clone()).unwrap();

    assert_eq!(cast.dtype(), &target_dtype);
    assert_eq!(cast.len(), 3);

    let cast_struct = cast.to_struct();
    assert_eq!(cast_struct.unmasked_fields().len(), 2);
}
215
/// Casting to a struct type that carries an extra nullable field `c` yields a struct with
/// three fields and the requested dtype.
#[test]
fn cast_add_fields() {
    let ints = buffer![1i32, 2, 3].into_array();
    let longs = buffer![10i64, 20, 30].into_array();

    // The target keeps the source field types and appends a nullable decimal field.
    let added_type = DType::Decimal(DecimalDType::new(38, 10), Nullability::Nullable);
    let target_dtype = DType::struct_(
        [
            ("a", ints.dtype().clone()),
            ("b", longs.dtype().clone()),
            ("c", added_type),
        ],
        Nullability::NonNullable,
    );

    let source = StructArray::try_new(
        FieldNames::from(["a", "b"]),
        vec![ints, longs],
        3,
        Validity::NonNullable,
    )
    .unwrap();

    let cast = source.to_array().cast(target_dtype.clone()).unwrap();

    assert_eq!(cast.dtype(), &target_dtype);
    assert_eq!(cast.len(), 3);

    let cast_struct = cast.to_struct();
    assert_eq!(cast_struct.unmasked_fields().len(), 3);
}
241}