vortex_array/arrays/struct_/mod.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use itertools::Itertools;
5use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::stats::{ArrayStats, StatsSetRef};
10use crate::validity::Validity;
11use crate::vtable::{
12    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
13    ValidityVTableFromValidityHelper,
14};
15use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
16
17mod compute;
18mod serde;
19
20vtable!(Struct);
21
22impl VTable for StructVTable {
23    type Array = StructArray;
24    type Encoding = StructEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromValidityHelper;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = NotSupported;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.struct")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(StructEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct StructArray {
46    len: usize,
47    dtype: DType,
48    fields: Vec<ArrayRef>,
49    validity: Validity,
50    stats_set: ArrayStats,
51}
52
53#[derive(Clone, Debug)]
54pub struct StructEncoding;
55
56impl StructArray {
57    pub fn fields(&self) -> &[ArrayRef] {
58        &self.fields
59    }
60
61    pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
62        let name = name.as_ref();
63        self.field_by_name_opt(name).ok_or_else(|| {
64            vortex_err!(
65                "Field {name} not found in struct array with names {:?}",
66                self.names()
67            )
68        })
69    }
70
71    pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
72        let name = name.as_ref();
73        self.names()
74            .iter()
75            .position(|field_name| field_name.as_ref() == name)
76            .map(|idx| &self.fields[idx])
77    }
78
79    pub fn names(&self) -> &FieldNames {
80        self.struct_dtype().names()
81    }
82
83    pub fn struct_dtype(&self) -> &Arc<StructDType> {
84        let Some(struct_dtype) = &self.dtype.as_struct() else {
85            unreachable!(
86                "struct arrays must have be a DType::Struct, this is likely an internal bug."
87            )
88        };
89        struct_dtype
90    }
91
92    pub fn try_new(
93        names: FieldNames,
94        fields: Vec<ArrayRef>,
95        length: usize,
96        validity: Validity,
97    ) -> VortexResult<Self> {
98        let nullability = validity.nullability();
99
100        if names.len() != fields.len() {
101            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
102        }
103
104        for field in fields.iter() {
105            if field.len() != length {
106                vortex_bail!(
107                    "Expected all struct fields to have length {length}, found {}",
108                    field.len()
109                );
110            }
111        }
112
113        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
114        let dtype = DType::Struct(Arc::new(StructDType::new(names, field_dtypes)), nullability);
115
116        Ok(Self {
117            len: length,
118            dtype,
119            fields,
120            validity,
121            stats_set: Default::default(),
122        })
123    }
124
125    pub fn try_new_with_dtype(
126        fields: Vec<ArrayRef>,
127        dtype: Arc<StructDType>,
128        length: usize,
129        validity: Validity,
130    ) -> VortexResult<Self> {
131        for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
132            if field.len() != length {
133                vortex_bail!(
134                    "Expected all struct fields to have length {length}, found {}",
135                    field.len()
136                );
137            }
138
139            if &struct_dt != field.dtype() {
140                vortex_bail!(
141                    "Expected all struct fields to have dtype {}, found {}",
142                    struct_dt,
143                    field.dtype()
144                );
145            }
146        }
147
148        Ok(Self {
149            len: length,
150            dtype: DType::Struct(dtype, validity.nullability()),
151            fields,
152            validity,
153            stats_set: Default::default(),
154        })
155    }
156
157    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
158        let names = items.iter().map(|(name, _)| FieldName::from(name.as_ref()));
159        let fields: Vec<ArrayRef> = items.iter().map(|(_, array)| array.to_array()).collect();
160        let len = fields
161            .first()
162            .map(|f| f.len())
163            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
164
165        Self::try_new(
166            FieldNames::from_iter(names),
167            fields,
168            len,
169            Validity::NonNullable,
170        )
171    }
172
173    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
174    /// Return a new StructArray with the given projection applied.
175    ///
176    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
177    /// which specifies the new ordering of columns in the struct. The projection can be used to
178    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
179    /// copying.
180    #[allow(clippy::same_name_method)]
181    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
182        let mut children = Vec::with_capacity(projection.len());
183        let mut names = Vec::with_capacity(projection.len());
184
185        for f_name in projection.iter() {
186            let idx = self
187                .names()
188                .iter()
189                .position(|name| name == f_name)
190                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
191
192            names.push(self.names()[idx].clone());
193            children.push(self.fields()[idx].clone());
194        }
195
196        StructArray::try_new(
197            FieldNames::from(names.as_slice()),
198            children,
199            self.len(),
200            self.validity().clone(),
201        )
202    }
203}
204
205impl ValidityHelper for StructArray {
206    fn validity(&self) -> &Validity {
207        &self.validity
208    }
209}
210
211impl ArrayVTable<StructVTable> for StructVTable {
212    fn len(array: &StructArray) -> usize {
213        array.len
214    }
215
216    fn dtype(array: &StructArray) -> &DType {
217        &array.dtype
218    }
219
220    fn stats(array: &StructArray) -> StatsSetRef<'_> {
221        array.stats_set.to_ref(array.as_ref())
222    }
223}
224
225impl CanonicalVTable<StructVTable> for StructVTable {
226    fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
227        Ok(Canonical::Struct(array.clone()))
228    }
229}
230
231impl OperationsVTable<StructVTable> for StructVTable {
232    fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
233        let fields = array
234            .fields()
235            .iter()
236            .map(|field| field.slice(start, stop))
237            .try_collect()?;
238        StructArray::try_new_with_dtype(
239            fields,
240            array.struct_dtype().clone(),
241            stop - start,
242            array.validity().slice(start, stop)?,
243        )
244        .map(|a| a.into_array())
245    }
246
247    fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
248        Ok(Scalar::struct_(
249            array.dtype().clone(),
250            array
251                .fields()
252                .iter()
253                .map(|field| field.scalar_at(index))
254                .try_collect()?,
255        ))
256    }
257}
258
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projecting `[zs, xs]` reorders and drops columns without touching the data.
    #[test]
    fn test_project() {
        let children = vec![
            PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable).into_array(),
            VarBinArray::from_vec(
                vec!["a", "b", "c", "d", "e"],
                DType::Utf8(Nullability::NonNullable),
            )
            .into_array(),
            BoolArray::from_iter([true, true, true, false, false]).into_array(),
        ];

        let struct_a = StructArray::try_new(
            FieldNames::from(["xs".into(), "ys".into(), "zs".into()]),
            children,
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let projection = [FieldName::from("zs"), FieldName::from("xs")];
        let struct_b = struct_a.project(&projection).unwrap();
        assert_eq!(struct_b.names().as_ref(), projection);
        assert_eq!(struct_b.len(), 5);

        let zs_values: Vec<bool> = struct_b.fields[0]
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect();
        assert_eq!(zs_values, vec![true, true, true, false, false]);

        let xs_values = struct_b.fields[1]
            .as_::<PrimitiveVTable>()
            .as_slice::<i64>();
        assert_eq!(xs_values, [0i64, 1, 2, 3, 4]);
    }
}
314}