vortex_layout/layouts/struct_/
writer.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::try_join_all;
8use futures::{StreamExt, TryStreamExt, pin_mut};
9use itertools::Itertools;
10use vortex_array::{Array, ArrayContext, ToCanonical};
11use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail};
12use vortex_io::kanal_ext::KanalExt;
13use vortex_io::runtime::Handle;
14use vortex_utils::aliases::DefaultHashBuilder;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::layouts::struct_::StructLayout;
18use crate::segments::SegmentSinkRef;
19use crate::sequence::{
20 SendableSequentialStream, SequencePointer, SequentialStreamAdapter, SequentialStreamExt,
21};
22use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};
23
24pub struct StructStrategy {
25 child: Arc<dyn LayoutStrategy>,
26}
27
28impl StructStrategy {
30 pub fn new<S: LayoutStrategy>(child: S) -> Self {
31 Self {
32 child: Arc::new(child),
33 }
34 }
35}
36
37#[async_trait]
38impl LayoutStrategy for StructStrategy {
39 async fn write_stream(
40 &self,
41 ctx: ArrayContext,
42 segment_sink: SegmentSinkRef,
43 stream: SendableSequentialStream,
44 mut eof: SequencePointer,
45 handle: Handle,
46 ) -> VortexResult<LayoutRef> {
47 let dtype = stream.dtype().clone();
48 let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
49 return self
51 .child
52 .write_stream(ctx, segment_sink, stream, eof, handle)
53 .await;
54 };
55 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
56 != struct_dtype.names().len()
57 {
58 vortex_bail!("StructLayout must have unique field names");
59 }
60
61 let stream = stream.map(|chunk| {
62 let (sequence_id, chunk) = chunk?;
63 if !chunk.all_valid() {
64 vortex_bail!("Cannot push struct chunks with top level invalid values");
65 };
66 Ok((sequence_id, chunk))
67 });
68
69 if struct_dtype.nfields() == 0 {
71 let row_count = stream
72 .try_fold(
73 0u64,
74 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
75 )
76 .await?;
77 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
78 }
79
80 let columns_vec_stream = stream.map(|chunk| {
82 let (sequence_id, chunk) = chunk?;
83 let mut sequence_pointer = sequence_id.descend();
84 let struct_chunk = chunk.to_struct();
85 let columns: Vec<_> = (0..struct_chunk.struct_fields().nfields())
86 .map(|idx| {
87 (
88 sequence_pointer.advance(),
89 struct_chunk.fields()[idx].to_array(),
90 )
91 })
92 .collect();
93 Ok(columns)
94 });
95
96 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) = (0..struct_dtype.nfields())
97 .map(|_| kanal::bounded_async(1))
98 .unzip();
99
100 handle
102 .spawn(async move {
103 pin_mut!(columns_vec_stream);
104 while let Some(result) = columns_vec_stream.next().await {
105 match result {
106 Ok(columns) => {
107 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
108 {
109 let _ = tx.send(Ok(column)).await;
110 }
111 }
112 Err(e) => {
113 let e: Arc<VortexError> = Arc::new(e);
114 for tx in column_streams_tx.iter() {
115 let _ = tx.send(Err(VortexError::from(e.clone()))).await;
116 }
117 break;
118 }
119 }
120 }
121 })
122 .detach();
123
124 let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
125 struct_dtype
126 .field_by_index(idx)
127 .vortex_expect("bound checked")
128 });
129
130 let layout_futures: Vec<_> = column_dtypes
131 .zip_eq(column_streams_rx)
132 .map(move |(dtype, recv)| {
133 let column_stream =
134 SequentialStreamAdapter::new(dtype, recv.into_stream().boxed()).sendable();
135 let child_eof = eof.split_off();
136 handle.spawn_nested(|h| {
137 let child = self.child.clone();
138 let ctx = ctx.clone();
139 let segment_sink = segment_sink.clone();
140 async move {
141 child
142 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
143 .await
144 }
145 })
146 })
147 .collect();
148
149 let column_layouts = try_join_all(layout_futures).await?;
150 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
153 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
154 }
155
156 fn buffered_bytes(&self) -> u64 {
157 self.child.buffered_bytes()
158 }
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{BoolArray, ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Writing a struct whose two fields share the name "a" must not succeed.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let dtype = DType::Struct(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
            ]
            .into_iter()
            .collect(),
            Nullability::NonNullable,
        );
        let stream = Canonical::empty(&dtype)
            .into_array()
            .to_array_stream()
            .sequenced(ptr);

        block_on(|handle| strategy.write_stream(ctx, segments, stream, eof, handle)).unwrap();
    }

    /// A struct chunk with an invalid top-level row is rejected with the expected message.
    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        let array = StructArray::try_new(
            ["a"].into(),
            vec![buffer![1, 2, 3].into_array()],
            3,
            Validity::Array(BoolArray::from_iter(vec![true, true, false]).into_array()),
        )
        .unwrap()
        .into_array();
        let stream = array.to_array_stream().sequenced(ptr);

        let res = block_on(|handle| strategy.write_stream(ctx, segments, stream, eof, handle));

        let message = res.unwrap_err().to_string();
        assert!(message.starts_with("Cannot push struct chunks with top level invalid values"));
    }

    /// Field-less struct chunks of 3 and 5 rows produce a layout with 8 rows.
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());

        // Builds a struct array with no fields and the given number of rows.
        let fieldless_chunk = |len: usize| {
            StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable)
                .unwrap()
                .into_array()
        };
        let stream = ChunkedArray::from_iter([fieldless_chunk(3), fieldless_chunk(5)])
            .into_array()
            .to_array_stream()
            .sequenced(ptr);

        let res = block_on(|handle| strategy.write_stream(ctx, segments, stream, eof, handle));

        assert_eq!(res.unwrap().row_count(), 8);
    }
}