vortex_layout/layouts/struct_/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::StreamExt;
8use futures::TryStreamExt;
9use futures::future::try_join_all;
10use futures::pin_mut;
11use itertools::Itertools;
12use vortex_array::Array;
13use vortex_array::ArrayContext;
14use vortex_array::ArrayRef;
15use vortex_array::IntoArray;
16use vortex_array::ToCanonical;
17use vortex_dtype::DType;
18use vortex_dtype::Nullability;
19use vortex_error::VortexError;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22use vortex_io::kanal_ext::KanalExt;
23use vortex_io::runtime::Handle;
24use vortex_utils::aliases::DefaultHashBuilder;
25use vortex_utils::aliases::hash_set::HashSet;
26
27use crate::IntoLayout as _;
28use crate::LayoutRef;
29use crate::LayoutStrategy;
30use crate::layouts::struct_::StructLayout;
31use crate::segments::SegmentSinkRef;
32use crate::sequence::SendableSequentialStream;
33use crate::sequence::SequenceId;
34use crate::sequence::SequencePointer;
35use crate::sequence::SequentialStreamAdapter;
36use crate::sequence::SequentialStreamExt;
37
38#[derive(Clone)]
39pub struct StructStrategy {
40    child: Arc<dyn LayoutStrategy>,
41    validity: Arc<dyn LayoutStrategy>,
42}
43
44/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
45impl StructStrategy {
46    pub fn new<S: LayoutStrategy, V: LayoutStrategy>(child: S, validity: V) -> Self {
47        Self {
48            child: Arc::new(child),
49            validity: Arc::new(validity),
50        }
51    }
52}
53
54#[async_trait]
55impl LayoutStrategy for StructStrategy {
56    async fn write_stream(
57        &self,
58        ctx: ArrayContext,
59        segment_sink: SegmentSinkRef,
60        stream: SendableSequentialStream,
61        mut eof: SequencePointer,
62        handle: Handle,
63    ) -> VortexResult<LayoutRef> {
64        let dtype = stream.dtype().clone();
65        let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
66            return self
67                .child
68                .write_stream(ctx, segment_sink, stream, eof, handle)
69                .await;
70        };
71
72        // Check for unique field names at write time.
73        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
74            != struct_dtype.names().len()
75        {
76            vortex_bail!("StructLayout must have unique field names");
77        }
78
79        let is_nullable = dtype.is_nullable();
80
81        // Optimization: when there are no fields, don't spawn any work and just write a trivial
82        // StructLayout.
83        if struct_dtype.nfields() == 0 && !is_nullable {
84            let row_count = stream
85                .try_fold(
86                    0u64,
87                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
88                )
89                .await?;
90            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
91        }
92
93        // stream<struct_chunk> -> stream<vec<column_chunk>>
94        let columns_vec_stream = stream.map(move |chunk| {
95            let (sequence_id, chunk) = chunk?;
96            let mut sequence_pointer = sequence_id.descend();
97            let struct_chunk = chunk.to_struct();
98            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
99            if is_nullable {
100                columns.push((
101                    sequence_pointer.advance(),
102                    chunk.validity_mask().into_array(),
103                ));
104            }
105
106            columns.extend(
107                struct_chunk
108                    .fields()
109                    .iter()
110                    .map(|field| (sequence_pointer.advance(), field.to_array())),
111            );
112
113            Ok(columns)
114        });
115
116        let mut stream_count = struct_dtype.nfields();
117        if is_nullable {
118            stream_count += 1;
119        }
120
121        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
122            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
123
124        // Spawn a task to fan out column chunks to their respective transposed streams
125        handle
126            .spawn(async move {
127                pin_mut!(columns_vec_stream);
128                while let Some(result) = columns_vec_stream.next().await {
129                    match result {
130                        Ok(columns) => {
131                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
132                            {
133                                let _ = tx.send(Ok(column)).await;
134                            }
135                        }
136                        Err(e) => {
137                            let e: Arc<VortexError> = Arc::new(e);
138                            for tx in column_streams_tx.iter() {
139                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
140                            }
141                            break;
142                        }
143                    }
144                }
145            })
146            .detach();
147
148        // First child column is the validity, subsequence children are the individual struct fields
149        let column_dtypes: Vec<DType> = if is_nullable {
150            std::iter::once(DType::Bool(Nullability::NonNullable))
151                .chain(struct_dtype.fields())
152                .collect()
153        } else {
154            struct_dtype.fields().collect()
155        };
156
157        let layout_futures: Vec<_> = column_dtypes
158            .into_iter()
159            .zip_eq(column_streams_rx)
160            .enumerate()
161            .map(move |(index, (dtype, recv))| {
162                let column_stream =
163                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
164                        .sendable();
165                let child_eof = eof.split_off();
166                handle.spawn_nested(|h| {
167                    let child = self.child.clone();
168                    let validity = self.validity.clone();
169                    let this = self.clone();
170                    let ctx = ctx.clone();
171                    let dtype = dtype.clone();
172                    let segment_sink = segment_sink.clone();
173                    async move {
174                        // Write validity stream
175                        if index == 0 && is_nullable {
176                            validity
177                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
178                                .await
179                        } else {
180                            // Build recursive StructLayout for nested struct fields
181                            // TODO(aduffy): add branch for ListLayout once that's implemented
182                            if dtype.is_struct() {
183                                this.write_stream(ctx, segment_sink, column_stream, child_eof, h)
184                                    .await
185                            } else {
186                                child
187                                    .write_stream(ctx, segment_sink, column_stream, child_eof, h)
188                                    .await
189                            }
190                        }
191                    }
192                })
193            })
194            .collect();
195
196        let column_layouts = try_join_all(layout_futures).await?;
197        // TODO(os): transposed stream could count row counts as well,
198        // This must hold though, all columns must have the same row count of the struct layout
199        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
200        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
201    }
202
203    fn buffered_bytes(&self) -> u64 {
204        self.child.buffered_bytes()
205    }
206}
207
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::Canonical;
    use vortex_array::IntoArray as _;
    use vortex_array::arrays::ChunkedArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::validity::Validity;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    /// Builds the strategy under test: flat layouts for both the field data and the validity.
    fn flat_struct_strategy() -> StructStrategy {
        StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default())
    }

    /// Writing a struct whose dtype declares the field `a` twice must not succeed.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let i32_dtype = || DType::Primitive(PType::I32, Nullability::NonNullable);
        let duplicated = DType::Struct(
            [("a", i32_dtype()), ("a", i32_dtype())]
                .into_iter()
                .collect(),
            Nullability::NonNullable,
        );
        let empty = Canonical::empty(&duplicated).into_array();

        let outcome = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                empty.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        });
        outcome.unwrap();
    }

    /// Two field-less, non-nullable chunks of 3 and 5 rows yield a layout with 8 rows.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = flat_struct_strategy();
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let first = StructArray::try_new(FieldNames::default(), vec![], 3, Validity::NonNullable)
            .unwrap()
            .into_array();
        let second = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let chunked = ChunkedArray::from_iter([first, second]).into_array();

        let layout = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                chunked.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        })
        .unwrap();

        assert_eq!(layout.row_count(), 8);
    }
}