Skip to main content

vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::arrays::StructArray;
19use crate::arrays::struct_::StructArrayExt;
20use crate::builders::ArrayBuilder;
21use crate::builders::DEFAULT_BUILDER_CAPACITY;
22use crate::builders::LazyBitBufferBuilder;
23use crate::builders::builder_with_capacity;
24use crate::canonical::Canonical;
25use crate::canonical::ToCanonical;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::dtype::StructFields;
29use crate::scalar::Scalar;
30use crate::scalar::StructScalar;
31
32/// The builder for building a [`StructArray`].
33pub struct StructBuilder {
34    dtype: DType,
35    builders: Vec<Box<dyn ArrayBuilder>>,
36    nulls: LazyBitBufferBuilder,
37}
38
39impl StructBuilder {
40    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
41    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
42        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
43    }
44
45    /// Creates a new `StructBuilder` with the given `capacity`.
46    pub fn with_capacity(
47        struct_dtype: StructFields,
48        nullability: Nullability,
49        capacity: usize,
50    ) -> Self {
51        let builders = struct_dtype
52            .fields()
53            .map(|dt| builder_with_capacity(&dt, capacity))
54            .collect();
55
56        Self {
57            builders,
58            nulls: LazyBitBufferBuilder::new(capacity),
59            dtype: DType::Struct(struct_dtype, nullability),
60        }
61    }
62
63    /// Appends a struct `value` to the builder.
64    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
65        if !self.dtype.is_nullable() && struct_scalar.is_null() {
66            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
67        }
68
69        if struct_scalar.struct_fields() != self.struct_fields() {
70            vortex_bail!(
71                "Tried to append a `StructScalar` with fields {} to a \
72                    struct builder with fields {}",
73                struct_scalar.struct_fields(),
74                self.struct_fields()
75            );
76        }
77
78        if let Some(fields) = struct_scalar.fields_iter() {
79            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
80                builder.append_scalar(&field)?;
81            }
82            self.nulls.append_non_null();
83        } else {
84            self.append_null()
85        }
86
87        Ok(())
88    }
89
90    /// Finishes the builder directly into a [`StructArray`].
91    pub fn finish_into_struct(&mut self) -> StructArray {
92        let len = self.len();
93        let fields = self
94            .builders
95            .iter_mut()
96            .map(|builder| builder.finish())
97            .collect::<Vec<_>>();
98
99        if fields.len() > 1 {
100            let expected_length = fields[0].len();
101            for (index, field) in fields[1..].iter().enumerate() {
102                assert_eq!(
103                    field.len(),
104                    expected_length,
105                    "Field {index} does not have expected length {expected_length}"
106                );
107            }
108        }
109
110        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
111
112        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
113            .vortex_expect("Fields must all have same length.")
114    }
115
116    /// The [`StructFields`] of this struct builder.
117    pub fn struct_fields(&self) -> &StructFields {
118        let DType::Struct(struct_fields, _) = &self.dtype else {
119            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
120        };
121
122        struct_fields
123    }
124}
125
126impl ArrayBuilder for StructBuilder {
127    fn as_any(&self) -> &dyn Any {
128        self
129    }
130
131    fn as_any_mut(&mut self) -> &mut dyn Any {
132        self
133    }
134
135    fn dtype(&self) -> &DType {
136        &self.dtype
137    }
138
139    fn len(&self) -> usize {
140        self.nulls.len()
141    }
142
143    fn append_zeros(&mut self, n: usize) {
144        self.builders
145            .iter_mut()
146            .for_each(|builder| builder.append_zeros(n));
147        self.nulls.append_n_non_nulls(n);
148    }
149
150    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
151        self.builders
152            .iter_mut()
153            // We push zero values into our children when appending a null in case the children are
154            // themselves non-nullable.
155            .for_each(|builder| builder.append_defaults(n));
156        self.nulls.append_n_nulls(n);
157    }
158
159    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
160        vortex_ensure!(
161            scalar.dtype() == self.dtype(),
162            "StructBuilder expected scalar with dtype {}, got {}",
163            self.dtype(),
164            scalar.dtype()
165        );
166
167        self.append_value(scalar.as_struct())
168    }
169
170    unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
171        let array = array.to_struct();
172
173        for (a, builder) in array
174            .iter_unmasked_fields()
175            .zip_eq(self.builders.iter_mut())
176        {
177            builder.extend_from_array(a);
178        }
179
180        self.nulls.append_validity_mask(
181            array
182                .validity()
183                .vortex_expect("validity_mask")
184                .to_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
185                .vortex_expect("Failed to compute validity mask"),
186        );
187    }
188
189    fn reserve_exact(&mut self, capacity: usize) {
190        self.builders.iter_mut().for_each(|builder| {
191            builder.reserve_exact(capacity);
192        });
193        self.nulls.reserve_exact(capacity);
194    }
195
196    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
197        self.nulls = LazyBitBufferBuilder::new(validity.len());
198        self.nulls.append_validity_mask(validity);
199    }
200
201    fn finish(&mut self) -> ArrayRef {
202        self.finish_into_struct().into_array()
203    }
204
205    fn finish_into_canonical(&mut self) -> Canonical {
206        Canonical::Struct(self.finish_into_struct())
207    }
208}
209
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructArray;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn test_struct_builder() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    #[test]
    fn test_append_nullable_struct() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);
        assert_eq!(
            struct_
                .valid_count(&mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            1
        );
    }

    #[test]
    fn test_append_value_rejects_invalid_input() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let mut builder = StructBuilder::new(sdt.clone(), Nullability::NonNullable);

        // A null struct cannot be appended to a non-nullable builder.
        let null_scalar = Scalar::null(DType::Struct(sdt, Nullability::Nullable));
        assert!(builder.append_value(null_scalar.as_struct()).is_err());

        // A struct with different field names is rejected.
        let other = StructFields::new(["x", "y"].into(), vec![I32.into(), I32.into()]);
        let other_scalar = Scalar::struct_(
            DType::Struct(other, Nullability::NonNullable),
            vec![1.into(), 2.into()],
        );
        assert!(builder.append_value(other_scalar.as_struct()).is_err());

        // Rejected values must not have appended anything.
        assert_eq!(builder.len(), 0);
    }

    #[test]
    fn test_append_scalar() {
        // Build the fields once and derive the struct dtype from them.
        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Test appending a valid struct value.
        let struct_scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar1).unwrap();

        // Test appending another struct value.
        let struct_scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar2).unwrap();

        // Test appending null value.
        let null_scalar = Scalar::null(dtype.clone());
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_struct();

        // The field values in the null row are arbitrary; only the validity marks that row null.
        let expected = StructArray::try_from_iter_with_validity(
            [
                (
                    "a",
                    PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)])
                        .into_array(),
                ),
                (
                    "b",
                    <VarBinArray as FromIterator<_>>::from_iter([
                        Some("hello"),
                        Some("world"),
                        Some("x"),
                    ])
                    .into_array(),
                ),
            ],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        // Test wrong dtype error.
        let mut builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}