vortex_array/builders/
struct_.rs

1use std::any::Any;
2use std::sync::Arc;
3
4use itertools::Itertools;
5use vortex_dtype::{DType, Nullability, StructDType};
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7use vortex_mask::Mask;
8use vortex_scalar::StructScalar;
9
10use crate::arrays::StructArray;
11use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
12use crate::builders::{ArrayBuilder, ArrayBuilderExt, builder_with_capacity};
13use crate::{Array, ArrayRef, IntoArray, ToCanonical};
14
15pub struct StructBuilder {
16    builders: Vec<Box<dyn ArrayBuilder>>,
17    validity: LazyNullBufferBuilder,
18    struct_dtype: Arc<StructDType>,
19    nullability: Nullability,
20    dtype: DType,
21}
22
23impl StructBuilder {
24    pub fn with_capacity(
25        struct_dtype: Arc<StructDType>,
26        nullability: Nullability,
27        capacity: usize,
28    ) -> Self {
29        let builders = struct_dtype
30            .fields()
31            .map(|dt| builder_with_capacity(&dt, capacity))
32            .collect();
33
34        Self {
35            builders,
36            validity: LazyNullBufferBuilder::new(capacity),
37            struct_dtype: struct_dtype.clone(),
38            nullability,
39            dtype: DType::Struct(struct_dtype, nullability),
40        }
41    }
42
43    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
44        if struct_scalar.dtype() != &DType::Struct(self.struct_dtype.clone(), self.nullability) {
45            vortex_bail!(
46                "Expected struct scalar with dtype {:?}, found {:?}",
47                self.struct_dtype,
48                struct_scalar.dtype()
49            )
50        }
51
52        if let Some(fields) = struct_scalar.fields() {
53            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
54                builder.append_scalar(&field)?;
55            }
56            self.validity.append_non_null();
57        } else {
58            self.append_null()
59        }
60
61        Ok(())
62    }
63}
64
65impl ArrayBuilder for StructBuilder {
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn as_any_mut(&mut self) -> &mut dyn Any {
71        self
72    }
73
74    fn dtype(&self) -> &DType {
75        &self.dtype
76    }
77
78    fn len(&self) -> usize {
79        self.validity.len()
80    }
81
82    fn append_zeros(&mut self, n: usize) {
83        self.builders
84            .iter_mut()
85            .for_each(|builder| builder.append_zeros(n));
86        self.validity.append_n_non_nulls(n);
87    }
88
89    fn append_nulls(&mut self, n: usize) {
90        self.builders
91            .iter_mut()
92            // We push zero values into our children when appending a null in case the children are
93            // themselves non-nullable.
94            .for_each(|builder| builder.append_zeros(n));
95        self.validity.append_null();
96    }
97
98    fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
99        let array = array.to_struct()?;
100
101        if array.dtype() != self.dtype() {
102            vortex_bail!(
103                "Cannot extend from array with different dtype: expected {}, found {}",
104                self.dtype(),
105                array.dtype()
106            );
107        }
108
109        for (a, builder) in (0..array.struct_dtype().nfields())
110            .map(|i| &array.fields()[i])
111            .zip_eq(self.builders.iter_mut())
112        {
113            a.append_to_builder(builder.as_mut())?;
114        }
115
116        self.validity.append_validity_mask(array.validity_mask()?);
117        Ok(())
118    }
119
120    fn ensure_capacity(&mut self, capacity: usize) {
121        self.builders.iter_mut().for_each(|builder| {
122            builder.ensure_capacity(capacity);
123        });
124        self.validity.ensure_capacity(capacity);
125    }
126
127    fn set_validity(&mut self, validity: Mask) {
128        self.validity = LazyNullBufferBuilder::new(validity.len());
129        self.validity.append_validity_mask(validity);
130    }
131
132    fn finish(&mut self) -> ArrayRef {
133        let len = self.len();
134        let fields = self
135            .builders
136            .iter_mut()
137            .map(|builder| builder.finish())
138            .collect::<Vec<_>>();
139
140        if fields.len() > 1 {
141            let expected_length = fields[0].len();
142            for (index, field) in fields[1..].iter().enumerate() {
143                assert_eq!(
144                    field.len(),
145                    expected_length,
146                    "Field {} does not have expected length {}",
147                    index,
148                    expected_length
149                );
150            }
151        }
152
153        let validity = self.validity.finish_with_nullability(self.nullability);
154
155        StructArray::try_new_with_dtype(fields, self.struct_dtype.clone(), len, validity)
156            .vortex_expect("Fields must all have same length.")
157            .into_array()
158    }
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructDType};
    use vortex_scalar::Scalar;

    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;

    /// Struct dtype with two `I32` fields named `a` and `b`.
    fn two_i32_fields() -> Arc<StructDType> {
        Arc::new(StructDType::new(
            vec![Arc::from("a"), Arc::from("b")].into(),
            vec![I32.into(), I32.into()],
        ))
    }

    /// Appends one `{a: 1, b: 2}` row to a zero-capacity builder with the given nullability and
    /// checks the length and dtype of the finished array.
    fn check_single_append(nullability: Nullability) {
        let fields = two_i32_fields();
        let expected_dtype = DType::Struct(fields.clone(), nullability);
        let mut builder = StructBuilder::with_capacity(fields, nullability, 0);

        let row = Scalar::struct_(expected_dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let built = builder.finish();
        assert_eq!(built.len(), 1);
        assert_eq!(built.dtype(), &expected_dtype);
    }

    #[test]
    fn test_struct_builder() {
        check_single_append(Nullability::NonNullable);
    }

    #[test]
    fn test_append_nullable_struct() {
        check_single_append(Nullability::Nullable);
    }
}