1use std::fmt::Debug;
5use std::iter::once;
6
7use itertools::Itertools;
8use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::validity::Validity;
14use crate::vtable::{
15 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16 ValidityVTableFromValidityHelper,
17};
18use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
19
20mod compute;
21mod serde;
22
// Vtable boilerplate for the struct encoding. NOTE(review): presumably this macro declares
// `StructVTable` (implemented below) and related glue; its definition is not visible here.
vtable!(Struct);
24
25impl VTable for StructVTable {
26 type Array = StructArray;
27 type Encoding = StructEncoding;
28
29 type ArrayVTable = Self;
30 type CanonicalVTable = Self;
31 type OperationsVTable = Self;
32 type ValidityVTable = ValidityVTableFromValidityHelper;
33 type VisitorVTable = Self;
34 type ComputeVTable = NotSupported;
35 type EncodeVTable = NotSupported;
36 type PipelineVTable = NotSupported;
37 type SerdeVTable = Self;
38
39 fn id(_encoding: &Self::Encoding) -> EncodingId {
40 EncodingId::new_ref("vortex.struct")
41 }
42
43 fn encoding(_array: &Self::Array) -> EncodingRef {
44 EncodingRef::new_ref(StructEncoding.as_ref())
45 }
46}
47
/// A columnar array of structs: one child array per field plus a struct-level validity.
#[derive(Clone, Debug)]
pub struct StructArray {
    /// Number of rows; the validating constructors require every child in `fields` to have
    /// this length.
    len: usize,
    /// Always a `DType::Struct` describing `fields`; its nullability mirrors `validity`.
    dtype: DType,
    /// Child arrays, one per struct field, in the same order as the dtype's field names.
    fields: Vec<ArrayRef>,
    /// Struct-level validity supplied at construction.
    validity: Validity,
    /// Statistics storage, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
174
/// Marker type for the struct encoding, identified as `"vortex.struct"` by `StructVTable::id`.
#[derive(Clone, Debug)]
pub struct StructEncoding;
177
178impl StructArray {
179 pub fn fields(&self) -> &[ArrayRef] {
180 &self.fields
181 }
182
183 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
184 let name = name.as_ref();
185 self.field_by_name_opt(name).ok_or_else(|| {
186 vortex_err!(
187 "Field {name} not found in struct array with names {:?}",
188 self.names()
189 )
190 })
191 }
192
193 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
194 let name = name.as_ref();
195 self.names()
196 .iter()
197 .position(|field_name| field_name.as_ref() == name)
198 .map(|idx| &self.fields[idx])
199 }
200
201 pub fn names(&self) -> &FieldNames {
202 self.struct_fields().names()
203 }
204
205 pub fn struct_fields(&self) -> &StructFields {
206 let Some(struct_dtype) = &self.dtype.as_struct_opt() else {
207 unreachable!(
208 "struct arrays must have be a DType::Struct, this is likely an internal bug."
209 )
210 };
211 struct_dtype
212 }
213
214 pub fn new_with_len(len: usize) -> Self {
216 Self::try_new(
217 FieldNames::default(),
218 Vec::new(),
219 len,
220 Validity::NonNullable,
221 )
222 .vortex_expect("StructArray::new_with_len should not fail")
223 }
224
225 pub fn try_new(
226 names: FieldNames,
227 fields: Vec<ArrayRef>,
228 length: usize,
229 validity: Validity,
230 ) -> VortexResult<Self> {
231 let nullability = validity.nullability();
232
233 if names.len() != fields.len() {
234 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
235 }
236
237 for field in fields.iter() {
238 if field.len() != length {
239 vortex_bail!(
240 "Expected all struct fields to have length {length}, found {}",
241 fields.iter().map(|f| f.len()).format(","),
242 );
243 }
244 }
245
246 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
247 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
248
249 if length != validity.maybe_len().unwrap_or(length) {
250 vortex_bail!(
251 "array length {} and validity length must match {}",
252 length,
253 validity
254 .maybe_len()
255 .vortex_expect("can only fail if maybe is some")
256 )
257 }
258
259 Ok(Self {
260 len: length,
261 dtype,
262 fields,
263 validity,
264 stats_set: Default::default(),
265 })
266 }
267
268 pub(crate) fn new_unchecked(
274 fields: Vec<ArrayRef>,
275 dtype: StructFields,
276 length: usize,
277 validity: Validity,
278 ) -> Self {
279 Self {
280 len: length,
281 dtype: DType::Struct(dtype, validity.nullability()),
282 fields,
283 validity,
284 stats_set: Default::default(),
285 }
286 }
287
288 pub fn try_new_with_dtype(
289 fields: Vec<ArrayRef>,
290 dtype: StructFields,
291 length: usize,
292 validity: Validity,
293 ) -> VortexResult<Self> {
294 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
295 if field.len() != length {
296 vortex_bail!(
297 "Expected all struct fields to have length {length}, found {}",
298 field.len()
299 );
300 }
301
302 if &struct_dt != field.dtype() {
303 vortex_bail!(
304 "Expected all struct fields to have dtype {}, found {}",
305 struct_dt,
306 field.dtype()
307 );
308 }
309 }
310
311 Ok(Self {
312 len: length,
313 dtype: DType::Struct(dtype, validity.nullability()),
314 fields,
315 validity,
316 stats_set: Default::default(),
317 })
318 }
319
320 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
321 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
322 }
323
324 pub fn try_from_iter_with_validity<
325 N: AsRef<str>,
326 A: IntoArray,
327 T: IntoIterator<Item = (N, A)>,
328 >(
329 iter: T,
330 validity: Validity,
331 ) -> VortexResult<Self> {
332 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
333 .into_iter()
334 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
335 .unzip();
336 let len = fields
337 .first()
338 .map(|f| f.len())
339 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
340
341 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
342 }
343
344 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
345 iter: T,
346 ) -> VortexResult<Self> {
347 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
348 }
349
350 #[allow(clippy::same_name_method)]
358 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
359 let mut children = Vec::with_capacity(projection.len());
360 let mut names = Vec::with_capacity(projection.len());
361
362 for f_name in projection.iter() {
363 let idx = self
364 .names()
365 .iter()
366 .position(|name| name == f_name)
367 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
368
369 names.push(self.names()[idx].clone());
370 children.push(self.fields()[idx].clone());
371 }
372
373 StructArray::try_new(
374 FieldNames::from(names.as_slice()),
375 children,
376 self.len(),
377 self.validity().clone(),
378 )
379 }
380
381 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
384 let name = name.into();
385
386 let struct_dtype = self.struct_fields().clone();
387
388 let position = struct_dtype
389 .names()
390 .iter()
391 .position(|field_name| field_name.as_ref() == name.as_ref())?;
392
393 let field = self.fields.remove(position);
394
395 if let Ok(new_dtype) = struct_dtype.without_field(position) {
396 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
397 return Some(field);
398 }
399 None
400 }
401
402 pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
404 let name = name.into();
405 let struct_dtype = self.struct_fields().clone();
406
407 let names = struct_dtype.names().iter().cloned().chain(once(name));
408 let types = struct_dtype.fields().chain(once(array.dtype().clone()));
409 let new_fields = StructFields::new(names.collect(), types.collect());
410
411 let mut children = self.fields.clone();
412 children.push(array);
413
414 Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
415 }
416}
417
418impl ValidityHelper for StructArray {
419 fn validity(&self) -> &Validity {
420 &self.validity
421 }
422}
423
424impl ArrayVTable<StructVTable> for StructVTable {
425 fn len(array: &StructArray) -> usize {
426 array.len
427 }
428
429 fn dtype(array: &StructArray) -> &DType {
430 &array.dtype
431 }
432
433 fn stats(array: &StructArray) -> StatsSetRef<'_> {
434 array.stats_set.to_ref(array.as_ref())
435 }
436}
437
438impl CanonicalVTable<StructVTable> for StructVTable {
439 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
440 Ok(Canonical::Struct(array.clone()))
441 }
442}
443
444impl OperationsVTable<StructVTable> for StructVTable {
445 fn slice(array: &StructArray, start: usize, stop: usize) -> ArrayRef {
446 let fields = array
447 .fields()
448 .iter()
449 .map(|field| field.slice(start, stop))
450 .collect_vec();
451 StructArray::new_unchecked(
452 fields,
453 array.struct_fields().clone(),
454 stop - start,
455 array.validity().slice(start, stop),
456 )
457 .into_array()
458 }
459
460 fn scalar_at(array: &StructArray, index: usize) -> Scalar {
461 Scalar::struct_(
462 array.dtype().clone(),
463 array
464 .fields()
465 .iter()
466 .map(|field| field.scalar_at(index))
467 .collect_vec(),
468 )
469 }
470}
471
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projection keeps only the requested columns, in the requested order.
    #[test]
    fn test_project() {
        let ints = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let flags = BoolArray::from_iter([true, true, true, false, false]);

        let original = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            vec![ints.into_array(), strings.into_array(), flags.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = original.project(&wanted).unwrap();
        assert_eq!(projected.names().as_ref(), wanted);
        assert_eq!(projected.len(), 5);

        let bool_values: Vec<_> = projected.fields[0]
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect();
        assert_eq!(bool_values, vec![true, true, true, false, false]);

        let int_values = projected.fields[1]
            .as_::<PrimitiveVTable>()
            .as_slice::<i64>();
        assert_eq!(int_values, [0i64, 1, 2, 3, 4]);
    }

    /// Removing a column returns its data and shrinks the struct; unknown names yield `None`.
    #[test]
    fn test_remove_column() {
        let signed = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let unsigned = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut table = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![signed.into_array(), unsigned.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let removed = table.remove_column("xs").unwrap();
        assert_eq!(
            removed.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        assert_eq!(table.names(), &["ys"]);
        assert_eq!(table.fields.len(), 1);
        assert_eq!(table.len(), 5);

        let remaining = &table.fields[0];
        assert_eq!(
            remaining.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        // An unknown name is reported with `None` and leaves the struct unchanged.
        let missing = table.remove_column("non_existent");
        assert!(
            missing.is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(table.names(), &["ys"]);
    }

    /// Duplicate names are accepted; name lookups resolve to the earliest matching field.
    #[test]
    fn test_duplicate_field_names() {
        let first = buffer![1i32, 2, 3].into_array();
        let middle = buffer![10i32, 20, 30].into_array();
        let last = buffer![100i32, 200, 300].into_array();

        let array = StructArray::try_new(
            FieldNames::from(["value", "other", "value"]),
            vec![first, middle, last],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let by_name = array.field_by_name("value").unwrap();
        assert_eq!(
            by_name.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        let by_name_opt = array.field_by_name_opt("value").unwrap();
        assert_eq!(
            by_name_opt.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        let by_index = &array.fields()[2];
        assert_eq!(
            by_index.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [100i32, 200, 300]
        );
    }
}