vortex_array/arrays/struct_/
mod.rs

1use std::fmt::Debug;
2
3use itertools::Itertools;
4use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
5use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::stats::{ArrayStats, StatsSetRef};
9use crate::validity::Validity;
10use crate::vtable::{
11    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
12    ValidityVTableFromValidityHelper,
13};
14use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
15
16mod compute;
17mod serde;
18
19vtable!(Struct);
20
21impl VTable for StructVTable {
22    type Array = StructArray;
23    type Encoding = StructEncoding;
24
25    type ArrayVTable = Self;
26    type CanonicalVTable = Self;
27    type OperationsVTable = Self;
28    type ValidityVTable = ValidityVTableFromValidityHelper;
29    type VisitorVTable = Self;
30    type ComputeVTable = NotSupported;
31    type EncodeVTable = NotSupported;
32    type SerdeVTable = Self;
33
34    fn id(_encoding: &Self::Encoding) -> EncodingId {
35        EncodingId::new_ref("vortex.struct")
36    }
37
38    fn encoding(_array: &Self::Array) -> EncodingRef {
39        EncodingRef::new_ref(StructEncoding.as_ref())
40    }
41}
42
43#[derive(Clone, Debug)]
44pub struct StructArray {
45    len: usize,
46    dtype: DType,
47    fields: Vec<ArrayRef>,
48    validity: Validity,
49    stats_set: ArrayStats,
50}
51
52#[derive(Clone, Debug)]
53pub struct StructEncoding;
54
55impl StructArray {
56    pub fn fields(&self) -> &[ArrayRef] {
57        &self.fields
58    }
59
60    pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
61        let name = name.as_ref();
62        self.field_by_name_opt(name).ok_or_else(|| {
63            vortex_err!(
64                "Field {name} not found in struct array with names {:?}",
65                self.names()
66            )
67        })
68    }
69
70    pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
71        let name = name.as_ref();
72        self.names()
73            .iter()
74            .position(|field_name| field_name.as_ref() == name)
75            .map(|idx| &self.fields[idx])
76    }
77
78    pub fn names(&self) -> &FieldNames {
79        self.struct_fields().names()
80    }
81
82    pub fn struct_fields(&self) -> &StructFields {
83        let Some(struct_dtype) = &self.dtype.as_struct() else {
84            unreachable!(
85                "struct arrays must have be a DType::Struct, this is likely an internal bug."
86            )
87        };
88        struct_dtype
89    }
90
91    pub fn try_new(
92        names: FieldNames,
93        fields: Vec<ArrayRef>,
94        length: usize,
95        validity: Validity,
96    ) -> VortexResult<Self> {
97        let nullability = validity.nullability();
98
99        if names.len() != fields.len() {
100            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
101        }
102
103        for field in fields.iter() {
104            if field.len() != length {
105                vortex_bail!(
106                    "Expected all struct fields to have length {length}, found {}",
107                    fields.iter().map(|f| f.len()).format(","),
108                );
109            }
110        }
111
112        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
113        let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
114
115        if length != validity.maybe_len().unwrap_or(length) {
116            vortex_bail!(
117                "array length {} and validity length must match {}",
118                length,
119                validity
120                    .maybe_len()
121                    .vortex_expect("can only fail if maybe is some")
122            )
123        }
124
125        Ok(Self {
126            len: length,
127            dtype,
128            fields,
129            validity,
130            stats_set: Default::default(),
131        })
132    }
133
134    pub fn try_new_with_dtype(
135        fields: Vec<ArrayRef>,
136        dtype: StructFields,
137        length: usize,
138        validity: Validity,
139    ) -> VortexResult<Self> {
140        for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
141            if field.len() != length {
142                vortex_bail!(
143                    "Expected all struct fields to have length {length}, found {}",
144                    field.len()
145                );
146            }
147
148            if &struct_dt != field.dtype() {
149                vortex_bail!(
150                    "Expected all struct fields to have dtype {}, found {}",
151                    struct_dt,
152                    field.dtype()
153                );
154            }
155        }
156
157        Ok(Self {
158            len: length,
159            dtype: DType::Struct(dtype, validity.nullability()),
160            fields,
161            validity,
162            stats_set: Default::default(),
163        })
164    }
165
166    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
167        Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
168    }
169
170    pub fn try_from_iter_with_validity<
171        N: AsRef<str>,
172        A: IntoArray,
173        T: IntoIterator<Item = (N, A)>,
174    >(
175        iter: T,
176        validity: Validity,
177    ) -> VortexResult<Self> {
178        let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
179            .into_iter()
180            .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
181            .unzip();
182        let len = fields
183            .first()
184            .map(|f| f.len())
185            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
186
187        Self::try_new(FieldNames::from_iter(names), fields, len, validity)
188    }
189
190    pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
191        iter: T,
192    ) -> VortexResult<Self> {
193        Self::try_from_iter_with_validity(iter, Validity::NonNullable)
194    }
195
196    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
197    /// Return a new StructArray with the given projection applied.
198    ///
199    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
200    /// which specifies the new ordering of columns in the struct. The projection can be used to
201    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
202    /// copying.
203    #[allow(clippy::same_name_method)]
204    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
205        let mut children = Vec::with_capacity(projection.len());
206        let mut names = Vec::with_capacity(projection.len());
207
208        for f_name in projection.iter() {
209            let idx = self
210                .names()
211                .iter()
212                .position(|name| name == f_name)
213                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
214
215            names.push(self.names()[idx].clone());
216            children.push(self.fields()[idx].clone());
217        }
218
219        StructArray::try_new(
220            FieldNames::from(names.as_slice()),
221            children,
222            self.len(),
223            self.validity().clone(),
224        )
225    }
226
227    /// Removes and returns a column from the struct array by name.
228    /// If the column does not exist, returns `None`.
229    pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
230        let name = name.into();
231
232        let Some(struct_dtype) = self.dtype.as_struct() else {
233            unreachable!(
234                "struct arrays must have be a DType::Struct, this is likely an internal bug."
235            )
236        };
237
238        let position = struct_dtype
239            .names()
240            .iter()
241            .position(|field_name| field_name.as_ref() == name.as_ref())?;
242
243        let field = self.fields.remove(position);
244
245        let new_dtype = struct_dtype.without_field(position);
246        self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
247
248        Some(field)
249    }
250}
251
252impl ValidityHelper for StructArray {
253    fn validity(&self) -> &Validity {
254        &self.validity
255    }
256}
257
258impl ArrayVTable<StructVTable> for StructVTable {
259    fn len(array: &StructArray) -> usize {
260        array.len
261    }
262
263    fn dtype(array: &StructArray) -> &DType {
264        &array.dtype
265    }
266
267    fn stats(array: &StructArray) -> StatsSetRef<'_> {
268        array.stats_set.to_ref(array.as_ref())
269    }
270}
271
272impl CanonicalVTable<StructVTable> for StructVTable {
273    fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
274        Ok(Canonical::Struct(array.clone()))
275    }
276}
277
278impl OperationsVTable<StructVTable> for StructVTable {
279    fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
280        let fields = array
281            .fields()
282            .iter()
283            .map(|field| field.slice(start, stop))
284            .try_collect()?;
285        StructArray::try_new_with_dtype(
286            fields,
287            array.struct_fields().clone(),
288            stop - start,
289            array.validity().slice(start, stop)?,
290        )
291        .map(|a| a.into_array())
292    }
293
294    fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
295        Ok(Scalar::struct_(
296            array.dtype().clone(),
297            array
298                .fields()
299                .iter()
300                .map(|field| field.scalar_at(index))
301                .try_collect()?,
302        ))
303    }
304}
305
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projection reorders and drops columns, and the projected children keep their data.
    #[test]
    fn test_project() {
        let ints =
            PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable).into_array();
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        )
        .into_array();
        let flags = BoolArray::from_iter([true, true, true, false, false]).into_array();

        let original = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            vec![ints, strings, flags],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let projected = original
            .project(&[FieldName::from("zs"), FieldName::from("xs")])
            .unwrap();
        assert_eq!(
            projected.names().as_ref(),
            [FieldName::from("zs"), FieldName::from("xs")],
        );
        assert_eq!(projected.len(), 5);

        let first_child = &projected.fields[0];
        let bool_values = first_child
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect::<Vec<_>>();
        assert_eq!(bool_values, vec![true, true, true, false, false]);

        let second_child = &projected.fields[1];
        assert_eq!(
            second_child.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }

    /// Removing a column returns it intact and leaves the remaining columns untouched;
    /// removing an unknown column is a no-op returning `None`.
    #[test]
    fn test_remove_column() {
        let signed =
            PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable).into_array();
        let unsigned =
            PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable).into_array();

        let mut table = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![signed, unsigned],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let removed = table.remove_column("xs").unwrap();
        assert_eq!(
            removed.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        assert_eq!(table.names(), &[FieldName::from("ys")].into());
        assert_eq!(table.fields.len(), 1);
        assert_eq!(table.len(), 5);

        let remaining = &table.fields[0];
        assert_eq!(
            remaining.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        let missing = table.remove_column("non_existent");
        assert!(
            missing.is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(table.names(), &[FieldName::from("ys")].into());
    }
}
404        assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
405    }
406}