vortex_array/arrays/struct_/
mod.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use itertools::Itertools;
5use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
6use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::stats::{ArrayStats, StatsSetRef};
10use crate::validity::Validity;
11use crate::vtable::{
12    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
13    ValidityVTableFromValidityHelper,
14};
15use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
16
17mod compute;
18mod serde;
19
20vtable!(Struct);
21
22impl VTable for StructVTable {
23    type Array = StructArray;
24    type Encoding = StructEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromValidityHelper;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = NotSupported;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.struct")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(StructEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct StructArray {
46    len: usize,
47    dtype: DType,
48    fields: Vec<ArrayRef>,
49    validity: Validity,
50    stats_set: ArrayStats,
51}
52
/// Stateless marker type identifying the struct encoding (`"vortex.struct"`).
#[derive(Debug, Clone)]
pub struct StructEncoding;
55
56impl StructArray {
57    pub fn fields(&self) -> &[ArrayRef] {
58        &self.fields
59    }
60
61    pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
62        let name = name.as_ref();
63        self.field_by_name_opt(name).ok_or_else(|| {
64            vortex_err!(
65                "Field {name} not found in struct array with names {:?}",
66                self.names()
67            )
68        })
69    }
70
71    pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
72        let name = name.as_ref();
73        self.names()
74            .iter()
75            .position(|field_name| field_name.as_ref() == name)
76            .map(|idx| &self.fields[idx])
77    }
78
79    pub fn names(&self) -> &FieldNames {
80        self.struct_fields().names()
81    }
82
83    pub fn struct_fields(&self) -> &Arc<StructFields> {
84        let Some(struct_dtype) = &self.dtype.as_struct() else {
85            unreachable!(
86                "struct arrays must have be a DType::Struct, this is likely an internal bug."
87            )
88        };
89        struct_dtype
90    }
91
92    pub fn try_new(
93        names: FieldNames,
94        fields: Vec<ArrayRef>,
95        length: usize,
96        validity: Validity,
97    ) -> VortexResult<Self> {
98        let nullability = validity.nullability();
99
100        if names.len() != fields.len() {
101            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
102        }
103
104        for field in fields.iter() {
105            if field.len() != length {
106                vortex_bail!(
107                    "Expected all struct fields to have length {length}, found {}",
108                    fields.iter().map(|f| f.len()).format(","),
109                );
110            }
111        }
112
113        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
114        let dtype = DType::Struct(
115            Arc::new(StructFields::new(names, field_dtypes)),
116            nullability,
117        );
118
119        if length != validity.maybe_len().unwrap_or(length) {
120            vortex_bail!(
121                "array length {} and validity length must match {}",
122                length,
123                validity
124                    .maybe_len()
125                    .vortex_expect("can only fail if maybe is some")
126            )
127        }
128
129        Ok(Self {
130            len: length,
131            dtype,
132            fields,
133            validity,
134            stats_set: Default::default(),
135        })
136    }
137
138    pub fn try_new_with_dtype(
139        fields: Vec<ArrayRef>,
140        dtype: Arc<StructFields>,
141        length: usize,
142        validity: Validity,
143    ) -> VortexResult<Self> {
144        for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
145            if field.len() != length {
146                vortex_bail!(
147                    "Expected all struct fields to have length {length}, found {}",
148                    field.len()
149                );
150            }
151
152            if &struct_dt != field.dtype() {
153                vortex_bail!(
154                    "Expected all struct fields to have dtype {}, found {}",
155                    struct_dt,
156                    field.dtype()
157                );
158            }
159        }
160
161        Ok(Self {
162            len: length,
163            dtype: DType::Struct(dtype, validity.nullability()),
164            fields,
165            validity,
166            stats_set: Default::default(),
167        })
168    }
169
170    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
171        Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
172    }
173
174    pub fn try_from_iter_with_validity<
175        N: AsRef<str>,
176        A: IntoArray,
177        T: IntoIterator<Item = (N, A)>,
178    >(
179        iter: T,
180        validity: Validity,
181    ) -> VortexResult<Self> {
182        let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
183            .into_iter()
184            .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
185            .unzip();
186        let len = fields
187            .first()
188            .map(|f| f.len())
189            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
190
191        Self::try_new(FieldNames::from_iter(names), fields, len, validity)
192    }
193
194    pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
195        iter: T,
196    ) -> VortexResult<Self> {
197        Self::try_from_iter_with_validity(iter, Validity::NonNullable)
198    }
199
200    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
201    /// Return a new StructArray with the given projection applied.
202    ///
203    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
204    /// which specifies the new ordering of columns in the struct. The projection can be used to
205    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
206    /// copying.
207    #[allow(clippy::same_name_method)]
208    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
209        let mut children = Vec::with_capacity(projection.len());
210        let mut names = Vec::with_capacity(projection.len());
211
212        for f_name in projection.iter() {
213            let idx = self
214                .names()
215                .iter()
216                .position(|name| name == f_name)
217                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
218
219            names.push(self.names()[idx].clone());
220            children.push(self.fields()[idx].clone());
221        }
222
223        StructArray::try_new(
224            FieldNames::from(names.as_slice()),
225            children,
226            self.len(),
227            self.validity().clone(),
228        )
229    }
230
231    /// Removes and returns a column from the struct array by name.
232    /// If the column does not exist, returns `None`.
233    pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
234        let name = name.into();
235
236        let Some(struct_dtype) = self.dtype.as_struct() else {
237            unreachable!(
238                "struct arrays must have be a DType::Struct, this is likely an internal bug."
239            )
240        };
241
242        let position = struct_dtype
243            .names()
244            .iter()
245            .position(|field_name| field_name.as_ref() == name.as_ref())?;
246
247        let field = self.fields.remove(position);
248
249        let new_dtype = struct_dtype.without_field(position);
250        self.dtype = DType::Struct(Arc::new(new_dtype), self.dtype.nullability());
251
252        Some(field)
253    }
254}
255
256impl ValidityHelper for StructArray {
257    fn validity(&self) -> &Validity {
258        &self.validity
259    }
260}
261
262impl ArrayVTable<StructVTable> for StructVTable {
263    fn len(array: &StructArray) -> usize {
264        array.len
265    }
266
267    fn dtype(array: &StructArray) -> &DType {
268        &array.dtype
269    }
270
271    fn stats(array: &StructArray) -> StatsSetRef<'_> {
272        array.stats_set.to_ref(array.as_ref())
273    }
274}
275
276impl CanonicalVTable<StructVTable> for StructVTable {
277    fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
278        Ok(Canonical::Struct(array.clone()))
279    }
280}
281
282impl OperationsVTable<StructVTable> for StructVTable {
283    fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
284        let fields = array
285            .fields()
286            .iter()
287            .map(|field| field.slice(start, stop))
288            .try_collect()?;
289        StructArray::try_new_with_dtype(
290            fields,
291            array.struct_fields().clone(),
292            stop - start,
293            array.validity().slice(start, stop)?,
294        )
295        .map(|a| a.into_array())
296    }
297
298    fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
299        Ok(Scalar::struct_(
300            array.dtype().clone(),
301            array
302                .fields()
303                .iter()
304                .map(|field| field.scalar_at(index))
305                .try_collect()?,
306        ))
307    }
308}
309
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projecting onto `["zs", "xs"]` drops `ys`, reorders the remaining columns and keeps
    /// their row data intact.
    #[test]
    fn test_project() {
        let ints = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let flags = BoolArray::from_iter([true, true, true, false, false]);

        let source = StructArray::try_new(
            FieldNames::from(["xs".into(), "ys".into(), "zs".into()]),
            vec![ints.into_array(), strings.into_array(), flags.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = source.project(&wanted).unwrap();

        assert_eq!(projected.names().as_ref(), wanted);
        assert_eq!(projected.len(), 5);

        // First projected column is the former `zs`.
        let flag_values = projected.fields[0]
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect::<Vec<_>>();
        assert_eq!(flag_values, vec![true, true, true, false, false]);

        // Second projected column is the former `xs`.
        let int_column = &projected.fields[1];
        assert_eq!(
            int_column.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }

    /// Removing an existing column returns it and shrinks the struct; removing an unknown
    /// column returns `None` and leaves the struct untouched.
    #[test]
    fn test_remove_column() {
        let signed = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let unsigned = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut subject = StructArray::try_new(
            FieldNames::from(["xs".into(), "ys".into()]),
            vec![signed.into_array(), unsigned.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        // The removed column comes back with its dtype and data intact.
        let taken = subject.remove_column("xs").unwrap();
        let i64_dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        assert_eq!(taken.dtype(), &i64_dtype);
        assert_eq!(
            taken.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        // Only `ys` is left, unchanged.
        assert_eq!(subject.names().as_ref(), [FieldName::from("ys")]);
        assert_eq!(subject.fields.len(), 1);
        assert_eq!(subject.len(), 5);

        let u64_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        let remaining = &subject.fields[0];
        assert_eq!(remaining.dtype(), &u64_dtype);
        assert_eq!(
            remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        let missing = subject.remove_column("non_existent");
        assert!(
            missing.is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(subject.names().as_ref(), [FieldName::from("ys")]);
    }
}