vortex_layout/layouts/struct_/
writer.rs1use std::collections::VecDeque;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker, ready};
8
9use async_trait::async_trait;
10use futures::future::try_join_all;
11use futures::task::{ArcWake, waker_ref};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
13use itertools::Itertools;
14use parking_lot::Mutex;
15use vortex_array::{Array, ArrayContext, ToCanonical};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
17use vortex_utils::aliases::DefaultHashBuilder;
18use vortex_utils::aliases::hash_map::HashMap;
19use vortex_utils::aliases::hash_set::HashSet;
20
21use crate::layouts::struct_::StructLayout;
22use crate::segments::SequenceWriter;
23use crate::{
24 IntoLayout as _, LayoutRef, LayoutStrategy, SendableSequentialStream, SequentialStreamAdapter,
25 SequentialStreamExt,
26};
27
28pub struct StructStrategy<S> {
29 child: S,
30}
31
32impl<S> StructStrategy<S>
34where
35 S: LayoutStrategy,
36{
37 pub fn new(child: S) -> Self {
38 Self { child }
39 }
40}
41
42#[async_trait]
43impl<S> LayoutStrategy for StructStrategy<S>
44where
45 S: LayoutStrategy,
46{
47 async fn write_stream(
48 &self,
49 ctx: &ArrayContext,
50 sequence_writer: SequenceWriter,
51 stream: SendableSequentialStream,
52 ) -> VortexResult<LayoutRef> {
53 let dtype = stream.dtype().clone();
54 let Some(struct_dtype) = stream.dtype().as_struct_opt().cloned() else {
55 return self.child.write_stream(ctx, sequence_writer, stream).await;
57 };
58 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
59 != struct_dtype.names().len()
60 {
61 vortex_bail!("StructLayout must have unique field names");
62 }
63
64 let stream = stream.map(|chunk| {
65 let (sequence_id, chunk) = chunk?;
66 if !chunk.all_valid()? {
67 vortex_bail!("Cannot push struct chunks with top level invalid values");
68 };
69 Ok((sequence_id, chunk))
70 });
71
72 if struct_dtype.nfields() == 0 {
74 let row_count = stream
75 .try_fold(
76 0u64,
77 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
78 )
79 .await?;
80 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
81 }
82
83 let columns_vec_stream = stream.map(|chunk| {
85 let (sequence_id, chunk) = chunk?;
86 let mut sequence_pointer = sequence_id.descend();
87 let struct_chunk = chunk.to_struct()?;
88 let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
89 .map(|idx| {
90 (
91 sequence_pointer.advance(),
92 struct_chunk.fields()[idx].to_array(),
93 )
94 })
95 .collect();
96 Ok(columns)
97 });
98
99 let column_streams = transpose_stream(columns_vec_stream, struct_dtype.nfields());
101
102 let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
103 struct_dtype
104 .field_by_index(idx)
105 .vortex_expect("bound checked")
106 });
107
108 let layout_futures: Vec<_> = column_dtypes
109 .zip_eq(column_streams)
110 .map(move |(dtype, stream)| {
111 let column_stream = SequentialStreamAdapter::new(dtype, stream).sendable();
112 self.child
113 .write_stream(ctx, sequence_writer.clone(), column_stream)
114 .boxed()
115 })
116 .collect();
117
118 let column_layouts = try_join_all(layout_futures).await?;
119 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
122 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
123 }
124}
125
126fn transpose_stream<T, S>(stream: S, elements: usize) -> Vec<impl Stream<Item = VortexResult<T>>>
127where
128 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
129 T: Unpin + 'static,
130{
131 let state = Arc::new(Mutex::new(TransposeState {
132 upstream: stream,
133 buffers: (0..elements).map(|_| VecDeque::new()).collect(),
134 exhausted: false,
135 }));
136
137 let shared_waker = Arc::new(SharedWaker {
138 wakers: Default::default(),
139 });
140
141 (0..elements)
142 .map(|index| TransposedStream {
143 index,
144 state: state.clone(),
145 shared_waker: shared_waker.clone(),
146 })
147 .collect()
148}
149
150struct TransposeState<T, S>
151where
152 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
153 T: Unpin,
154{
155 upstream: S,
156 buffers: Vec<VecDeque<VortexResult<T>>>,
158 exhausted: bool,
159}
160
161struct SharedWaker {
162 wakers: Arc<Mutex<HashMap<usize, Waker>>>,
163}
164
165impl SharedWaker {
166 pub fn add(self: Arc<Self>, index: usize, waker: Waker) {
167 self.wakers.lock().insert(index, waker);
168 }
169}
170
171impl ArcWake for SharedWaker {
172 fn wake_by_ref(arc_self: &Arc<Self>) {
173 for (_, waker) in arc_self.wakers.lock().drain() {
174 waker.wake();
175 }
176 }
177}
178
179struct TransposedStream<T, S>
180where
181 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
182 T: Unpin,
183{
184 index: usize,
185 state: Arc<Mutex<TransposeState<T, S>>>,
186 shared_waker: Arc<SharedWaker>,
187}
188
189impl<T, S> Stream for TransposedStream<T, S>
190where
191 S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
192 T: Unpin,
193{
194 type Item = VortexResult<T>;
195 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
196 let mut guard = self.state.lock();
197 if let Some(item) = guard.buffers[self.index].pop_front() {
198 return Poll::Ready(Some(item));
199 }
200
201 if guard.exhausted {
203 return Poll::Ready(None);
204 }
205
206 self.shared_waker
207 .clone()
208 .add(self.index, cx.waker().clone());
209
210 let shared_waker_ref = waker_ref(&self.shared_waker);
211 let mut upstream_cx = Context::from_waker(&shared_waker_ref);
212 match ready!(Pin::new(&mut guard.upstream).poll_next(&mut upstream_cx)) {
213 None => {
214 guard.exhausted = true;
215 Poll::Ready(None)
216 }
217 Some(Ok(vec_t)) => {
218 for (t, buffer) in vec_t.into_iter().zip_eq(guard.buffers.iter_mut()) {
219 buffer.push_back(Ok(t));
220 }
221 let item = guard.buffers[self.index]
222 .pop_front()
223 .vortex_expect("just pushed");
224 Poll::Ready(Some(item))
225 }
226 Some(Err(err)) => {
227 let shared_err = Arc::new(err);
228 for buffer in guard.buffers.iter_mut() {
229 buffer.push_back(Err(shared_err.clone().into()));
230 }
231 Poll::Ready(Some(Err(shared_err.into())))
232 }
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use futures::StreamExt as _;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_error::{VortexResult, vortex_err};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::{StructStrategy, transpose_stream};
    use crate::segments::{SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        [
                            ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                            ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                        ]
                        .into_iter()
                        .collect(),
                        Nullability::NonNullable,
                    ),
                    stream::empty(),
                )
                .sendable(),
            ),
        )
        .unwrap();
    }

    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let res = block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        [("a", DType::Primitive(PType::I32, Nullability::NonNullable))]
                            .into_iter()
                            .collect(),
                        Nullability::Nullable,
                    ),
                    stream::once(async move {
                        Ok((
                            SequenceId::root().downgrade(),
                            StructArray::try_new(
                                ["a"].into(),
                                vec![buffer![1, 2, 3].into_array()],
                                3,
                                Validity::Array(
                                    BoolArray::from_iter(vec![true, true, false]).into_array(),
                                ),
                            )
                            .unwrap()
                            .into_array(),
                        ))
                    }),
                )
                .sendable(),
            ),
        );
        assert!(
            format!("{}", res.unwrap_err())
                .starts_with("Cannot push struct chunks with top level invalid values"),
        )
    }

    #[test]
    fn write_empty_field_struct_array() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let res = block_on(
            strategy.write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(TestSegments::default())),
                SequentialStreamAdapter::new(
                    DType::Struct(
                        StructFields::new(FieldNames::default(), vec![]),
                        Nullability::NonNullable,
                    ),
                    stream::iter([
                        {
                            Ok((
                                SequenceId::root().downgrade(),
                                StructArray::try_new(
                                    FieldNames::default(),
                                    vec![],
                                    3,
                                    Validity::NonNullable,
                                )
                                .unwrap()
                                .into_array(),
                            ))
                        },
                        {
                            Ok((
                                SequenceId::root().advance(),
                                StructArray::try_new(
                                    FieldNames::default(),
                                    vec![],
                                    5,
                                    Validity::NonNullable,
                                )
                                .unwrap()
                                .into_array(),
                            ))
                        },
                    ]),
                )
                .sendable(),
            ),
        );

        assert_eq!(res.unwrap().row_count(), 8);
    }

    #[test]
    fn transpose_splits_rows_into_columns() {
        let rows: Vec<VortexResult<Vec<i32>>> = vec![Ok(vec![1, 2]), Ok(vec![3, 4])];
        let mut columns = transpose_stream(stream::iter(rows), 2);
        let second = columns.pop().unwrap();
        let first = columns.pop().unwrap();

        // Drain the first column completely before touching the second, so the second column's
        // values can only come from the buffer the first column filled on its behalf.
        let first: Vec<VortexResult<i32>> = block_on(first.collect());
        let second: Vec<VortexResult<i32>> = block_on(second.collect());

        assert_eq!(
            first.into_iter().collect::<VortexResult<Vec<_>>>().unwrap(),
            vec![1, 3]
        );
        assert_eq!(
            second.into_iter().collect::<VortexResult<Vec<_>>>().unwrap(),
            vec![2, 4]
        );
    }

    #[test]
    fn transpose_fans_out_upstream_errors() {
        let rows: Vec<VortexResult<Vec<i32>>> = vec![Err(vortex_err!("boom"))];
        let columns = transpose_stream(stream::iter(rows), 2);
        for column in columns {
            let items: Vec<VortexResult<i32>> = block_on(column.collect());
            assert!(matches!(items.first(), Some(Err(_))));
        }
    }
}