1use std::any::Any;
5
6use itertools::Itertools;
7use vortex_dtype::DType;
8use vortex_dtype::Nullability;
9use vortex_dtype::StructFields;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_error::vortex_panic;
15use vortex_mask::Mask;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::StructArray;
21use crate::builders::ArrayBuilder;
22use crate::builders::DEFAULT_BUILDER_CAPACITY;
23use crate::builders::LazyBitBufferBuilder;
24use crate::builders::builder_with_capacity;
25use crate::canonical::Canonical;
26use crate::canonical::ToCanonical;
27use crate::scalar::Scalar;
28use crate::scalar::StructScalar;
29
30pub struct StructBuilder {
32 dtype: DType,
33 builders: Vec<Box<dyn ArrayBuilder>>,
34 nulls: LazyBitBufferBuilder,
35}
36
37impl StructBuilder {
38 pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
40 Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
41 }
42
43 pub fn with_capacity(
45 struct_dtype: StructFields,
46 nullability: Nullability,
47 capacity: usize,
48 ) -> Self {
49 let builders = struct_dtype
50 .fields()
51 .map(|dt| builder_with_capacity(&dt, capacity))
52 .collect();
53
54 Self {
55 builders,
56 nulls: LazyBitBufferBuilder::new(capacity),
57 dtype: DType::Struct(struct_dtype, nullability),
58 }
59 }
60
61 pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
63 if !self.dtype.is_nullable() && struct_scalar.is_null() {
64 vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
65 }
66
67 if struct_scalar.struct_fields() != self.struct_fields() {
68 vortex_bail!(
69 "Tried to append a `StructScalar` with fields {} to a \
70 struct builder with fields {}",
71 struct_scalar.struct_fields(),
72 self.struct_fields()
73 );
74 }
75
76 if let Some(fields) = struct_scalar.fields_iter() {
77 for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
78 builder.append_scalar(&field)?;
79 }
80 self.nulls.append_non_null();
81 } else {
82 self.append_null()
83 }
84
85 Ok(())
86 }
87
88 pub fn finish_into_struct(&mut self) -> StructArray {
90 let len = self.len();
91 let fields = self
92 .builders
93 .iter_mut()
94 .map(|builder| builder.finish())
95 .collect::<Vec<_>>();
96
97 if fields.len() > 1 {
98 let expected_length = fields[0].len();
99 for (index, field) in fields[1..].iter().enumerate() {
100 assert_eq!(
101 field.len(),
102 expected_length,
103 "Field {index} does not have expected length {expected_length}"
104 );
105 }
106 }
107
108 let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
109
110 StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
111 .vortex_expect("Fields must all have same length.")
112 }
113
114 pub fn struct_fields(&self) -> &StructFields {
116 let DType::Struct(struct_fields, _) = &self.dtype else {
117 vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
118 };
119
120 struct_fields
121 }
122}
123
124impl ArrayBuilder for StructBuilder {
125 fn as_any(&self) -> &dyn Any {
126 self
127 }
128
129 fn as_any_mut(&mut self) -> &mut dyn Any {
130 self
131 }
132
133 fn dtype(&self) -> &DType {
134 &self.dtype
135 }
136
137 fn len(&self) -> usize {
138 self.nulls.len()
139 }
140
141 fn append_zeros(&mut self, n: usize) {
142 self.builders
143 .iter_mut()
144 .for_each(|builder| builder.append_zeros(n));
145 self.nulls.append_n_non_nulls(n);
146 }
147
148 unsafe fn append_nulls_unchecked(&mut self, n: usize) {
149 self.builders
150 .iter_mut()
151 .for_each(|builder| builder.append_defaults(n));
154 self.nulls.append_n_nulls(n);
155 }
156
157 fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
158 vortex_ensure!(
159 scalar.dtype() == self.dtype(),
160 "StructBuilder expected scalar with dtype {}, got {}",
161 self.dtype(),
162 scalar.dtype()
163 );
164
165 self.append_value(scalar.as_struct())
166 }
167
168 unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) {
169 let array = array.to_struct();
170
171 for (a, builder) in array
172 .unmasked_fields()
173 .iter()
174 .zip_eq(self.builders.iter_mut())
175 {
176 builder.extend_from_array(a.as_ref());
177 }
178
179 self.nulls.append_validity_mask(
180 array
181 .validity_mask()
182 .vortex_expect("validity_mask in extend_from_array_unchecked"),
183 );
184 }
185
186 fn reserve_exact(&mut self, capacity: usize) {
187 self.builders.iter_mut().for_each(|builder| {
188 builder.reserve_exact(capacity);
189 });
190 self.nulls.reserve_exact(capacity);
191 }
192
193 unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
194 self.nulls = LazyBitBufferBuilder::new(validity.len());
195 self.nulls.append_validity_mask(validity);
196 }
197
198 fn finish(&mut self) -> ArrayRef {
199 self.finish_into_struct().into_array()
200 }
201
202 fn finish_into_canonical(&mut self) -> Canonical {
203 Canonical::Struct(self.finish_into_struct())
204 }
205}
206
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructBuilder;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// A single non-null value round-trips with the expected length and dtype.
    #[test]
    fn test_struct_builder() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::NonNullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 1);
        assert_eq!(struct_.dtype(), &dtype);
    }

    /// Null rows on a nullable builder count toward the length but not toward the
    /// valid count.
    #[test]
    fn test_append_nullable_struct() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(sdt.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder
            .append_value(Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]).as_struct())
            .unwrap();

        builder.append_nulls(2);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.dtype(), &dtype);
        assert_eq!(struct_.valid_count().unwrap(), 1);
    }

    /// `append_zeros` adds valid rows, and nulls appended afterwards stay invalid.
    #[test]
    fn test_append_zeros() {
        let sdt = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let mut builder = StructBuilder::with_capacity(sdt, Nullability::Nullable, 0);

        builder.append_zeros(2);
        builder.append_nulls(1);

        let struct_ = builder.finish();
        assert_eq!(struct_.len(), 3);
        assert_eq!(struct_.valid_count().unwrap(), 2);
    }

    /// `append_scalar` accepts valid and null struct scalars, and rejects scalars of
    /// another dtype.
    #[test]
    fn test_append_scalar() {
        let struct_fields = StructFields::from_iter([
            ("a", DType::Primitive(I32, Nullability::Nullable)),
            ("b", DType::Utf8(Nullability::Nullable)),
        ]);
        let dtype = DType::Struct(struct_fields.clone(), Nullability::Nullable);

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        let struct_scalar1 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(42i32, Nullability::Nullable),
                Scalar::utf8("hello", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar1).unwrap();

        let struct_scalar2 = Scalar::struct_(
            dtype.clone(),
            vec![
                Scalar::primitive(84i32, Nullability::Nullable),
                Scalar::utf8("world", Nullability::Nullable),
            ],
        );
        builder.append_scalar(&struct_scalar2).unwrap();

        let null_scalar = Scalar::null(dtype.clone());
        builder.append_scalar(&null_scalar).unwrap();

        let array = builder.finish_into_struct();

        // The child values in the null third row are arbitrary: the struct-level
        // validity masks them out.
        let expected = StructArray::try_from_iter_with_validity(
            [
                (
                    "a",
                    PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)])
                        .into_array(),
                ),
                (
                    "b",
                    <VarBinArray as FromIterator<_>>::from_iter([
                        Some("hello"),
                        Some("world"),
                        Some("x"),
                    ])
                    .into_array(),
                ),
            ],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&array, &expected);

        let mut builder = StructBuilder::new(struct_fields, Nullability::NonNullable);
        let wrong_scalar = Scalar::from(42i32);
        assert!(builder.append_scalar(&wrong_scalar).is_err());
    }
}