vortex_array/arrays/struct_/compute/
mod.rs

1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7
8use crate::arrays::StructVTable;
9use crate::arrays::struct_::StructArray;
10use crate::compute::{
11    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
12    MinMaxResult, TakeKernel, TakeKernelAdapter, is_constant_opts, take,
13};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, register_kernel};
16
17impl TakeKernel for StructVTable {
18    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        StructArray::try_new_with_dtype(
20            array
21                .fields()
22                .iter()
23                .map(|field| take(field, indices))
24                .try_collect()?,
25            array.struct_fields().clone(),
26            indices.len(),
27            array.validity().take(indices)?,
28        )
29        .map(|a| a.into_array())
30    }
31}
32
33register_kernel!(TakeKernelAdapter(StructVTable).lift());
34
35impl MinMaxKernel for StructVTable {
36    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
37        // TODO(joe): Implement struct min max
38        Ok(None)
39    }
40}
41
42register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
43
44impl IsConstantKernel for StructVTable {
45    fn is_constant(
46        &self,
47        array: &StructArray,
48        opts: &IsConstantOpts,
49    ) -> VortexResult<Option<bool>> {
50        let children = array.children();
51        if children.is_empty() {
52            return Ok(None);
53        }
54
55        for child in children.iter() {
56            match is_constant_opts(child, opts)? {
57                // Un-determined
58                None => return Ok(None),
59                Some(false) => return Ok(Some(false)),
60                Some(true) => {}
61            }
62        }
63
64        Ok(Some(true))
65    }
66}
67
68register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
69
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    // A field-less struct still has a length; filtering must honour the mask's true count.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = [
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask(
            StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    // Nested struct + nullable utf8 + nullable bool fields, run through the mask conformance suite.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            StructArray::try_new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    StructArray::try_new(
                        ["left".into(), "right".into()].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            Arc::from(StructFields::new([].into(), vec![])),
            Nullability::NonNullable,
        );
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype = DType::Struct(
            Arc::from(StructFields::new([].into(), vec![])),
            Nullability::Nullable,
        );
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    // Casting must not reorder fields, even when every field has the same dtype.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);

        let result = cast(
            array.as_ref(),
            &DType::Struct(
                Arc::from(StructFields::new(
                    FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                )),
                Nullability::NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    // Tightens nullability at different nesting levels of an all-valid, fully nullable struct.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                StructArray::try_new(
                    ["left".into(), "right".into()].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Only `xs.left` becomes non-nullable; `xs.right` stays nullable.
        let non_null_xs_left = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::NonNullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::Nullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // The nested `xs` struct itself becomes non-nullable; its children stay nullable.
        let non_null_xs = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::Nullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::NonNullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}