vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure, vortex_panic};
9use vortex_mask::Mask;
10use vortex_scalar::{Scalar, StructScalar};
11
12use crate::arrays::StructArray;
13use crate::builders::{
14    ArrayBuilder, DEFAULT_BUILDER_CAPACITY, LazyNullBufferBuilder, builder_with_capacity,
15};
16use crate::canonical::{Canonical, ToCanonical};
17use crate::{Array, ArrayRef, IntoArray};
18
/// The builder for building a [`StructArray`].
///
/// Holds one child builder per struct field together with a null buffer that records the
/// validity of each struct row.
pub struct StructBuilder {
    /// The struct dtype being built. Always constructed as a `DType::Struct`.
    dtype: DType,
    /// One child builder per field, in the iteration order of the struct's fields.
    builders: Vec<Box<dyn ArrayBuilder>>,
    /// Validity of the struct rows themselves; its length is reported as the builder's length.
    nulls: LazyNullBufferBuilder,
}
25
26impl StructBuilder {
27    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
28    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
29        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
30    }
31
32    /// Creates a new `StructBuilder` with the given `capacity`.
33    pub fn with_capacity(
34        struct_dtype: StructFields,
35        nullability: Nullability,
36        capacity: usize,
37    ) -> Self {
38        let builders = struct_dtype
39            .fields()
40            .map(|dt| builder_with_capacity(&dt, capacity))
41            .collect();
42
43        Self {
44            builders,
45            nulls: LazyNullBufferBuilder::new(capacity),
46            dtype: DType::Struct(struct_dtype, nullability),
47        }
48    }
49
50    /// Appends a struct `value` to the builder.
51    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
52        if !self.dtype.is_nullable() && struct_scalar.is_null() {
53            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
54        }
55
56        if struct_scalar.struct_fields() != self.struct_fields() {
57            vortex_bail!(
58                "Tried to append a `StructScalar` with fields {} to a \
59                    struct builder with fields {}",
60                struct_scalar.struct_fields(),
61                self.struct_fields()
62            );
63        }
64
65        if let Some(fields) = struct_scalar.fields() {
66            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
67                builder.append_scalar(&field)?;
68            }
69            self.nulls.append_non_null();
70        } else {
71            self.append_null()
72        }
73
74        Ok(())
75    }
76
77    /// Finishes the builder directly into a [`StructArray`].
78    pub fn finish_into_struct(&mut self) -> StructArray {
79        let len = self.len();
80        let fields = self
81            .builders
82            .iter_mut()
83            .map(|builder| builder.finish())
84            .collect::<Vec<_>>();
85
86        if fields.len() > 1 {
87            let expected_length = fields[0].len();
88            for (index, field) in fields[1..].iter().enumerate() {
89                assert_eq!(
90                    field.len(),
91                    expected_length,
92                    "Field {index} does not have expected length {expected_length}"
93                );
94            }
95        }
96
97        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
98
99        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
100            .vortex_expect("Fields must all have same length.")
101    }
102
103    /// The [`StructFields`] of this struct builder.
104    pub fn struct_fields(&self) -> &StructFields {
105        let DType::Struct(struct_fields, _) = &self.dtype else {
106            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
107        };
108
109        struct_fields
110    }
111}
112
113impl ArrayBuilder for StructBuilder {
114    fn as_any(&self) -> &dyn Any {
115        self
116    }
117
118    fn as_any_mut(&mut self) -> &mut dyn Any {
119        self
120    }
121
122    fn dtype(&self) -> &DType {
123        &self.dtype
124    }
125
126    fn len(&self) -> usize {
127        self.nulls.len()
128    }
129
130    fn append_zeros(&mut self, n: usize) {
131        self.builders
132            .iter_mut()
133            .for_each(|builder| builder.append_zeros(n));
134        self.nulls.append_n_non_nulls(n);
135    }
136
137    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
138        self.builders
139            .iter_mut()
140            // We push zero values into our children when appending a null in case the children are
141            // themselves non-nullable.
142            .for_each(|builder| builder.append_defaults(n));
143        self.nulls.append_null();
144    }
145
146    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
147        vortex_ensure!(
148            scalar.dtype() == self.dtype(),
149            "StructBuilder expected scalar with dtype {:?}, got {:?}",
150            self.dtype(),
151            scalar.dtype()
152        );
153
154        let struct_scalar = StructScalar::try_from(scalar)?;
155        self.append_value(struct_scalar)
156    }
157
158    unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
159        let array = array.to_struct();
160
161        for (a, builder) in (0..array.struct_fields().nfields())
162            .map(|i| &array.fields()[i])
163            .zip_eq(self.builders.iter_mut())
164        {
165            a.append_to_builder(builder.as_mut());
166        }
167
168        self.nulls.append_validity_mask(array.validity_mask());
169    }
170
171    fn ensure_capacity(&mut self, capacity: usize) {
172        self.builders.iter_mut().for_each(|builder| {
173            builder.ensure_capacity(capacity);
174        });
175        self.nulls.ensure_capacity(capacity);
176    }
177
178    fn set_validity(&mut self, validity: Mask) {
179        self.nulls = LazyNullBufferBuilder::new(validity.len());
180        self.nulls.append_validity_mask(validity);
181    }
182
183    fn finish(&mut self) -> ArrayRef {
184        self.finish_into_struct().into_array()
185    }
186
187    fn finish_into_canonical(&mut self) -> Canonical {
188        Canonical::Struct(self.finish_into_struct())
189    }
190}
191
#[cfg(test)]
mod tests {

    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};
    use vortex_scalar::Scalar;

    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;

    /// A single value appended to a non-nullable builder yields a one-row array of that dtype.
    #[test]
    fn test_struct_builder() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    /// A single value appended to a nullable builder yields a one-row array of that dtype.
    #[test]
    fn test_append_nullable_struct() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    /// Exercises `append_scalar` with valid values, a null value, and a scalar of the wrong dtype.
    #[test]
    fn test_append_scalar() {
        use crate::vtable::ValidityHelper;

        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Test appending a valid struct value.
        let struct_scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar1).unwrap();

        // Test appending another struct value.
        let struct_scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar2).unwrap();

        // Test appending null value.
        let null_scalar = Scalar::null(dtype.clone());
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_struct();
        assert_eq!(array.len(), 3);

        // Check actual values using scalar_at. The first two rows must come back as non-null
        // structs; `expect` makes the test fail instead of silently skipping the value checks.
        let scalar0 = array.scalar_at(0);
        let struct0 = scalar0.as_struct();
        let fields0 = struct0
            .fields()
            .expect("row 0 was appended as a non-null struct")
            .collect::<Vec<_>>();
        assert_eq!(fields0[0].as_primitive().typed_value::<i32>(), Some(42));
        assert_eq!(fields0[1].as_utf8().value().as_deref(), Some("hello"));

        let scalar1 = array.scalar_at(1);
        let struct1 = scalar1.as_struct();
        let fields1 = struct1
            .fields()
            .expect("row 1 was appended as a non-null struct")
            .collect::<Vec<_>>();
        assert_eq!(fields1[0].as_primitive().typed_value::<i32>(), Some(84));
        assert_eq!(fields1[1].as_utf8().value().as_deref(), Some("world"));

        let scalar2 = array.scalar_at(2);
        let struct2 = scalar2.as_struct();
        assert!(struct2.fields().is_none()); // Null struct has no fields.

        // Check validity - first two should be valid, third should be null.
        assert!(array.validity().is_valid(0));
        assert!(array.validity().is_valid(1));
        assert!(!array.validity().is_valid(2));

        // Test wrong dtype error.
        let mut builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}