vortex_layout/layouts/struct_/
writer.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::future::try_join_all;
8use futures::{StreamExt, TryStreamExt, pin_mut};
9use itertools::Itertools;
10use vortex_array::{Array, ArrayContext, ToCanonical};
11use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail};
12use vortex_io::kanal_ext::KanalExt;
13use vortex_io::runtime::Handle;
14use vortex_utils::aliases::DefaultHashBuilder;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::layouts::struct_::StructLayout;
18use crate::segments::SegmentSinkRef;
19use crate::sequence::{
20 SendableSequentialStream, SequencePointer, SequentialStreamAdapter, SequentialStreamExt,
21};
22use crate::{IntoLayout as _, LayoutRef, LayoutStrategy};
23
24pub struct StructStrategy {
25 child: Arc<dyn LayoutStrategy>,
26}
27
28impl StructStrategy {
30 pub fn new<S: LayoutStrategy>(child: S) -> Self {
31 Self {
32 child: Arc::new(child),
33 }
34 }
35}
36
37#[async_trait]
38impl LayoutStrategy for StructStrategy {
39 async fn write_stream(
40 &self,
41 ctx: ArrayContext,
42 segment_sink: SegmentSinkRef,
43 stream: SendableSequentialStream,
44 mut eof: SequencePointer,
45 handle: Handle,
46 ) -> VortexResult<LayoutRef> {
47 let dtype = stream.dtype().clone();
48 let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
49 return self
51 .child
52 .write_stream(ctx, segment_sink, stream, eof, handle)
53 .await;
54 };
55 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
56 != struct_dtype.names().len()
57 {
58 vortex_bail!("StructLayout must have unique field names");
59 }
60
61 let stream = stream.map(|chunk| {
62 let (sequence_id, chunk) = chunk?;
63 if !chunk.all_valid() {
64 vortex_bail!("Cannot push struct chunks with top level invalid values");
65 };
66 Ok((sequence_id, chunk))
67 });
68
69 if struct_dtype.nfields() == 0 {
71 let row_count = stream
72 .try_fold(
73 0u64,
74 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
75 )
76 .await?;
77 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
78 }
79
80 let columns_vec_stream = stream.map(|chunk| {
82 let (sequence_id, chunk) = chunk?;
83 let mut sequence_pointer = sequence_id.descend();
84 let struct_chunk = chunk.to_struct();
85 let columns: Vec<_> = struct_chunk
86 .fields()
87 .iter()
88 .map(|field| (sequence_pointer.advance(), field.to_array()))
89 .collect();
90 Ok(columns)
91 });
92
93 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) = (0..struct_dtype.nfields())
94 .map(|_| kanal::bounded_async(1))
95 .unzip();
96
97 handle
99 .spawn(async move {
100 pin_mut!(columns_vec_stream);
101 while let Some(result) = columns_vec_stream.next().await {
102 match result {
103 Ok(columns) => {
104 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
105 {
106 let _ = tx.send(Ok(column)).await;
107 }
108 }
109 Err(e) => {
110 let e: Arc<VortexError> = Arc::new(e);
111 for tx in column_streams_tx.iter() {
112 let _ = tx.send(Err(VortexError::from(e.clone()))).await;
113 }
114 break;
115 }
116 }
117 }
118 })
119 .detach();
120
121 let column_dtypes = (0..struct_dtype.nfields()).map(move |idx| {
122 struct_dtype
123 .field_by_index(idx)
124 .vortex_expect("bound checked")
125 });
126
127 let layout_futures: Vec<_> = column_dtypes
128 .zip_eq(column_streams_rx)
129 .map(move |(dtype, recv)| {
130 let column_stream =
131 SequentialStreamAdapter::new(dtype, recv.into_stream().boxed()).sendable();
132 let child_eof = eof.split_off();
133 handle.spawn_nested(|h| {
134 let child = self.child.clone();
135 let ctx = ctx.clone();
136 let segment_sink = segment_sink.clone();
137 async move {
138 child
139 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
140 .await
141 }
142 })
143 })
144 .collect();
145
146 let column_layouts = try_join_all(layout_futures).await?;
147 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
150 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
151 }
152
153 fn buffered_bytes(&self) -> u64 {
154 self.child.buffered_bytes()
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{BoolArray, ChunkedArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayContext, Canonical, IntoArray as _};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType};
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};

    /// Writing a struct whose two fields are both named `a` must fail.
    // NOTE(review): a bare `#[should_panic]` accepts any panic, not only the duplicate-name
    // error; consider checking the message once it is confirmed nothing upstream panics first.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                Canonical::empty(&DType::Struct(
                    [
                        ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                        ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                    ]
                    .into_iter()
                    .collect(),
                    Nullability::NonNullable,
                ))
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        })
        .unwrap();
    }

    /// A struct chunk containing a top-level null must be rejected with a descriptive error.
    #[test]
    fn fails_on_top_level_nulls() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let res = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                StructArray::try_new(
                    ["a"].into(),
                    vec![buffer![1, 2, 3].into_array()],
                    3,
                    Validity::Array(BoolArray::from_iter(vec![true, true, false]).into_array()),
                )
                .unwrap()
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        });
        assert!(
            res.unwrap_err()
                .to_string()
                .starts_with("Cannot push struct chunks with top level invalid values"),
        )
    }

    /// A field-less struct still records the total row count across all chunks (3 + 5).
    #[test]
    fn write_empty_field_struct_array() {
        let strategy = StructStrategy::new(FlatLayoutStrategy::default());
        let (ptr, eof) = SequenceId::root().split();
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let res = block_on(|handle| {
            strategy.write_stream(
                ctx,
                segments,
                ChunkedArray::from_iter([
                    StructArray::try_new(FieldNames::default(), vec![], 3, Validity::NonNullable)
                        .unwrap()
                        .into_array(),
                    StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
                        .unwrap()
                        .into_array(),
                ])
                .into_array()
                .to_array_stream()
                .sequenced(ptr),
                eof,
                handle,
            )
        });

        assert_eq!(res.unwrap().row_count(), 8);
    }
}