vortex_array/arrays/struct_/
mod.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
5use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
6use vortex_mask::Mask;
7
8use crate::array::{ArrayCanonicalImpl, ArrayValidityImpl};
9use crate::stats::{ArrayStats, Precision, Stat, StatsSet, StatsSetRef};
10use crate::validity::Validity;
11use crate::variants::StructArrayTrait;
12use crate::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
13use crate::{
14    Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayVariantsImpl, Canonical, EmptyMetadata,
15    Encoding, EncodingId,
16};
17mod compute;
18mod serde;
19
/// Array of struct values, stored as one child array per field plus a struct-level validity.
20#[derive(Clone, Debug)]
21pub struct StructArray {
    /// Number of rows; `try_new` rejects any child in `fields` whose length differs from this.
22    len: usize,
    /// `DType::Struct` assembled in `try_new` from the field names, the children's dtypes and
    /// the nullability reported by `validity`.
23    dtype: DType,
    /// Child arrays, positionally matched with the field names passed to `try_new`.
24    fields: Vec<ArrayRef>,
    /// Validity of the struct values themselves.
25    validity: Validity,
    /// Statistics holder for this array; `try_new` initialises it with `Default::default()`.
26    stats_set: ArrayStats,
27}
28
/// Encoding marker type whose array type is [`StructArray`].
29pub struct StructEncoding;
/// Associates the encoding with its array type; the metadata type is `EmptyMetadata`.
30impl Encoding for StructEncoding {
31    type Array = StructArray;
32    type Metadata = EmptyMetadata;
33}
34
35impl EncodingVTable for StructEncoding {
    /// Returns the identifier of this encoding, built from the string `"vortex.struct"`.
36    fn id(&self) -> EncodingId {
37        EncodingId::new_ref("vortex.struct")
38    }
39}
40
41impl StructArray {
42    pub fn validity(&self) -> &Validity {
43        &self.validity
44    }
45
46    pub fn fields(&self) -> &[ArrayRef] {
47        &self.fields
48    }
49
50    pub fn try_new(
51        names: FieldNames,
52        fields: Vec<ArrayRef>,
53        length: usize,
54        validity: Validity,
55    ) -> VortexResult<Self> {
56        let nullability = validity.nullability();
57
58        if names.len() != fields.len() {
59            vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
60        }
61
62        for field in fields.iter() {
63            if field.len() != length {
64                vortex_bail!(
65                    "Expected all struct fields to have length {length}, found {}",
66                    field.len()
67                );
68            }
69        }
70
71        let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
72        let dtype = DType::Struct(Arc::new(StructDType::new(names, field_dtypes)), nullability);
73
74        Ok(Self {
75            len: length,
76            dtype,
77            fields,
78            validity,
79            stats_set: Default::default(),
80        })
81    }
82
83    pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
84        let names = items.iter().map(|(name, _)| FieldName::from(name.as_ref()));
85        let fields: Vec<ArrayRef> = items.iter().map(|(_, array)| array.to_array()).collect();
86        let len = fields
87            .first()
88            .map(|f| f.len())
89            .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
90
91        Self::try_new(
92            FieldNames::from_iter(names),
93            fields,
94            len,
95            Validity::NonNullable,
96        )
97    }
98
99    // TODO(aduffy): Add equivalent function to support field masks for nested column access.
100    /// Return a new StructArray with the given projection applied.
101    ///
102    /// Projection does not copy data arrays. Projection is defined by an ordinal array slice
103    /// which specifies the new ordering of columns in the struct. The projection can be used to
104    /// perform column re-ordering, deletion, or duplication at a logical level, without any data
105    /// copying.
106    #[allow(clippy::same_name_method)]
107    pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
108        let mut children = Vec::with_capacity(projection.len());
109        let mut names = Vec::with_capacity(projection.len());
110
111        for f_name in projection.iter() {
112            let idx = self
113                .names()
114                .iter()
115                .position(|name| name == f_name)
116                .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
117
118            names.push(self.names()[idx].clone());
119            children.push(
120                self.maybe_null_field_by_idx(idx)
121                    .vortex_expect("never out of bounds"),
122            );
123        }
124
125        StructArray::try_new(
126            FieldNames::from(names.as_slice()),
127            children,
128            self.len(),
129            self.validity().clone(),
130        )
131    }
132}
133
/// Basic accessors for `StructArray`: its encoding type, row count, dtype and vtable.
134impl ArrayImpl for StructArray {
135    type Encoding = StructEncoding;
136
    /// Returns the stored row count.
137    fn _len(&self) -> usize {
138        self.len
139    }
140
    /// Returns a reference to the stored dtype.
141    fn _dtype(&self) -> &DType {
142        &self.dtype
143    }
144
    /// Returns a vtable reference created from `StructEncoding`.
145    fn _vtable(&self) -> VTableRef {
146        VTableRef::new_ref(&StructEncoding)
147    }
148}
149
150impl ArrayStatisticsImpl for StructArray {
    /// Returns a `StatsSetRef` obtained from the array's `stats_set`, handing it `self`.
151    fn _stats_ref(&self) -> StatsSetRef<'_> {
152        self.stats_set.to_ref(self)
153    }
154}
155
156impl ArrayVariantsImpl for StructArray {
    /// Always returns `Some(self)`: a `StructArray` is exposed through `StructArrayTrait`.
157    fn _as_struct_typed(&self) -> Option<&dyn StructArrayTrait> {
158        Some(self)
159    }
160}
161
162impl StructArrayTrait for StructArray {
163    fn maybe_null_field_by_idx(&self, idx: usize) -> VortexResult<ArrayRef> {
164        Ok(self.fields[idx].clone())
165    }
166
167    fn project(&self, projection: &[FieldName]) -> VortexResult<ArrayRef> {
168        self.project(projection).map(|a| a.into_array())
169    }
170}
171
172impl ArrayCanonicalImpl for StructArray {
    /// Wraps a clone of this array in `Canonical::Struct`; no other conversion is performed.
173    fn _to_canonical(&self) -> VortexResult<Canonical> {
174        Ok(Canonical::Struct(self.clone()))
175    }
176}
177
/// Validity queries for `StructArray`; every method delegates to the stored `validity`.
178impl ArrayValidityImpl for StructArray {
    /// Delegates to `Validity::is_valid` for the given `index`.
179    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
180        self.validity.is_valid(index)
181    }
182
    /// Delegates to `Validity::all_valid`.
