vortex_layout/layouts/struct_/
writer.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::StreamExt;
8use futures::TryStreamExt;
9use futures::future::try_join_all;
10use futures::pin_mut;
11use itertools::Itertools;
12use vortex_array::Array;
13use vortex_array::ArrayContext;
14use vortex_array::ArrayRef;
15use vortex_array::IntoArray;
16use vortex_array::ToCanonical;
17use vortex_dtype::DType;
18use vortex_dtype::Nullability;
19use vortex_error::VortexError;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22use vortex_io::kanal_ext::KanalExt;
23use vortex_io::runtime::Handle;
24use vortex_utils::aliases::DefaultHashBuilder;
25use vortex_utils::aliases::hash_set::HashSet;
26
27use crate::IntoLayout as _;
28use crate::LayoutRef;
29use crate::LayoutStrategy;
30use crate::layouts::struct_::StructLayout;
31use crate::segments::SegmentSinkRef;
32use crate::sequence::SendableSequentialStream;
33use crate::sequence::SequenceId;
34use crate::sequence::SequencePointer;
35use crate::sequence::SequentialStreamAdapter;
36use crate::sequence::SequentialStreamExt;
37
38#[derive(Clone)]
39pub struct StructStrategy {
40 child: Arc<dyn LayoutStrategy>,
41 validity: Arc<dyn LayoutStrategy>,
42}
43
44impl StructStrategy {
46 pub fn new<S: LayoutStrategy, V: LayoutStrategy>(child: S, validity: V) -> Self {
47 Self {
48 child: Arc::new(child),
49 validity: Arc::new(validity),
50 }
51 }
52}
53
54#[async_trait]
55impl LayoutStrategy for StructStrategy {
56 async fn write_stream(
57 &self,
58 ctx: ArrayContext,
59 segment_sink: SegmentSinkRef,
60 stream: SendableSequentialStream,
61 mut eof: SequencePointer,
62 handle: Handle,
63 ) -> VortexResult<LayoutRef> {
64 let dtype = stream.dtype().clone();
65 let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
66 return self
67 .child
68 .write_stream(ctx, segment_sink, stream, eof, handle)
69 .await;
70 };
71
72 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
74 != struct_dtype.names().len()
75 {
76 vortex_bail!("StructLayout must have unique field names");
77 }
78
79 let is_nullable = dtype.is_nullable();
80
81 if struct_dtype.nfields() == 0 && !is_nullable {
84 let row_count = stream
85 .try_fold(
86 0u64,
87 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
88 )
89 .await?;
90 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
91 }
92
93 let columns_vec_stream = stream.map(move |chunk| {
95 let (sequence_id, chunk) = chunk?;
96 let mut sequence_pointer = sequence_id.descend();
97 let struct_chunk = chunk.to_struct();
98 let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
99 if is_nullable {
100 columns.push((
101 sequence_pointer.advance(),
102 chunk.validity_mask().into_array(),
103 ));
104 }
105
106 columns.extend(
107 struct_chunk
108 .fields()
109 .iter()
110 .map(|field| (sequence_pointer.advance(), field.to_array())),
111 );
112
113 Ok(columns)
114 });
115
116 let mut stream_count = struct_dtype.nfields();
117 if is_nullable {
118 stream_count += 1;
119 }
120
121 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
122 (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
123
124 handle
126 .spawn(async move {
127 pin_mut!(columns_vec_stream);
128 while let Some(result) = columns_vec_stream.next().await {
129 match result {
130 Ok(columns) => {
131 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
132 {
133 let _ = tx.send(Ok(column)).await;
134 }
135 }
136 Err(e) => {
137 let e: Arc<VortexError> = Arc::new(e);
138 for tx in column_streams_tx.iter() {
139 let _ = tx.send(Err(VortexError::from(e.clone()))).await;
140 }
141 break;
142 }
143 }
144 }
145 })
146 .detach();
147
148 let column_dtypes: Vec<DType> = if is_nullable {
150 std::iter::once(DType::Bool(Nullability::NonNullable))
151 .chain(struct_dtype.fields())
152 .collect()
153 } else {
154 struct_dtype.fields().collect()
155 };
156
157 let layout_futures: Vec<_> = column_dtypes
158 .into_iter()
159 .zip_eq(column_streams_rx)
160 .enumerate()
161 .map(move |(index, (dtype, recv))| {
162 let column_stream =
163 SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
164 .sendable();
165 let child_eof = eof.split_off();
166 handle.spawn_nested(|h| {
167 let child = self.child.clone();
168 let validity = self.validity.clone();
169 let this = self.clone();
170 let ctx = ctx.clone();
171 let dtype = dtype.clone();
172 let segment_sink = segment_sink.clone();
173 async move {
174 if index == 0 && is_nullable {
176 validity
177 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
178 .await
179 } else {
180 if dtype.is_struct() {
183 this.write_stream(ctx, segment_sink, column_stream, child_eof, h)
184 .await
185 } else {
186 child
187 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
188 .await
189 }
190 }
191 }
192 })
193 })
194 .collect();
195
196 let column_layouts = try_join_all(layout_futures).await?;
197 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
200 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
201 }
202
203 fn buffered_bytes(&self) -> u64 {
204 self.child.buffered_bytes()
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::Canonical;
    use vortex_array::IntoArray as _;
    use vortex_array::arrays::ChunkedArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::validity::Validity;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    /// Builds a [`StructStrategy`] that writes both fields and validity with the default
    /// [`FlatLayoutStrategy`].
    fn flat_struct_strategy() -> StructStrategy {
        StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default())
    }

    /// Writing a struct whose two fields are both named `a` must not succeed; the test
    /// expects the final `unwrap` (or anything before it) to panic.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        block_on(|handle| {
            let int32 = || DType::Primitive(PType::I32, Nullability::NonNullable);
            let duplicated = DType::Struct(
                [("a", int32()), ("a", int32())].into_iter().collect(),
                Nullability::NonNullable,
            );
            let stream = Canonical::empty(&duplicated)
                .into_array()
                .to_array_stream()
                .sequenced(ptr);

            strategy.write_stream(ctx, segments, stream, eof, handle)
        })
        .unwrap();
    }

    /// Two field-less, non-nullable struct chunks of 3 and 5 rows must produce a layout
    /// reporting 8 rows.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let layout = block_on(|handle| {
            let chunks = [3, 5].map(|len| {
                StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                    .unwrap()
                    .into_array()
            });
            let stream = ChunkedArray::from_iter(chunks)
                .into_array()
                .to_array_stream()
                .sequenced(ptr);

            strategy.write_stream(ctx, segments, stream, eof, handle)
        })
        .unwrap();

        assert_eq!(layout.row_count(), 8);
    }
}