vortex_array/arrays/struct_/compute/
mod.rs

1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_dtype::Nullability::NonNullable;
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructVTable;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
14    MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
15};
16use crate::validity::Validity;
17use crate::vtable::ValidityHelper;
18use crate::{Array, ArrayRef, IntoArray, register_kernel};
19
20impl TakeKernel for StructVTable {
21    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22        // If the struct array is empty then the indices must be all null, otherwise it will access
23        // an out of bounds element
24        if array.is_empty() {
25            return StructArray::try_new_with_dtype(
26                array.fields().to_vec(),
27                array.struct_fields().clone(),
28                indices.len(),
29                Validity::AllInvalid,
30            )
31            .map(StructArray::into_array);
32        }
33        // The validity is applied to the struct validity,
34        let inner_indices = &fill_null(
35            indices,
36            &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
37        )?;
38        StructArray::try_new_with_dtype(
39            array
40                .fields()
41                .iter()
42                .map(|field| take(field, inner_indices))
43                .try_collect()?,
44            array.struct_fields().clone(),
45            indices.len(),
46            array.validity().take(indices)?,
47        )
48        .map(|a| a.into_array())
49    }
50}
51
52register_kernel!(TakeKernelAdapter(StructVTable).lift());
53
54impl MinMaxKernel for StructVTable {
55    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
56        // TODO(joe): Implement struct min max
57        Ok(None)
58    }
59}
60
61register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
62
63impl IsConstantKernel for StructVTable {
64    fn is_constant(
65        &self,
66        array: &StructArray,
67        opts: &IsConstantOpts,
68    ) -> VortexResult<Option<bool>> {
69        let children = array.children();
70        if children.is_empty() {
71            return Ok(None);
72        }
73
74        for child in children.iter() {
75            match is_constant_opts(child, opts)? {
76                // Un-determined
77                None => return Ok(None),
78                Some(false) => return Ok(Some(false)),
79                Some(true) => {}
80            }
81        }
82
83        Ok(Some(true))
84    }
85}
86
87register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use Nullability::{NonNullable, Nullable};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array that has no fields and `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Dtype of a struct without fields and with the given top-level nullability.
    fn fieldless_dtype(nullability: Nullability) -> DType {
        DType::Struct(Arc::new(StructFields::new([].into(), vec![])), nullability)
    }

    /// Nullable `{xs: {left: i64, right: i64}, ys: utf8?, zs: bool?}` dtype in which the
    /// nullability of `left`, `right` and the nested `xs` struct is chosen by the caller.
    fn nested_dtype(left: Nullability, right: Nullability, xs: Nullability) -> DType {
        let pair = StructFields::new(
            ["left".into(), "right".into()].into(),
            vec![
                DType::Primitive(PType::I64, left),
                DType::Primitive(PType::I64, right),
            ],
        );
        let outer = StructFields::new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                DType::Struct(Arc::from(pair), xs),
                DType::Utf8(Nullable),
                DType::Bool(Nullable),
            ],
        );
        DType::Struct(Arc::from(outer), Nullable)
    }

    #[test]
    fn filter_empty_struct() {
        let source = fieldless_struct(10);
        // Keep every odd position: false, true, false, true, ... (five `true`s in total).
        let keep_odd = Mask::from_iter((0..10).map(|position| position % 2 == 1));

        let kept = filter(source.as_ref(), &keep_odd).unwrap();

        assert_eq!(kept.len(), 5);
    }

    #[test]
    fn take_empty_struct() {
        let source = fieldless_struct(10);
        let picks = PrimitiveArray::from_option_iter([Some(1), None]);

        let result = take(source.as_ref(), picks.as_ref()).unwrap();
        assert_eq!(result.len(), 2);

        // A valid index yields a (field-less) struct value, a null index yields a null struct.
        let present = Scalar::struct_(fieldless_dtype(Nullable), vec![]);
        assert_eq!(result.scalar_at(0).unwrap(), present);

        let absent = Scalar::null(fieldless_dtype(Nullable));
        assert_eq!(result.scalar_at(1).unwrap(), absent);
    }

    #[test]
    fn take_field_struct() {
        let column = PrimitiveArray::from_iter(0..10).to_array();
        let source = StructArray::from_fields(&[("a", column)]).unwrap();
        let picks = PrimitiveArray::from_option_iter([Some(1), None]);

        let result = take(source.as_ref(), picks.as_ref()).unwrap();
        assert_eq!(result.len(), 2);

        // Taking with nullable indices makes the resulting struct nullable.
        let expected_dtype = source.dtype().union_nullability(Nullable);

        let present = Scalar::struct_(
            expected_dtype.clone(),
            vec![Scalar::primitive(1, NonNullable)],
        );
        assert_eq!(result.scalar_at(0).unwrap(), present);

        let absent = Scalar::null(expected_dtype);
        assert_eq!(result.scalar_at(1).unwrap(), absent);
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let source = fieldless_struct(0);
        let nothing = Mask::from_iter::<[bool; 0]>([]);

        let kept = filter(source.as_ref(), &nothing).unwrap();

        assert_eq!(kept.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        let source = fieldless_struct(5);
        test_mask(source.as_ref());
    }

    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        // The first field is itself a struct holding the same column twice.
        let nested = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let outer = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(outer.as_ref());
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = fieldless_struct(5).into_array();

        // Casting to either nullability must produce exactly the requested dtype.
        for nullability in [NonNullable, Nullable] {
            let target = fieldless_dtype(nullability);
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let one_byte = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![one_byte(), one_byte(), one_byte()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        // Same field dtypes, but with the first two names swapped.
        let tu8 = DType::Primitive(PType::U8, NonNullable);
        let reordered = DType::Struct(
            Arc::from(StructFields::new(
                FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                vec![tu8.clone(), tu8.clone(), tu8],
            )),
            NonNullable,
        );

        let result = cast(array.as_ref(), &reordered);

        let rejected_with_message = matches!(
            &result,
            Err(err) if err
                .to_string()
                .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
        );
        assert!(rejected_with_message, "{result:?}");
    }

    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter((0i64..5).map(Some));
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let nested = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Every level is nullable in its dtype while holding no actual nulls.
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let targets = [
            // Only the outermost struct becomes non-nullable.
            fully_nullable_array.dtype().as_nonnullable(),
            // Only `xs.left` becomes non-nullable.
            nested_dtype(NonNullable, Nullable, Nullable),
            // Only the nested `xs` struct becomes non-nullable.
            nested_dtype(Nullable, Nullable, NonNullable),
        ];

        for target in targets {
            let casted = cast(&fully_nullable_array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }
}