Skip to main content

vortex_array/arrays/struct_/compute/
rules.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::array::ArrayView;
11use crate::arrays::ConstantArray;
12use crate::arrays::Struct;
13use crate::arrays::StructArray;
14use crate::arrays::dict::TakeReduceAdaptor;
15use crate::arrays::scalar_fn::ExactScalarFn;
16use crate::arrays::scalar_fn::ScalarFnArrayView;
17use crate::arrays::scalar_fn::ScalarFnFactoryExt;
18use crate::arrays::slice::SliceReduceAdaptor;
19use crate::arrays::struct_::StructArrayExt;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::optimizer::rules::ArrayParentReduceRule;
23use crate::optimizer::rules::ParentRuleSet;
24use crate::scalar_fn::EmptyOptions;
25use crate::scalar_fn::fns::cast::CastReduce;
26use crate::scalar_fn::fns::cast::CastReduceAdaptor;
27use crate::scalar_fn::fns::get_item::GetItem;
28use crate::scalar_fn::fns::mask::Mask;
29use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
30use crate::validity::Validity;
31
32pub(crate) const PARENT_RULES: ParentRuleSet<Struct> = ParentRuleSet::new(&[
33    ParentRuleSet::lift(&CastReduceAdaptor(Struct)),
34    ParentRuleSet::lift(&StructGetItemRule),
35    ParentRuleSet::lift(&MaskReduceAdaptor(Struct)),
36    ParentRuleSet::lift(&SliceReduceAdaptor(Struct)),
37    ParentRuleSet::lift(&TakeReduceAdaptor(Struct)),
38]);
39
40/// Push the cast into struct fields without execution.
41///
42/// Supports schema evolution by allowing new nullable fields to be added during the cast,
43/// filled with null values. For nullability changes, only handles the cheap path
44/// (`try_cast_nullability`); when statistics computation is required to determine whether
45/// the array contains invalid values, returns `Ok(None)` so [`CastKernel`] can run instead.
46///
47/// [`CastKernel`]: crate::scalar_fn::fns::cast::CastKernel
48impl CastReduce for Struct {
49    fn cast(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
50        let Some(target_fields) = dtype.as_struct_fields_opt() else {
51            return Ok(None);
52        };
53
54        let Some(validity) = array
55            .validity()?
56            .trivial_cast_nullability(dtype.nullability(), array.len())?
57        else {
58            return Ok(None);
59        };
60
61        let mut new_fields = Vec::with_capacity(target_fields.nfields());
62
63        for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields())
64        {
65            match array.unmasked_field_by_name(target_name).ok() {
66                Some(field) => {
67                    new_fields.push(field.cast(target_dtype)?);
68                }
69                None => {
70                    // Not found - create NULL array (schema evolution)
71                    vortex_ensure!(
72                        target_dtype.is_nullable(),
73                        "Cannot add non-nullable field '{}' during struct cast",
74                        target_name
75                    );
76                    new_fields.push(
77                        ConstantArray::new(crate::scalar::Scalar::null(target_dtype), array.len())
78                            .into_array(),
79                    );
80                }
81            }
82        }
83
84        Ok(Some(
85            unsafe {
86                StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity)
87            }
88            .into_array(),
89        ))
90    }
91}
92
93/// Rule to flatten get_item from struct by field name
94#[derive(Debug)]
95pub(crate) struct StructGetItemRule;
96impl ArrayParentReduceRule<Struct> for StructGetItemRule {
97    type Parent = ExactScalarFn<GetItem>;
98
99    fn reduce_parent(
100        &self,
101        child: ArrayView<'_, Struct>,
102        parent: ScalarFnArrayView<'_, GetItem>,
103        _child_idx: usize,
104    ) -> VortexResult<Option<ArrayRef>> {
105        let field_name = parent.options;
106        let field = child
107            .unmasked_field_by_name_opt(field_name)
108            .ok_or_else(|| {
109                vortex_err!(
110                    "Field '{}' missing from struct array {}",
111                    field_name,
112                    child.struct_fields().names()
113                )
114            })?;
115
116        match child.validity()? {
117            Validity::NonNullable | Validity::AllValid => {
118                // If the struct is non-nullable or all valid, the field's validity is unchanged
119                Ok(Some(field.clone()))
120            }
121            Validity::AllInvalid => {
122                // If everything is invalid, the field is also all invalid
123                Ok(Some(
124                    ConstantArray::new(
125                        crate::scalar::Scalar::null(field.dtype().clone()),
126                        field.len(),
127                    )
128                    .into_array(),
129                ))
130            }
131            Validity::Array(mask) => {
132                // If the validity is an array, we need to combine it with the field's validity
133                Mask.try_new_array(field.len(), EmptyOptions, [field.clone(), mask])
134                    .map(Some)
135            }
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    //! Tests for the struct cast push-down implemented by `CastReduce for Struct`.

    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    #[test]
    fn test_struct_cast_field_reorder() {
        // Source: {a, b}, Target: {c, b, a} - reordered + new null field
        let source = StructArray::try_new(
            FieldNames::from(["a", "b"]),
            vec![
                VarBinViewArray::from_iter_str(["A"]).into_array(),
                VarBinViewArray::from_iter_str(["B"]).into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let utf8_null = DType::Utf8(Nullability::Nullable);
        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["c", "b", "a"]),
                vec![utf8_null.clone(); 3],
            ),
            Nullability::NonNullable,
        );

        // Use `ArrayBuiltins::cast`, which goes through the optimizer and applies the struct
        // cast push-down (`CastReduce for Struct`).
        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_arrays_eq!(
            result.unmasked_field_by_name("a").unwrap(),
            VarBinViewArray::from_iter_nullable_str([Some("A")])
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("b").unwrap(),
            VarBinViewArray::from_iter_nullable_str([Some("B")])
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("c").unwrap(),
            ConstantArray::new(Scalar::null(utf8_null), 1)
        );
    }

    /// Regression test: casting a struct to a non-struct DType must not panic. Previously, the
    /// struct cast push-down rule (then named `StructCastPushDownRule`) called
    /// `as_struct_fields()`, which panics on non-struct types.
    #[test]
    fn cast_struct_to_non_struct_does_not_panic() {
        let source = StructArray::try_new(
            FieldNames::from(["x"]),
            vec![buffer![1i32, 2, 3].into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        // Casting a struct to a primitive type should not panic. Before the fix, the push-down
        // rule would panic via `as_struct_fields()` on the non-struct target.
        let result = source
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable));
        // Whether this errors or succeeds depends on execution, but the key invariant is that the
        // optimizer rule does not panic.
        if let Ok(arr) = &result {
            assert_eq!(
                arr.dtype(),
                &DType::Primitive(PType::I32, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn cast_struct_drop_field() {
        // Casting to a struct with a subset of fields should succeed.
        let source = StructArray::try_new(
            FieldNames::from(["a", "b", "c"]),
            vec![
                buffer![1i32, 2, 3].into_array(),
                buffer![10i64, 20, 30].into_array(),
                buffer![100u8, 200, 255].into_array(),
            ],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["a", "c"]),
                vec![
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                    DType::Primitive(PType::U8, Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(result.unmasked_fields().len(), 2);
        assert_arrays_eq!(
            result.unmasked_field_by_name("a").unwrap(),
            buffer![1i32, 2, 3].into_array()
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("c").unwrap(),
            buffer![100u8, 200, 255].into_array()
        );
    }

    #[test]
    fn cast_struct_field_type_widening() {
        // Casting struct fields to wider types (i32 -> i64).
        let source = StructArray::try_new(
            FieldNames::from(["val"]),
            vec![buffer![1i32, 2, 3].into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["val"]),
                vec![DType::Primitive(PType::I64, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );

        let result = source
            .into_array()
            .cast(target)
            .unwrap()
            .execute::<StructArray>(&mut SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(
            result.unmasked_field_by_name("val").unwrap().dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_arrays_eq!(
            result.unmasked_field_by_name("val").unwrap(),
            buffer![1i64, 2, 3].into_array()
        );
    }

    #[test]
    fn cast_struct_add_non_nullable_field_fails() {
        // Adding a non-nullable field via cast should fail: there is no value to fill it with.
        let source = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let target = DType::Struct(
            StructFields::new(
                FieldNames::from(["a", "b"]),
                vec![
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        assert!(source.into_array().cast(target).is_err());
    }
}