Skip to main content

vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::StructArray;
17use crate::builders::ArrayBuilder;
18use crate::builders::DEFAULT_BUILDER_CAPACITY;
19use crate::builders::LazyBitBufferBuilder;
20use crate::builders::builder_with_capacity;
21use crate::canonical::Canonical;
22use crate::canonical::ToCanonical;
23use crate::dtype::DType;
24use crate::dtype::Nullability;
25use crate::dtype::StructFields;
26use crate::scalar::Scalar;
27use crate::scalar::StructScalar;
28
/// The builder for building a [`StructArray`].
///
/// Every appended row adds one value to each per-field child builder and one entry to the
/// row-validity builder, so a whole struct row can be null independently of its fields.
pub struct StructBuilder {
    /// The builder's dtype. Constructed as `DType::Struct(fields, nullability)`: the fields
    /// determine the child builders, and the nullability decides whether null rows are allowed.
    dtype: DType,
    /// One child builder per struct field, in field order.
    builders: Vec<Box<dyn ArrayBuilder>>,
    /// Row validity; its length is the builder's row count.
    nulls: LazyBitBufferBuilder,
}
35
36impl StructBuilder {
37    /// Creates a new `StructBuilder` with a capacity of [`DEFAULT_BUILDER_CAPACITY`].
38    pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
39        Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
40    }
41
42    /// Creates a new `StructBuilder` with the given `capacity`.
43    pub fn with_capacity(
44        struct_dtype: StructFields,
45        nullability: Nullability,
46        capacity: usize,
47    ) -> Self {
48        let builders = struct_dtype
49            .fields()
50            .map(|dt| builder_with_capacity(&dt, capacity))
51            .collect();
52
53        Self {
54            builders,
55            nulls: LazyBitBufferBuilder::new(capacity),
56            dtype: DType::Struct(struct_dtype, nullability),
57        }
58    }
59
60    /// Appends a struct `value` to the builder.
61    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
62        if !self.dtype.is_nullable() && struct_scalar.is_null() {
63            vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
64        }
65
66        if struct_scalar.struct_fields() != self.struct_fields() {
67            vortex_bail!(
68                "Tried to append a `StructScalar` with fields {} to a \
69                    struct builder with fields {}",
70                struct_scalar.struct_fields(),
71                self.struct_fields()
72            );
73        }
74
75        if let Some(fields) = struct_scalar.fields_iter() {
76            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
77                builder.append_scalar(&field)?;
78            }
79            self.nulls.append_non_null();
80        } else {
81            self.append_null()
82        }
83
84        Ok(())
85    }
86
87    /// Finishes the builder directly into a [`StructArray`].
88    pub fn finish_into_struct(&mut self) -> StructArray {
89        let len = self.len();
90        let fields = self
91            .builders
92            .iter_mut()
93            .map(|builder| builder.finish())
94            .collect::<Vec<_>>();
95
96        if fields.len() > 1 {
97            let expected_length = fields[0].len();
98            for (index, field) in fields[1..].iter().enumerate() {
99                assert_eq!(
100                    field.len(),
101                    expected_length,
102                    "Field {index} does not have expected length {expected_length}"
103                );
104            }
105        }
106
107        let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
108
109        StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
110            .vortex_expect("Fields must all have same length.")
111    }
112
113    /// The [`StructFields`] of this struct builder.
114    pub fn struct_fields(&self) -> &StructFields {
115        let DType::Struct(struct_fields, _) = &self.dtype else {
116            vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
117        };
118
119        struct_fields
120    }
121}
122
123impl ArrayBuilder for StructBuilder {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn as_any_mut(&mut self) -> &mut dyn Any {
129        self
130    }
131
132    fn dtype(&self) -> &DType {
133        &self.dtype
134    }
135
136    fn len(&self) -> usize {
137        self.nulls.len()
138    }
139
140    fn append_zeros(&mut self, n: usize) {
141        self.builders
142            .iter_mut()
143            .for_each(|builder| builder.append_zeros(n));
144        self.nulls.append_n_non_nulls(n);
145    }
146
147    unsafe fn append_nulls_unchecked(&mut self, n: usize) {
148        self.builders
149            .iter_mut()
150            // We push zero values into our children when appending a null in case the children are
151            // themselves non-nullable.
152            .for_each(|builder| builder.append_defaults(n));
153        self.nulls.append_n_nulls(n);
154    }
155
156    fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
157        vortex_ensure!(
158            scalar.dtype() == self.dtype(),
159            "StructBuilder expected scalar with dtype {}, got {}",
160            self.dtype(),
161            scalar.dtype()
162        );
163
164        self.append_value(scalar.as_struct())
165    }
166
167    unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
168        let array = array.to_struct();
169
170        for (a, builder) in array
171            .unmasked_fields()
172            .iter()
173            .zip_eq(self.builders.iter_mut())
174        {
175            builder.extend_from_array(a);
176        }
177
178        self.nulls.append_validity_mask(
179            array
180                .validity_mask()
181                .vortex_expect("validity_mask in extend_from_array_unchecked"),
182        );
183    }
184
185    fn reserve_exact(&mut self, capacity: usize) {
186        self.builders.iter_mut().for_each(|builder| {
187            builder.reserve_exact(capacity);
188        });
189        self.nulls.reserve_exact(capacity);
190    }
191
192    unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
193        self.nulls = LazyBitBufferBuilder::new(validity.len());
194        self.nulls.append_validity_mask(validity);
195    }
196
197    fn finish(&mut self) -> ArrayRef {
198        self.finish_into_struct().into_array()
199    }
200
201    fn finish_into_canonical(&mut self) -> Canonical {
202        Canonical::Struct(self.finish_into_struct())
203    }
204}
205
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Struct fields `a` and `b`, both of dtype `I32.into()`.
    fn two_i32_fields() -> StructFields {
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()])
    }

    /// A nullable struct dtype with a nullable `i32` field `a` and a nullable UTF-8 field `b`.
    fn nullable_i32_utf8_dtype() -> DType {
        DType::Struct(
            StructFields::from_iter([
                ("a", DType::Primitive(I32, Nullability::Nullable)),
                ("b", DType::Utf8(Nullability::Nullable)),
            ]),
            Nullability::Nullable,
        )
    }

    /// Returns a clone of the fields of a `DType::Struct`, panicking for any other dtype.
    fn struct_fields_of(dtype: &DType) -> StructFields {
        match dtype {
            DType::Struct(fields, _) => fields.clone(),
            _ => panic!("Expected struct dtype"),
        }
    }

    #[test]
    fn test_struct_builder() {
        let sdt = two_i32_fields();
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    #[test]
    fn test_append_nullable_struct() {
        let sdt = two_i32_fields();
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);
        assert_eq!(struct_.valid_count().unwrap(), 1);
    }

    #[test]
    fn test_append_null_to_non_nullable_struct_fails() {
        let sdt = two_i32_fields();
        let nullable = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        let null_scalar = Scalar::null(nullable);
        assert!(builder.append_value(null_scalar.as_struct()).is_err());
        // The rejected value must not have been partially appended.
        assert_eq!(builder.len(), 0);
    }

    #[test]
    #[should_panic(expected = "does not have expected length")]
    fn test_finish_panics_on_mismatched_field_lengths() {
        let mut builder =
            StructBuilder::with_capacity(two_i32_fields(), Nullability::NonNullable, 0);
        // Deliberately desynchronise one child so the field-length check fires.
        builder.builders[1].append_zeros(1);
        let _array = builder.finish_into_struct();
    }

    #[test]
    fn test_append_scalar() {
        let dtype = nullable_i32_utf8_dtype();
        let mut builder = StructBuilder::new(struct_fields_of(&dtype), Nullability::Nullable);

        // Test appending a valid struct value.
        let struct_scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar1).unwrap();

        // Test appending another struct value.
        let struct_scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar2).unwrap();

        // Test appending null value.
        let null_scalar = Scalar::null(dtype.clone());
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_struct();

        // Row 2 is null at the struct level, so its child values are masked out;
        // `123` and `"x"` are arbitrary placeholders.
        let expected = StructArray::try_from_iter_with_validity(
            [
                (
                    "a",
                    PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)])
                        .into_array(),
                ),
                (
                    "b",
                    <VarBinArray as FromIterator<_>>::from_iter([
                        Some("hello"),
                        Some("world"),
                        Some("x"),
                    ])
                    .into_array(),
                ),
            ],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);
    }

    #[test]
    fn test_append_scalar_wrong_dtype() {
        let dtype = nullable_i32_utf8_dtype();
        let mut builder = StructBuilder::new(struct_fields_of(&dtype), Nullability::NonNullable);

        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
        assert_eq!(builder.len(), 0);
    }
}