vortex_layout/layouts/struct_/
writer.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::try_join_all;
8use futures::{StreamExt, TryStreamExt, pin_mut};
9use itertools::Itertools;
10use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
11use vortex_dtype::{DType, Nullability};
12use vortex_error::{VortexError, VortexResult, vortex_bail};
13use vortex_io::kanal_ext::KanalExt;
14use vortex_io::runtime::Handle;
15use vortex_utils::aliases::DefaultHashBuilder;
16use vortex_utils::aliases::hash_set::HashSet;
17
18use crate::layouts::struct_::StructLayout;
19use crate::segments::SegmentSinkRef;
20use crate::sequence::{
21 SendableSequentialStream, SequenceId, SequencePointer, SequentialStreamAdapter,
22 SequentialStreamExt,
23};
24use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};
25
26#[derive(Clone)]
27pub struct StructStrategy {
28 child: Arc<dyn LayoutStrategy>,
29 validity: Arc<dyn LayoutStrategy>,
30}
31
32impl StructStrategy {
34 pub fn new<S: LayoutStrategy, V: LayoutStrategy>(child: S, validity: V) -> Self {
35 Self {
36 child: Arc::new(child),
37 validity: Arc::new(validity),
38 }
39 }
40}
41
42#[async_trait]
43impl LayoutStrategy for StructStrategy {
44 async fn write_stream(
45 &self,
46 ctx: ArrayContext,
47 segment_sink: SegmentSinkRef,
48 stream: SendableSequentialStream,
49 mut eof: SequencePointer,
50 handle: Handle,
51 ) -> VortexResult<LayoutRef> {
52 let dtype = stream.dtype().clone();
53 let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
54 return self
55 .child
56 .write_stream(ctx, segment_sink, stream, eof, handle)
57 .await;
58 };
59
60 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
62 != struct_dtype.names().len()
63 {
64 vortex_bail!("StructLayout must have unique field names");
65 }
66
67 let is_nullable = dtype.is_nullable();
68
69 if struct_dtype.nfields() == 0 && !is_nullable {
72 let row_count = stream
73 .try_fold(
74 0u64,
75 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
76 )
77 .await?;
78 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
79 }
80
81 let columns_vec_stream = stream.map(move |chunk| {
83 let (sequence_id, chunk) = chunk?;
84 let mut sequence_pointer = sequence_id.descend();
85 let struct_chunk = chunk.to_struct();
86 let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
87 if is_nullable {
88 columns.push((
89 sequence_pointer.advance(),
90 chunk.validity_mask().into_array(),
91 ));
92 }
93
94 columns.extend(
95 struct_chunk
96 .fields()
97 .iter()
98 .map(|field| (sequence_pointer.advance(), field.to_array())),
99 );
100
101 Ok(columns)
102 });
103
104 let mut stream_count = struct_dtype.nfields();
105 if is_nullable {
106 stream_count += 1;
107 }
108
109 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
110 (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
111
112 handle
114 .spawn(async move {
115 pin_mut!(columns_vec_stream);
116 while let Some(result) = columns_vec_stream.next().await {
117 match result {
118 Ok(columns) => {
119 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
120 {
121 let _ = tx.send(Ok(column)).await;
122 }
123 }
124 Err(e) => {
125 let e: Arc<VortexError> = Arc::new(e);
126 for tx in column_streams_tx.iter() {
127 let _ = tx.send(Err(VortexError::from(e.clone()))).await;
128 }
129 break;
130 }
131 }
132 }
133 })
134 .detach();
135
136 let column_dtypes: Vec<DType> = if is_nullable {
138 std::iter::once(DType::Bool(Nullability::NonNullable))
139 .chain(struct_dtype.fields())
140 .collect()
141 } else {
142 struct_dtype.fields().collect()
143 };
144
145 let layout_futures: Vec<_> = column_dtypes
146 .into_iter()
147 .zip_eq(column_streams_rx)
148 .enumerate()
149 .map(move |(index, (dtype, recv))| {
150 let column_stream =
151 SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
152 .sendable();
153 let child_eof = eof.split_off();
154 handle.spawn_nested(|h| {
155 let child = self.child.clone();
156 let validity = self.validity.clone();
157 let this = self.clone();
158 let ctx = ctx.clone();
159 let dtype = dtype.clone();
160 let segment_sink = segment_sink.clone();
161 async move {
162 if index == 0 && is_nullable {
164 validity
165 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
166 .await
167 } else {
168 if dtype.is_struct() {
171 this.write_stream(ctx, segment_sink, column_stream, child_eof, h)
172 .await
173 } else {
174 child
175 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
176 .await
177 }
178 }
179 }
180 })
181 })
182 .collect();
183
184 let column_layouts = try_join_all(layout_futures).await?;
185 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
188 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
189 }
190
191 fn buffered_bytes(&self) -> u64 {
192 self.child.buffered_bytes()
193 }
194}
195
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Builds a [`StructStrategy`] that uses the default flat strategy for
    /// both field columns and validity.
    fn flat_struct_strategy() -> StructStrategy {
        StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default())
    }

    /// Writing a struct whose dtype repeats a field name must not succeed.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let i32_dtype = || DType::Primitive(PType::I32, Nullability::NonNullable);
        let duplicated = DType::Struct(
            [("a", i32_dtype()), ("a", i32_dtype())]
                .into_iter()
                .collect(),
            Nullability::NonNullable,
        );
        let input = Canonical::empty(&duplicated).into_array();

        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let outcome = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                input.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        });
        // Unwrapping the error is what makes this test panic as expected.
        outcome.unwrap();
    }

    /// A field-less, non-nullable struct still reports the total row count
    /// across all of its chunks.
    #[test]
    fn write_empty_field_struct_array() {
        let fieldless = |len: usize| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        };
        let input = ChunkedArray::from_iter([fieldless(3), fieldless(5)]).into_array();

        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let layout = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                input.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        })
        .unwrap();

        assert_eq!(layout.row_count(), 8);
    }
}