1use std::fmt::Debug;
2
3use itertools::Itertools;
4use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
5use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::stats::{ArrayStats, StatsSetRef};
9use crate::validity::Validity;
10use crate::vtable::{
11 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
12 ValidityVTableFromValidityHelper,
13};
14use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
15
16mod compute;
17mod serde;
18
// Invoke the crate's `vtable!` macro for the `Struct` encoding. `StructVTable` (implemented
// below) is presumably generated here, along with related boilerplate such as the `as_ref`
// used on `StructEncoding`. The exact expansion is not visible in this file — TODO confirm.
vtable!(Struct);
20
21impl VTable for StructVTable {
22 type Array = StructArray;
23 type Encoding = StructEncoding;
24
25 type ArrayVTable = Self;
26 type CanonicalVTable = Self;
27 type OperationsVTable = Self;
28 type ValidityVTable = ValidityVTableFromValidityHelper;
29 type VisitorVTable = Self;
30 type ComputeVTable = NotSupported;
31 type EncodeVTable = NotSupported;
32 type SerdeVTable = Self;
33
34 fn id(_encoding: &Self::Encoding) -> EncodingId {
35 EncodingId::new_ref("vortex.struct")
36 }
37
38 fn encoding(_array: &Self::Array) -> EncodingRef {
39 EncodingRef::new_ref(StructEncoding.as_ref())
40 }
41}
42
/// An array of struct values, stored column-wise as one child array per field.
#[derive(Clone, Debug)]
pub struct StructArray {
    /// Number of rows; every child in `fields` is checked to have this length at construction.
    len: usize,
    /// Always a `DType::Struct`. Its nullability follows the nullability of `validity`.
    dtype: DType,
    /// One child array per struct field, in the same order as the dtype's field names.
    fields: Vec<ArrayRef>,
    /// Top-level (struct-level) validity.
    validity: Validity,
    /// Statistics storage exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
51
/// Marker type for the struct encoding (encoding id `"vortex.struct"`).
#[derive(Clone, Debug)]
pub struct StructEncoding;
54
55impl StructArray {
56 pub fn fields(&self) -> &[ArrayRef] {
57 &self.fields
58 }
59
60 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
61 let name = name.as_ref();
62 self.field_by_name_opt(name).ok_or_else(|| {
63 vortex_err!(
64 "Field {name} not found in struct array with names {:?}",
65 self.names()
66 )
67 })
68 }
69
70 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
71 let name = name.as_ref();
72 self.names()
73 .iter()
74 .position(|field_name| field_name.as_ref() == name)
75 .map(|idx| &self.fields[idx])
76 }
77
78 pub fn names(&self) -> &FieldNames {
79 self.struct_fields().names()
80 }
81
82 pub fn struct_fields(&self) -> &StructFields {
83 let Some(struct_dtype) = &self.dtype.as_struct() else {
84 unreachable!(
85 "struct arrays must have be a DType::Struct, this is likely an internal bug."
86 )
87 };
88 struct_dtype
89 }
90
91 pub fn try_new(
92 names: FieldNames,
93 fields: Vec<ArrayRef>,
94 length: usize,
95 validity: Validity,
96 ) -> VortexResult<Self> {
97 let nullability = validity.nullability();
98
99 if names.len() != fields.len() {
100 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
101 }
102
103 for field in fields.iter() {
104 if field.len() != length {
105 vortex_bail!(
106 "Expected all struct fields to have length {length}, found {}",
107 fields.iter().map(|f| f.len()).format(","),
108 );
109 }
110 }
111
112 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
113 let dtype = DType::Struct(StructFields::new(names, field_dtypes), nullability);
114
115 if length != validity.maybe_len().unwrap_or(length) {
116 vortex_bail!(
117 "array length {} and validity length must match {}",
118 length,
119 validity
120 .maybe_len()
121 .vortex_expect("can only fail if maybe is some")
122 )
123 }
124
125 Ok(Self {
126 len: length,
127 dtype,
128 fields,
129 validity,
130 stats_set: Default::default(),
131 })
132 }
133
134 pub fn try_new_with_dtype(
135 fields: Vec<ArrayRef>,
136 dtype: StructFields,
137 length: usize,
138 validity: Validity,
139 ) -> VortexResult<Self> {
140 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
141 if field.len() != length {
142 vortex_bail!(
143 "Expected all struct fields to have length {length}, found {}",
144 field.len()
145 );
146 }
147
148 if &struct_dt != field.dtype() {
149 vortex_bail!(
150 "Expected all struct fields to have dtype {}, found {}",
151 struct_dt,
152 field.dtype()
153 );
154 }
155 }
156
157 Ok(Self {
158 len: length,
159 dtype: DType::Struct(dtype, validity.nullability()),
160 fields,
161 validity,
162 stats_set: Default::default(),
163 })
164 }
165
166 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
167 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
168 }
169
170 pub fn try_from_iter_with_validity<
171 N: AsRef<str>,
172 A: IntoArray,
173 T: IntoIterator<Item = (N, A)>,
174 >(
175 iter: T,
176 validity: Validity,
177 ) -> VortexResult<Self> {
178 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
179 .into_iter()
180 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
181 .unzip();
182 let len = fields
183 .first()
184 .map(|f| f.len())
185 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
186
187 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
188 }
189
190 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
191 iter: T,
192 ) -> VortexResult<Self> {
193 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
194 }
195
196 #[allow(clippy::same_name_method)]
204 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
205 let mut children = Vec::with_capacity(projection.len());
206 let mut names = Vec::with_capacity(projection.len());
207
208 for f_name in projection.iter() {
209 let idx = self
210 .names()
211 .iter()
212 .position(|name| name == f_name)
213 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
214
215 names.push(self.names()[idx].clone());
216 children.push(self.fields()[idx].clone());
217 }
218
219 StructArray::try_new(
220 FieldNames::from(names.as_slice()),
221 children,
222 self.len(),
223 self.validity().clone(),
224 )
225 }
226
227 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
230 let name = name.into();
231
232 let Some(struct_dtype) = self.dtype.as_struct() else {
233 unreachable!(
234 "struct arrays must have be a DType::Struct, this is likely an internal bug."
235 )
236 };
237
238 let position = struct_dtype
239 .names()
240 .iter()
241 .position(|field_name| field_name.as_ref() == name.as_ref())?;
242
243 let field = self.fields.remove(position);
244
245 let new_dtype = struct_dtype.without_field(position);
246 self.dtype = DType::Struct(new_dtype, self.dtype.nullability());
247
248 Some(field)
249 }
250}
251
/// Exposes the struct-level validity. `StructVTable` uses `ValidityVTableFromValidityHelper`,
/// so validity queries are served from this field.
impl ValidityHelper for StructArray {
    fn validity(&self) -> &Validity {
        &self.validity
    }
}
257
258impl ArrayVTable<StructVTable> for StructVTable {
259 fn len(array: &StructArray) -> usize {
260 array.len
261 }
262
263 fn dtype(array: &StructArray) -> &DType {
264 &array.dtype
265 }
266
267 fn stats(array: &StructArray) -> StatsSetRef<'_> {
268 array.stats_set.to_ref(array.as_ref())
269 }
270}
271
272impl CanonicalVTable<StructVTable> for StructVTable {
273 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
274 Ok(Canonical::Struct(array.clone()))
275 }
276}
277
278impl OperationsVTable<StructVTable> for StructVTable {
279 fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
280 let fields = array
281 .fields()
282 .iter()
283 .map(|field| field.slice(start, stop))
284 .try_collect()?;
285 StructArray::try_new_with_dtype(
286 fields,
287 array.struct_fields().clone(),
288 stop - start,
289 array.validity().slice(start, stop)?,
290 )
291 .map(|a| a.into_array())
292 }
293
294 fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
295 Ok(Scalar::struct_(
296 array.dtype().clone(),
297 array
298 .fields()
299 .iter()
300 .map(|field| field.scalar_at(index))
301 .try_collect()?,
302 ))
303 }
304}
305
/// Unit tests for `StructArray` projection and column removal.
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::arrays::struct_::StructArray;
    use crate::arrays::varbin::VarBinArray;
    use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
    use crate::validity::Validity;

    /// Projecting `[zs, xs]` out of `[xs, ys, zs]` reorders the fields, drops `ys` and keeps the
    /// original child data and length.
    #[test]
    fn test_project() {
        let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::NonNullable),
        );
        let zs = BoolArray::from_iter([true, true, true, false, false]);

        let struct_a = StructArray::try_new(
            FieldNames::from(["xs", "ys", "zs"]),
            vec![xs.into_array(), ys.into_array(), zs.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        let struct_b = struct_a
            .project(&[FieldName::from("zs"), FieldName::from("xs")])
            .unwrap();
        // Names follow the projection order, not the source order.
        assert_eq!(
            struct_b.names().as_ref(),
            [FieldName::from("zs"), FieldName::from("xs")],
        );

        assert_eq!(struct_b.len(), 5);

        // First projected child is the former `zs` boolean column.
        let bools = &struct_b.fields[0];
        assert_eq!(
            bools
                .as_::<BoolVTable>()
                .boolean_buffer()
                .iter()
                .collect::<Vec<_>>(),
            vec![true, true, true, false, false]
        );

        // Second projected child is the former `xs` i64 column.
        let prims = &struct_b.fields[1];
        assert_eq!(
            prims.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );
    }

    /// Removing an existing column returns it and shrinks the struct. Removing a missing column
    /// returns `None` and leaves the struct untouched.
    #[test]
    fn test_remove_column() {
        let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
        let ys = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

        let mut struct_a = StructArray::try_new(
            FieldNames::from(["xs", "ys"]),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        // The removed child keeps its own dtype and data.
        let removed = struct_a.remove_column("xs").unwrap();
        assert_eq!(
            removed.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(
            removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
            [0i64, 1, 2, 3, 4]
        );

        // Only `ys` remains; the row count is unchanged.
        assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
        assert_eq!(struct_a.fields.len(), 1);
        assert_eq!(struct_a.len(), 5);
        assert_eq!(
            struct_a.fields[0].dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );
        assert_eq!(
            struct_a.fields[0]
                .as_::<PrimitiveVTable>()
                .as_slice::<u64>(),
            [4u64, 5, 6, 7, 8]
        );

        // Unknown names are a no-op.
        let empty = struct_a.remove_column("non_existent");
        assert!(
            empty.is_none(),
            "Expected None when removing non-existent column"
        );
        assert_eq!(struct_a.names(), &[FieldName::from("ys")].into());
    }
}
406}