1use std::any::Any;
5
6use itertools::Itertools;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::LEGACY_SESSION;
17use crate::VortexSessionExecute;
18use crate::arrays::StructArray;
19use crate::arrays::struct_::StructArrayExt;
20use crate::builders::ArrayBuilder;
21use crate::builders::DEFAULT_BUILDER_CAPACITY;
22use crate::builders::LazyBitBufferBuilder;
23use crate::builders::builder_with_capacity;
24use crate::canonical::Canonical;
25use crate::canonical::ToCanonical;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::dtype::StructFields;
29use crate::scalar::Scalar;
30use crate::scalar::StructScalar;
31
32pub struct StructBuilder {
34 dtype: DType,
35 builders: Vec<Box<dyn ArrayBuilder>>,
36 nulls: LazyBitBufferBuilder,
37}
38
39impl StructBuilder {
40 pub fn new(struct_dtype: StructFields, nullability: Nullability) -> Self {
42 Self::with_capacity(struct_dtype, nullability, DEFAULT_BUILDER_CAPACITY)
43 }
44
45 pub fn with_capacity(
47 struct_dtype: StructFields,
48 nullability: Nullability,
49 capacity: usize,
50 ) -> Self {
51 let builders = struct_dtype
52 .fields()
53 .map(|dt| builder_with_capacity(&dt, capacity))
54 .collect();
55
56 Self {
57 builders,
58 nulls: LazyBitBufferBuilder::new(capacity),
59 dtype: DType::Struct(struct_dtype, nullability),
60 }
61 }
62
63 pub fn append_value(&mut self, struct_scalar: StructScalar) -> VortexResult<()> {
65 if !self.dtype.is_nullable() && struct_scalar.is_null() {
66 vortex_bail!("Tried to append a null `StructScalar` to a non-nullable struct builder",);
67 }
68
69 if struct_scalar.struct_fields() != self.struct_fields() {
70 vortex_bail!(
71 "Tried to append a `StructScalar` with fields {} to a \
72 struct builder with fields {}",
73 struct_scalar.struct_fields(),
74 self.struct_fields()
75 );
76 }
77
78 if let Some(fields) = struct_scalar.fields_iter() {
79 for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
80 builder.append_scalar(&field)?;
81 }
82 self.nulls.append_non_null();
83 } else {
84 self.append_null()
85 }
86
87 Ok(())
88 }
89
90 pub fn finish_into_struct(&mut self) -> StructArray {
92 let len = self.len();
93 let fields = self
94 .builders
95 .iter_mut()
96 .map(|builder| builder.finish())
97 .collect::<Vec<_>>();
98
99 if fields.len() > 1 {
100 let expected_length = fields[0].len();
101 for (index, field) in fields[1..].iter().enumerate() {
102 assert_eq!(
103 field.len(),
104 expected_length,
105 "Field {index} does not have expected length {expected_length}"
106 );
107 }
108 }
109
110 let validity = self.nulls.finish_with_nullability(self.dtype.nullability());
111
112 StructArray::try_new_with_dtype(fields, self.struct_fields().clone(), len, validity)
113 .vortex_expect("Fields must all have same length.")
114 }
115
116 pub fn struct_fields(&self) -> &StructFields {
118 let DType::Struct(struct_fields, _) = &self.dtype else {
119 vortex_panic!("`StructBuilder` somehow had dtype {}", self.dtype);
120 };
121
122 struct_fields
123 }
124}
125
126impl ArrayBuilder for StructBuilder {
127 fn as_any(&self) -> &dyn Any {
128 self
129 }
130
131 fn as_any_mut(&mut self) -> &mut dyn Any {
132 self
133 }
134
135 fn dtype(&self) -> &DType {
136 &self.dtype
137 }
138
139 fn len(&self) -> usize {
140 self.nulls.len()
141 }
142
143 fn append_zeros(&mut self, n: usize) {
144 self.builders
145 .iter_mut()
146 .for_each(|builder| builder.append_zeros(n));
147 self.nulls.append_n_non_nulls(n);
148 }
149
150 unsafe fn append_nulls_unchecked(&mut self, n: usize) {
151 self.builders
152 .iter_mut()
153 .for_each(|builder| builder.append_defaults(n));
156 self.nulls.append_n_nulls(n);
157 }
158
159 fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
160 vortex_ensure!(
161 scalar.dtype() == self.dtype(),
162 "StructBuilder expected scalar with dtype {}, got {}",
163 self.dtype(),
164 scalar.dtype()
165 );
166
167 self.append_value(scalar.as_struct())
168 }
169
170 unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) {
171 let array = array.to_struct();
172
173 for (a, builder) in array
174 .iter_unmasked_fields()
175 .zip_eq(self.builders.iter_mut())
176 {
177 builder.extend_from_array(a);
178 }
179
180 self.nulls.append_validity_mask(
181 array
182 .validity()
183 .vortex_expect("validity_mask")
184 .to_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
185 .vortex_expect("Failed to compute validity mask"),
186 );
187 }
188
189 fn reserve_exact(&mut self, capacity: usize) {
190 self.builders.iter_mut().for_each(|builder| {
191 builder.reserve_exact(capacity);
192 });
193 self.nulls.reserve_exact(capacity);
194 }
195
196 unsafe fn set_validity_unchecked(&mut self, validity: Mask) {
197 self.nulls = LazyBitBufferBuilder::new(validity.len());
198 self.nulls.append_validity_mask(validity);
199 }
200
201 fn finish(&mut self) -> ArrayRef {
202 self.finish_into_struct().into_array()
203 }
204
205 fn finish_into_canonical(&mut self) -> Canonical {
206 Canonical::Struct(self.finish_into_struct())
207 }
208}
209
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::struct_::StructArray;
    use crate::builders::struct_::StructBuilder;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Struct fields `a` and `b`, both built from `I32`.
    fn two_i32_fields() -> StructFields {
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()])
    }

    #[test]
    fn test_struct_builder() {
        let fields = two_i32_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::NonNullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::NonNullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();

        let finished = builder.finish();
        assert_eq!(finished.len(), 1);
        assert_eq!(finished.dtype(), &dtype);
    }

    #[test]
    fn test_append_nullable_struct() {
        let fields = two_i32_fields();
        let dtype = DType::Struct(fields.clone(), Nullability::Nullable);
        let mut builder = StructBuilder::with_capacity(fields, Nullability::Nullable, 0);

        let row = Scalar::struct_(dtype.clone(), vec![1.into(), 2.into()]);
        builder.append_value(row.as_struct()).unwrap();
        builder.append_nulls(2);

        let finished = builder.finish();
        assert_eq!(finished.len(), 3);
        assert_eq!(finished.dtype(), &dtype);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(finished.valid_count(&mut ctx).unwrap(), 1);
    }

    #[test]
    fn test_append_scalar() {
        let dtype = DType::Struct(
            StructFields::from_iter([
                ("a", DType::Primitive(I32, Nullability::Nullable)),
                ("b", DType::Utf8(Nullability::Nullable)),
            ]),
            Nullability::Nullable,
        );
        let DType::Struct(struct_fields, _) = &dtype else {
            panic!("Expected struct dtype");
        };

        let mut builder = StructBuilder::new(struct_fields.clone(), Nullability::Nullable);

        // Two valid rows, followed by one null row.
        for (number, text) in [(42i32, "hello"), (84, "world")] {
            let row = Scalar::struct_(
                dtype.clone(),
                vec![
                    Scalar::primitive(number, Nullability::Nullable),
                    Scalar::utf8(text, Nullability::Nullable),
                ],
            );
            builder.append_scalar(&row).unwrap();
        }
        builder.append_scalar(&Scalar::null(dtype.clone())).unwrap();

        let actual = builder.finish_into_struct();

        // The third row is null at the struct level; its field values (123, "x") are
        // placeholders that the comparison presumably ignores because the row is null —
        // TODO confirm against `assert_arrays_eq!` semantics.
        let expected = StructArray::try_from_iter_with_validity(
            [
                (
                    "a",
                    PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(123)])
                        .into_array(),
                ),
                (
                    "b",
                    <VarBinArray as FromIterator<_>>::from_iter([
                        Some("hello"),
                        Some("world"),
                        Some("x"),
                    ])
                    .into_array(),
                ),
            ],
            Validity::from_iter([true, true, false]),
        )
        .unwrap();
        assert_arrays_eq!(&actual, &expected);

        // A scalar whose dtype is not the builder's struct dtype must be rejected.
        let mut non_nullable = StructBuilder::new(struct_fields.clone(), Nullability::NonNullable);
        assert!(non_nullable.append_scalar(&Scalar::from(42i32)).is_err());
    }
}