vortex_array/arrays/struct_/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::iter::once;
6
7use itertools::Itertools;
8use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::validity::Validity;
14use crate::vtable::{
15    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16    ValidityVTableFromValidityHelper,
17};
18use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
19
20mod compute;
21mod serde;
22
23vtable!(Struct);
24
25impl VTable for StructVTable {
26    type Array = StructArray;
27    type Encoding = StructEncoding;
28
29    type ArrayVTable = Self;
30    type CanonicalVTable = Self;
31    type OperationsVTable = Self;
32    type ValidityVTable = ValidityVTableFromValidityHelper;
33    type VisitorVTable = Self;
34    type ComputeVTable = NotSupported;
35    type EncodeVTable = NotSupported;
36    type SerdeVTable = Self;
37
38    fn id(_encoding: &Self::Encoding) -> EncodingId {
39        EncodingId::new_ref("vortex.struct")
40    }
41
42    fn encoding(_array: &Self::Array) -> EncodingRef {
43        EncodingRef::new_ref(StructEncoding.as_ref())
44    }
45}
46
47#[derive(Clone, Debug)]
48pub struct StructArray {
49    len: usize,
50    dtype: DType,
51    fields: Vec<ArrayRef>,
52    validity: Validity,
53    stats_set: ArrayStats,
54}
55
/// Unit marker type identifying the struct encoding.
#[derive(Debug, Clone)]
pub struct StructEncoding;
58
59impl StructArray {
60    pub fn fields(&self) -> &[ArrayRef] {
61        &self.fields
62    }
63
64    pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
65        let name = name.as_ref();
66        self.field_by_name_opt(name).ok_or_else(|| {
67            vortex_err!(
68                "Field {name} not found in struct array with names {:?}",
69                self.names()
70            )
71        })
72    }
73
74    pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
75        let name = name.as_ref();
76        self.names()
77            .iter()
78            .position(|field_name| field_name.as_ref() == name)
79            .map(|idx| &self.fields[idx])
80    }
81
82    pub fn names(&self) -> &FieldNames {
83        self.struct_fields().names()
84    }
85
86    pub fn struct_fields(&self) -> &StructFields {
87        let Some(struct_dtype) = &self.dtype.as_struct() else {
88            unreachable!(
89                "struct arrays must have be a DType::Struct, this is likely an internal bug."
90            )
91        };
92        struct_dtype
93    }
94
95    /// Create a new `StructArray` with the given length, but without any fields.
96    pub fn new_with_len(len: usize) -> Self {
97        Self::try_new(
98            FieldNames::default(),
99            Vec::new(),
100            len,
101            Validity::NonNullable,
102        )
103        .vortex_expect("StructArray::new_with_len should not fail")
104    }
105
106    pub fn try_new(
107        names: FieldNames,
108        fields: Vec<ArrayRef>,
109        length: usize,
110        validity: Validity,
111    ) -> VortexResult<Self> {
112        let nullability = validity.nullability();
113
114        if names.len() != fields.len() {
115            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
116        }
117
118        for field in fields.iter() {
119            if field.len() != length {
120                vortex_bail!(
121                    "Expected all struct fields to have length {length}, found {}",
122                    fields.iter().map(|f| f.len()).format(","),
123                );
124            }
125        }
126
127        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
128        let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
129
130        if length != validity.maybe_len().unwrap_or(length) {
131            vortex_bail!(
132                "array length {} and validity length must match {}",
133                length,
134                validity
135                    .maybe_len()
136                    .vortex_expect("can only fail if maybe is some")
137            )
138        }
139
140        Ok(Self {
141            len: length,
142            dtype,
143            fields,
144            validity,
145            stats_set: Default::default(),
146        })
147    }
148
149    pub fn try_new_with_dtype(
150        fields: Vec<ArrayRef>,
151        dtype: StructFields,
152        length: usize,
153        validity: Validity,
154    ) -> VortexResult<Self> {
155        for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
156            if field.len() != length {
157                vortex_bail!(
158                    "Expected all struct fields to have length {length}, found {}",
159                    field.len()
160                );
161            }
162
163            if &struct_dt != field.dtype() {
164                vortex_bail!(
165                    "Expected all struct fields to have dtype {}, found {}",
166                    struct_dt,
167                    field.dtype()
168                );
169            }
170        }
171
172        Ok(Self {
173            len: length,
174            dtype: DType::Struct(dtype, validity.nullability()),
175            fields,
176            validity,
177            stats_set: Default::default(),
178        })
179    }
180
181    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
182        Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
183    }
184
185    pub fn try_from_iter_with_validity<
186        N: AsRef<str>,
187        A: IntoArray,
188        T: IntoIterator<Item = (N, A)>,
189    >(
190        iter: T,
191        validity: Validity,
192    ) -> VortexResult<Self> {
193        let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
194            .into_iter()
195            .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
196            .unzip();
197        let len = fields
198            .first()
199            .map(|f| f.len())
200            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
201
202        Self::try_new(FieldNames::from_iter(names), fields, len, validity)
203    }
204
205    pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
206        iter: T,
207    ) -> VortexResult<Self> {
208        Self::try_from_iter_with_validity(iter, Validity::NonNullable)
209    }
210
211    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
212    /// Return a new StructArray with the given projection applied.
213    ///
214    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
215    /// which specifies the new ordering of columns in the struct. The projection can be used to
216    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
217    /// copying.
218    #[allow(clippy::same_name_method)]
219    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
220        let mut children = Vec::with_capacity(projection.len());
221        let mut names = Vec::with_capacity(projection.len());
222
223        for f_name in projection.iter() {
224            let idx = self
225                .names()
226                .iter()
227                .position(|name| name == f_name)
228                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
229
230            names.push(self.names()[idx].clone());
231            children.push(self.fields()[idx].clone());
232        }
233
234        StructArray::try_new(
235            FieldNames::from(names.as_slice()),
236            children,
237            self.len(),
238            self.validity().clone(),
239        )
240    }
241
242    /// Removes and returns a column from the struct array by name.
243    /// If the column does not exist, returns `None`.
244    pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
245        let name = name.into();
246
247        let struct_dtype = self.struct_fields().clone();
248
249        let position = struct_dtype
250            .names()
251            .iter()
252            .position(|field_name| field_name.as_ref() == name.as_ref())?;
253
254        let field = self.fields.remove(position);
255
256        let new_dtype = struct_dtype.without_field(position);
257        self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
258
259        Some(field)
260    }
261
262    /// Create a new StructArray by appending a new column onto the existing array.
263    pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
264        let name = name.into();
265        let struct_dtype = self.struct_fields().clone();
266
267        let names = struct_dtype.names().iter().cloned().chain(once(name));
268        let types = struct_dtype.fields().chain(once(array.dtype().clone()));
269        let new_fields = StructFields::new(names.collect(), types.collect());
270
271        let mut children = self.fields.clone();
272        children.push(array);
273
274        Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
275    }
276}
277
278impl ValidityHelper for StructArray {
279    fn validity(&self) -> &Validity {
280        &self.validity
281    }
282}
283
284impl ArrayVTable<StructVTable> for StructVTable {
285    fn len(array: &StructArray) -> usize {
286        array.len
287    }
288
289    fn dtype(array: &StructArray) -> &DType {
290        &array.dtype
291    }
292
293    fn stats(array: &StructArray) -> StatsSetRef<'_> {
294        array.stats_set.to_ref(array.as_ref())
295    }
296}
297
298impl CanonicalVTable<StructVTable> for StructVTable {
299    fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
300        Ok(Canonical::Struct(array.clone()))
301    }
302}
303
304impl OperationsVTable<StructVTable> for StructVTable {
305    fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
306        let fields = array
307            .fields()
308            .iter()
309            .map(|field| field.slice(start, stop))
310            .try_collect()?;
311        StructArray::try_new_with_dtype(
312            fields,
313            array.struct_fields().clone(),
314            stop - start,
315            array.validity().slice(start, stop)?,
316        )
317        .map(|a| a.into_array())
318    }
319
320    fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
321        Ok(Scalar::struct_(
322            array.dtype().clone(),
323            array
324                .fields()
325                .iter()
326                .map(|field| field.scalar_at(index))
327                .try_collect()?,
328        ))
329    }
330}
331
332#[cfg(test)]
333mod test {
334    use vortex_buffer::buffer;
335    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};
336
337    use crate::IntoArray;
338    use crate::arrays::primitive::PrimitiveArray;
339    use crate::arrays::struct_::StructArray;
340    use crate::arrays::varbin::VarBinArray;
341    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
342    use crate::validity::Validity;
343
344    #[test]
345    fn test_project() {
346        let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
347        let ys = VarBinArray::from_vec(
348            vec!["a", "b", "c", "d", "e"],
349            DType::Utf8(Nullability::NonNullable),
350        );
351        let zs = BoolArray::from_iter([true, true, true, false, false]);
352
353        let struct_a = StructArray::try_new(
354            FieldNames::from(["xs", "ys", "zs"]),
355            vec![xs.into_array(), ys.into_array(), zs.into_array()],
356            5,
357            Validity::NonNullable,
358        )
359        .unwrap();
360
361        let struct_b = struct_a
362            .project(&[FieldName::from("zs"), FieldName::from("xs")])
363            .unwrap();
364        assert_eq!(
365            struct_b.names().as_ref(),
366            [FieldName::from("zs"), FieldName::from("xs")],
367        );
368
369        assert_eq!(struct_b.len(), 5);
370
371        let bools = &struct_b.fields[0];
372        assert_eq!(
373            bools
374                .as_::<BoolVTable>()
375                .boolean_buffer()
376                .iter()
377                .collect::<Vec<_>>(),
378            vec![true, true, true, false, false]
379        );
380
381        let prims = &struct_b.fields[1];
382        assert_eq!(
383            prims.as_::<PrimitiveVTable>().as_slice::<i64>(),
384            [0i64, 1, 2, 3, 4]
385        );
386    }
387
388    #[test]
389    fn test_remove_column() {
390        let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
391        let ys = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);
392
393        let mut struct_a = StructArray::try_new(
394            FieldNames::from(["xs", "ys"]),
395            vec![xs.into_array(), ys.into_array()],
396            5,
397            Validity::NonNullable,
398        )
399        .unwrap();
400
401        let removed = struct_a.remove_column("xs").unwrap();
402        assert_eq!(
403            removed.dtype(),
404            &DType::Primitive(PType::I64, Nullability::NonNullable)
405        );
406        assert_eq!(
407            removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
408            [0i64, 1, 2, 3, 4]
409        );
410
411        assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
412        assert_eq!(struct_a.fields.len(), 1);
413        assert_eq!(struct_a.len(), 5);
414        assert_eq!(
415            struct_a.fields[0].dtype(),
416            &DType::Primitive(PType::U64, Nullability::NonNullable)
417        );
418        assert_eq!(
419            struct_a.fields[0]
420                .as_::<PrimitiveVTable>()
421                .as_slice::<u64>(),
422            [4u64, 5, 6, 7, 8]
423        );
424
425        let empty = struct_a.remove_column("non_existent");
426        assert!(
427            empty.is_none(),
428            "Expected None when removing non-existent column"
429        );
430        assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
431    }
432}