vortex_array/arrays/struct_/compute/mod.rs

1mod to_arrow;
2
3use itertools::Itertools;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail};
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::StructEncoding;
10use crate::arrays::struct_::StructArray;
11use crate::compute::{
12    CastFn, FilterFn, IsConstantFn, IsConstantOpts, MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn,
13    SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter, is_constant_opts, scalar_at, slice,
14    take, try_cast, uncompressed_size,
15};
16use crate::variants::StructArrayTrait;
17use crate::vtable::ComputeVTable;
18use crate::{Array, ArrayRef, ArrayVisitor};
19
20impl ComputeVTable for StructEncoding {
21    fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
30        Some(self)
31    }
32
33    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
34        Some(self)
35    }
36
37    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
38        Some(self)
39    }
40
41    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
42        Some(self)
43    }
44
45    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
46        Some(self)
47    }
48
49    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
50        Some(self)
51    }
52
53    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
54        Some(self)
55    }
56
57    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
58        Some(self)
59    }
60}
61
62impl CastFn<&StructArray> for StructEncoding {
63    fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<ArrayRef> {
64        let Some(target_sdtype) = dtype.as_struct() else {
65            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
66        };
67
68        let source_sdtype = array
69            .dtype()
70            .as_struct()
71            .vortex_expect("struct array must have struct dtype");
72
73        if target_sdtype.names() != source_sdtype.names() {
74            vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
75        }
76
77        let validity = array
78            .validity()
79            .clone()
80            .cast_nullability(dtype.nullability())?;
81
82        StructArray::try_new(
83            target_sdtype.names().clone(),
84            array
85                .fields()
86                .iter()
87                .zip_eq(target_sdtype.fields())
88                .map(|(field, dtype)| try_cast(field, &dtype))
89                .try_collect()?,
90            array.len(),
91            validity,
92        )
93        .map(|a| a.into_array())
94    }
95}
96
97impl ScalarAtFn<&StructArray> for StructEncoding {
98    fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
99        Ok(Scalar::struct_(
100            array.dtype().clone(),
101            array
102                .fields()
103                .iter()
104                .map(|field| scalar_at(field, index))
105                .try_collect()?,
106        ))
107    }
108}
109
110impl TakeFn<&StructArray> for StructEncoding {
111    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
112        StructArray::try_new(
113            array.names().clone(),
114            array
115                .fields()
116                .iter()
117                .map(|field| take(field, indices))
118                .try_collect()?,
119            indices.len(),
120            array.validity().take(indices)?,
121        )
122        .map(|a| a.into_array())
123    }
124}
125
126impl SliceFn<&StructArray> for StructEncoding {
127    fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
128        let fields = array
129            .fields()
130            .iter()
131            .map(|field| slice(field, start, stop))
132            .try_collect()?;
133        StructArray::try_new(
134            array.names().clone(),
135            fields,
136            stop - start,
137            array.validity().slice(start, stop)?,
138        )
139        .map(|a| a.into_array())
140    }
141}
142
143impl FilterFn<&StructArray> for StructEncoding {
144    fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
145        let validity = array.validity().filter(mask)?;
146
147        let fields: Vec<ArrayRef> = array
148            .fields()
149            .iter()
150            .map(|field| filter(field, mask))
151            .try_collect()?;
152        let length = fields
153            .first()
154            .map(|a| a.len())
155            .unwrap_or_else(|| mask.true_count());
156
157        StructArray::try_new(array.names().clone(), fields, length, validity)
158            .map(|a| a.into_array())
159    }
160}
161
162impl MaskFn<&StructArray> for StructEncoding {
163    fn mask(&self, array: &StructArray, filter_mask: Mask) -> VortexResult<ArrayRef> {
164        let validity = array.validity().mask(&filter_mask)?;
165
166        StructArray::try_new(
167            array.names().clone(),
168            array.fields().to_vec(),
169            array.len(),
170            validity,
171        )
172        .map(|a| a.into_array())
173    }
174}
175
176impl MinMaxFn<&StructArray> for StructEncoding {
177    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
178        // TODO(joe): Implement struct min max
179        Ok(None)
180    }
181}
182
183impl IsConstantFn<&StructArray> for StructEncoding {
184    fn is_constant(
185        &self,
186        array: &StructArray,
187        opts: &IsConstantOpts,
188    ) -> VortexResult<Option<bool>> {
189        let children = array.children();
190        if children.is_empty() {
191            return Ok(None);
192        }
193
194        for child in children.iter() {
195            if !is_constant_opts(child, opts)? {
196                return Ok(Some(false));
197            }
198        }
199
200        Ok(Some(true))
201    }
202}
203
204impl UncompressedSizeFn<&StructArray> for StructEncoding {
205    fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
206        let mut sum = array.validity().uncompressed_size();
207        for child in array.children().into_iter() {
208            sum += uncompressed_size(child.as_ref())?;
209        }
210
211        Ok(sum)
212    }
213}
214
#[cfg(test)]
mod tests {
    // Tests for the struct compute kernels in this module: filter, mask and cast.
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::test_harness::test_mask;
    use crate::compute::{filter, try_cast};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a struct with no fields must still yield one row per selected mask entry,
    /// even though there is no field to take the length from.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(&struct_arr, &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// A zero-length, zero-field struct filtered by an empty mask stays empty.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(&struct_arr, &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Runs the shared mask test harness over a struct with no fields.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(&StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable).unwrap());
    }

    /// Runs the shared mask test harness over a struct that contains a nested struct, a
    /// nullable string field and a nullable bool field.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            &StructArray::try_new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    StructArray::try_new(
                        ["left".into(), "right".into()].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap(),
        );
    }

    /// A zero-field struct can be cast to both the non-nullable and the nullable
    /// zero-field struct dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            Arc::from(StructDType::new([].into(), vec![])),
            Nullability::NonNullable,
        );
        let casted = try_cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype = DType::Struct(
            Arc::from(StructDType::new([].into(), vec![])),
            Nullability::Nullable,
        );
        let casted = try_cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Field names must match in order. A target dtype that reorders the names is rejected
    /// with a "cannot cast" error.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);

        let result = try_cast(
            &array,
            &DType::Struct(
                Arc::from(StructDType::new(
                    FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                )),
                Nullability::NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{:?}",
            result
        );
    }

    /// Casts a struct that is nullable at every level (top level, nested struct and leaf
    /// fields) to targets that tighten nullability at one level at a time.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                StructArray::try_new(
                    ["left".into(), "right".into()].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Only the top-level struct becomes non-nullable.
        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Only the nested leaf `xs.left` becomes non-nullable; `xs.right` stays nullable.
        let non_null_xs_left = DType::Struct(
            Arc::from(StructDType::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructDType::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::NonNullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::Nullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = try_cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // Only the nested struct `xs` itself becomes non-nullable; its leaves stay nullable.
        let non_null_xs = DType::Struct(
            Arc::from(StructDType::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructDType::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::Nullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::NonNullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}