1use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::StructArray;
17use crate::builders::ArrayBuilder;
18use crate::builders::DEFAULT_BUILDER_CAPACITY;
19use crate::builders::LazyBitBufferBuilder;
20use crate::builders::builder_with_capacity;
21use crate::canonical::Canonical;
22use crate::canonical::ToCanonical;
23use crate::dtype::DType;
24use crate::dtype::Nullability;
25use crate::dtype::StructFields;
26use crate::scalar::Scalar;
27use crate::scalar::StructScalar;
28
29pub struct StructBuilder {
31 dtype: DType,
32 builders: Vec<Box<dyn ArrayBuilder>>,
33 nulls: LazyBitBufferBuilder,
34}
35
36impl StructBuilder {
37 pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
39 Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
40 }
41
42 pub fn with_capacity(
44 struct_dtype: StructFields,
45 nullability: Nullability,
46 capacity: usize,
47 ) -> Self {
48 let builders = struct_dtype
49 .fields()
50 .map(|dt| builder_with_capacity(&dt, capacity))
51 .collect();
52
53 Self {
54 builders,
55 nulls: LazyBitBufferBuilder::new(capacity),
56 dtype: DType::Struct(struct_dtype, nullability),
57 }
58 }
59
60 pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
62 if !self.dtype.is_nullable() && struct_scalar.is_null() {
63 vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
64 }
65
66 if struct_scalar.struct_fields() != self.struct_fields() {
67 vortex_bail!(
68 "Tried to append a `StructScalar` with fields {} to a \
69 struct builder with fields {}",
70 struct_scalar.struct_fields(),
71 self.struct_fields()
72 );
73 }
74
75 if let Some(fields) = struct_scalar.fields_iter() {
76 for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
77 builder.append_scalar(&field)?;
78 }
79 self.nulls.append_non_null();
80 } else {
81 self.append_null()
82 }
83
84 Ok(())
85 }
86
87 pub fn finish_into_struct(&mut self) -> StructArray {
89 let len = self.len();
90 let fields = self
91 .builders
92 .iter_mut()
93 .map(|builder| builder.finish())
94 .collect::<Vec<_>>();
95
96 if fields.len() > 1 {
97 let expected_length = fields[0].len();
98 for (index, field) in fields[1..].iter().enumerate() {
99 assert_eq!(
100 field.len(),
101 expected_length,
102 "Field {index} does not have expected length {expected_length}"
103 );
104 }
105 }
106
107 let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
108
109 StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
110 .vortex_expect("Fields must all have same length.")
111 }
112
113 pub fn struct_fields(&self) -> &StructFields {
115 let DType::Struct(struct_fields, _) = &self.dtype else {
116 vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
117 };
118
119 struct_fields
120 }
121}
122
123impl ArrayBuilder for StructBuilder {
124 fn as_any(&self) -> &dyn Any {
125 self
126 }
127
128 fn as_any_mut(&mut self) -> &mut dyn Any {
129 self
130 }
131
132 fn dtype(&self) -> &DType {
133 &self.dtype
134 }
135
136 fn len(&self) -> usize {
137 self.nulls.len()
138 }
139
140 fn append_zeros(&mut self, n: usize) {
141 self.builders
142 .iter_mut()
143 .for_each(|builder| builder.append_zeros(n));
144 self.nulls.append_n_non_nulls(n);
145 }
146
147 unsafe fn append_nulls_unchecked(&mut self, n: usize) {
148 self.builders
149 .iter_mut()
150 .for_each(|builder| builder.append_defaults(n));
153 self.nulls.append_n_nulls(n);
154 }
155
156 fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
157 vortex_ensure!(
158 scalar.dtype() == self.dtype(),
159 "StructBuilder expected scalar with dtype {}, got {}",
160 self.dtype(),
161 scalar.dtype()
162 );
163
164 self.append_value(scalar.as_struct())
165 }
166
167 unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
168 let array = array.to_struct();
169
170 for (a, builder) in array
171 .unmasked_fields()
172 .iter()
173 .zip_eq(self.builders.iter_mut())
174 {
175 builder.extend_from_array(a);
176 }
177
178 self.nulls.append_validity_mask(
179 array
180 .validity_mask()
181 .vortex_expect("validity_mask in extend_from_array_unchecked"),
182 );
183 }
184
185 fn reserve_exact(&mut self, capacity: usize) {
186 self.builders.iter_mut().for_each(|builder| {
187 builder.reserve_exact(capacity);
188 });
189 self.nulls.reserve_exact(capacity);
190 }
191
192 unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
193 self.nulls = LazyBitBufferBuilder::new(validity.len());
194 self.nulls.append_validity_mask(validity);
195 }
196
197 fn finish(&mut self) -> ArrayRef {
198 self.finish_into_struct().into_array()
199 }
200
201 fn finish_into_canonical(&mut self) -> Canonical {
202 Canonical::Struct(self.finish_into_struct())
203 }
204}
205
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// A single appended row yields a one-element array carrying the builder's dtype.
    #[test]
    fn test_struct_builder() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    /// One valid row followed by two nulls gives three rows, exactly one of them valid.
    #[test]
    fn test_append_nullable_struct() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);
        assert_eq!(struct_.valid_count().unwrap(), 1);
    }

    /// `append_scalar` accepts matching struct scalars (including null) and rejects a scalar
    /// of a different dtype.
    #[test]
    fn test_append_scalar() {
        let dtype = DType::Struct(
            StructFields::from_iter([
                ("a", DType::Primitive(I32, Nullability::Nullable)),
                ("b", DType::Utf8(Nullability::Nullable)),
            ]),
            Nullability::Nullable,
        );

        let DType::Struct(struct_fields, _) = &dtype else {
            panic!("Expected struct dtype");
        };
        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Two fully populated rows.
        for (number, text) in [(42i32, "hello"), (84, "world")] {
            let row = Scalar::struct_(
                dtype.clone(),
                vec![
                    Scalar::primitive(number, Nullability::Nullable),
                    Scalar::utf8(text, Nullability::Nullable),
                ],
            );
            builder.append_scalar(&row).unwrap();
        }

        // One null row.
        let null_row = Scalar::null(dtype.clone());
        builder.append_scalar(&null_row).unwrap();

        let array = builder.finish_into_struct();

        // The third row is masked out by the validity below, so its child values (123 / "x")
        // are arbitrary placeholders; presumably `assert_arrays_eq!` ignores values under a
        // null row.
        let expected_a =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)]).into_array();
        let expected_b = <VarBinArray as FromIterator<_>>::from_iter([
            Some("hello"),
            Some("world"),
            Some("x"),
        ])
        .into_array();
        let expected = StructArray::try_from_iter_with_validity(
            [("a", expected_a), ("b", expected_b)],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        // A scalar whose dtype is not the builder's struct dtype must be rejected.
        let DType::Struct(struct_fields, _) = &dtype else {
            panic!("Expected struct dtype");
        };
        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}