vortex_layout/layouts/struct_/
writer.rs1use std::collections::VecDeque;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker};
8
9use arcref::ArcRef;
10use futures::future::try_join_all;
11use futures::{Stream, StreamExt, TryStreamExt};
12use itertools::Itertools;
13use parking_lot::Mutex;
14use vortex_array::{Array, ArrayContext, ToCanonical};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_utils::aliases::DefaultHashBuilder;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::layouts::struct_::StructLayout;
20use crate::segments::SequenceWriter;
21use crate::{
22 IntoLayout as _, LayoutStrategy, SendableLayoutFuture, SendableSequentialStream,
23 SequentialStreamAdapter, SequentialStreamExt,
24};
25
26pub struct StructStrategy {
27 child: ArcRef<dyn LayoutStrategy>,
28}
29
30impl StructStrategy {
32 pub fn new(child: ArcRef<dyn LayoutStrategy>) -> Self {
33 Self { child }
34 }
35}
36
37impl LayoutStrategy for StructStrategy {
38 fn write_stream(
39 &self,
40 ctx: &ArrayContext,
41 sequence_writer: SequenceWriter,
42 stream: SendableSequentialStream,
43 ) -> SendableLayoutFuture {
44 let dtype = stream.dtype().clone();
45 let Some(struct_dtype) = stream.dtype().as_struct().cloned() else {
46 return self.child.write_stream(ctx, sequence_writer, stream);
48 };
49 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
50 != struct_dtype.names().len()
51 {
52 return Box::pin(async { vortex_bail!("StructLayout must have unique field names") });
53 }
54
55 let stream = stream.map(|chunk| {
56 let (sequence_id, chunk) = chunk?;
57 if !chunk.all_valid()? {
58 vortex_bail!("Cannot push struct chunks with top level invalid values");
59 };
60 Ok((sequence_id, chunk))
61 });
62
63 if struct_dtype.nfields() == 0 {
65 return Box::pin(async move {
66 let row_count = stream
67 .try_fold(
68 0u64,
69 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
70 )
71 .await?;
72 Ok(StructLayout::new(row_count, dtype, vec![]).into_layout())
73 });
74 }
75
76 let columns_vec_stream = stream.map(|chunk| {
78 let (sequence_id, chunk) = chunk?;
79 let mut sequence_pointer = sequence_id.descend();
80 let struct_chunk = chunk.to_struct()?;
81 let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
82 .map(|idx| {
83 (
84 sequence_pointer.advance(),
85 struct_chunk.fields()[idx].to_array(),
86 )
87 })
88 .collect();
89 Ok(columns)
90 });
91
92 let column_streams = transpose_stream(columns_vec_stream, struct_dtype.nfields());
94
95 let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
96 struct_dtype
97 .field_by_index(idx)
98 .vortex_expect("bound checked")
99 });
100 let child = self.child.clone();
101 let ctx = ctx.clone();
102 let layout_futures = column_dtypes
103 .zip_eq(column_streams)
104 .map(move |(dtype, stream)| {
105 let column_stream = SequentialStreamAdapter::new(dtype, stream).sendable();
106 child.write_stream(&ctx, sequence_writer.clone(), column_stream)
107 });
108
109 Box::pin(async move {
110 let column_layouts = try_join_all(layout_futures).await?;
111 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
114 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
115 })
116 }
117}
118
119fn transpose_stream<T, S>(stream: S, elements: usize) -> Vec<impl Stream<Item = VortexResult<T>>>
120where
121 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
122 T: Unpin + 'static,
123{
124 let state = Arc::new(Mutex::new(TransposeState {
125 upstream: stream,
126 buffers: (0..elements).map(|_| VecDeque::new()).collect(),
127 wakers: Vec::new(),
128 exhausted: false,
129 }));
130 (0..elements)
131 .map(|index| TransposedStream {
132 index,
133 state: state.clone(),
134 })
135 .collect()
136}
137
138struct TransposeState<T, S>
139where
140 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
141 T: Unpin,
142{
143 upstream: S,
144 buffers: Vec<VecDeque<VortexResult<T>>>,
146 wakers: Vec<Waker>,
147 exhausted: bool,
148}
149
150struct TransposedStream<T, S>
151where
152 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
153 T: Unpin,
154{
155 index: usize,
156 state: Arc<Mutex<TransposeState<T, S>>>,
157}
158
159impl<T, S> Stream for TransposedStream<T, S>
160where
161 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
162 T: Unpin,
163{
164 type Item = VortexResult<T>;
165 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
166 let mut guard = self.state.lock();
167 if let Some(item) = guard.buffers[self.index].pop_front() {
168 return Poll::Ready(Some(item));
169 }
170
171 if guard.exhausted {
173 return Poll::Ready(None);
174 }
175
176 let poll_result = match Pin::new(&mut guard.upstream).poll_next(cx) {
177 Poll::Pending => {
178 guard.wakers.push(cx.waker().clone());
179 Poll::Pending
180 }
181 Poll::Ready(None) => {
182 guard.exhausted = true;
183 Poll::Ready(None)
184 }
185 Poll::Ready(Some(Ok(vec_t))) => {
186 for (t, buffer) in vec_t.into_iter().zip_eq(guard.buffers.iter_mut()) {
187 buffer.push_back(Ok(t));
188 }
189 let item = guard.buffers[self.index]
190 .pop_front()
191 .vortex_expect("just pushed");
192 Poll::Ready(Some(item))
193 }
194 Poll::Ready(Some(Err(err))) => {
195 let shared_err = Arc::new(err);
196 for buffer in guard.buffers.iter_mut() {
197 buffer.push_back(Err(shared_err.clone().into()));
198 }
199 Poll::Ready(Some(Err(shared_err.into())))
200 }
201 };
202
203 if matches!(poll_result, Poll::Ready(_)) {
204 let wakers = std::mem::take(&mut guard.wakers);
205
206 drop(guard);
207 for waker in wakers {
208 waker.wake();
209 }
210 }
211 poll_result
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arcref::ArcRef;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::{SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    /// Builds a [`StructStrategy`] that writes every column with the flat layout strategy.
    fn flat_struct_strategy() -> StructStrategy {
        StructStrategy::new(ArcRef::new_arc(Arc::new(FlatLayoutStrategy::default())))
    }

    /// Duplicate field names are reported as an error from the returned future. Assert on the
    /// message rather than using `#[should_panic]`, which would accept any unrelated panic.
    #[test]
    fn fails_on_duplicate_field() {
        let strategy = flat_struct_strategy();
        let res = block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        [
                            ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                            ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                        ]
                        .into_iter()
                        .collect(),
                        Nullability::NonNullable,
                    ),
                    stream::empty(),
                )
                .sendable(),
            ),
        );
        assert!(
            res.unwrap_err()
                .to_string()
                .starts_with("StructLayout must have unique field names"),
        );
    }

    /// A chunk with a null top-level row is rejected with an error.
    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = flat_struct_strategy();
        let res = block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        [("a", DType::Primitive(PType::I32, Nullability::NonNullable))]
                            .into_iter()
                            .collect(),
                        Nullability::Nullable,
                    ),
                    stream::once(async move {
                        Ok((
                            SequenceId::root().downgrade(),
                            StructArray::try_new(
                                ["a"].into(),
                                vec![buffer![1, 2, 3].into_array()],
                                3,
                                Validity::Array(
                                    BoolArray::from_iter(vec![true, true, false]).into_array(),
                                ),
                            )
                            .unwrap()
                            .into_array(),
                        ))
                    }),
                )
                .sendable(),
            ),
        );
        assert!(
            format!("{}", res.unwrap_err())
                .starts_with("Cannot push struct chunks with top level invalid values"),
        )
    }

    /// A field-less struct still records the total row count across all chunks.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = flat_struct_strategy();
        let res = block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        StructFields::new(FieldNames::default(), vec![]),
                        Nullability::NonNullable,
                    ),
                    stream::iter([
                        {
                            Ok((
                                SequenceId::root().downgrade(),
                                StructArray::try_new(
                                    FieldNames::default(),
                                    vec![],
                                    3,
                                    Validity::NonNullable,
                                )
                                .unwrap()
                                .into_array(),
                            ))
                        },
                        {
                            Ok((
                                SequenceId::root().advance(),
                                StructArray::try_new(
                                    FieldNames::default(),
                                    vec![],
                                    5,
                                    Validity::NonNullable,
                                )
                                .unwrap()
                                .into_array(),
                            ))
                        },
                    ]),
                )
                .sendable(),
            ),
        );

        assert_eq!(res.unwrap().row_count(), 8);
    }
}