vortex_layout/layouts/struct_/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::try_join_all;
8use futures::{StreamExt, TryStreamExt, pin_mut};
9use itertools::Itertools;
10use vortex_array::{Array, ArrayContext, ToCanonical};
11use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail};
12use vortex_io::kanal_ext::KanalExt;
13use vortex_io::runtime::Handle;
14use vortex_utils::aliases::DefaultHashBuilder;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::layouts::struct_::StructLayout;
18use crate::segments::SegmentSinkRef;
19use crate::sequence::{
20    SendableSequentialStream, SequencePointer, SequentialStreamAdapter, SequentialStreamExt,
21};
22use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};
23
/// A [`LayoutStrategy`] that writes each field of a struct array through a
/// dedicated child layout writer (see the `LayoutStrategy` impl below).
pub struct StructStrategy {
    // Strategy used for every child (field) stream; kept behind an `Arc` so
    // each spawned per-column writer task can hold its own handle to it.
    child: Arc<dyn LayoutStrategy>,
}
27
28/// A [`LayoutStrategy`] that splits a StructArray batch into child layout writers
29impl StructStrategy {
30    pub fn new<S: LayoutStrategy>(child: S) -> Self {
31        Self {
32            child: Arc::new(child),
33        }
34    }
35}
36
#[async_trait]
impl LayoutStrategy for StructStrategy {
    /// Write a stream of struct chunks by transposing it into one stream per
    /// column and delegating each column stream to the child strategy.
    ///
    /// Non-struct dtypes fall straight through to the child strategy. A
    /// zero-field struct is written as a leaf [`StructLayout`] that records
    /// only the total row count.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();
        let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
            // nothing we can do if dtype is not struct
            return self
                .child
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        };
        // Duplicate field names would make the per-column fan-out below
        // ambiguous, so reject them up front.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }

        // Validate every incoming chunk: top-level nulls are not supported by
        // this layout, so fail the stream as soon as one is seen.
        let stream = stream.map(|chunk| {
            let (sequence_id, chunk) = chunk?;
            if !chunk.all_valid() {
                vortex_bail!("Cannot push struct chunks with top level invalid values");
            };
            Ok((sequence_id, chunk))
        });

        // There are no fields so this is the layout leaf
        if struct_dtype.nfields() == 0 {
            // Only the total row count needs to be recorded; no child layouts
            // are written for an empty struct.
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // stream<struct_chunk> -> stream<vec<column_chunk>>
        let columns_vec_stream = stream.map(|chunk| {
            let (sequence_id, chunk) = chunk?;
            // Each column chunk gets its own sequence id one level below the
            // parent chunk's id, advanced once per column.
            let mut sequence_pointer = sequence_id.descend();
            let struct_chunk = chunk.to_struct();
            let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
                .map(|idx| {
                    (
                        sequence_pointer.advance(),
                        struct_chunk.fields()[idx].to_array(),
                    )
                })
                .collect();
            Ok(columns)
        });

        // One bounded (capacity 1) channel per column carries that column's
        // chunks to its dedicated child writer.
        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) = (0..struct_dtype.nfields())
            .map(|_| kanal::bounded_async(1))
            .unzip();

        // Spawn a task to fan out column chunks to their respective transposed streams
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                // A failed send means the receiver hung up;
                                // the chunk is deliberately dropped then.
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            // Broadcast the same error to every column writer
                            // (shared via Arc), then stop draining the source.
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
            struct_dtype
                .field_by_index(idx)
                .vortex_expect("bound checked")
        });

        // Spawn one child writer per column; each gets its own split of the
        // eof pointer via `split_off`.
        let layout_futures: Vec<_> = column_dtypes
            .zip_eq(column_streams_rx)
            .map(move |(dtype, recv)| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype, recv.into_stream().boxed()).sendable();
                let child_eof = eof.split_off();
                handle.spawn_nested(|h| {
                    let child = self.child.clone();
                    let ctx = ctx.clone();
                    let segment_sink = segment_sink.clone();
                    async move {
                        child
                            .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                            .await
                    }
                })
            })
            .collect();

        // Wait for every column writer; the first error aborts the join.
        let column_layouts = try_join_all(layout_futures).await?;
        // TODO(os): transposed stream could count row counts as well,
        // This must hold though, all columns must have the same row count of the struct layout
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }

    /// Buffered bytes are whatever the shared child strategy reports.
    fn buffered_bytes(&self) -> u64 {
        self.child.buffered_bytes()
    }
}
160
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{BoolArray, ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Duplicate field names must be rejected before any writing happens.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let writer = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let sink = Arc::new(TestSegments::default());
        // Two fields that share the name "a".
        let int = DType::Primitive(PType::I32, Nullability::NonNullable);
        let dtype = DType::Struct(
            [("a", int.clone()), ("a", int)].into_iter().collect(),
            Nullability::NonNullable,
        );
        let stream = Canonical::empty(&dtype)
            .into_array()
            .to_array_stream()
            .sequenced(ptr);
        block_on(|handle| writer.write_stream(ctx, sink, stream, eof, handle)).unwrap();
    }

    /// A struct chunk carrying top-level nulls is an error.
    #[test]
    fn fails_on_top_level_nulls() {
        let writer = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let sink = Arc::new(TestSegments::default());
        let validity =
            Validity::Array(BoolArray::from_iter(vec![true, true, false]).into_array());
        let array = StructArray::try_new(
            ["a"].into(),
            vec![buffer![1, 2, 3].into_array()],
            3,
            validity,
        )
        .unwrap()
        .into_array();
        let outcome = block_on(|handle| {
            writer.write_stream(
                ctx,
                sink,
                array.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        });
        let message = format!("{}", outcome.unwrap_err());
        assert!(message.starts_with("Cannot push struct chunks with top level invalid values"))
    }

    /// A zero-field struct still records the total row count across chunks.
    #[test]
    fn write_empty_field_struct_array() {
        let writer = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let sink = Arc::new(TestSegments::default());
        // Build an empty-field struct chunk of the given length.
        let empty_struct = |len| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        };
        let chunked = ChunkedArray::from_iter([empty_struct(3), empty_struct(5)]).into_array();
        let layout = block_on(|handle| {
            writer.write_stream(
                ctx,
                sink,
                chunked.to_array_stream().sequenced(ptr),
                eof,
                handle,
            )
        })
        .unwrap();

        assert_eq!(layout.row_count(), 8);
    }
}
266}