1use std::fmt::Debug;
5use std::iter::once;
6
7use itertools::Itertools;
8use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::validity::Validity;
14use crate::vtable::{
15 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16 ValidityVTableFromValidityHelper,
17};
18use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
19
20mod compute;
21mod serde;
22
// Invokes the crate's `vtable!` macro for the `Struct` encoding.
// NOTE(review): presumably this is what declares `StructVTable`, which the impls below
// are written against — confirm against the macro definition.
vtable!(Struct);
24
25impl VTable for StructVTable {
26 type Array = StructArray;
27 type Encoding = StructEncoding;
28
29 type ArrayVTable = Self;
30 type CanonicalVTable = Self;
31 type OperationsVTable = Self;
32 type ValidityVTable = ValidityVTableFromValidityHelper;
33 type VisitorVTable = Self;
34 type ComputeVTable = NotSupported;
35 type EncodeVTable = NotSupported;
36 type SerdeVTable = Self;
37
38 fn id(_encoding: &Self::Encoding) -> EncodingId {
39 EncodingId::new_ref("vortex.struct")
40 }
41
42 fn encoding(_array: &Self::Array) -> EncodingRef {
43 EncodingRef::new_ref(StructEncoding.as_ref())
44 }
45}
46
/// An array of struct values, stored as one child array per field.
///
/// The field names and field dtypes are held in `dtype`; `fields` holds the child arrays
/// and is indexed by the position of the corresponding name.
#[derive(Clone, Debug)]
pub struct StructArray {
    /// Number of struct values (rows) in the array.
    len: usize,
    /// The array's dtype. Every constructor in this file builds a `DType::Struct`.
    dtype: DType,
    /// Child arrays, one per field. The validating constructors require each child to
    /// have length `len`.
    fields: Vec<ArrayRef>,
    /// Validity of the struct values themselves; it also determines the nullability of
    /// `dtype`.
    validity: Validity,
    /// Statistics storage for this array; starts out as `Default::default()`.
    stats_set: ArrayStats,
}
173
/// Unit type naming the struct encoding; it carries no data.
#[derive(Clone, Debug)]
pub struct StructEncoding;
176
177impl StructArray {
178 pub fn fields(&self) -> &[ArrayRef] {
179 &self.fields
180 }
181
182 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
183 let name = name.as_ref();
184 self.field_by_name_opt(name).ok_or_else(|| {
185 vortex_err!(
186 "Field {name} not found in struct array with names {:?}",
187 self.names()
188 )
189 })
190 }
191
192 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
193 let name = name.as_ref();
194 self.names()
195 .iter()
196 .position(|field_name| field_name.as_ref() == name)
197 .map(|idx| &self.fields[idx])
198 }
199
200 pub fn names(&self) -> &FieldNames {
201 self.struct_fields().names()
202 }
203
204 pub fn struct_fields(&self) -> &StructFields {
205 let Some(struct_dtype) = &self.dtype.as_struct_opt() else {
206 unreachable!(
207 "struct arrays must have be a DType::Struct, this is likely an internal bug."
208 )
209 };
210 struct_dtype
211 }
212
213 pub fn new_with_len(len: usize) -> Self {
215 Self::try_new(
216 FieldNames::default(),
217 Vec::new(),
218 len,
219 Validity::NonNullable,
220 )
221 .vortex_expect("StructArray::new_with_len should not fail")
222 }
223
224 pub fn try_new(
225 names: FieldNames,
226 fields: Vec<ArrayRef>,
227 length: usize,
228 validity: Validity,
229 ) -> VortexResult<Self> {
230 let nullability = validity.nullability();
231
232 if names.len() != fields.len() {
233 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
234 }
235
236 for field in fields.iter() {
237 if field.len() != length {
238 vortex_bail!(
239 "Expected all struct fields to have length {length}, found {}",
240 fields.iter().map(|f| f.len()).format(","),
241 );
242 }
243 }
244
245 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
246 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
247
248 if length != validity.maybe_len().unwrap_or(length) {
249 vortex_bail!(
250 "array length {} and validity length must match {}",
251 length,
252 validity
253 .maybe_len()
254 .vortex_expect("can only fail if maybe is some")
255 )
256 }
257
258 Ok(Self {
259 len: length,
260 dtype,
261 fields,
262 validity,
263 stats_set: Default::default(),
264 })
265 }
266
267 pub(crate) fn new_unchecked(
273 fields: Vec<ArrayRef>,
274 dtype: StructFields,
275 length: usize,
276 validity: Validity,
277 ) -> Self {
278 Self {
279 len: length,
280 dtype: DType::Struct(dtype, validity.nullability()),
281 fields,
282 validity,
283 stats_set: Default::default(),
284 }
285 }
286
287 pub fn try_new_with_dtype(
288 fields: Vec<ArrayRef>,
289 dtype: StructFields,
290 length: usize,
291 validity: Validity,
292 ) -> VortexResult<Self> {
293 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
294 if field.len() != length {
295 vortex_bail!(
296 "Expected all struct fields to have length {length}, found {}",
297 field.len()
298 );
299 }
300
301 if &struct_dt != field.dtype() {
302 vortex_bail!(
303 "Expected all struct fields to have dtype {}, found {}",
304 struct_dt,
305 field.dtype()
306 );
307 }
308 }
309
310 Ok(Self {
311 len: length,
312 dtype: DType::Struct(dtype, validity.nullability()),
313 fields,
314 validity,
315 stats_set: Default::default(),
316 })
317 }
318
319 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
320 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
321 }
322
323 pub fn try_from_iter_with_validity<
324 N: AsRef<str>,
325 A: IntoArray,
326 T: IntoIterator<Item = (N, A)>,
327 >(
328 iter: T,
329 validity: Validity,
330 ) -> VortexResult<Self> {
331 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
332 .into_iter()
333 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
334 .unzip();
335 let len = fields
336 .first()
337 .map(|f| f.len())
338 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
339
340 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
341 }
342
343 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
344 iter: T,
345 ) -> VortexResult<Self> {
346 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
347 }
348
349 #[allow(clippy::same_name_method)]
357 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
358 let mut children = Vec::with_capacity(projection.len());
359 let mut names = Vec::with_capacity(projection.len());
360
361 for f_name in projection.iter() {
362 let idx = self
363 .names()
364 .iter()
365 .position(|name| name == f_name)
366 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
367
368 names.push(self.names()[idx].clone());
369 children.push(self.fields()[idx].clone());
370 }
371
372 StructArray::try_new(
373 FieldNames::from(names.as_slice()),
374 children,
375 self.len(),
376 self.validity().clone(),
377 )
378 }
379
380 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
383 let name = name.into();
384
385 let struct_dtype = self.struct_fields().clone();
386
387 let position = struct_dtype
388 .names()
389 .iter()
390 .position(|field_name| field_name.as_ref() == name.as_ref())?;
391
392 let field = self.fields.remove(position);
393
394 if let Ok(new_dtype) = struct_dtype.without_field(position) {
395 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
396 return Some(field);
397 }
398 None
399 }
400
401 pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
403 let name = name.into();
404 let struct_dtype = self.struct_fields().clone();
405
406 let names = struct_dtype.names().iter().cloned().chain(once(name));
407 let types = struct_dtype.fields().chain(once(array.dtype().clone()));
408 let new_fields = StructFields::new(names.collect(), types.collect());
409
410 let mut children = self.fields.clone();
411 children.push(array);
412
413 Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
414 }
415}
416
417impl ValidityHelper for StructArray {
418 fn validity(&self) -> &Validity {
419 &self.validity
420 }
421}
422
423impl ArrayVTable<StructVTable> for StructVTable {
424 fn len(array: &StructArray) -> usize {
425 array.len
426 }
427
428 fn dtype(array: &StructArray) -> &DType {
429 &array.dtype
430 }
431
432 fn stats(array: &StructArray) -> StatsSetRef<'_> {
433 array.stats_set.to_ref(array.as_ref())
434 }
435}
436
437impl CanonicalVTable<StructVTable> for StructVTable {
438 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
439 Ok(Canonical::Struct(array.clone()))
440 }
441}
442
443impl OperationsVTable<StructVTable> for StructVTable {
444 fn slice(array: &StructArray, start: usize, stop: usize) -> ArrayRef {
445 let fields = array
446 .fields()
447 .iter()
448 .map(|field| field.slice(start, stop))
449 .collect_vec();
450 StructArray::new_unchecked(
451 fields,
452 array.struct_fields().clone(),
453 stop - start,
454 array.validity().slice(start, stop),
455 )
456 .into_array()
457 }
458
459 fn scalar_at(array: &StructArray, index: usize) -> Scalar {
460 Scalar::struct_(
461 array.dtype().clone(),
462 array
463 .fields()
464 .iter()
465 .map(|field| field.scalar_at(index))
466 .collect_vec(),
467 )
468 }
469}
470
471#[cfg(test)]
472mod test {
473 use vortex_buffer::buffer;
474 use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};
475
476 use crate::IntoArray;
477 use crate::arrays::primitive::PrimitiveArray;
478 use crate::arrays::struct_::StructArray;
479 use crate::arrays::varbin::VarBinArray;
480 use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
481 use crate::validity::Validity;
482
483 #[test]
484 fn test_project() {
485 let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
486 let ys = VarBinArray::from_vec(
487 vec!["a", "b", "c", "d", "e"],
488 DType::Utf8(Nullability::NonNullable),
489 );
490 let zs = BoolArray::from_iter([true, true, true, false, false]);
491
492 let struct_a = StructArray::try_new(
493 FieldNames::from(["xs", "ys", "zs"]),
494 vec![xs.into_array(), ys.into_array(), zs.into_array()],
495 5,
496 Validity::NonNullable,
497 )
498 .unwrap();
499
500 let struct_b = struct_a
501 .project(&[FieldName::from("zs"), FieldName::from("xs")])
502 .unwrap();
503 assert_eq!(
504 struct_b.names().as_ref(),
505 [FieldName::from("zs"), FieldName::from("xs")],
506 );
507
508 assert_eq!(struct_b.len(), 5);
509
510 let bools = &struct_b.fields[0];
511 assert_eq!(
512 bools
513 .as_::<BoolVTable>()
514 .boolean_buffer()
515 .iter()
516 .collect::<Vec<_>>(),
517 vec![true, true, true, false, false]
518 );
519
520 let prims = &struct_b.fields[1];
521 assert_eq!(
522 prims.as_::<PrimitiveVTable>().as_slice::<i64>(),
523 [0i64, 1, 2, 3, 4]
524 );
525 }
526
527 #[test]
528 fn test_remove_column() {
529 let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
530 let ys = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);
531
532 let mut struct_a = StructArray::try_new(
533 FieldNames::from(["xs", "ys"]),
534 vec![xs.into_array(), ys.into_array()],
535 5,
536 Validity::NonNullable,
537 )
538 .unwrap();
539
540 let removed = struct_a.remove_column("xs").unwrap();
541 assert_eq!(
542 removed.dtype(),
543 &DType::Primitive(PType::I64, Nullability::NonNullable)
544 );
545 assert_eq!(
546 removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
547 [0i64, 1, 2, 3, 4]
548 );
549
550 assert_eq!(struct_a.names(), &["ys"]);
551 assert_eq!(struct_a.fields.len(), 1);
552 assert_eq!(struct_a.len(), 5);
553 assert_eq!(
554 struct_a.fields[0].dtype(),
555 &DType::Primitive(PType::U64, Nullability::NonNullable)
556 );
557 assert_eq!(
558 struct_a.fields[0]
559 .as_::<PrimitiveVTable>()
560 .as_slice::<u64>(),
561 [4u64, 5, 6, 7, 8]
562 );
563
564 let empty = struct_a.remove_column("non_existent");
565 assert!(
566 empty.is_none(),
567 "Expected None when removing non-existent column"
568 );
569 assert_eq!(struct_a.names(), &["ys"]);
570 }
571
572 #[test]
573 fn test_duplicate_field_names() {
574 let field1 = buffer![1i32, 2, 3].into_array();
576 let field2 = buffer![10i32, 20, 30].into_array();
577 let field3 = buffer![100i32, 200, 300].into_array();
578
579 let struct_array = StructArray::try_new(
581 FieldNames::from(["value", "other", "value"]),
582 vec![field1, field2, field3],
583 3,
584 Validity::NonNullable,
585 )
586 .unwrap();
587
588 let first_value_field = struct_array.field_by_name("value").unwrap();
590 assert_eq!(
591 first_value_field.as_::<PrimitiveVTable>().as_slice::<i32>(),
592 [1i32, 2, 3] );
594
595 let opt_field = struct_array.field_by_name_opt("value").unwrap();
597 assert_eq!(
598 opt_field.as_::<PrimitiveVTable>().as_slice::<i32>(),
599 [1i32, 2, 3] );
601
602 let third_field = &struct_array.fields()[2];
604 assert_eq!(
605 third_field.as_::<PrimitiveVTable>().as_slice::<i32>(),
606 [100i32, 200, 300]
607 );
608 }
609}