1use std::fmt::Debug;
2use std::sync::Arc;
3
4use itertools::Itertools;
5use vortex_dtype::{DType, FieldName, FieldNames, StructFields};
6use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::stats::{ArrayStats, StatsSetRef};
10use crate::validity::Validity;
11use crate::vtable::{
12 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
13 ValidityVTableFromValidityHelper,
14};
15use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
16
17mod compute;
18mod serde;
19
20vtable!(Struct);
21
22impl VTable for StructVTable {
23 type Array = StructArray;
24 type Encoding = StructEncoding;
25
26 type ArrayVTable = Self;
27 type CanonicalVTable = Self;
28 type OperationsVTable = Self;
29 type ValidityVTable = ValidityVTableFromValidityHelper;
30 type VisitorVTable = Self;
31 type ComputeVTable = NotSupported;
32 type EncodeVTable = NotSupported;
33 type SerdeVTable = Self;
34
35 fn id(_encoding: &Self::Encoding) -> EncodingId {
36 EncodingId::new_ref("vortex.struct")
37 }
38
39 fn encoding(_array: &Self::Array) -> EncodingRef {
40 EncodingRef::new_ref(StructEncoding.as_ref())
41 }
42}
43
44#[derive(Clone, Debug)]
45pub struct StructArray {
46 len: usize,
47 dtype: DType,
48 fields: Vec<ArrayRef>,
49 validity: Validity,
50 stats_set: ArrayStats,
51}
52
/// Unit type used as the `Encoding` associated type of the struct vtable.
#[derive(Debug, Clone)]
pub struct StructEncoding;
55
56impl StructArray {
57 pub fn fields(&self) -> &[ArrayRef] {
58 &self.fields
59 }
60
61 pub fn field_by_name(&self, name: impl AsRef<str>) -> VortexResult<&ArrayRef> {
62 let name = name.as_ref();
63 self.field_by_name_opt(name).ok_or_else(|| {
64 vortex_err!(
65 "Field {name} not found in struct array with names {:?}",
66 self.names()
67 )
68 })
69 }
70
71 pub fn field_by_name_opt(&self, name: impl AsRef<str>) -> Option<&ArrayRef> {
72 let name = name.as_ref();
73 self.names()
74 .iter()
75 .position(|field_name| field_name.as_ref() == name)
76 .map(|idx| &self.fields[idx])
77 }
78
79 pub fn names(&self) -> &FieldNames {
80 self.struct_fields().names()
81 }
82
83 pub fn struct_fields(&self) -> &Arc<StructFields> {
84 let Some(struct_dtype) = &self.dtype.as_struct() else {
85 unreachable!(
86 "struct arrays must have be a DType::Struct, this is likely an internal bug."
87 )
88 };
89 struct_dtype
90 }
91
92 pub fn try_new(
93 names: FieldNames,
94 fields: Vec<ArrayRef>,
95 length: usize,
96 validity: Validity,
97 ) -> VortexResult<Self> {
98 let nullability = validity.nullability();
99
100 if names.len() != fields.len() {
101 vortex_bail!("Got {} names and {} fields", names.len(), fields.len());
102 }
103
104 for field in fields.iter() {
105 if field.len() != length {
106 vortex_bail!(
107 "Expected all struct fields to have length {length}, found {}",
108 fields.iter().map(|f| f.len()).format(","),
109 );
110 }
111 }
112
113 let field_dtypes: Vec<_> = fields.iter().map(|d| d.dtype()).cloned().collect();
114 let dtype = DType::Struct(
115 Arc::new(StructFields::new(names, field_dtypes)),
116 nullability,
117 );
118
119 if length != validity.maybe_len().unwrap_or(length) {
120 vortex_bail!(
121 "array length {} and validity length must match {}",
122 length,
123 validity
124 .maybe_len()
125 .vortex_expect("can only fail if maybe is some")
126 )
127 }
128
129 Ok(Self {
130 len: length,
131 dtype,
132 fields,
133 validity,
134 stats_set: Default::default(),
135 })
136 }
137
138 pub fn try_new_with_dtype(
139 fields: Vec<ArrayRef>,
140 dtype: Arc<StructFields>,
141 length: usize,
142 validity: Validity,
143 ) -> VortexResult<Self> {
144 for (field, struct_dt) in fields.iter().zip(dtype.fields()) {
145 if field.len() != length {
146 vortex_bail!(
147 "Expected all struct fields to have length {length}, found {}",
148 field.len()
149 );
150 }
151
152 if &struct_dt != field.dtype() {
153 vortex_bail!(
154 "Expected all struct fields to have dtype {}, found {}",
155 struct_dt,
156 field.dtype()
157 );
158 }
159 }
160
161 Ok(Self {
162 len: length,
163 dtype: DType::Struct(dtype, validity.nullability()),
164 fields,
165 validity,
166 stats_set: Default::default(),
167 })
168 }
169
170 pub fn from_fields<N: AsRef<str>>(items: &[(N, ArrayRef)]) -> VortexResult<Self> {
171 Self::try_from_iter(items.iter().map(|(a, b)| (a, b.to_array())))
172 }
173
174 pub fn try_from_iter_with_validity<
175 N: AsRef<str>,
176 A: IntoArray,
177 T: IntoIterator<Item = (N, A)>,
178 >(
179 iter: T,
180 validity: Validity,
181 ) -> VortexResult<Self> {
182 let (names, fields): (Vec<FieldName>, Vec<ArrayRef>) = iter
183 .into_iter()
184 .map(|(name, fields)| (FieldName::from(name.as_ref()), fields.into_array()))
185 .unzip();
186 let len = fields
187 .first()
188 .map(|f| f.len())
189 .ok_or_else(|| vortex_err!("StructArray cannot be constructed from an empty slice of arrays because the length is unspecified"))?;
190
191 Self::try_new(FieldNames::from_iter(names), fields, len, validity)
192 }
193
194 pub fn try_from_iter<N: AsRef<str>, A: IntoArray, T: IntoIterator<Item = (N, A)>>(
195 iter: T,
196 ) -> VortexResult<Self> {
197 Self::try_from_iter_with_validity(iter, Validity::NonNullable)
198 }
199
200 #[allow(clippy::same_name_method)]
208 pub fn project(&self, projection: &[FieldName]) -> VortexResult<Self> {
209 let mut children = Vec::with_capacity(projection.len());
210 let mut names = Vec::with_capacity(projection.len());
211
212 for f_name in projection.iter() {
213 let idx = self
214 .names()
215 .iter()
216 .position(|name| name == f_name)
217 .ok_or_else(|| vortex_err!("Unknown field {f_name}"))?;
218
219 names.push(self.names()[idx].clone());
220 children.push(self.fields()[idx].clone());
221 }
222
223 StructArray::try_new(
224 FieldNames::from(names.as_slice()),
225 children,
226 self.len(),
227 self.validity().clone(),
228 )
229 }
230
231 pub fn remove_column(&mut self, name: impl Into<FieldName>) -> Option<ArrayRef> {
234 let name = name.into();
235
236 let Some(struct_dtype) = self.dtype.as_struct() else {
237 unreachable!(
238 "struct arrays must have be a DType::Struct, this is likely an internal bug."
239 )
240 };
241
242 let position = struct_dtype
243 .names()
244 .iter()
245 .position(|field_name| field_name.as_ref() == name.as_ref())?;
246
247 let field = self.fields.remove(position);
248
249 let new_dtype = struct_dtype.without_field(position);
250 self.dtype = DType::Struct(Arc::new(new_dtype), self.dtype.nullability());
251
252 Some(field)
253 }
254}
255
256impl ValidityHelper for StructArray {
257 fn validity(&self) -> &Validity {
258 &self.validity
259 }
260}
261
262impl ArrayVTable<StructVTable> for StructVTable {
263 fn len(array: &StructArray) -> usize {
264 array.len
265 }
266
267 fn dtype(array: &StructArray) -> &DType {
268 &array.dtype
269 }
270
271 fn stats(array: &StructArray) -> StatsSetRef<'_> {
272 array.stats_set.to_ref(array.as_ref())
273 }
274}
275
276impl CanonicalVTable<StructVTable> for StructVTable {
277 fn canonicalize(array: &StructArray) -> VortexResult<Canonical> {
278 Ok(Canonical::Struct(array.clone()))
279 }
280}
281
282impl OperationsVTable<StructVTable> for StructVTable {
283 fn slice(array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
284 let fields = array
285 .fields()
286 .iter()
287 .map(|field| field.slice(start, stop))
288 .try_collect()?;
289 StructArray::try_new_with_dtype(
290 fields,
291 array.struct_fields().clone(),
292 stop - start,
293 array.validity().slice(start, stop)?,
294 )
295 .map(|a| a.into_array())
296 }
297
298 fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
299 Ok(Scalar::struct_(
300 array.dtype().clone(),
301 array
302 .fields()
303 .iter()
304 .map(|field| field.scalar_at(index))
305 .try_collect()?,
306 ))
307 }
308}
309
310#[cfg(test)]
311mod test {
312 use vortex_buffer::buffer;
313 use vortex_dtype::{DType, FieldName, FieldNames, Nullability, PType};
314
315 use crate::IntoArray;
316 use crate::arrays::primitive::PrimitiveArray;
317 use crate::arrays::struct_::StructArray;
318 use crate::arrays::varbin::VarBinArray;
319 use crate::arrays::{BoolArray, BoolVTable, PrimitiveVTable};
320 use crate::validity::Validity;
321
/// Projecting onto `["zs", "xs"]` keeps only those fields, in the requested
/// order, with their data and the array length intact.
#[test]
fn test_project() {
    let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
    let ys = VarBinArray::from_vec(
        vec!["a", "b", "c", "d", "e"],
        DType::Utf8(Nullability::NonNullable),
    );
    let zs = BoolArray::from_iter([true, true, true, false, false]);

    let field_names = FieldNames::from(["xs".into(), "ys".into(), "zs".into()]);
    let children = vec![xs.into_array(), ys.into_array(), zs.into_array()];
    let struct_a =
        StructArray::try_new(field_names, children, 5, Validity::NonNullable).unwrap();

    let struct_b = struct_a
        .project(&[FieldName::from("zs"), FieldName::from("xs")])
        .unwrap();

    // Names come back in projection order, not in the original field order.
    assert_eq!(
        struct_b.names().as_ref(),
        [FieldName::from("zs"), FieldName::from("xs")],
    );
    assert_eq!(struct_b.len(), 5);

    let bools = &struct_b.fields[0];
    let bool_values = bools
        .as_::<BoolVTable>()
        .boolean_buffer()
        .iter()
        .collect::<Vec<_>>();
    assert_eq!(bool_values, vec![true, true, true, false, false]);

    let prims = &struct_b.fields[1];
    assert_eq!(
        prims.as_::<PrimitiveVTable>().as_slice::<i64>(),
        [0i64, 1, 2, 3, 4]
    );
}
365
/// Removing an existing column returns its data and shrinks the struct;
/// removing an unknown column returns `None` and leaves the struct unchanged.
#[test]
fn test_remove_column() {
    let xs = PrimitiveArray::new(buffer![0i64, 1, 2, 3, 4], Validity::NonNullable);
    let ys = PrimitiveArray::new(buffer![4u64, 5, 6, 7, 8], Validity::NonNullable);

    let field_names = FieldNames::from(["xs".into(), "ys".into()]);
    let children = vec![xs.into_array(), ys.into_array()];
    let mut struct_a =
        StructArray::try_new(field_names, children, 5, Validity::NonNullable).unwrap();

    // The removed column is handed back with its dtype and values intact.
    let removed = struct_a.remove_column("xs").unwrap();
    assert_eq!(
        removed.dtype(),
        &DType::Primitive(PType::I64, Nullability::NonNullable)
    );
    assert_eq!(
        removed.as_::<PrimitiveVTable>().as_slice::<i64>(),
        [0i64, 1, 2, 3, 4]
    );

    // Only `ys` remains, and the row count is unchanged.
    assert_eq!(struct_a.names().as_ref(), [FieldName::from("ys")]);
    assert_eq!(struct_a.fields.len(), 1);
    assert_eq!(struct_a.len(), 5);

    let remaining = &struct_a.fields[0];
    assert_eq!(
        remaining.dtype(),
        &DType::Primitive(PType::U64, Nullability::NonNullable)
    );
    assert_eq!(
        remaining.as_::<PrimitiveVTable>().as_slice::<u64>(),
        [4u64, 5, 6, 7, 8]
    );

    // Removing an unknown column is a no-op.
    let empty = struct_a.remove_column("non_existent");
    assert!(
        empty.is_none(),
        "Expected None when removing non-existent column"
    );
    assert_eq!(struct_a.names().as_ref(), [FieldName::from("ys")]);
}
410}