vortex_array/builders/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5
6use itertools::Itertools;
7use vortex_dtype::{DType, Nullability, StructFields};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_mask::Mask;
10use vortex_scalar::StructScalar;
11
12use crate::arrays::StructArray;
13use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
14use crate::builders::{ArrayBuilder, ArrayBuilderExt, builder_with_capacity};
15use crate::{Array, ArrayRef, IntoArray, ToCanonical};
16
17pub struct StructBuilder {
18    builders: Vec<Box<dyn ArrayBuilder>>,
19    validity: LazyNullBufferBuilder,
20    struct_dtype: StructFields,
21    nullability: Nullability,
22    dtype: DType,
23}
24
25impl StructBuilder {
26    pub fn with_capacity(
27        struct_dtype: StructFields,
28        nullability: Nullability,
29        capacity: usize,
30    ) -> Self {
31        let builders = struct_dtype
32            .fields()
33            .map(|dt| builder_with_capacity(&dt, capacity))
34            .collect();
35
36        Self {
37            builders,
38            validity: LazyNullBufferBuilder::new(capacity),
39            struct_dtype: struct_dtype.clone(),
40            nullability,
41            dtype: DType::Struct(struct_dtype, nullability),
42        }
43    }
44
45    pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
46        if struct_scalar.dtype() != &DType::Struct(self.struct_dtype.clone(), self.nullability) {
47            vortex_bail!(
48                "Expected struct scalar with dtype {:?}, found {:?}",
49                self.struct_dtype,
50                struct_scalar.dtype()
51            )
52        }
53
54        if let Some(fields) = struct_scalar.fields() {
55            for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
56                builder.append_scalar(&field)?;
57            }
58            self.validity.append_non_null();
59        } else {
60            self.append_null()
61        }
62
63        Ok(())
64    }
65}
66
67impl ArrayBuilder for StructBuilder {
68    fn as_any(&self) -> &dyn Any {
69        self
70    }
71
72    fn as_any_mut(&mut self) -> &mut dyn Any {
73        self
74    }
75
76    fn dtype(&self) -> &DType {
77        &self.dtype
78    }
79
80    fn len(&self) -> usize {
81        self.validity.len()
82    }
83
84    fn append_zeros(&mut self, n: usize) {
85        self.builders
86            .iter_mut()
87            .for_each(|builder| builder.append_zeros(n));
88        self.validity.append_n_non_nulls(n);
89    }
90
91    fn append_nulls(&mut self, n: usize) {
92        self.builders
93            .iter_mut()
94            // We push zero values into our children when appending a null in case the children are
95            // themselves non-nullable.
96            .for_each(|builder| builder.append_zeros(n));
97        self.validity.append_null();
98    }
99
100    fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
101        let array = array.to_struct()?;
102
103        if array.dtype() != self.dtype() {
104            vortex_bail!(
105                "Cannot extend from array with different dtype: expected {}, found {}",
106                self.dtype(),
107                array.dtype()
108            );
109        }
110
111        for (a, builder) in (0..array.struct_fields().nfields())
112            .map(|i| &array.fields()[i])
113            .zip_eq(self.builders.iter_mut())
114        {
115            a.append_to_builder(builder.as_mut())?;
116        }
117
118        self.validity.append_validity_mask(array.validity_mask()?);
119        Ok(())
120    }
121
122    fn ensure_capacity(&mut self, capacity: usize) {
123        self.builders.iter_mut().for_each(|builder| {
124            builder.ensure_capacity(capacity);
125        });
126        self.validity.ensure_capacity(capacity);
127    }
128
129    fn set_validity(&mut self, validity: Mask) {
130        self.validity = LazyNullBufferBuilder::new(validity.len());
131        self.validity.append_validity_mask(validity);
132    }
133
134    fn finish(&mut self) -> ArrayRef {
135        let len = self.len();
136        let fields = self
137            .builders
138            .iter_mut()
139            .map(|builder| builder.finish())
140            .collect::<Vec<_>>();
141
142        if fields.len() > 1 {
143            let expected_length = fields[0].len();
144            for (index, field) in fields[1..].iter().enumerate() {
145                assert_eq!(
146                    field.len(),
147                    expected_length,
148                    "Field {index} does not have expected length {expected_length}"
149                );
150            }
151        }
152
153        let validity = self.validity.finish_with_nullability(self.nullability);
154
155        StructArray::try_new_with_dtype(fields, self.struct_dtype.clone(), len, validity)
156            .vortex_expect("Fields must all have same length.")
157            .into_array()
158    }
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability, StructFields};
    use vortex_scalar::Scalar;

    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;

    /// Appends one `{a: 1, b: 2}` row to a fresh builder with the given nullability, then checks
    /// the length and dtype of the finished array.
    fn check_single_row(nullability: Nullability) {
        let fields = StructFields::new(
            vec![Arc::from("a"), Arc::from("b")].into(),
            vec![I32.into(), I32.into()],
        );
        let expected_dtype = DType::Struct(fields.clone(), nullability);
        let mut builder = StructBuilder::with_capacity(fields, nullability, 0);

        let row = Scalar::struct_(expected_dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let array = builder.finish();
        assert_eq!(array.len(), 1);
        assert_eq!(array.dtype(), &expected_dtype);
    }

    #[test]
    fn test_struct_builder() {
        check_single_row(Nullability::NonNullable);
    }

    #[test]
    fn test_append_nullable_struct() {
        check_single_row(Nullability::Nullable);
    }
}