// vortex_layout/layouts/struct_/writer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{StreamExt, TryStreamExt, pin_mut};
use itertools::Itertools;
use vortex_array::{Array, ArrayContext, ToCanonical};
use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail};
use vortex_io::kanal_ext::KanalExt;
use vortex_io::runtime::Handle;
use vortex_utils::aliases::DefaultHashBuilder;
use vortex_utils::aliases::hash_set::HashSet;

use crate::layouts::struct_::StructLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::{
    SendableSequentialStream, SequencePointer, SequentialStreamAdapter, SequentialStreamExt,
};
use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};

24pub struct StructStrategy {
25    child: Arc<dyn LayoutStrategy>,
26}
27
28/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
29impl StructStrategy {
30    pub fn new<S: LayoutStrategy>(child: S) -> Self {
31        Self {
32            child: Arc::new(child),
33        }
34    }
35}
36
#[async_trait]
impl LayoutStrategy for StructStrategy {
    /// Writes a struct-typed stream by transposing it into one stream per
    /// field and writing every field with the child strategy, concurrently.
    ///
    /// Non-struct dtypes are delegated to the child strategy unchanged.
    /// Every incoming struct chunk must be all-valid at the top level.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();
        let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
            // nothing we can do if dtype is not struct; hand the whole stream
            // to the child strategy unchanged
            return self
                .child
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        };
        // Duplicate field names would make columns ambiguous: if the name set
        // is smaller than the name list, at least one name repeats.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }

        // Validate each chunk as it flows through: a struct layout cannot
        // represent top-level nulls, so reject any chunk that has them.
        let stream = stream.map(|chunk| {
            let (sequence_id, chunk) = chunk?;
            if !chunk.all_valid() {
                vortex_bail!("Cannot push struct chunks with top level invalid values");
            };
            Ok((sequence_id, chunk))
        });

        // There are no fields, so this is the layout leaf: only the total row
        // count needs to be recorded, with no child layouts.
        if struct_dtype.nfields() == 0 {
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // stream<struct_chunk> -> stream<vec<column_chunk>>
        // Each column chunk gets its own sequence id derived from the parent
        // chunk's id (descend once, then advance per field).
        let columns_vec_stream = stream.map(|chunk| {
            let (sequence_id, chunk) = chunk?;
            let mut sequence_pointer = sequence_id.descend();
            let struct_chunk = chunk.to_struct();
            let columns: Vec<_> = struct_chunk
                .fields()
                .iter()
                .map(|field| (sequence_pointer.advance(), field.to_array()))
                .collect();
            Ok(columns)
        });

        // One bounded (capacity-1) channel per field; the sender side feeds the
        // fan-out task below, the receiver side becomes that field's stream.
        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) = (0..struct_dtype.nfields())
            .map(|_| kanal::bounded_async(1))
            .unzip();

        // Spawn a task to fan out column chunks to their respective transposed streams
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            // zip_eq: chunk validation above guarantees one
                            // column per field. Send errors (receiver dropped)
                            // are deliberately ignored — best effort.
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            // Broadcast the same error to every column writer
                            // (Arc-wrapped so it can be cloned), then stop.
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
            struct_dtype
                .field_by_index(idx)
                .vortex_expect("bound checked")
        });

        // One child write task per column; each gets its own EOF pointer split
        // off this layout's pointer.
        let layout_futures: Vec<_> = column_dtypes
            .zip_eq(column_streams_rx)
            .map(move |(dtype, recv)| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype, recv.into_stream().boxed()).sendable();
                let child_eof = eof.split_off();
                handle.spawn_nested(|h| {
                    let child = self.child.clone();
                    let ctx = ctx.clone();
                    let segment_sink = segment_sink.clone();
                    async move {
                        child
                            .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                            .await
                    }
                })
            })
            .collect();

        let column_layouts = try_join_all(layout_futures).await?;
        // TODO(os): the transposed stream could count rows as well.
        // This invariant must hold though: every column has the same row count
        // as the struct layout, so the first column's count is used.
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }

    /// Buffered bytes are whatever the child strategy reports.
    fn buffered_bytes(&self) -> u64 {
        self.child.buffered_bytes()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{BoolArray, ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Duplicate field names must be rejected before any data is written.
    ///
    /// The `expected` substring pins the panic to the unique-name check, so an
    /// unrelated panic in test setup cannot make this test pass spuriously.
    #[test]
    #[should_panic(expected = "unique field names")]
    fn fails_on_duplicate_field() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                // Two fields deliberately share the name "a".
                Canonical::empty(&DType::Struct(
                    [
                        ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                        ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                    ]
                    .into_iter()
                    .collect(),
                    Nullability::NonNullable,
                ))
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        })
        .unwrap();
    }

    /// A struct chunk whose top-level validity contains a null row must fail.
    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let res = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                // Third row is invalid at the struct's top level.
                StructArray::try_new(
                    ["a"].into(),
                    vec![buffer![1, 2, 3].into_array()],
                    3,
                    Validity::Array(BoolArray::from_iter(vec![true, true, false]).into_array()),
                )
                .unwrap()
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        });
        assert!(
            res.unwrap_err()
                .to_string()
                .starts_with("Cannot push struct chunks with top level invalid values"),
        )
    }

    /// A field-less struct is a layout leaf: only the row count (3 + 5 = 8)
    /// should be recorded, across multiple chunks.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let res = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                ChunkedArray::from_iter([
                    StructArray::try_new(FieldNames::default(), vec![], 3, Validity::NonNullable)
                        .unwrap()
                        .into_array(),
                    StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
                        .unwrap()
                        .into_array(),
                ])
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        });

        assert_eq!(res.unwrap().row_count(), 8);
    }
}