vortex_array/arrays/struct_/compute/
cast.rs1use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_ensure;
9use vortex_scalar::Scalar;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::StructArray;
15use crate::arrays::StructVTable;
16use crate::compute::CastKernel;
17use crate::compute::CastKernelAdapter;
18use crate::compute::cast;
19use crate::register_kernel;
20use crate::vtable::ValidityHelper;
21
22impl CastKernel for StructVTable {
23 fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24 let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
25 return Ok(None);
26 };
27
28 let source_sdtype = array
29 .dtype()
30 .as_struct_fields_opt()
31 .vortex_expect("struct array must have struct dtype");
32
33 let mut cast_fields = vec![];
35 for (target_name, target_type) in
36 target_sdtype.names().iter().zip_eq(target_sdtype.fields())
37 {
38 match source_sdtype.find(target_name) {
39 None => {
40 vortex_ensure!(
43 target_type.is_nullable(),
44 "CAST for struct only supports added nullable fields"
45 );
46
47 cast_fields.push(
48 ConstantArray::new(Scalar::null(target_type), array.len).into_array(),
49 );
50 }
51 Some(src_field_idx) => {
52 let cast_field = cast(array.fields()[src_field_idx].as_ref(), &target_type)?;
54 cast_fields.push(cast_field);
55 }
56 }
57 }
58
59 let validity = array
60 .validity()
61 .clone()
62 .cast_nullability(dtype.nullability(), array.len())?;
63
64 StructArray::try_new(
65 target_sdtype.names().clone(),
66 cast_fields,
67 array.len(),
68 validity,
69 )
70 .map(|a| Some(a.into_array()))
71 }
72}
73
74register_kernel!(CastKernelAdapter(StructVTable).lift());
75
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Three-row struct `{a: i32, b: utf8?}`; the struct itself is nullable (all valid) when
    /// `nullable` is true, otherwise non-nullable.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Three-row struct `{id: i64, point: {x: f32, y: f32}}`, all non-nullable.
    fn create_nested_struct() -> StructArray {
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        let outer_names: FieldNames = ["id", "point"].into();
        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Single-row struct `{value: u8}`, non-nullable.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);
        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// A target field absent from the source is allowed when nullable; it is filled with nulls.
    #[test]
    fn cast_adds_missing_nullable_field() {
        let source = create_simple_struct().into_array();

        let target_dtype = DType::struct_(
            [
                ("value", DType::Primitive(PType::U8, Nullability::NonNullable)),
                ("extra", DType::Primitive(PType::I32, Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&source, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 1);
    }

    /// A target field absent from the source cannot be synthesized when it is non-nullable.
    #[test]
    fn cast_rejects_missing_non_nullable_field() {
        let source = create_simple_struct().into_array();

        let target_dtype = DType::struct_(
            [
                ("value", DType::Primitive(PType::U8, Nullability::NonNullable)),
                ("extra", DType::Primitive(PType::I32, Nullability::NonNullable)),
            ],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&source, &target_dtype).is_err());
    }
}