1use std::fmt::Debug;
5use std::iter::once;
6
7use itertools::Itertools;
8use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::validity::Validity;
14use crate::vtable::{
15 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16 ValidityVTableFromValidityHelper,
17};
18use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
19
20mod compute;
21mod serde;
22
23vtable!(Struct);
24
25impl VTable for StructVTable {
26 type Array = StructArray;
27 type Encoding = StructEncoding;
28
29 type ArrayVTable = Self;
30 type CanonicalVTable = Self;
31 type OperationsVTable = Self;
32 type ValidityVTable = ValidityVTableFromValidityHelper;
33 type VisitorVTable = Self;
34 type ComputeVTable = NotSupported;
35 type EncodeVTable = NotSupported;
36 type SerdeVTable = Self;
37
38 fn id(_encoding: &Self::Encoding) -> EncodingId {
39 EncodingId::new_ref("vortex.struct")
40 }
41
42 fn encoding(_array: &Self::Array) -> EncodingRef {
43 EncodingRef::new_ref(StructEncoding.as_ref())
44 }
45}
46
47#[derive(Clone, Debug)]
48pub struct StructArray {
49 len: usize,
50 dtype: DType,
51 fields: Vec<ArrayRef>,
52 validity: Validity,
53 stats_set: ArrayStats,
54}
55
56#[derive(Clone, Debug)]
57pub struct StructEncoding;
58
59impl StructArray {
60 pub fn fields(&self) -> &[ArrayRef] {
61 &self.fields
62 }
63
64 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
65 let name = name.as_ref();
66 self.field_by_name_opt(name).ok_or_else(|| {
67 vortex_err!(
68 "Field {name} not found in struct array with names {:?}",
69 self.names()
70 )
71 })
72 }
73
74 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
75 let name = name.as_ref();
76 self.names()
77 .iter()
78 .position(|field_name| field_name.as_ref() == name)
79 .map(|idx| &self.fields[idx])
80 }
81
82 pub fn names(&self) -> &FieldNames {
83 self.struct_fields().names()
84 }
85
86 pub fn struct_fields(&self) -> &StructFields {
87 let Some(struct_dtype) = &self.dtype.as_struct() else {
88 unreachable!(
89 "struct arrays must have be a DType::Struct, this is likely an internal bug."
90 )
91 };
92 struct_dtype
93 }
94
95 pub fn new_with_len(len: usize) -> Self {
97 Self::try_new(
98 FieldNames::default(),
99 Vec::new(),
100 len,
101 Validity::NonNullable,
102 )
103 .vortex_expect("StructArray::new_with_len should not fail")
104 }
105
106 pub fn try_new(
107 names: FieldNames,
108 fields: Vec<ArrayRef>,
109 length: usize,
110 validity: Validity,
111 ) -> VortexResult<Self> {
112 let nullability = validity.nullability();
113
114 if names.len() != fields.len() {
115 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
116 }
117
118 for field in fields.iter() {
119 if field.len() != length {
120 vortex_bail!(
121 "Expected all struct fields to have length {length}, found {}",
122 fields.iter().map(|f| f.len()).format(","),
123 );
124 }
125 }
126
127 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
128 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
129
130 if length != validity.maybe_len().unwrap_or(length) {
131 vortex_bail!(
132 "array length {} and validity length must match {}",
133 length,
134 validity
135 .maybe_len()
136 .vortex_expect("can only fail if maybe is some")
137 )
138 }
139
140 Ok(Self {
141 len: length,
142 dtype,
143 fields,
144 validity,
145 stats_set: Default::default(),
146 })
147 }
148
149 pub fn try_new_with_dtype(
150 fields: Vec<ArrayRef>,
151 dtype: StructFields,
152 length: usize,
153 validity: Validity,
154 ) -> VortexResult<Self> {
155 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
156 if field.len() != length {
157 vortex_bail!(
158 "Expected all struct fields to have length {length}, found {}",
159 field.len()
160 );
161 }
162
163 if &struct_dt != field.dtype() {
164 vortex_bail!(
165 "Expected all struct fields to have dtype {}, found {}",
166 struct_dt,
167 field.dtype()
168 );
169 }
170 }
171
172 Ok(Self {
173 len: length,
174 dtype: DType::Struct(dtype, validity.nullability()),
175 fields,
176 validity,
177 stats_set: Default::default(),
178 })
179 }
180
181 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
182 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
183 }
184
185 pub fn try_from_iter_with_validity<
186 N: AsRef<str>,
187 A: IntoArray,
188 T: IntoIterator<Item = (N, A)>,
189 >(
190 iter: T,
191 validity: Validity,
192 ) -> VortexResult<Self> {
193 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
194 .into_iter()
195 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
196 .unzip();
197 let len = fields
198 .first()
199 .map(|f| f.len())
200 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
201
202 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
203 }
204
205 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
206 iter: T,
207 ) -> VortexResult<Self> {
208 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
209 }
210
211 #[allow(clippy::same_name_method)]
219 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
220 let mut children = Vec::with_capacity(projection.len());
221 let mut names = Vec::with_capacity(projection.len());
222
223 for f_name in projection.iter() {
224 let idx = self
225 .names()
226 .iter()
227 .position(|name| name == f_name)
228 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
229
230 names.push(self.names()[idx].clone());
231 children.push(self.fields()[idx].clone());
232 }
233
234 StructArray::try_new(
235 FieldNames::from(names.as_slice()),
236 children,
237 self.len(),
238 self.validity().clone(),
239 )
240 }
241
242 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
245 let name = name.into();
246
247 let struct_dtype = self.struct_fields().clone();
248
249 let position = struct_dtype
250 .names()
251 .iter()
252 .position(|field_name| field_name.as_ref() == name.as_ref())?;
253
254 let field = self.fields.remove(position);
255
256 let new_dtype = struct_dtype.without_field(position);
257 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
258
259 Some(field)
260 }
261
262 pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
264 let name = name.into();
265 let struct_dtype = self.struct_fields().clone();
266
267 let names = struct_dtype.names().iter().cloned().chain(once(name));
268 let types = struct_dtype.fields().chain(once(array.dtype().clone()));
269 let new_fields = StructFields::new(names.collect(), types.collect());
270
271 let mut children = self.fields.clone();
272 children.push(array);
273
274 Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
275 }
276}
277
278impl ValidityHelper for StructArray {
279 fn validity(&self) -> &Validity {
280 &self.validity
281 }
282}
283
284impl ArrayVTable<StructVTable> for StructVTable {
285 fn len(array: &StructArray) -> usize {
286 array.len
287 }
288
289 fn dtype(array: &StructArray) -> &DType {
290 &array.dtype
291 }
292
293 fn stats(array: &StructArray) -> StatsSetRef<'_> {
294 array.stats_set.to_ref(array.as_ref())
295 }
296}
297
298impl CanonicalVTable<StructVTable> for StructVTable {
299 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
300 Ok(Canonical::Struct(array.clone()))
301 }
302}
303
304impl OperationsVTable<StructVTable> for StructVTable {
305 fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
306 let fields = array
307 .fields()
308 .iter()
309 .map(|field| field.slice(start, stop))
310 .try_collect()?;
311 StructArray::try_new_with_dtype(
312 fields,
313 array.struct_fields().clone(),
314 stop - start,
315 array.validity().slice(start, stop)?,
316 )
317 .map(|a| a.into_array())
318 }
319
320 fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
321 Ok(Scalar::struct_(
322 array.dtype().clone(),
323 array
324 .fields()
325 .iter()
326 .map(|field| field.scalar_at(index))
327 .try_collect()?,
328 ))
329 }
330}
331
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projecting a subset of fields keeps the requested order and the original row count.
    #[test]
    fn test_project() {
        let ints = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let flags = BoolArray::from_iter([true, true, true, false, false]);

        let original = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            vec![ints.into_array(), strings.into_array(), flags.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = original.project(&wanted).unwrap();

        assert_eq!(projected.names().as_ref(), wanted);
        assert_eq!(projected.len(), 5);

        let bool_values: Vec<bool> = projected.fields[0]
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect();
        assert_eq!(bool_values, vec![true, true, true, false, false]);

        assert_eq!(
            projected.fields[1].as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }

    /// Removing a field returns its child and shrinks the dtype; removing an unknown field is a
    /// no-op that returns `None`.
    #[test]
    fn test_remove_column() {
        let signed = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let unsigned = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut array = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![signed.into_array(), unsigned.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let removed = array.remove_column("xs").unwrap();
        assert_eq!(
            removed.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        let only_ys: FieldNames = [FieldName::from("ys")].into();
        assert_eq!(array.names(), &only_ys);
        assert_eq!(array.fields.len(), 1);
        assert_eq!(array.len(), 5);

        let survivor = &array.fields[0];
        assert_eq!(
            survivor.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            survivor.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        assert!(
            array.remove_column("non_existent").is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(array.names(), &only_ys);
    }
}