vortex_array/arrays/struct_/compute/mod.rs

1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_dtype::Nullability::NonNullable;
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructVTable;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
14    MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
15};
16use crate::validity::Validity;
17use crate::vtable::ValidityHelper;
18use crate::{Array, ArrayRef, IntoArray, register_kernel};
19
20impl TakeKernel for StructVTable {
21    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22        // If the struct array is empty then the indices must be all null, otherwise it will access
23        // an out of bounds element
24        if array.is_empty() {
25            return StructArray::try_new_with_dtype(
26                array.fields().to_vec(),
27                array.struct_fields().clone(),
28                indices.len(),
29                Validity::AllInvalid,
30            )
31            .map(StructArray::into_array);
32        }
33        // The validity is applied to the struct validity,
34        let inner_indices = &fill_null(
35            indices,
36            &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
37        )?;
38        StructArray::try_new_with_dtype(
39            array
40                .fields()
41                .iter()
42                .map(|field| take(field, inner_indices))
43                .try_collect()?,
44            array.struct_fields().clone(),
45            indices.len(),
46            array.validity().take(indices)?,
47        )
48        .map(|a| a.into_array())
49    }
50}
51
52register_kernel!(TakeKernelAdapter(StructVTable).lift());
53
54impl MinMaxKernel for StructVTable {
55    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
56        // TODO(joe): Implement struct min max
57        Ok(None)
58    }
59}
60
61register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
62
63impl IsConstantKernel for StructVTable {
64    fn is_constant(
65        &self,
66        array: &StructArray,
67        opts: &IsConstantOpts,
68    ) -> VortexResult<Option<bool>> {
69        let children = array.children();
70        if children.is_empty() {
71            return Ok(None);
72        }
73
74        for child in children.iter() {
75            match is_constant_opts(child, opts)? {
76                // Un-determined
77                None => return Ok(None),
78                Some(false) => return Ok(Some(false)),
79                Some(true) => {}
80            }
81        }
82
83        Ok(Some(true))
84    }
85}
86
87register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
88
#[cfg(test)]
mod tests {
    use Nullability::{NonNullable, Nullable};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a field-less struct keeps exactly the rows selected by the mask.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// Taking from a field-less struct with nullable indices yields a nullable struct. A valid
    /// index gives an empty struct scalar and a null index gives a null scalar.
    #[test]
    fn take_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(
                DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable),
                vec![]
            )
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::null(DType::Struct(
                StructFields::new(FieldNames::default(), vec![]),
                Nullable
            ))
        );
    }

    /// Taking from a single-field struct with nullable indices pulls the selected field value
    /// for valid indices and produces a null struct row for null indices.
    #[test]
    fn take_field_struct() {
        let struct_arr =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(
                struct_arr.dtype().union_nullability(Nullable),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::null(struct_arr.dtype().union_nullability(Nullable))
        );
    }

    /// Filtering a zero-length, field-less struct with an empty mask yields an empty result.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Runs the mask conformance suite on a field-less struct.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(
            StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the mask conformance suite on a struct containing a nested struct, a nullable
    /// string field, and a nullable boolean field.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            StructArray::try_new(
                ["xs", "ys", "zs"].into(),
                vec![
                    StructArray::try_new(
                        ["left", "right"].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// A field-less, non-nullable struct can be cast to both the non-nullable and the
    /// nullable field-less struct dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            StructFields::new(FieldNames::default(), vec![]),
            NonNullable,
        );
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype =
            DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable);
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Casting must not reorder struct fields, even when the target has the same field
    /// types under a permuted name order.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);

        let result = cast(
            array.as_ref(),
            &DType::Struct(
                StructFields::new(
                    FieldNames::from(["ys", "xs", "zs"]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                ),
                NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    /// A fully nullable struct whose data has no nulls can be cast to tighten nullability. The
    /// cases are the top level, a nested field (`xs.left`), and a nested struct (`xs`).
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![
                StructArray::try_new(
                    ["left", "right"].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Only `xs.left` becomes non-nullable; `xs.right` stays nullable.
        let non_null_xs_left = DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![
                    DType::Struct(
                        StructFields::new(
                            ["left", "right"].into(),
                            vec![
                                DType::Primitive(PType::I64, NonNullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        ),
                        Nullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            ),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // The nested struct `xs` itself becomes non-nullable; its fields stay nullable.
        let non_null_xs = DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![
                    DType::Struct(
                        StructFields::new(
                            ["left", "right"].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        ),
                        NonNullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            ),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}