vortex_array/arrays/struct_/compute/
mod.rs

1mod cast;
2mod mask;
3mod to_arrow;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructEncoding;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13    FilterKernel, FilterKernelAdapter, IsConstantFn, IsConstantOpts, MinMaxFn, MinMaxResult,
14    ScalarAtFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter, is_constant_opts,
15    scalar_at, slice, take, uncompressed_size,
16};
17use crate::vtable::ComputeVTable;
18use crate::{Array, ArrayRef, ArrayVisitor, register_kernel};
19
20impl ComputeVTable for StructEncoding {
21    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
30        Some(self)
31    }
32
33    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
34        Some(self)
35    }
36
37    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
38        Some(self)
39    }
40
41    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
42        Some(self)
43    }
44
45    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
46        Some(self)
47    }
48}
49
50impl ScalarAtFn<&StructArray> for StructEncoding {
51    fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
52        Ok(Scalar::struct_(
53            array.dtype().clone(),
54            array
55                .fields()
56                .iter()
57                .map(|field| scalar_at(field, index))
58                .try_collect()?,
59        ))
60    }
61}
62
63impl TakeFn<&StructArray> for StructEncoding {
64    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
65        StructArray::try_new_with_dtype(
66            array
67                .fields()
68                .iter()
69                .map(|field| take(field, indices))
70                .try_collect()?,
71            array.struct_dtype().clone(),
72            indices.len(),
73            array.validity().take(indices)?,
74        )
75        .map(|a| a.into_array())
76    }
77}
78
79impl SliceFn<&StructArray> for StructEncoding {
80    fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
81        let fields = array
82            .fields()
83            .iter()
84            .map(|field| slice(field, start, stop))
85            .try_collect()?;
86        StructArray::try_new_with_dtype(
87            fields,
88            array.struct_dtype().clone(),
89            stop - start,
90            array.validity().slice(start, stop)?,
91        )
92        .map(|a| a.into_array())
93    }
94}
95
96impl FilterKernel for StructEncoding {
97    fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
98        let validity = array.validity().filter(mask)?;
99
100        let fields: Vec<ArrayRef> = array
101            .fields()
102            .iter()
103            .map(|field| filter(field, mask))
104            .try_collect()?;
105        let length = fields
106            .first()
107            .map(|a| a.len())
108            .unwrap_or_else(|| mask.true_count());
109
110        StructArray::try_new_with_dtype(fields, array.struct_dtype().clone(), length, validity)
111            .map(|a| a.into_array())
112    }
113}
114
115register_kernel!(FilterKernelAdapter(StructEncoding).lift());
116
117impl MinMaxFn<&StructArray> for StructEncoding {
118    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
119        // TODO(joe): Implement struct min max
120        Ok(None)
121    }
122}
123
124impl IsConstantFn<&StructArray> for StructEncoding {
125    fn is_constant(
126        &self,
127        array: &StructArray,
128        opts: &IsConstantOpts,
129    ) -> VortexResult<Option<bool>> {
130        let children = array.children();
131        if children.is_empty() {
132            return Ok(None);
133        }
134
135        for child in children.iter() {
136            if !is_constant_opts(child, opts)? {
137                return Ok(Some(false));
138            }
139        }
140
141        Ok(Some(true))
142    }
143}
144
145impl UncompressedSizeFn<&StructArray> for StructEncoding {
146    fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
147        let mut sum = array.validity().uncompressed_size();
148        for child in array.children().into_iter() {
149            sum += uncompressed_size(child.as_ref())?;
150        }
151
152        Ok(sum)
153    }
154}
155
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// A non-nullable struct array with no fields and `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Wraps field names and field dtypes into a struct `DType` with the given nullability.
    fn struct_dtype(names: FieldNames, fields: Vec<DType>, nullability: Nullability) -> DType {
        DType::Struct(Arc::from(StructDType::new(names, fields)), nullability)
    }

    /// Target dtype for `test_cast_complex_struct`: a nullable struct
    /// `{xs: {left: i64, right: i64}, ys: utf8, zs: bool}`.
    ///
    /// `xs.right`, `ys` and `zs` are nullable. The caller picks the nullability of
    /// `xs.left` and of the nested `xs` struct itself.
    fn complex_target(left: Nullability, xs: Nullability) -> DType {
        struct_dtype(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                struct_dtype(
                    ["left".into(), "right".into()].into(),
                    vec![
                        DType::Primitive(PType::I64, left),
                        DType::Primitive(PType::I64, Nullability::Nullable),
                    ],
                    xs,
                ),
                DType::Utf8(Nullability::Nullable),
                DType::Bool(Nullability::Nullable),
            ],
            Nullability::Nullable,
        )
    }

    #[test]
    fn filter_empty_struct() {
        let array = fieldless_struct(10);
        // Keep every odd position: five of the ten rows survive.
        let keep_odd = Mask::from_iter((0..10).map(|i| i % 2 == 1));
        let result = filter(&array, &keep_odd).unwrap();
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let array = fieldless_struct(0);
        let no_rows = Mask::from_iter::<[bool; 0]>([]);
        let result = filter(&array, &no_rows).unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask(&fieldless_struct(5));
    }

    #[test]
    fn test_mask_complex_struct() {
        let ints = buffer![0i64, 1, 2, 3, 4].into_array();
        let strings = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let bools =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let nested = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![ints.clone(), ints],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let outer = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested, strings, bools],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(&outer);
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = fieldless_struct(5).into_array();

        // Casting a field-less struct must succeed for both nullabilities.
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = struct_dtype([].into(), vec![], nullability);
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let single_byte = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![single_byte(), single_byte(), single_byte()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let reordered = struct_dtype(
            FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
            vec![tu8.clone(), tu8.clone(), tu8],
            Nullability::NonNullable,
        );

        let result = cast(&array, &reordered);
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{:?}",
            result
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let ints =
            PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let bools = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let left_right = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![ints.to_array(), ints.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![left_right, strings.into_array(), bools.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Drop nullability only at the top level.
        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Drop nullability of the nested `xs.left` field.
        let non_null_xs_left = complex_target(Nullability::NonNullable, Nullability::Nullable);
        let casted = cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // Drop nullability of the nested `xs` struct itself.
        let non_null_xs = complex_target(Nullability::Nullable, Nullability::NonNullable);
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}