// vortex_array/arrays/struct_/mod.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use itertools::Itertools;
5use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::stats::{ArrayStats, StatsSetRef};
10use crate::validity::Validity;
11use crate::vtable::{
12    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
13    ValidityVTableFromValidityHelper,
14};
15use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
16
17mod compute;
18mod serde;
19
20vtable!(Struct);
21
22impl VTable for StructVTable {
23    type Array = StructArray;
24    type Encoding = StructEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromValidityHelper;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = NotSupported;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.struct")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(StructEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct StructArray {
46    len: usize,
47    dtype: DType,
48    fields: Vec<ArrayRef>,
49    validity: Validity,
50    stats_set: ArrayStats,
51}
52
/// Zero-sized marker type for the struct encoding; used as `VTable::Encoding` for
/// `StructVTable`.
#[derive(Clone, Debug)]
pub struct StructEncoding;
55
56impl StructArray {
57    pub fn fields(&self) -> &[ArrayRef] {
58        &self.fields
59    }
60
61    pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
62        let name = name.as_ref();
63        self.field_by_name_opt(name).ok_or_else(|| {
64            vortex_err!(
65                "Field {name} not found in struct array with names {:?}",
66                self.names()
67            )
68        })
69    }
70
71    pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
72        let name = name.as_ref();
73        self.names()
74            .iter()
75            .position(|field_name| field_name.as_ref() == name)
76            .map(|idx| &self.fields[idx])
77    }
78
79    pub fn names(&self) -> &FieldNames {
80        self.struct_fields().names()
81    }
82
83    pub fn struct_fields(&self) -> &Arc<StructFields> {
84        let Some(struct_dtype) = &self.dtype.as_struct() else {
85            unreachable!(
86                "struct arrays must have be a DType::Struct, this is likely an internal bug."
87            )
88        };
89        struct_dtype
90    }
91
92    pub fn try_new(
93        names: FieldNames,
94        fields: Vec<ArrayRef>,
95        length: usize,
96        validity: Validity,
97    ) -> VortexResult<Self> {
98        let nullability = validity.nullability();
99
100        if names.len() != fields.len() {
101            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
102        }
103
104        for field in fields.iter() {
105            if field.len() != length {
106                vortex_bail!(
107                    "Expected all struct fields to have length {length}, found {}",
108                    field.len()
109                );
110            }
111        }
112
113        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
114        let dtype = DType::Struct(
115            Arc::new(StructFields::new(names, field_dtypes)),
116            nullability,
117        );
118
119        Ok(Self {
120            len: length,
121            dtype,
122            fields,
123            validity,
124            stats_set: Default::default(),
125        })
126    }
127
128    pub fn try_new_with_dtype(
129        fields: Vec<ArrayRef>,
130        dtype: Arc<StructFields>,
131        length: usize,
132        validity: Validity,
133    ) -> VortexResult<Self> {
134        for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
135            if field.len() != length {
136                vortex_bail!(
137                    "Expected all struct fields to have length {length}, found {}",
138                    field.len()
139                );
140            }
141
142            if &struct_dt != field.dtype() {
143                vortex_bail!(
144                    "Expected all struct fields to have dtype {}, found {}",
145                    struct_dt,
146                    field.dtype()
147                );
148            }
149        }
150
151        Ok(Self {
152            len: length,
153            dtype: DType::Struct(dtype, validity.nullability()),
154            fields,
155            validity,
156            stats_set: Default::default(),
157        })
158    }
159
160    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
161        let names = items.iter().map(|(name, _)| FieldName::from(name.as_ref()));
162        let fields: Vec<ArrayRef> = items.iter().map(|(_, array)| array.to_array()).collect();
163        let len = fields
164            .first()
165            .map(|f| f.len())
166            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
167
168        Self::try_new(
169            FieldNames::from_iter(names),
170            fields,
171            len,
172            Validity::NonNullable,
173        )
174    }
175
176    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
177    /// Return a new StructArray with the given projection applied.
178    ///
179    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
180    /// which specifies the new ordering of columns in the struct. The projection can be used to
181    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
182    /// copying.
183    #[allow(clippy::same_name_method)]
184    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
185        let mut children = Vec::with_capacity(projection.len());
186        let mut names = Vec::with_capacity(projection.len());
187
188        for f_name in projection.iter() {
189            let idx = self
190                .names()
191                .iter()
192                .position(|name| name == f_name)
193                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
194
195            names.push(self.names()[idx].clone());
196            children.push(self.fields()[idx].clone());
197        }
198
199        StructArray::try_new(
200            FieldNames::from(names.as_slice()),
201            children,
202            self.len(),
203            self.validity().clone(),
204        )
205    }
206}
207
208impl ValidityHelper for StructArray {
209    fn validity(&self) -> &Validity {
210        &self.validity
211    }
212}
213
214impl ArrayVTable<StructVTable> for StructVTable {
215    fn len(array: &StructArray) -> usize {
216        array.len
217    }
218
219    fn dtype(array: &StructArray) -> &DType {
220        &array.dtype
221    }
222
223    fn stats(array: &StructArray) -> StatsSetRef<'_> {
224        array.stats_set.to_ref(array.as_ref())
225    }
226}
227
228impl CanonicalVTable<StructVTable> for StructVTable {
229    fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
230        Ok(Canonical::Struct(array.clone()))
231    }
232}
233
234impl OperationsVTable<StructVTable> for StructVTable {
235    fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
236        let fields = array
237            .fields()
238            .iter()
239            .map(|field| field.slice(start, stop))
240            .try_collect()?;
241        StructArray::try_new_with_dtype(
242            fields,
243            array.struct_fields().clone(),
244            stop - start,
245            array.validity().slice(start, stop)?,
246        )
247        .map(|a| a.into_array())
248    }
249
250    fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
251        Ok(Scalar::struct_(
252            array.dtype().clone(),
253            array
254                .fields()
255                .iter()
256                .map(|field| field.scalar_at(index))
257                .try_collect()?,
258        ))
259    }
260}
261
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projecting `{xs, ys, zs}` onto `[zs, xs]` must reorder the columns, drop `ys`, and leave
    /// the row count and column data untouched.
    #[test]
    fn test_project() {
        let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let zs = BoolArray::from_iter([true, true, true, false, false]);

        let source = StructArray::try_new(
            FieldNames::from(["xs".into(), "ys".into(), "zs".into()]),
            vec![xs.into_array(), ys.into_array(), zs.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let projected = source
            .project(&[FieldName::from("zs"), FieldName::from("xs")])
            .unwrap();

        // Names follow the projection order.
        assert_eq!(
            projected.names().as_ref(),
            [FieldName::from("zs"), FieldName::from("xs")],
        );

        // Projection never changes the number of rows.
        assert_eq!(projected.len(), 5);

        // First projected column is the original `zs`.
        let bool_column = &projected.fields()[0];
        let bool_values: Vec<_> = bool_column
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect();
        assert_eq!(bool_values, vec![true, true, true, false, false]);

        // Second projected column is the original `xs`.
        let int_column = &projected.fields()[1];
        assert_eq!(
            int_column.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }
}