Skip to main content

vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_dtype::DType;
8use vortex_dtype::Nullability;
9use vortex_dtype::StructFields;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_error::vortex_panic;
15use vortex_mask::Mask;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::StructArray;
21use crate::builders::ArrayBuilder;
22use crate::builders::DEFAULT_BUILDER_CAPACITY;
23use crate::builders::LazyBitBufferBuilder;
24use crate::builders::builder_with_capacity;
25use crate::canonical::Canonical;
26use crate::canonical::ToCanonical;
27use crate::scalar::Scalar;
28use crate::scalar::StructScalar;
29
30/// The builder for building a [`StructArray`].
31pub struct StructBuilder {
32    dtype: DType,
33    builders: Vec<Box<dyn ArrayBuilder>>,
34    nulls: LazyBitBufferBuilder,
35}
36
37impl StructBuilder {
38    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
39    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
40        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
41    }
42
43    /// Creates a new `StructBuilder` with the given `capacity`.
44    pub fn with_capacity(
45        struct_dtype: StructFields,
46        nullability: Nullability,
47        capacity: usize,
48    ) -> Self {
49        let builders = struct_dtype
50            .fields()
51            .map(|dt| builder_with_capacity(&dt, capacity))
52            .collect();
53
54        Self {
55            builders,
56            nulls: LazyBitBufferBuilder::new(capacity),
57            dtype: DType::Struct(struct_dtype, nullability),
58        }
59    }
60
61    /// Appends a struct `value` to the builder.
62    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
63        if !self.dtype.is_nullable() && struct_scalar.is_null() {
64            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
65        }
66
67        if struct_scalar.struct_fields() != self.struct_fields() {
68            vortex_bail!(
69                "Tried to append a `StructScalar` with fields {} to a \
70                    struct builder with fields {}",
71                struct_scalar.struct_fields(),
72                self.struct_fields()
73            );
74        }
75
76        if let Some(fields) = struct_scalar.fields_iter() {
77            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
78                builder.append_scalar(&field)?;
79            }
80            self.nulls.append_non_null();
81        } else {
82            self.append_null()
83        }
84
85        Ok(())
86    }
87
88    /// Finishes the builder directly into a [`StructArray`].
89    pub fn finish_into_struct(&mut self) -> StructArray {
90        let len = self.len();
91        let fields = self
92            .builders
93            .iter_mut()
94            .map(|builder| builder.finish())
95            .collect::<Vec<_>>();
96
97        if fields.len() > 1 {
98            let expected_length = fields[0].len();
99            for (index, field) in fields[1..].iter().enumerate() {
100                assert_eq!(
101                    field.len(),
102                    expected_length,
103                    "Field {index} does not have expected length {expected_length}"
104                );
105            }
106        }
107
108        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
109
110        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
111            .vortex_expect("Fields must all have same length.")
112    }
113
114    /// The [`StructFields`] of this struct builder.
115    pub fn struct_fields(&self) -> &StructFields {
116        let DType::Struct(struct_fields, _) = &self.dtype else {
117            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
118        };
119
120        struct_fields
121    }
122}
123
124impl ArrayBuilder for StructBuilder {
125    fn as_any(&self) -> &dyn Any {
126        self
127    }
128
129    fn as_any_mut(&mut self) -> &mut dyn Any {
130        self
131    }
132
133    fn dtype(&self) -> &DType {
134        &self.dtype
135    }
136
137    fn len(&self) -> usize {
138        self.nulls.len()
139    }
140
141    fn append_zeros(&mut self, n: usize) {
142        self.builders
143            .iter_mut()
144            .for_each(|builder| builder.append_zeros(n));
145        self.nulls.append_n_non_nulls(n);
146    }
147
148    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
149        self.builders
150            .iter_mut()
151            // We push zero values into our children when appending a null in case the children are
152            // themselves non-nullable.
153            .for_each(|builder| builder.append_defaults(n));
154        self.nulls.append_n_nulls(n);
155    }
156
157    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
158        vortex_ensure!(
159            scalar.dtype() == self.dtype(),
160            "StructBuilder expected scalar with dtype {}, got {}",
161            self.dtype(),
162            scalar.dtype()
163        );
164
165        self.append_value(scalar.as_struct())
166    }
167
168    unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
169        let array = array.to_struct();
170
171        for (a, builder) in array
172            .unmasked_fields()
173            .iter()
174            .zip_eq(self.builders.iter_mut())
175        {
176            builder.extend_from_array(a.as_ref());
177        }
178
179        self.nulls.append_validity_mask(
180            array
181                .validity_mask()
182                .vortex_expect("validity_mask in extend_from_array_unchecked"),
183        );
184    }
185
186    fn reserve_exact(&mut self, capacity: usize) {
187        self.builders.iter_mut().for_each(|builder| {
188            builder.reserve_exact(capacity);
189        });
190        self.nulls.reserve_exact(capacity);
191    }
192
193    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
194        self.nulls = LazyBitBufferBuilder::new(validity.len());
195        self.nulls.append_validity_mask(validity);
196    }
197
198    fn finish(&mut self) -> ArrayRef {
199        self.finish_into_struct().into_array()
200    }
201
202    fn finish_into_canonical(&mut self) -> Canonical {
203        Canonical::Struct(self.finish_into_struct())
204    }
205}
206
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Struct fields `{a: i32, b: i32}` shared by the simple builder tests.
    fn two_i32_fields() -> StructFields {
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()])
    }

    /// A single non-null row appended to a non-nullable builder yields a one-row array with the
    /// requested dtype.
    #[test]
    fn test_struct_builder() {
        let fields = two_i32_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::NonNullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let finished = builder.finish();
        assert_eq!(finished.len(), 1);
        assert_eq!(finished.dtype(), &dtype);
    }

    /// One valid row followed by two nulls yields three rows, of which exactly one is valid.
    #[test]
    fn test_append_nullable_struct() {
        let fields = two_i32_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::Nullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        builder.append_nulls(2);

        let finished = builder.finish();
        assert_eq!(finished.len(), 3);
        assert_eq!(finished.dtype(), &dtype);
        assert_eq!(finished.valid_count().unwrap(), 1);
    }

    /// `append_scalar` accepts valid and null struct scalars of the builder's dtype and rejects a
    /// scalar of a different dtype.
    #[test]
    fn test_append_scalar() {
        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Two valid struct values.
        for (number, text) in [(42i32, "hello"), (84, "world")] {
            let row = Scalar::struct_(
                dtype.clone(),
                vec![
                    Scalar::primitive(number, Nullability::Nullable),
                    Scalar::utf8(text, Nullability::Nullable),
                ],
            );
            builder.append_scalar(&row).unwrap();
        }

        // A null struct value.
        let null_row = Scalar::null(dtype.clone());
        builder.append_scalar(&null_row).unwrap();

        let array = builder.finish_into_struct();

        // The third row is null at the struct level. NOTE(review): the child values under it
        // (123 and "x") are presumably ignored by the comparison; confirm.
        let expected_a =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)]).into_array();
        let expected_b = <VarBinArray as FromIterator<_>>::from_iter([
            Some("hello"),
            Some("world"),
            Some("x"),
        ])
        .into_array();
        let expected = StructArray::try_from_iter_with_validity(
            [("a", expected_a), ("b", expected_b)],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        // A scalar with a non-struct dtype must be rejected.
        let mut non_nullable_builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(non_nullable_builder.append_scalar(&wrong_scalar).is_err());
    }
}