183    fn _all_valid(&self) -> VortexResult<bool> {
184        self.validity.all_valid()
185    }
186
    /// Delegates to `Validity::all_invalid`.
187    fn _all_invalid(&self) -> VortexResult<bool> {
188        self.validity.all_invalid()
189    }
190
    /// Converts the stored validity into a `Mask`, passing the array's length to `to_logical`.
191    fn _validity_mask(&self) -> VortexResult<Mask> {
192        self.validity.to_logical(self.len())
193    }
194}
195
196impl StatisticsVTable<&StructArray> for StructEncoding {
197    fn compute_statistics(&self, array: &StructArray, stat: Stat) -> VortexResult<StatsSet> {
198        Ok(match stat {
199            Stat::NullCount => StatsSet::of(
200                stat,
201                Precision::exact(array.validity().null_count(array.len())?),
202            ),
203            _ => StatsSet::default(),
204        })
205    }
206}
207
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::ArrayExt;
    use crate::array::Array;
    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::validity::Validity;
    use crate::variants::StructArrayTrait;

    /// Projecting a three-column struct onto `["zs", "xs"]` must reorder the columns, drop
    /// `ys`, keep the row count, and leave the column contents unchanged.
    #[test]
    fn test_project() {
        let ints = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let flags = BoolArray::from_iter([true, true, true, false, false]);

        let source = StructArray::try_new(
            FieldNames::from(["xs".into(), "ys".into(), "zs".into()]),
            vec![ints.into_array(), strings.into_array(), flags.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = source.project(&wanted).unwrap();

        // Names follow the order of the projection, not the order of the source.
        assert_eq!(projected.names().as_ref(), wanted);
        assert_eq!(projected.len(), 5);

        // Position 0 now holds the former `zs` column.
        let first = projected.maybe_null_field_by_idx(0).unwrap();
        let first_values = first
            .as_::<BoolArray>()
            .boolean_buffer()
            .iter()
            .collect::<Vec<_>>();
        assert_eq!(first_values, vec![true, true, true, false, false]);

        // Position 1 now holds the former `xs` column.
        let second = projected.maybe_null_field_by_idx(1).unwrap();
        assert_eq!(
            second.as_::<PrimitiveArray>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }
}