vortex_layout/layouts/struct_/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::try_join_all;
8use futures::{StreamExt, TryStreamExt, pin_mut};
9use itertools::Itertools;
10use vortex_array::{Array, ArrayContext, ArrayRef, IntoArray, ToCanonical};
11use vortex_dtype::{DType, Nullability};
12use vortex_error::{VortexError, VortexResult, vortex_bail};
13use vortex_io::kanal_ext::KanalExt;
14use vortex_io::runtime::Handle;
15use vortex_utils::aliases::DefaultHashBuilder;
16use vortex_utils::aliases::hash_set::HashSet;
17
18use crate::layouts::struct_::StructLayout;
19use crate::segments::SegmentSinkRef;
20use crate::sequence::{
21    SendableSequentialStream, SequenceId, SequencePointer, SequentialStreamAdapter,
22    SequentialStreamExt,
23};
24use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};
25
26#[derive(Clone)]
27pub struct StructStrategy {
28    child: Arc<dyn LayoutStrategy>,
29    validity: Arc<dyn LayoutStrategy>,
30}
31
32/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
33impl StructStrategy {
34    pub fn new<S: LayoutStrategy, V: LayoutStrategy>(child: S, validity: V) -> Self {
35        Self {
36            child: Arc::new(child),
37            validity: Arc::new(validity),
38        }
39    }
40}
41
42#[async_trait]
43impl LayoutStrategy for StructStrategy {
44    async fn write_stream(
45        &self,
46        ctx: ArrayContext,
47        segment_sink: SegmentSinkRef,
48        stream: SendableSequentialStream,
49        mut eof: SequencePointer,
50        handle: Handle,
51    ) -> VortexResult<LayoutRef> {
52        let dtype = stream.dtype().clone();
53        let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
54            return self
55                .child
56                .write_stream(ctx, segment_sink, stream, eof, handle)
57                .await;
58        };
59
60        // Check for unique field names at write time.
61        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
62            != struct_dtype.names().len()
63        {
64            vortex_bail!("StructLayout must have unique field names");
65        }
66
67        let is_nullable = dtype.is_nullable();
68
69        // Optimization: when there are no fields, don't spawn any work and just write a trivial
70        // StructLayout.
71        if struct_dtype.nfields() == 0 && !is_nullable {
72            let row_count = stream
73                .try_fold(
74                    0u64,
75                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
76                )
77                .await?;
78            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
79        }
80
81        // stream<struct_chunk> -> stream<vec<column_chunk>>
82        let columns_vec_stream = stream.map(move |chunk| {
83            let (sequence_id, chunk) = chunk?;
84            let mut sequence_pointer = sequence_id.descend();
85            let struct_chunk = chunk.to_struct();
86            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
87            if is_nullable {
88                columns.push((
89                    sequence_pointer.advance(),
90                    chunk.validity_mask().into_array(),
91                ));
92            }
93
94            columns.extend(
95                struct_chunk
96                    .fields()
97                    .iter()
98                    .map(|field| (sequence_pointer.advance(), field.to_array())),
99            );
100
101            Ok(columns)
102        });
103
104        let mut stream_count = struct_dtype.nfields();
105        if is_nullable {
106            stream_count += 1;
107        }
108
109        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
110            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
111
112        // Spawn a task to fan out column chunks to their respective transposed streams
113        handle
114            .spawn(async move {
115                pin_mut!(columns_vec_stream);
116                while let Some(result) = columns_vec_stream.next().await {
117                    match result {
118                        Ok(columns) => {
119                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
120                            {
121                                let _ = tx.send(Ok(column)).await;
122                            }
123                        }
124                        Err(e) => {
125                            let e: Arc<VortexError> = Arc::new(e);
126                            for tx in column_streams_tx.iter() {
127                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
128                            }
129                            break;
130                        }
131                    }
132                }
133            })
134            .detach();
135
136        // First child column is the validity, subsequence children are the individual struct fields
137        let column_dtypes: Vec<DType> = if is_nullable {
138            std::iter::once(DType::Bool(Nullability::NonNullable))
139                .chain(struct_dtype.fields())
140                .collect()
141        } else {
142            struct_dtype.fields().collect()
143        };
144
145        let layout_futures: Vec<_> = column_dtypes
146            .into_iter()
147            .zip_eq(column_streams_rx)
148            .enumerate()
149            .map(move |(index, (dtype, recv))| {
150                let column_stream =
151                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
152                        .sendable();
153                let child_eof = eof.split_off();
154                handle.spawn_nested(|h| {
155                    let child = self.child.clone();
156                    let validity = self.validity.clone();
157                    let this = self.clone();
158                    let ctx = ctx.clone();
159                    let dtype = dtype.clone();
160                    let segment_sink = segment_sink.clone();
161                    async move {
162                        // Write validity stream
163                        if index == 0 && is_nullable {
164                            validity
165                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
166                                .await
167                        } else {
168                            // Build recursive StructLayout for nested struct fields
169                            // TODO(aduffy): add branch for ListLayout once that's implemented
170                            if dtype.is_struct() {
171                                this.write_stream(ctx, segment_sink, column_stream, child_eof, h)
172                                    .await
173                            } else {
174                                child
175                                    .write_stream(ctx, segment_sink, column_stream, child_eof, h)
176                                    .await
177                            }
178                        }
179                    }
180                })
181            })
182            .collect();
183
184        let column_layouts = try_join_all(layout_futures).await?;
185        // TODO(os): transposed stream could count row counts as well,
186        // This must hold though, all columns must have the same row count of the struct layout
187        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
188        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
189    }
190
191    fn buffered_bytes(&self) -> u64 {
192        self.child.buffered_bytes()
193    }
194}
195
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Writing a struct whose dtype repeats the field name `a` is expected to panic
    /// (the write result is unwrapped).
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let duplicated_dtype = DType::Struct(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
            ]
            .into_iter()
            .collect(),
            Nullability::NonNullable,
        );

        let (ptr, eof) = SequenceId::root().split();
        let input = Canonical::empty(&duplicated_dtype)
            .into_array()
            .to_array_stream()
            .sequenced(ptr);

        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let strategy =
            StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default());

        block_on(|handle| strategy.write_stream(ctx, segments, input, eof, handle)).unwrap();
    }

    /// Two field-less, non-nullable struct chunks of 3 and 5 rows produce a layout with
    /// 8 rows.
    #[test]
    fn write_empty_field_struct_array() {
        let chunks = [3, 5].map(|len| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        });

        let (ptr, eof) = SequenceId::root().split();
        let input = ChunkedArray::from_iter(chunks)
            .into_array()
            .to_array_stream()
            .sequenced(ptr);

        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let strategy =
            StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default());

        let layout = block_on(|handle| strategy.write_stream(ctx, segments, input, eof, handle));

        assert_eq!(layout.unwrap().row_count(), 8);
    }
}