Skip to main content

vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::arrays::StructArray;
19use crate::arrays::struct_::StructArrayExt;
20use crate::builders::ArrayBuilder;
21use crate::builders::DEFAULT_BUILDER_CAPACITY;
22use crate::builders::LazyBitBufferBuilder;
23use crate::builders::builder_with_capacity;
24use crate::canonical::Canonical;
25#[expect(deprecated)]
26use crate::canonical::ToCanonical as _;
27use crate::dtype::DType;
28use crate::dtype::Nullability;
29use crate::dtype::StructFields;
30use crate::scalar::Scalar;
31use crate::scalar::StructScalar;
32
33/// The builder for building a [`StructArray`].
34pub struct StructBuilder {
35    dtype: DType,
36    builders: Vec<Box<dyn ArrayBuilder>>,
37    nulls: LazyBitBufferBuilder,
38}
39
40impl StructBuilder {
41    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
42    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
43        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
44    }
45
46    /// Creates a new `StructBuilder` with the given `capacity`.
47    pub fn with_capacity(
48        struct_dtype: StructFields,
49        nullability: Nullability,
50        capacity: usize,
51    ) -> Self {
52        let builders = struct_dtype
53            .fields()
54            .map(|dt| builder_with_capacity(&dt, capacity))
55            .collect();
56
57        Self {
58            builders,
59            nulls: LazyBitBufferBuilder::new(capacity),
60            dtype: DType::Struct(struct_dtype, nullability),
61        }
62    }
63
64    /// Appends a struct `value` to the builder.
65    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
66        if !self.dtype.is_nullable() && struct_scalar.is_null() {
67            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
68        }
69
70        if struct_scalar.struct_fields() != self.struct_fields() {
71            vortex_bail!(
72                "Tried to append a `StructScalar` with fields {} to a \
73                    struct builder with fields {}",
74                struct_scalar.struct_fields(),
75                self.struct_fields()
76            );
77        }
78
79        if let Some(fields) = struct_scalar.fields_iter() {
80            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
81                builder.append_scalar(&field)?;
82            }
83            self.nulls.append_non_null();
84        } else {
85            self.append_null()
86        }
87
88        Ok(())
89    }
90
91    /// Finishes the builder directly into a [`StructArray`].
92    pub fn finish_into_struct(&mut self) -> StructArray {
93        let len = self.len();
94        let fields = self
95            .builders
96            .iter_mut()
97            .map(|builder| builder.finish())
98            .collect::<Vec<_>>();
99
100        if fields.len() > 1 {
101            let expected_length = fields[0].len();
102            for (index, field) in fields[1..].iter().enumerate() {
103                assert_eq!(
104                    field.len(),
105                    expected_length,
106                    "Field {index} does not have expected length {expected_length}"
107                );
108            }
109        }
110
111        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
112
113        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
114            .vortex_expect("Fields must all have same length.")
115    }
116
117    /// The [`StructFields`] of this struct builder.
118    pub fn struct_fields(&self) -> &StructFields {
119        let DType::Struct(struct_fields, _) = &self.dtype else {
120            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
121        };
122
123        struct_fields
124    }
125}
126
127impl ArrayBuilder for StructBuilder {
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn as_any_mut(&mut self) -> &mut dyn Any {
133        self
134    }
135
136    fn dtype(&self) -> &DType {
137        &self.dtype
138    }
139
140    fn len(&self) -> usize {
141        self.nulls.len()
142    }
143
144    fn append_zeros(&mut self, n: usize) {
145        self.builders
146            .iter_mut()
147            .for_each(|builder| builder.append_zeros(n));
148        self.nulls.append_n_non_nulls(n);
149    }
150
151    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
152        self.builders
153            .iter_mut()
154            // We push zero values into our children when appending a null in case the children are
155            // themselves non-nullable.
156            .for_each(|builder| builder.append_defaults(n));
157        self.nulls.append_n_nulls(n);
158    }
159
160    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
161        vortex_ensure!(
162            scalar.dtype() == self.dtype(),
163            "StructBuilder expected scalar with dtype {}, got {}",
164            self.dtype(),
165            scalar.dtype()
166        );
167
168        self.append_value(scalar.as_struct())
169    }
170
171    unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
172        #[expect(deprecated)]
173        let array = array.to_struct();
174
175        for (a, builder) in array
176            .iter_unmasked_fields()
177            .zip_eq(self.builders.iter_mut())
178        {
179            builder.extend_from_array(a);
180        }
181
182        self.nulls.append_validity_mask(
183            &array
184                .validity()
185                .vortex_expect("validity_mask")
186                .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
187                .vortex_expect("Failed to compute validity mask"),
188        );
189    }
190
191    fn reserve_exact(&mut self, capacity: usize) {
192        self.builders.iter_mut().for_each(|builder| {
193            builder.reserve_exact(capacity);
194        });
195        self.nulls.reserve_exact(capacity);
196    }
197
198    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
199        self.nulls = LazyBitBufferBuilder::from_validity_mask(validity);
200    }
201
202    fn finish(&mut self) -> ArrayRef {
203        self.finish_into_struct().into_array()
204    }
205
206    fn finish_into_canonical(&mut self) -> Canonical {
207        Canonical::Struct(self.finish_into_struct())
208    }
209}
210
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructArray;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The two-column `{a: i32, b: i32}` layout shared by the simpler tests.
    fn int_pair_fields() -> StructFields {
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()])
    }

    /// A single appended value yields a one-row array with the builder's dtype.
    #[test]
    fn test_struct_builder() {
        let fields = int_pair_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::NonNullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    /// Appended nulls count towards the length but not towards the valid count.
    #[test]
    fn test_append_nullable_struct() {
        let fields = int_pair_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::Nullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);

        let valid_rows = struct_
            .valid_count(&mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        assert_eq!(valid_rows, 1);
    }

    /// `append_scalar` accepts matching struct scalars (including null) and rejects others.
    #[test]
    fn test_append_scalar() {
        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Two valid struct values.
        let first = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&first).unwrap();

        let second = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&second).unwrap();

        // A null struct value.
        let null_row = Scalar::null(dtype.clone());
        builder.append_scalar(&null_row).unwrap();

        let array = builder.finish_into_struct();

        // The third row is marked invalid by the validity below.
        let expected_a =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)]).into_array();
        let expected_b = <VarBinArray as FromIterator<_>>::from_iter([
            Some("hello"),
            Some("world"),
            Some("x"),
        ])
        .into_array();
        let expected = StructArray::try_from_iter_with_validity(
            [("a", expected_a), ("b", expected_b)],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        // A scalar of the wrong dtype must be rejected.
        let mut builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}