vortex_array/arrays/struct_/compute/
cast.rs1use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::arrays::StructArray;
13use crate::arrays::StructVTable;
14use crate::compute::CastKernel;
15use crate::compute::CastKernelAdapter;
16use crate::compute::cast;
17use crate::register_kernel;
18use crate::vtable::ValidityHelper;
19
20impl CastKernel for StructVTable {
21 fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
22 let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
23 return Ok(None);
24 };
25
26 let source_sdtype = array
27 .dtype()
28 .as_struct_fields_opt()
29 .vortex_expect("struct array must have struct dtype");
30
31 if target_sdtype.names() != source_sdtype.names() {
32 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
33 }
34
35 let validity = array
36 .validity()
37 .clone()
38 .cast_nullability(dtype.nullability(), array.len())?;
39
40 StructArray::try_new(
41 target_sdtype.names().clone(),
42 array
43 .fields()
44 .iter()
45 .zip_eq(target_sdtype.fields())
46 .map(|(field, dtype)| cast(field, &dtype))
47 .collect::<Result<Vec<_>, _>>()?,
48 array.len(),
49 validity,
50 )
51 .map(|a| Some(a.into_array()))
52 }
53}
54
55register_kernel!(CastKernelAdapter(StructVTable).lift());
56
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a 3-row struct with fields `a: i32` and `b: nullable utf8`.
    ///
    /// When `nullable` is true the struct itself is nullable (every row valid);
    /// otherwise it is non-nullable.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Builds a 3-row non-nullable struct `{id: i64, point: {x: f32, y: f32}}`.
    fn create_nested_struct() -> StructArray {
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        let outer_names = FieldNames::from(["id", "point"]);
        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a single-row, single-field non-nullable struct `{value: u8}`.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);
        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// Child fields must be cast to the target field dtypes, not just relabelled.
    #[test]
    fn cast_struct_casts_each_field() {
        let array = create_test_struct(false).into_array();
        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I64, Nullability::NonNullable)),
                ("b", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&array, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 3);
    }

    /// A struct cast may change field dtypes but never field names.
    #[test]
    fn cast_struct_rejects_renamed_fields() {
        let array = create_test_struct(false).into_array();
        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("c", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }
}