vortex_array/arrays/struct_/compute/
cast.rs1use itertools::Itertools;
5use vortex_dtype::DType;
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7
8use crate::arrays::{StructArray, StructVTable};
9use crate::compute::{CastKernel, CastKernelAdapter, cast};
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, IntoArray, register_kernel};
12
13impl CastKernel for StructVTable {
14 fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
16 return Ok(None);
17 };
18
19 let source_sdtype = array
20 .dtype()
21 .as_struct_fields_opt()
22 .vortex_expect("struct array must have struct dtype");
23
24 if target_sdtype.names() != source_sdtype.names() {
25 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
26 }
27
28 let validity = array
29 .validity()
30 .clone()
31 .cast_nullability(dtype.nullability(), array.len())?;
32
33 StructArray::try_new(
34 target_sdtype.names().clone(),
35 array
36 .fields()
37 .iter()
38 .zip_eq(target_sdtype.fields())
39 .map(|(field, dtype)| cast(field, &dtype))
40 .try_collect()?,
41 array.len(),
42 validity,
43 )
44 .map(|a| Some(a.into_array()))
45 }
46}
47
// Register the `CastKernel` impl for `StructVTable` (wrapped in `CastKernelAdapter`)
// via `register_kernel!`. Presumably this is what lets the generic `cast` compute
// function dispatch struct arrays to it; the dispatch mechanics live in the macro
// and are not visible here.
register_kernel!(CastKernelAdapter(StructVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;

    /// Runs the shared cast conformance suite over several struct shapes.
    #[rstest]
    #[case(create_test_struct(false))]
    #[case(create_test_struct(true))]
    #[case(create_nested_struct())]
    #[case(create_simple_struct())]
    fn test_cast_struct_conformance(#[case] array: StructArray) {
        test_cast_conformance(array.as_ref());
    }

    /// Builds a three-row struct `{a: i32, b: utf8?}`.
    ///
    /// Field `b` contains a null. The struct itself is `AllValid` when
    /// `nullable` is true and `NonNullable` otherwise.
    fn create_test_struct(nullable: bool) -> StructArray {
        let names = FieldNames::from(["a", "b"]);

        let a = buffer![1i32, 2, 3].into_array();
        let b = VarBinArray::from_iter(
            vec![Some("x"), None, Some("z")],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        StructArray::try_new(
            names,
            vec![a, b],
            3,
            if nullable {
                Validity::AllValid
            } else {
                Validity::NonNullable
            },
        )
        .unwrap()
    }

    /// Builds a three-row struct `{id: i64, point: {x: f32, y: f32}}` with non-nullable levels.
    fn create_nested_struct() -> StructArray {
        let inner_names = FieldNames::from(["x", "y"]);

        let x = buffer![1.0f32, 2.0, 3.0].into_array();
        let y = buffer![4.0f32, 5.0, 6.0].into_array();
        let inner_struct = StructArray::try_new(inner_names, vec![x, y], 3, Validity::NonNullable)
            .unwrap()
            .into_array();

        let outer_names: FieldNames = ["id", "point"].into();
        let ids = buffer![100i64, 200, 300].into_array();

        StructArray::try_new(
            outer_names,
            vec![ids, inner_struct],
            3,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a single-row, single-field struct `{value: u8}`.
    fn create_simple_struct() -> StructArray {
        let names = FieldNames::from(["value"]);
        let values = buffer![42u8].into_array();

        StructArray::try_new(names, vec![values], 1, Validity::NonNullable).unwrap()
    }

    /// Casting a zero-length struct whose validity is `AllInvalid` to a
    /// non-nullable struct must succeed, since there are no rows to be null.
    #[test]
    fn cast_nullable_all_invalid() {
        let empty_struct = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![PrimitiveArray::new::<i32>(buffer![], Validity::AllInvalid).to_array()],
            0,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let target_dtype = DType::struct_(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))],
            Nullability::NonNullable,
        );

        let result = crate::compute::cast(&empty_struct, &target_dtype).unwrap();
        assert_eq!(result.dtype(), &target_dtype);
        assert_eq!(result.len(), 0);
    }

    /// The struct cast kernel rejects targets whose field names differ from
    /// the source, instead of silently re-labelling the children.
    #[test]
    fn cast_mismatched_field_names_fails() {
        let array = create_test_struct(false).into_array();

        // Same field count and types as the source, but the second field is renamed.
        let target_dtype = DType::struct_(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("c", DType::Utf8(Nullability::Nullable)),
            ],
            Nullability::NonNullable,
        );

        assert!(crate::compute::cast(&array, &target_dtype).is_err());
    }
}