1use std::fmt::Debug;
5use std::iter::once;
6
7use itertools::Itertools;
8use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::validity::Validity;
14use crate::vtable::{
15 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
16 ValidityVTableFromValidityHelper,
17};
18use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
19
20mod compute;
21mod serde;
22
// Macro invocation for the `Struct` encoding. `StructVTable` is used below but not defined
// anywhere in this file view, so this macro presumably generates it (and related glue) —
// NOTE(review): confirm against the `vtable!` macro definition.
vtable!(Struct);
24
25impl VTable for StructVTable {
26 type Array = StructArray;
27 type Encoding = StructEncoding;
28
29 type ArrayVTable = Self;
30 type CanonicalVTable = Self;
31 type OperationsVTable = Self;
32 type ValidityVTable = ValidityVTableFromValidityHelper;
33 type VisitorVTable = Self;
34 type ComputeVTable = NotSupported;
35 type EncodeVTable = NotSupported;
36 type SerdeVTable = Self;
37
38 fn id(_encoding: &Self::Encoding) -> EncodingId {
39 EncodingId::new_ref("vortex.struct")
40 }
41
42 fn encoding(_array: &Self::Array) -> EncodingRef {
43 EncodingRef::new_ref(StructEncoding.as_ref())
44 }
45}
46
/// An array of rows where each row is a tuple of named child values.
///
/// Each child array in `fields` holds one column; the column names and their dtypes are
/// stored in `dtype`, which is always a `DType::Struct`.
#[derive(Clone, Debug)]
pub struct StructArray {
    // Number of rows; every child in `fields` is expected to have this length.
    len: usize,
    // Always `DType::Struct`; `struct_fields()` panics otherwise.
    dtype: DType,
    // One child array per struct field, in the same order as the dtype's names.
    fields: Vec<ArrayRef>,
    // Row-level (top-level) nullability/validity of the struct itself.
    validity: Validity,
    // Cached statistics, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
173
/// Marker type for the struct encoding; returned by `StructVTable::encoding`.
#[derive(Clone, Debug)]
pub struct StructEncoding;
176
177impl StructArray {
178 pub fn fields(&self) -> &[ArrayRef] {
179 &self.fields
180 }
181
182 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
183 let name = name.as_ref();
184 self.field_by_name_opt(name).ok_or_else(|| {
185 vortex_err!(
186 "Field {name} not found in struct array with names {:?}",
187 self.names()
188 )
189 })
190 }
191
192 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
193 let name = name.as_ref();
194 self.names()
195 .iter()
196 .position(|field_name| field_name.as_ref() == name)
197 .map(|idx| &self.fields[idx])
198 }
199
200 pub fn names(&self) -> &FieldNames {
201 self.struct_fields().names()
202 }
203
204 pub fn struct_fields(&self) -> &StructFields {
205 let Some(struct_dtype) = &self.dtype.as_struct() else {
206 unreachable!(
207 "struct arrays must have be a DType::Struct, this is likely an internal bug."
208 )
209 };
210 struct_dtype
211 }
212
213 pub fn new_with_len(len: usize) -> Self {
215 Self::try_new(
216 FieldNames::default(),
217 Vec::new(),
218 len,
219 Validity::NonNullable,
220 )
221 .vortex_expect("StructArray::new_with_len should not fail")
222 }
223
224 pub fn try_new(
225 names: FieldNames,
226 fields: Vec<ArrayRef>,
227 length: usize,
228 validity: Validity,
229 ) -> VortexResult<Self> {
230 let nullability = validity.nullability();
231
232 if names.len() != fields.len() {
233 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
234 }
235
236 for field in fields.iter() {
237 if field.len() != length {
238 vortex_bail!(
239 "Expected all struct fields to have length {length}, found {}",
240 fields.iter().map(|f| f.len()).format(","),
241 );
242 }
243 }
244
245 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
246 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
247
248 if length != validity.maybe_len().unwrap_or(length) {
249 vortex_bail!(
250 "array length {} and validity length must match {}",
251 length,
252 validity
253 .maybe_len()
254 .vortex_expect("can only fail if maybe is some")
255 )
256 }
257
258 Ok(Self {
259 len: length,
260 dtype,
261 fields,
262 validity,
263 stats_set: Default::default(),
264 })
265 }
266
267 pub fn try_new_with_dtype(
268 fields: Vec<ArrayRef>,
269 dtype: StructFields,
270 length: usize,
271 validity: Validity,
272 ) -> VortexResult<Self> {
273 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
274 if field.len() != length {
275 vortex_bail!(
276 "Expected all struct fields to have length {length}, found {}",
277 field.len()
278 );
279 }
280
281 if &struct_dt != field.dtype() {
282 vortex_bail!(
283 "Expected all struct fields to have dtype {}, found {}",
284 struct_dt,
285 field.dtype()
286 );
287 }
288 }
289
290 Ok(Self {
291 len: length,
292 dtype: DType::Struct(dtype, validity.nullability()),
293 fields,
294 validity,
295 stats_set: Default::default(),
296 })
297 }
298
299 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
300 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
301 }
302
303 pub fn try_from_iter_with_validity<
304 N: AsRef<str>,
305 A: IntoArray,
306 T: IntoIterator<Item = (N, A)>,
307 >(
308 iter: T,
309 validity: Validity,
310 ) -> VortexResult<Self> {
311 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
312 .into_iter()
313 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
314 .unzip();
315 let len = fields
316 .first()
317 .map(|f| f.len())
318 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
319
320 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
321 }
322
323 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
324 iter: T,
325 ) -> VortexResult<Self> {
326 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
327 }
328
329 #[allow(clippy::same_name_method)]
337 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
338 let mut children = Vec::with_capacity(projection.len());
339 let mut names = Vec::with_capacity(projection.len());
340
341 for f_name in projection.iter() {
342 let idx = self
343 .names()
344 .iter()
345 .position(|name| name == f_name)
346 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
347
348 names.push(self.names()[idx].clone());
349 children.push(self.fields()[idx].clone());
350 }
351
352 StructArray::try_new(
353 FieldNames::from(names.as_slice()),
354 children,
355 self.len(),
356 self.validity().clone(),
357 )
358 }
359
360 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
363 let name = name.into();
364
365 let struct_dtype = self.struct_fields().clone();
366
367 let position = struct_dtype
368 .names()
369 .iter()
370 .position(|field_name| field_name.as_ref() == name.as_ref())?;
371
372 let field = self.fields.remove(position);
373
374 if let Ok(new_dtype) = struct_dtype.without_field(position) {
375 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
376 return Some(field);
377 }
378 None
379 }
380
381 pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
383 let name = name.into();
384 let struct_dtype = self.struct_fields().clone();
385
386 let names = struct_dtype.names().iter().cloned().chain(once(name));
387 let types = struct_dtype.fields().chain(once(array.dtype().clone()));
388 let new_fields = StructFields::new(names.collect(), types.collect());
389
390 let mut children = self.fields.clone();
391 children.push(array);
392
393 Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone())
394 }
395}
396
397impl ValidityHelper for StructArray {
398 fn validity(&self) -> &Validity {
399 &self.validity
400 }
401}
402
403impl ArrayVTable<StructVTable> for StructVTable {
404 fn len(array: &StructArray) -> usize {
405 array.len
406 }
407
408 fn dtype(array: &StructArray) -> &DType {
409 &array.dtype
410 }
411
412 fn stats(array: &StructArray) -> StatsSetRef<'_> {
413 array.stats_set.to_ref(array.as_ref())
414 }
415}
416
417impl CanonicalVTable<StructVTable> for StructVTable {
418 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
419 Ok(Canonical::Struct(array.clone()))
420 }
421}
422
423impl OperationsVTable<StructVTable> for StructVTable {
424 fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
425 let fields = array
426 .fields()
427 .iter()
428 .map(|field| field.slice(start, stop))
429 .try_collect()?;
430 StructArray::try_new_with_dtype(
431 fields,
432 array.struct_fields().clone(),
433 stop - start,
434 array.validity().slice(start, stop)?,
435 )
436 .map(|a| a.into_array())
437 }
438
439 fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
440 if array.is_valid(index)? {
441 Ok(Scalar::struct_(
442 array.dtype().clone(),
443 array
444 .fields()
445 .iter()
446 .map(|field| field.scalar_at(index))
447 .try_collect()?,
448 ))
449 } else {
450 Ok(Scalar::null(array.dtype().clone()))
451 }
452 }
453}
454
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projection keeps only the requested columns, in the requested order, with the same row count.
    #[test]
    fn test_project() {
        let int_col = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let str_col = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let bool_col = BoolArray::from_iter([true, true, true, false, false]);

        let full = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            vec![
                int_col.into_array(),
                str_col.into_array(),
                bool_col.into_array(),
            ],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let wanted = [FieldName::from("zs"), FieldName::from("xs")];
        let projected = full.project(&wanted).unwrap();

        assert_eq!(projected.names().as_ref(), wanted);
        assert_eq!(projected.len(), 5);

        let flags: Vec<_> = projected.fields[0]
            .as_::<BoolVTable>()
            .boolean_buffer()
            .iter()
            .collect();
        assert_eq!(flags, vec![true, true, true, false, false]);

        let ints = projected.fields[1].as_::<PrimitiveVTable>();
        assert_eq!(ints.as_slice::<i64>(), [0i64, 1, 2, 3, 4]);
    }

    /// Removing a column hands it back and shrinks the struct; unknown names change nothing.
    #[test]
    fn test_remove_column() {
        let signed = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let unsigned = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut arr = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![signed.into_array(), unsigned.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let taken = arr.remove_column("xs").unwrap();
        assert_eq!(
            taken.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            taken.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        assert_eq!(arr.names(), &["ys"]);
        assert_eq!(arr.fields.len(), 1);
        assert_eq!(arr.len(), 5);

        let remaining = &arr.fields[0];
        assert_eq!(
            remaining.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        assert!(
            arr.remove_column("non_existent").is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(arr.names(), &["ys"]);
    }

    /// With duplicate names, name lookups resolve to the lowest-index field.
    #[test]
    fn test_duplicate_field_names() {
        let children = vec![
            buffer![1i32, 2, 3].into_array(),
            buffer![10i32, 20, 30].into_array(),
            buffer![100i32, 200, 300].into_array(),
        ];

        let arr = StructArray::try_new(
            FieldNames::from(["value", "other", "value"]),
            children,
            3,
            Validity::NonNullable,
        )
        .unwrap();

        let by_name = arr.field_by_name("value").unwrap();
        assert_eq!(
            by_name.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        let by_name_opt = arr.field_by_name_opt("value").unwrap();
        assert_eq!(
            by_name_opt.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [1i32, 2, 3]
        );

        let last = &arr.fields()[2];
        assert_eq!(
            last.as_::<PrimitiveVTable>().as_slice::<i32>(),
            [100i32, 200, 300]
        );
    }
}