1use std::fmt::Debug;
5use std::iter::once;
6use std::ops::Range;
7
8use itertools::Itertools;
9use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::Scalar;
12
13use crate::stats::{ArrayStats, StatsSetRef};
14use crate::validity::Validity;
15use crate::vtable::{
16 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
17 ValidityVTableFromValidityHelper,
18};
19use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
20
21mod compute;
22mod serde;
23
24vtable!(Struct);
25
26impl VTable for StructVTable {
27 type Array = StructArray;
28 type Encoding = StructEncoding;
29
30 type ArrayVTable = Self;
31 type CanonicalVTable = Self;
32 type OperationsVTable = Self;
33 type ValidityVTable = ValidityVTableFromValidityHelper;
34 type VisitorVTable = Self;
35 type ComputeVTable = NotSupported;
36 type EncodeVTable = NotSupported;
37 type PipelineVTable = NotSupported;
38 type SerdeVTable = Self;
39
40 fn id(_encoding: &Self::Encoding) -> EncodingId {
41 EncodingId::new_ref("vortex.struct")
42 }
43
44 fn encoding(_array: &Self::Array) -> EncodingRef {
45 EncodingRef::new_ref(StructEncoding.as_ref())
46 }
47}
48
49#[derive(Clone, Debug)]
168pub struct StructArray {
169 len: usize,
170 dtype: DType,
171 fields: Vec<ArrayRef>,
172 validity: Validity,
173 stats_set: ArrayStats,
174}
175
/// Marker type for the struct encoding, identified as `"vortex.struct"`.
#[derive(Debug, Clone)]
pub struct StructEncoding;
178
179impl StructArray {
180 pub fn fields(&self) -> &[ArrayRef] {
181 &self.fields
182 }
183
184 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
185 let name = name.as_ref();
186 self.field_by_name_opt(name).ok_or_else(|| {
187 vortex_err!(
188 "Field {name} not found in struct array with names {:?}",
189 self.names()
190 )
191 })
192 }
193
194 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
195 let name = name.as_ref();
196 self.names()
197 .iter()
198 .position(|field_name| field_name.as_ref() == name)
199 .map(|idx| &self.fields[idx])
200 }
201
202 pub fn names(&self) -> &FieldNames {
203 self.struct_fields().names()
204 }
205
206 pub fn struct_fields(&self) -> &StructFields {
207 let Some(struct_dtype) = &self.dtype.as_struct_fields_opt() else {
208 unreachable!(
209 "struct arrays must have be a DType::Struct, this is likely an internal bug."
210 )
211 };
212 struct_dtype
213 }
214
215 pub fn new_with_len(len: usize) -> Self {
217 Self::try_new(
218 FieldNames::default(),
219 Vec::new(),
220 len,
221 Validity::NonNullable,
222 )
223 .vortex_expect("StructArray::new_with_len should not fail")
224 }
225
226 pub fn try_new(
227 names: FieldNames,
228 fields: Vec<ArrayRef>,
229 length: usize,
230 validity: Validity,
231 ) -> VortexResult<Self> {
232 let nullability = validity.nullability();
233
234 if names.len() != fields.len() {
235 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
236 }
237
238 for field in fields.iter() {
239 if field.len() != length {
240 vortex_bail!(
241 "Expected all struct fields to have length {length}, found {}",
242 fields.iter().map(|f| f.len()).format(","),
243 );
244 }
245 }
246
247 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
248 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
249
250 if length != validity.maybe_len().unwrap_or(length) {
251 vortex_bail!(
252 "array length {} and validity length must match {}",
253 length,
254 validity
255 .maybe_len()
256 .vortex_expect("can only fail if maybe is some")
257 )
258 }
259
260 Ok(Self {
261 len: length,
262 dtype,
263 fields,
264 validity,
265 stats_set: Default::default(),
266 })
267 }
268
269 pub(crate) fn new_unchecked(
275 fields: Vec<ArrayRef>,
276 dtype: StructFields,
277 length: usize,
278 validity: Validity,
279 ) -> Self {
280 Self {
281 len: length,
282 dtype: DType::Struct(dtype, validity.nullability()),
283 fields,
284 validity,
285 stats_set: Default::default(),
286 }
287 }
288
289 pub fn try_new_with_dtype(
290 fields: Vec<ArrayRef>,
291 dtype: StructFields,
292 length: usize,
293 validity: Validity,
294 ) -> VortexResult<Self> {
295 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
296 if field.len() != length {
297 vortex_bail!(
298 "Expected all struct fields to have length {length}, found {}",
299 field.len()
300 );
301 }
302
303 if &struct_dt != field.dtype() {
304 vortex_bail!(
305 "Expected all struct fields to have dtype {}, found {}",
306 struct_dt,
307 field.dtype()
308 );
309 }
310 }
311
312 Ok(Self {
313 len: length,
314 dtype: DType::Struct(dtype, validity.nullability()),
315 fields,
316 validity,
317 stats_set: Default::default(),
318 })
319 }
320
321 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
322 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
323 }
324
325 pub fn try_from_iter_with_validity<
326 N: AsRef<str>,
327 A: IntoArray,
328 T: IntoIterator<Item = (N, A)>,
329 >(
330 iter: T,
331 validity: Validity,
332 ) -> VortexResult<Self> {
333 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
334 .into_iter()
335 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
336 .unzip();
337 let len = fields
338 .first()
339 .map(|f| f.len())
340 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
341
342 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
343 }
344
345 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
346 iter: T,
347 ) -> VortexResult<Self> {
348 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
349 }
350
351 #[allow(clippy::same_name_method)]
359 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
360 let mut children = Vec::with_capacity(projection.len());
361 let mut names = Vec::with_capacity(projection.len());
362
363 for f_name in projection.iter() {
364 let idx = self
365 .names()
366 .iter()
367 .position(|name| name == f_name)
368 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
369
370 names.push(self.names()[idx].clone());
371 children.push(self.fields()[idx].clone());
372 }
373
374 StructArray::try_new(
375 FieldNames::from(names.as_slice()),
376 children,
377 self.len(),
378 self.validity().clone(),
379 )
380 }
381
382 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
385 let name = name.into();
386
387 let struct_dtype = self.struct_fields().clone();
388
389 let position = struct_dtype
390 .names()
391 .iter()
392 .position(|field_name| field_name.as_ref() == name.as_ref())?;
393
394 let field = self.fields.remove(position);
395
396 if let Ok(new_dtype) = struct_dtype.without_field(position) {
397 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
398 return Some(field);
399 }
400 None
401 }
402
403 pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
405 let name = name.into();
406 let struct_dtype = self.struct_fields().clone();
407
408 let names = struct_dtype.names().iter().cloned().chain(once(name));
409 let types = struct_dtype.fields().chain(once(array.dtype().clone()));
410 let new_fields = StructFields::new(names.collect(), types.collect());
411
412 let mut children = self.fields.clone();
413 children.push(array);
414
415 Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
416 }
417}
418
419impl ValidityHelper for StructArray {
420 fn validity(&self) -> &Validity {
421 &self.validity
422 }
423}
424
425impl ArrayVTable<StructVTable> for StructVTable {
426 fn len(array: &StructArray) -> usize {
427 array.len
428 }
429
430 fn dtype(array: &StructArray) -> &DType {
431 &array.dtype
432 }
433
434 fn stats(array: &StructArray) -> StatsSetRef<'_> {
435 array.stats_set.to_ref(array.as_ref())
436 }
437}
438
439impl CanonicalVTable<StructVTable> for StructVTable {
440 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
441 Ok(Canonical::Struct(array.clone()))
442 }
443}
444
445impl OperationsVTable<StructVTable> for StructVTable {
446 fn slice(array: &StructArray, range: Range<usize>) -> ArrayRef {
447 let fields = array
448 .fields()
449 .iter()
450 .map(|field| field.slice(range.clone()))
451 .collect_vec();
452 StructArray::new_unchecked(
453 fields,
454 array.struct_fields().clone(),
455 range.len(),
456 array.validity().slice(range),
457 )
458 .into_array()
459 }
460
461 fn scalar_at(array: &StructArray, index: usize) -> Scalar {
462 Scalar::struct_(
463 array.dtype().clone(),
464 array
465 .fields()
466 .iter()
467 .map(|field| field.scalar_at(index))
468 .collect_vec(),
469 )
470 }
471}
472
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projection keeps only the requested fields, in the requested order.
    #[test]
    fn test_project() {
        let ints = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let strings = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let flags = BoolArray::from_iter([true, true, true, false, false]);

        let children = vec![ints.into_array(), strings.into_array(), flags.into_array()];
        let source = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            children,
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = source.project(&wanted).unwrap();

        assert_eq!(
            projected.names().as_ref(),
            [FieldName::from("zs"), FieldName::from("xs")],
        );
        assert_eq!(projected.len(), 5);

        // "zs" was requested first, so the boolean child now sits at index 0.
        let bool_child = &projected.fields[0];
        let bool_values = bool_child
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect::<Vec<_>>();
        assert_eq!(bool_values, vec![true, true, true, false, false]);

        let int_child = &projected.fields[1];
        assert_eq!(
            int_child.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }

    /// Removing a column returns its child and updates names and children; removing an
    /// unknown column is a no-op that returns `None`.
    #[test]
    fn test_remove_column() {
        let signed = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let unsigned = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut array = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![signed.into_array(), unsigned.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let taken = array.remove_column("xs").unwrap();
        assert_eq!(
            taken.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            taken.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        // Only "ys" is left and it is untouched.
        assert_eq!(array.names(), &["ys"]);
        assert_eq!(array.fields.len(), 1);
        assert_eq!(array.len(), 5);

        let remaining = &array.fields[0];
        assert_eq!(
            remaining.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        let missing = array.remove_column("non_existent");
        assert!(
            missing.is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(array.names(), &["ys"]);
    }

    /// Duplicate field names are accepted; lookups by name resolve to the first match.
    #[test]
    fn test_duplicate_field_names() {
        let first = buffer![1i32, 2, 3].into_array();
        let second = buffer![10i32, 20, 30].into_array();
        let third = buffer![100i32, 200, 300].into_array();

        let array = StructArray::try_new(
            FieldNames::from(["value", "other", "value"]),
            vec![first, second, third],
            3,
            Validity::NonNullable,
        )
        .unwrap();

        // The fallible lookup returns the first "value" field.
        let by_name = array.field_by_name("value").unwrap();
        assert_eq!(
            by_name.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        // The optional lookup agrees with it.
        let by_name_opt = array.field_by_name_opt("value").unwrap();
        assert_eq!(
            by_name_opt.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        // The second "value" field is still reachable by index.
        let by_index = &array.fields()[2];
        assert_eq!(
            by_index.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [100i32, 200, 300]
        );
    }
}