Skip to main content

vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::arrays::StructArray;
19use crate::arrays::struct_::StructArrayExt;
20use crate::builders::ArrayBuilder;
21use crate::builders::DEFAULT_BUILDER_CAPACITY;
22use crate::builders::LazyBitBufferBuilder;
23use crate::builders::builder_with_capacity;
24use crate::canonical::Canonical;
25#[expect(deprecated)]
26use crate::canonical::ToCanonical as _;
27use crate::dtype::DType;
28use crate::dtype::Nullability;
29use crate::dtype::StructFields;
30use crate::scalar::Scalar;
31use crate::scalar::StructScalar;
32
33/// The builder for building a [`StructArray`].
34pub struct StructBuilder {
35    dtype: DType,
36    builders: Vec<Box<dyn ArrayBuilder>>,
37    nulls: LazyBitBufferBuilder,
38}
39
40impl StructBuilder {
41    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
42    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
43        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
44    }
45
46    /// Creates a new `StructBuilder` with the given `capacity`.
47    pub fn with_capacity(
48        struct_dtype: StructFields,
49        nullability: Nullability,
50        capacity: usize,
51    ) -> Self {
52        let builders = struct_dtype
53            .fields()
54            .map(|dt| builder_with_capacity(&dt, capacity))
55            .collect();
56
57        Self {
58            builders,
59            nulls: LazyBitBufferBuilder::new(capacity),
60            dtype: DType::Struct(struct_dtype, nullability),
61        }
62    }
63
64    /// Appends a struct `value` to the builder.
65    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
66        if !self.dtype.is_nullable() && struct_scalar.is_null() {
67            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
68        }
69
70        if struct_scalar.struct_fields() != self.struct_fields() {
71            vortex_bail!(
72                "Tried to append a `StructScalar` with fields {} to a \
73                    struct builder with fields {}",
74                struct_scalar.struct_fields(),
75                self.struct_fields()
76            );
77        }
78
79        if let Some(fields) = struct_scalar.fields_iter() {
80            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
81                builder.append_scalar(&field)?;
82            }
83            self.nulls.append_non_null();
84        } else {
85            self.append_null()
86        }
87
88        Ok(())
89    }
90
91    /// Finishes the builder directly into a [`StructArray`].
92    pub fn finish_into_struct(&mut self) -> StructArray {
93        let len = self.len();
94        let fields = self
95            .builders
96            .iter_mut()
97            .map(|builder| builder.finish())
98            .collect::<Vec<_>>();
99
100        if fields.len() > 1 {
101            let expected_length = fields[0].len();
102            for (index, field) in fields[1..].iter().enumerate() {
103                assert_eq!(
104                    field.len(),
105                    expected_length,
106                    "Field {index} does not have expected length {expected_length}"
107                );
108            }
109        }
110
111        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
112
113        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
114            .vortex_expect("Fields must all have same length.")
115    }
116
117    /// The [`StructFields`] of this struct builder.
118    pub fn struct_fields(&self) -> &StructFields {
119        let DType::Struct(struct_fields, _) = &self.dtype else {
120            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
121        };
122
123        struct_fields
124    }
125}
126
127impl ArrayBuilder for StructBuilder {
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn as_any_mut(&mut self) -> &mut dyn Any {
133        self
134    }
135
136    fn dtype(&self) -> &DType {
137        &self.dtype
138    }
139
140    fn len(&self) -> usize {
141        self.nulls.len()
142    }
143
144    fn append_zeros(&mut self, n: usize) {
145        self.builders
146            .iter_mut()
147            .for_each(|builder| builder.append_zeros(n));
148        self.nulls.append_n_non_nulls(n);
149    }
150
151    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
152        self.builders
153            .iter_mut()
154            // We push zero values into our children when appending a null in case the children are
155            // themselves non-nullable.
156            .for_each(|builder| builder.append_defaults(n));
157        self.nulls.append_n_nulls(n);
158    }
159
160    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
161        vortex_ensure!(
162            scalar.dtype() == self.dtype(),
163            "StructBuilder expected scalar with dtype {}, got {}",
164            self.dtype(),
165            scalar.dtype()
166        );
167
168        self.append_value(scalar.as_struct())
169    }
170
171    unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
172        #[expect(deprecated)]
173        let array = array.to_struct();
174
175        for (a, builder) in array
176            .iter_unmasked_fields()
177            .zip_eq(self.builders.iter_mut())
178        {
179            builder.extend_from_array(a);
180        }
181
182        self.nulls.append_validity_mask(
183            array
184                .validity()
185                .vortex_expect("validity_mask")
186                .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
187                .vortex_expect("Failed to compute validity mask"),
188        );
189    }
190
191    fn reserve_exact(&mut self, capacity: usize) {
192        self.builders.iter_mut().for_each(|builder| {
193            builder.reserve_exact(capacity);
194        });
195        self.nulls.reserve_exact(capacity);
196    }
197
198    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
199        self.nulls = LazyBitBufferBuilder::new(validity.len());
200        self.nulls.append_validity_mask(validity);
201    }
202
203    fn finish(&mut self) -> ArrayRef {
204        self.finish_into_struct().into_array()
205    }
206
207    fn finish_into_canonical(&mut self) -> Canonical {
208        Canonical::Struct(self.finish_into_struct())
209    }
210}
211
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructArray;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Struct fields `a` and `b`, both `i32`, shared by several tests below.
    fn i32_pair_fields() -> StructFields {
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()])
    }

    #[test]
    fn test_struct_builder() {
        let sdt = i32_pair_fields();
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    #[test]
    fn test_append_nullable_struct() {
        let sdt = i32_pair_fields();
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);
        assert_eq!(
            struct_
                .valid_count(&mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            1
        );
    }

    #[test]
    fn test_append_null_to_non_nullable_fails() {
        let sdt = i32_pair_fields();
        let mut builder = StructBuilder::with_capacity(sdt.clone(), Nullability::NonNullable, 0);

        let null_scalar = Scalar::null(DType::Struct(sdt, Nullability::Nullable));
        assert!(builder.append_value(null_scalar.as_struct()).is_err());
        // A rejected value must not add a row.
        assert_eq!(builder.len(), 0);
    }

    #[test]
    fn test_append_mismatched_fields_fails() {
        let mut builder = StructBuilder::new(i32_pair_fields(), Nullability::NonNullable);

        let other_fields = StructFields::new(["x", "y"].into(), vec![I32.into(), I32.into()]);
        let other_dtype = DType::Struct(other_fields, Nullability::NonNullable);
        let scalar = Scalar::struct_(other_dtype, vec![1.into(), 2.into()]);
        assert!(builder.append_value(scalar.as_struct()).is_err());
        assert_eq!(builder.len(), 0);
    }

    #[test]
    fn test_append_scalar() {
        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Test appending a valid struct value.
        let struct_scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar1).unwrap();

        // Test appending another struct value.
        let struct_scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar2).unwrap();

        // Test appending null value.
        let null_scalar = Scalar::null(dtype.clone());
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_struct();

        // The child values at the null row are placeholders. Only the struct-level validity
        // marks that row as null.
        let expected = StructArray::try_from_iter_with_validity(
            [
                (
                    "a",
                    PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)])
                        .into_array(),
                ),
                (
                    "b",
                    <VarBinArray as FromIterator<_>>::from_iter([
                        Some("hello"),
                        Some("world"),
                        Some("x"),
                    ])
                    .into_array(),
                ),
            ],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        // Test wrong dtype error.
        let mut builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}