// vortex_array/arrays/struct_/compute/rules.rs
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::ConstantArray;
use crate::arrays::Struct;
use crate::arrays::StructArray;
use crate::arrays::dict::TakeReduceAdaptor;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::arrays::struct_::StructArrayExt;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar::Scalar;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::fns::cast::CastReduce;
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
use crate::scalar_fn::fns::get_item::GetItem;
use crate::scalar_fn::fns::mask::Mask;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
use crate::validity::Validity;

32pub(crate) const PARENT_RULES: ParentRuleSet<Struct> = ParentRuleSet::new(&[
33 ParentRuleSet::lift(&CastReduceAdaptor(Struct)),
34 ParentRuleSet::lift(&StructGetItemRule),
35 ParentRuleSet::lift(&MaskReduceAdaptor(Struct)),
36 ParentRuleSet::lift(&SliceReduceAdaptor(Struct)),
37 ParentRuleSet::lift(&TakeReduceAdaptor(Struct)),
38]);
39
40impl CastReduce for Struct {
49 fn cast(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
50 let Some(target_fields) = dtype.as_struct_fields_opt() else {
51 return Ok(None);
52 };
53
54 let Some(validity) = array
55 .validity()?
56 .trivial_cast_nullability(dtype.nullability(), array.len())?
57 else {
58 return Ok(None);
59 };
60
61 let mut new_fields = Vec::with_capacity(target_fields.nfields());
62
63 for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields())
64 {
65 match array.unmasked_field_by_name(target_name).ok() {
66 Some(field) => {
67 new_fields.push(field.cast(target_dtype)?);
68 }
69 None => {
70 vortex_ensure!(
72 target_dtype.is_nullable(),
73 "Cannot add non-nullable field '{}' during struct cast",
74 target_name
75 );
76 new_fields.push(
77 ConstantArray::new(crate::scalar::Scalar::null(target_dtype), array.len())
78 .into_array(),
79 );
80 }
81 }
82 }
83
84 Ok(Some(
85 unsafe {
86 StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity)
87 }
88 .into_array(),
89 ))
90 }
91}
92
93#[derive(Debug)]
95pub(crate) struct StructGetItemRule;
96impl ArrayParentReduceRule<Struct> for StructGetItemRule {
97 type Parent = ExactScalarFn<GetItem>;
98
99 fn reduce_parent(
100 &self,
101 child: ArrayView<'_, Struct>,
102 parent: ScalarFnArrayView<'_, GetItem>,
103 _child_idx: usize,
104 ) -> VortexResult<Option<ArrayRef>> {
105 let field_name = parent.options;
106 let field = child
107 .unmasked_field_by_name_opt(field_name)
108 .ok_or_else(|| {
109 vortex_err!(
110 "Field '{}' missing from struct array {}",
111 field_name,
112 child.struct_fields().names()
113 )
114 })?;
115
116 match child.validity()? {
117 Validity::NonNullable | Validity::AllValid => {
118 Ok(Some(field.clone()))
120 }
121 Validity::AllInvalid => {
122 Ok(Some(
124 ConstantArray::new(
125 crate::scalar::Scalar::null(field.dtype().clone()),
126 field.len(),
127 )
128 .into_array(),
129 ))
130 }
131 Validity::Array(mask) => {
132 Mask.try_new_array(field.len(), EmptyOptions, [field.clone(), mask])
134 .map(Some)
135 }
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::validity::Validity;

    /// Shared session used to execute lazily-built arrays in these tests.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// The target reorders `a`/`b` and adds a nullable field `c` that the source lacks:
    /// existing fields are matched by name, and `c` is filled with nulls.
    #[test]
    fn test_struct_cast_field_reorder() {
        let source = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            vec![
                VarBinViewArray::from_iter_str(["A"]).into_array(),
                VarBinViewArray::from_iter_str(["B"]).into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let utf8_null = DType::Utf8(Nullability::Nullable);
        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["c", "b", "a"]),
                vec![utf8_null.clone(); 3],
            ),
            Nullability::NonNullable,
        );

        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_arrays_eq!(
            result.unmasked_field_by_name("a").unwrap(),
            VarBinViewArray::from_iter_nullable_str([Some("A")])
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("b").unwrap(),
            VarBinViewArray::from_iter_nullable_str([Some("B")])
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("c").unwrap(),
            ConstantArray::new(Scalar::null(utf8_null), 1)
        );
    }

    /// The struct cast rule declines non-struct targets. Whatever the overall cast then
    /// does, it must not panic, and a successful result must carry the requested dtype.
    #[test]
    fn cast_struct_to_non_struct_does_not_panic() {
        let source = StructArray::try_new(
            FieldNames::from(["x"]),
            vec![buffer![1i32, 2, 3].into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let result = source
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable));
        if let Ok(arr) = &result {
            assert_eq!(
                arr.dtype(),
                &DType::Primitive(PType::I32, Nullability::NonNullable)
            );
        }
    }

    /// Source fields that the target does not mention (`b`) are dropped.
    #[test]
    fn cast_struct_drop_field() {
        let source = StructArray::try_new(
            FieldNames::from(["a", "b", "c"]),
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![10i64, 20, 30].into_array(),
                buffer![100u8, 200, 255].into_array(),
            ],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["a", "c"]),
                vec![
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                    DType::Primitive(PType::U8, Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(result.unmasked_fields().len(), 2);
        assert_arrays_eq!(
            result.unmasked_field_by_name("a").unwrap(),
            buffer![1i32, 2, 3].into_array()
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("c").unwrap(),
            buffer![100u8, 200, 255].into_array()
        );
    }

    /// Each field is cast to its own target dtype (here `i32` widened to `i64`).
    #[test]
    fn cast_struct_field_type_widening() {
        let source = StructArray::try_new(
            FieldNames::from(["val"]),
            vec![buffer![1i32, 2, 3].into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["val"]),
                vec![DType::Primitive(PType::I64, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );

        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(
            result.unmasked_field_by_name("val").unwrap().dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("val").unwrap(),
            buffer![1i64, 2, 3].into_array()
        );
    }

    /// A field missing from the source can only be added as all-null, so a
    /// non-nullable added field must make the cast fail.
    #[test]
    fn cast_struct_add_non_nullable_field_fails() {
        let source = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["a", "b"]),
                vec![
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        assert!(source.into_array().cast(target).is_err());
    }
}