vortex_layout/layouts/struct_/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::VecDeque;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker};
8
9use arcref::ArcRef;
10use futures::future::try_join_all;
11use futures::{Stream, StreamExt, TryStreamExt};
12use itertools::Itertools;
13use parking_lot::Mutex;
14use vortex_array::{Array, ArrayContext, ToCanonical};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_utils::aliases::DefaultHashBuilder;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::layouts::struct_::StructLayout;
20use crate::segments::SequenceWriter;
21use crate::{
22    IntoLayout as _, LayoutStrategy, SendableLayoutFuture, SendableSequentialStream,
23    SequentialStreamAdapter, SequentialStreamExt,
24};
25
26pub struct StructStrategy {
27    child: ArcRef<dyn LayoutStrategy>,
28}
29
30/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
31impl StructStrategy {
32    pub fn new(child: ArcRef<dyn LayoutStrategy>) -> Self {
33        Self { child }
34    }
35}
36
37impl LayoutStrategy for StructStrategy {
38    fn write_stream(
39        &self,
40        ctx: &ArrayContext,
41        sequence_writer: SequenceWriter,
42        stream: SendableSequentialStream,
43    ) -> SendableLayoutFuture {
44        let dtype = stream.dtype().clone();
45        let Some(struct_dtype) = stream.dtype().as_struct().cloned() else {
46            // nothing we can do if dtype is not struct
47            return self.child.write_stream(ctx, sequence_writer, stream);
48        };
49        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
50            != struct_dtype.names().len()
51        {
52            return Box::pin(async { vortex_bail!("StructLayout must have unique field names") });
53        }
54
55        let stream = stream.map(|chunk| {
56            let (sequence_id, chunk) = chunk?;
57            if !chunk.all_valid()? {
58                vortex_bail!("Cannot push struct chunks with top level invalid values");
59            };
60            Ok((sequence_id, chunk))
61        });
62
63        // There are now fields so this is the layout leaf
64        if struct_dtype.nfields() == 0 {
65            return Box::pin(async move {
66                let row_count = stream
67                    .try_fold(
68                        0u64,
69                        |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
70                    )
71                    .await?;
72                Ok(StructLayout::new(row_count, dtype, vec![]).into_layout())
73            });
74        }
75
76        // stream<struct_chunk> -> stream<vec<column_chunk>>
77        let columns_vec_stream = stream.map(|chunk| {
78            let (sequence_id, chunk) = chunk?;
79            let mut sequence_pointer = sequence_id.descend();
80            let struct_chunk = chunk.to_struct()?;
81            let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
82                .map(|idx| {
83                    (
84                        sequence_pointer.advance(),
85                        struct_chunk.fields()[idx].to_array(),
86                    )
87                })
88                .collect();
89            Ok(columns)
90        });
91
92        // stream<vec<column_chunk>> -> vec<stream<column_chunk>>
93        let column_streams = transpose_stream(columns_vec_stream, struct_dtype.nfields());
94
95        let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
96            struct_dtype
97                .field_by_index(idx)
98                .vortex_expect("bound checked")
99        });
100        let child = self.child.clone();
101        let ctx = ctx.clone();
102        let layout_futures = column_dtypes
103            .zip_eq(column_streams)
104            .map(move |(dtype, stream)| {
105                let column_stream = SequentialStreamAdapter::new(dtype, stream).sendable();
106                child.write_stream(&ctx, sequence_writer.clone(), column_stream)
107            });
108
109        Box::pin(async move {
110            let column_layouts = try_join_all(layout_futures).await?;
111            // TODO(os): transposed stream could count row counts as well,
112            // This must hold though, all columns must have the same row count of the struct layout
113            let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
114            Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
115        })
116    }
117}
118
119fn transpose_stream<T, S>(stream: S, elements: usize) -> Vec<impl Stream<Item = VortexResult<T>>>
120where
121    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
122    T: Unpin + 'static,
123{
124    let state = Arc::new(Mutex::new(TransposeState {
125        upstream: stream,
126        buffers: (0..elements).map(|_| VecDeque::new()).collect(),
127        wakers: Vec::new(),
128        exhausted: false,
129    }));
130    (0..elements)
131        .map(|index| TransposedStream {
132            index,
133            state: state.clone(),
134        })
135        .collect()
136}
137
/// State shared by all [`TransposedStream`]s created from one upstream stream.
struct TransposeState<T, S>
where
    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
    T: Unpin,
{
    /// The stream of vectors that is being transposed.
    upstream: S,
    /// One queue per output stream, holding items already pulled from
    /// `upstream` that the corresponding stream has not consumed yet.
    // TODO(os): make these buffers bounded so transposed streams can not run ahead unbounded
    buffers: Vec<VecDeque<VortexResult<T>>>,
    /// Wakers of streams that polled while `upstream` was pending; they are
    /// woken as soon as a poll of `upstream` by any stream completes.
    wakers: Vec<Waker>,
    /// Set once `upstream` has returned `None`, so that it is not polled again.
    exhausted: bool,
}
149
/// One output of a transposed stream: yields the entry at position `index`
/// of every vector produced by the shared upstream stream.
struct TransposedStream<T, S>
where
    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
    T: Unpin,
{
    /// Position of this stream's entry within each upstream vector, which is
    /// also the index of its queue in `TransposeState::buffers`.
    index: usize,
    /// Upstream stream, buffers and wakers shared with the sibling streams.
    state: Arc<Mutex<TransposeState<T, S>>>,
}
158
159impl<T, S> Stream for TransposedStream<T, S>
160where
161    S: Stream<Item = VortexResult<Vec<T>>> + Unpin,
162    T: Unpin,
163{
164    type Item = VortexResult<T>;
165    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
166        let mut guard = self.state.lock();
167        if let Some(item) = guard.buffers[self.index].pop_front() {
168            return Poll::Ready(Some(item));
169        }
170
171        // if we know upstream is exhausted we can skip polling it again.
172        if guard.exhausted {
173            return Poll::Ready(None);
174        }
175
176        let poll_result = match Pin::new(&mut guard.upstream).poll_next(cx) {
177            Poll::Pending => {
178                guard.wakers.push(cx.waker().clone());
179                Poll::Pending
180            }
181            Poll::Ready(None) => {
182                guard.exhausted = true;
183                Poll::Ready(None)
184            }
185            Poll::Ready(Some(Ok(vec_t))) => {
186                for (t, buffer) in vec_t.into_iter().zip_eq(guard.buffers.iter_mut()) {
187                    buffer.push_back(Ok(t));
188                }
189                let item = guard.buffers[self.index]
190                    .pop_front()
191                    .vortex_expect("just pushed");
192                Poll::Ready(Some(item))
193            }
194            Poll::Ready(Some(Err(err))) => {
195                let shared_err = Arc::new(err);
196                for buffer in guard.buffers.iter_mut() {
197                    buffer.push_back(Err(shared_err.clone().into()));
198                }
199                Poll::Ready(Some(Err(shared_err.into())))
200            }
201        };
202
203        if matches!(poll_result, Poll::Ready(_)) {
204            let wakers = std::mem::take(&mut guard.wakers);
205
206            drop(guard);
207            for waker in wakers {
208                waker.wake();
209            }
210        }
211        poll_result
212    }
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arcref::ArcRef;
    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::arrays::{BoolArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::{SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    /// Builds the strategy under test with a default `FlatLayoutStrategy` as child.
    fn flat_struct_strategy() -> StructStrategy {
        StructStrategy::new(ArcRef::new_arc(Arc::new(FlatLayoutStrategy::default())))
    }

    /// Builds a sequence writer on top of default `TestSegments`.
    fn test_sequence_writer() -> SequenceWriter {
        SequenceWriter::new(Box::new(TestSegments::default()))
    }

    /// Writing a struct dtype that declares the field `a` twice must not succeed.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = flat_struct_strategy();
        let ctx = ArrayContext::empty();

        let int32 = || DType::Primitive(PType::I32, Nullability::NonNullable);
        let dtype = DType::Struct(
            [("a", int32()), ("a", int32())].into_iter().collect(),
            Nullability::NonNullable,
        );
        let input = SequentialStreamAdapter::new(dtype, stream::empty()).sendable();

        let outcome = block_on(strategy.write_stream(&ctx, test_sequence_writer(), input));
        outcome.unwrap();
    }

    /// A chunk with an invalid value at the struct level is rejected with a dedicated error.
    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = flat_struct_strategy();
        let ctx = ArrayContext::empty();

        let dtype = DType::Struct(
            [("a", DType::Primitive(PType::I32, Nullability::NonNullable))]
                .into_iter()
                .collect(),
            Nullability::Nullable,
        );
        let input = SequentialStreamAdapter::new(
            dtype,
            stream::once(async move {
                // The last of the three rows is marked invalid.
                let validity =
                    Validity::Array(BoolArray::from_iter(vec![true, true, false]).into_array());
                Ok((
                    SequenceId::root().downgrade(),
                    StructArray::try_new(
                        ["a"].into(),
                        vec![buffer![1, 2, 3].into_array()],
                        3,
                        validity,
                    )
                    .unwrap()
                    .into_array(),
                ))
            }),
        )
        .sendable();

        let res = block_on(strategy.write_stream(&ctx, test_sequence_writer(), input));

        let message = res.unwrap_err().to_string();
        assert!(message.starts_with("Cannot push struct chunks with top level invalid values"));
    }

    /// A struct without fields still reports the total number of rows written.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = flat_struct_strategy();
        let ctx = ArrayContext::empty();

        let dtype = DType::Struct(
            StructFields::new(FieldNames::default(), vec![]),
            Nullability::NonNullable,
        );
        // Builds a chunk of `len` rows that has no fields at all.
        let fieldless_chunk = |len| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        };
        let input = SequentialStreamAdapter::new(
            dtype,
            stream::iter([
                Ok((SequenceId::root().downgrade(), fieldless_chunk(3))),
                Ok((SequenceId::root().advance(), fieldless_chunk(5))),
            ]),
        )
        .sendable();

        let res = block_on(strategy.write_stream(&ctx, test_sequence_writer(), input));

        // 3 rows from the first chunk plus 5 rows from the second one.
        assert_eq!(res.unwrap().row_count(), 8);
    }
}