vortex_array/arrays/struct_/compute/
mod.rs

1mod to_arrow;
2
3use itertools::Itertools;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail};
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::StructEncoding;
10use crate::arrays::struct_::StructArray;
11use crate::compute::{
12    CastFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsConstantOpts, KernelRef, MaskFn,
13    MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter,
14    is_constant_opts, scalar_at, slice, take, try_cast, uncompressed_size,
15};
16use crate::vtable::ComputeVTable;
17use crate::{Array, ArrayComputeImpl, ArrayRef, ArrayVisitor};
18
19impl ArrayComputeImpl for StructArray {
20    const FILTER: Option<KernelRef> = FilterKernelAdapter(StructEncoding).some();
21}
22
23impl ComputeVTable for StructEncoding {
24    fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
25        Some(self)
26    }
27
28    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
29        Some(self)
30    }
31
32    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
33        Some(self)
34    }
35
36    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
37        Some(self)
38    }
39
40    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
41        Some(self)
42    }
43
44    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
45        Some(self)
46    }
47
48    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
49        Some(self)
50    }
51
52    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
53        Some(self)
54    }
55
56    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
57        Some(self)
58    }
59}
60
61impl CastFn<&StructArray> for StructEncoding {
62    fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<ArrayRef> {
63        let Some(target_sdtype) = dtype.as_struct() else {
64            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
65        };
66
67        let source_sdtype = array
68            .dtype()
69            .as_struct()
70            .vortex_expect("struct array must have struct dtype");
71
72        if target_sdtype.names() != source_sdtype.names() {
73            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
74        }
75
76        let validity = array
77            .validity()
78            .clone()
79            .cast_nullability(dtype.nullability())?;
80
81        StructArray::try_new(
82            target_sdtype.names().clone(),
83            array
84                .fields()
85                .iter()
86                .zip_eq(target_sdtype.fields())
87                .map(|(field, dtype)| try_cast(field, &dtype))
88                .try_collect()?,
89            array.len(),
90            validity,
91        )
92        .map(|a| a.into_array())
93    }
94}
95
96impl ScalarAtFn<&StructArray> for StructEncoding {
97    fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
98        Ok(Scalar::struct_(
99            array.dtype().clone(),
100            array
101                .fields()
102                .iter()
103                .map(|field| scalar_at(field, index))
104                .try_collect()?,
105        ))
106    }
107}
108
109impl TakeFn<&StructArray> for StructEncoding {
110    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
111        StructArray::try_new_with_dtype(
112            array
113                .fields()
114                .iter()
115                .map(|field| take(field, indices))
116                .try_collect()?,
117            array.struct_dtype().clone(),
118            indices.len(),
119            array.validity().take(indices)?,
120        )
121        .map(|a| a.into_array())
122    }
123}
124
125impl SliceFn<&StructArray> for StructEncoding {
126    fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
127        let fields = array
128            .fields()
129            .iter()
130            .map(|field| slice(field, start, stop))
131            .try_collect()?;
132        StructArray::try_new_with_dtype(
133            fields,
134            array.struct_dtype().clone(),
135            stop - start,
136            array.validity().slice(start, stop)?,
137        )
138        .map(|a| a.into_array())
139    }
140}
141
142impl FilterKernel for StructEncoding {
143    fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
144        let validity = array.validity().filter(mask)?;
145
146        let fields: Vec<ArrayRef> = array
147            .fields()
148            .iter()
149            .map(|field| filter(field, mask))
150            .try_collect()?;
151        let length = fields
152            .first()
153            .map(|a| a.len())
154            .unwrap_or_else(|| mask.true_count());
155
156        StructArray::try_new_with_dtype(fields, array.struct_dtype().clone(), length, validity)
157            .map(|a| a.into_array())
158    }
159}
160
161impl MaskFn<&StructArray> for StructEncoding {
162    fn mask(&self, array: &StructArray, filter_mask: Mask) -> VortexResult<ArrayRef> {
163        let validity = array.validity().mask(&filter_mask)?;
164
165        StructArray::try_new_with_dtype(
166            array.fields().to_vec(),
167            array.struct_dtype().clone(),
168            array.len(),
169            validity,
170        )
171        .map(|a| a.into_array())
172    }
173}
174
175impl MinMaxFn<&StructArray> for StructEncoding {
176    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
177        // TODO(joe): Implement struct min max
178        Ok(None)
179    }
180}
181
182impl IsConstantFn<&StructArray> for StructEncoding {
183    fn is_constant(
184        &self,
185        array: &StructArray,
186        opts: &IsConstantOpts,
187    ) -> VortexResult<Option<bool>> {
188        let children = array.children();
189        if children.is_empty() {
190            return Ok(None);
191        }
192
193        for child in children.iter() {
194            if !is_constant_opts(child, opts)? {
195                return Ok(Some(false));
196            }
197        }
198
199        Ok(Some(true))
200    }
201}
202
203impl UncompressedSizeFn<&StructArray> for StructEncoding {
204    fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
205        let mut sum = array.validity().uncompressed_size();
206        for child in array.children().into_iter() {
207            sum += uncompressed_size(child.as_ref())?;
208        }
209
210        Ok(sum)
211    }
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{filter, try_cast};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with no fields and `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Wraps field names and field dtypes into a struct `DType`.
    fn struct_dtype(names: FieldNames, fields: Vec<DType>, nullability: Nullability) -> DType {
        DType::Struct(Arc::from(StructDType::new(names, fields)), nullability)
    }

    /// Filtering a field-less struct keeps one row per set mask bit.
    #[test]
    fn filter_empty_struct() {
        let input = fieldless_struct(10);
        // Alternating mask: false, true, false, true, ... (five `true`s).
        let keep: Vec<bool> = (0..10).map(|row| row % 2 == 1).collect();

        let output = filter(&input, &Mask::from_iter(keep)).unwrap();

        assert_eq!(output.len(), 5);
    }

    /// Filtering a zero-length, field-less struct with an empty mask is a no-op.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let input = fieldless_struct(0);

        let output = filter(&input, &Mask::from_iter::<[bool; 0]>([])).unwrap();

        assert_eq!(output.len(), 0);
    }

    /// Runs the mask conformance suite against a field-less struct.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(&fieldless_struct(5));
    }

    /// Runs the mask conformance suite against a nested, multi-field struct.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let nested = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let outer = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(&outer);
    }

    /// A field-less struct can be cast to either nullability.
    #[test]
    fn test_cast_empty_struct() {
        let array = fieldless_struct(5).into_array();

        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = struct_dtype([].into(), vec![], nullability);
            let casted = try_cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    /// Casting must fail when the target lists the same names in another order.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let reordered = struct_dtype(
            FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
            vec![tu8.clone(), tu8.clone(), tu8],
            Nullability::NonNullable,
        );

        let result = try_cast(&array, &reordered);

        let failed_as_expected = result.as_ref().is_err_and(|err| {
            err.to_string()
                .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
        });
        assert!(failed_as_expected, "{:?}", result);
    }

    /// Nullability can be tightened at the top level, on a nested field, and
    /// on a nested struct when the data contains no nulls.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let nested = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Nullable outer struct whose `xs` field has the given dtype.
        let outer_with_xs = |xs_dtype: DType| {
            struct_dtype(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    xs_dtype,
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
                Nullability::Nullable,
            )
        };
        // `{left, right}` struct of i64 fields with the given nullabilities.
        let left_right = |left: Nullability, right: Nullability, outer: Nullability| {
            struct_dtype(
                ["left".into(), "right".into()].into(),
                vec![
                    DType::Primitive(PType::I64, left),
                    DType::Primitive(PType::I64, right),
                ],
                outer,
            )
        };

        // Only the top-level struct becomes non-nullable.
        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Only `xs.left` becomes non-nullable.
        let non_null_xs_left = outer_with_xs(left_right(
            Nullability::NonNullable,
            Nullability::Nullable,
            Nullability::Nullable,
        ));
        let casted = try_cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // Only the nested `xs` struct itself becomes non-nullable.
        let non_null_xs = outer_with_xs(left_right(
            Nullability::Nullable,
            Nullability::Nullable,
            Nullability::NonNullable,
        ));
        let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}