vortex_layout/layouts/struct_/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::VecDeque;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker, ready};
8
9use async_trait::async_trait;
10use futures::future::try_join_all;
11use futures::task::{ArcWake, waker_ref};
12use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
13use itertools::Itertools;
14use parking_lot::Mutex;
15use vortex_array::{Array, ArrayContext, ToCanonical};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
17use vortex_utils::aliases::DefaultHashBuilder;
18use vortex_utils::aliases::hash_map::HashMap;
19use vortex_utils::aliases::hash_set::HashSet;
20
21use crate::layouts::struct_::StructLayout;
22use crate::segments::SequenceWriter;
23use crate::{
24    IntoLayout as _, LayoutRef, LayoutStrategy, SendableSequentialStream, SequentialStreamAdapter,
25    SequentialStreamExt,
26};
27
28pub struct StructStrategy<S> {
29    child: S,
30}
31
32/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
33impl<S> StructStrategy<S>
34where
35    S: LayoutStrategy,
36{
37    pub fn new(child: S) -> Self {
38        Self { child }
39    }
40}
41
42#[async_trait]
43impl<S> LayoutStrategy for StructStrategy<S>
44where
45    S: LayoutStrategy,
46{
47    async fn write_stream(
48        &self,
49        ctx: &ArrayContext,
50        sequence_writer: SequenceWriter,
51        stream: SendableSequentialStream,
52    ) -> VortexResult<LayoutRef> {
53        let dtype = stream.dtype().clone();
54        let Some(struct_dtype) = stream.dtype().as_struct_opt().cloned() else {
55            // nothing we can do if dtype is not struct
56            return self.child.write_stream(ctx, sequence_writer, stream).await;
57        };
58        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
59            != struct_dtype.names().len()
60        {
61            vortex_bail!("StructLayout must have unique field names");
62        }
63
64        let stream = stream.map(|chunk| {
65            let (sequence_id, chunk) = chunk?;
66            if !chunk.all_valid()? {
67                vortex_bail!("Cannot push struct chunks with top level invalid values");
68            };
69            Ok((sequence_id, chunk))
70        });
71
72        // There are now fields so this is the layout leaf
73        if struct_dtype.nfields() == 0 {
74            let row_count = stream
75                .try_fold(
76                    0u64,
77                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
78                )
79                .await?;
80            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
81        }
82
83        // stream<struct_chunk> -> stream<vec<column_chunk>>
84        let columns_vec_stream = stream.map(|chunk| {
85            let (sequence_id, chunk) = chunk?;
86            let mut sequence_pointer = sequence_id.descend();
87            let struct_chunk = chunk.to_struct()?;
88            let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
89                .map(|idx| {
90                    (
91                        sequence_pointer.advance(),
92                        struct_chunk.fields()[idx].to_array(),
93                    )
94                })
95                .collect();
96            Ok(columns)
97        });
98
99        // stream<vec<column_chunk>> -> vec<stream<column_chunk>>
100        let column_streams = transpose_stream(columns_vec_stream, struct_dtype.nfields());
101
102        let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
103            struct_dtype
104                .field_by_index(idx)
105                .vortex_expect("bound checked")
106        });
107
108        let layout_futures: Vec<_> = column_dtypes
109            .zip_eq(column_streams)
110            .map(move |(dtype, stream)| {
111                let column_stream = SequentialStreamAdapter::new(dtype, stream).sendable();
112                self.child
113                    .write_stream(ctx, sequence_writer.clone(), column_stream)
114                    .boxed()
115            })
116            .collect();
117
118        let column_layouts = try_join_all(layout_futures).await?;
119        // TODO(os): transposed stream could count row counts as well,
120        // This must hold though, all columns must have the same row count of the struct layout
121        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
122        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
123    }
124}
125
126fn transpose_stream<T, S>(stream: S, elements: usize) -> Vec<impl Stream<Item = VortexResult<T>>>
127where
128    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
129    T: Unpin + 'static,
130{
131    let state = Arc::new(Mutex::new(TransposeState {
132        upstream: stream,
133        buffers: (0..elements).map(|_| VecDeque::new()).collect(),
134        exhausted: false,
135    }));
136
137    let shared_waker = Arc::new(SharedWaker {
138        wakers: Default::default(),
139    });
140
141    (0..elements)
142        .map(|index| TransposedStream {
143            index,
144            state: state.clone(),
145            shared_waker: shared_waker.clone(),
146        })
147        .collect()
148}
149
150struct TransposeState<T, S>
151where
152    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
153    T: Unpin,
154{
155    upstream: S,
156    // TODO(os): make these buffers bounded so transposed streams can not run ahead unbounded
157    buffers: Vec<VecDeque<VortexResult<T>>>,
158    exhausted: bool,
159}
160
161struct SharedWaker {
162    wakers: Arc<Mutex<HashMap<usize, Waker>>>,
163}
164
165impl SharedWaker {
166    pub fn add(self: Arc<Self>, index: usize, waker: Waker) {
167        self.wakers.lock().insert(index, waker);
168    }
169}
170
171impl ArcWake for SharedWaker {
172    fn wake_by_ref(arc_self: &Arc<Self>) {
173        for (_, waker) in arc_self.wakers.lock().drain() {
174            waker.wake();
175        }
176    }
177}
178
179struct TransposedStream<T, S>
180where
181    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
182    T: Unpin,
183{
184    index: usize,
185    state: Arc<Mutex<TransposeState<T, S>>>,
186    shared_waker: Arc<SharedWaker>,
187}
188
189impl<T, S> Stream for TransposedStream<T, S>
190where
191    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
192    T: Unpin,
193{
194    type Item = VortexResult<T>;
195    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
196        let mut guard = self.state.lock();
197        if let Some(item) = guard.buffers[self.index].pop_front() {
198            return Poll::Ready(Some(item));
199        }
200
201        // if we know upstream is exhausted we can skip polling it again.
202        if guard.exhausted {
203            return Poll::Ready(None);
204        }
205
206        self.shared_waker
207            .clone()
208            .add(self.index, cx.waker().clone());
209
210        let shared_waker_ref = waker_ref(&self.shared_waker);
211        let mut upstream_cx = Context::from_waker(&shared_waker_ref);
212        match ready!(Pin::new(&mut guard.upstream).poll_next(&mut upstream_cx)) {
213            None => {
214                guard.exhausted = true;
215                Poll::Ready(None)
216            }
217            Some(Ok(vec_t)) => {
218                for (t, buffer) in vec_t.into_iter().zip_eq(guard.buffers.iter_mut()) {
219                    buffer.push_back(Ok(t));
220                }
221                let item = guard.buffers[self.index]
222                    .pop_front()
223                    .vortex_expect("just pushed");
224                Poll::Ready(Some(item))
225            }
226            Some(Err(err)) => {
227                let shared_err = Arc::new(err);
228                for buffer in guard.buffers.iter_mut() {
229                    buffer.push_back(Err(shared_err.clone().into()));
230                }
231                Poll::Ready(Some(Err(shared_err.into())))
232            }
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::{SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    /// A struct dtype that names the same field twice must not be written.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let non_null_i32 = || DType::Primitive(PType::I32, Nullability::NonNullable);
        let dtype = DType::Struct(
            [("a", non_null_i32()), ("a", non_null_i32())]
                .into_iter()
                .collect(),
            Nullability::NonNullable,
        );
        let input = SequentialStreamAdapter::new(dtype, stream::empty()).sendable();

        let ctx = ArrayContext::empty();
        let writer = SequenceWriter::new(Box::new(TestSegments::default()));
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());

        block_on(strategy.write_stream(&ctx, writer, input)).unwrap();
    }

    /// A chunk with a null at the struct level is rejected with a descriptive error.
    #[test]
    fn fails_on_top_level_nulls() {
        let dtype = DType::Struct(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))]
                .into_iter()
                .collect(),
            Nullability::Nullable,
        );
        let input = SequentialStreamAdapter::new(
            dtype,
            stream::once(async move {
                Ok((
                    SequenceId::root().downgrade(),
                    StructArray::try_new(
                        ["a"].into(),
                        vec![buffer![1, 2, 3].into_array()],
                        3,
                        Validity::Array(
                            BoolArray::from_iter(vec![true, true, false]).into_array(),
                        ),
                    )
                    .unwrap()
                    .into_array(),
                ))
            }),
        )
        .sendable();

        let ctx = ArrayContext::empty();
        let writer = SequenceWriter::new(Box::new(TestSegments::default()));
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());

        let res = block_on(strategy.write_stream(&ctx, writer, input));

        let message = res.unwrap_err().to_string();
        assert!(message.starts_with("Cannot push struct chunks with top level invalid values"));
    }

    /// A struct without fields still reports the total number of rows it was given.
    #[test]
    fn write_empty_field_struct_array() {
        let empty_chunk = |len| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        };
        let dtype = DType::Struct(
            StructFields::new(FieldNames::default(), vec![]),
            Nullability::NonNullable,
        );
        let input = SequentialStreamAdapter::new(
            dtype,
            stream::iter([
                Ok((SequenceId::root().downgrade(), empty_chunk(3))),
                Ok((SequenceId::root().advance(), empty_chunk(5))),
            ]),
        )
        .sendable();

        let ctx = ArrayContext::empty();
        let writer = SequenceWriter::new(Box::new(TestSegments::default()));
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());

        let res = block_on(strategy.write_stream(&ctx, writer, input));

        assert_eq!(res.unwrap().row_count(), 8);
    }
}