vortex_layout/layouts/struct_/
writer.rs1#![allow(deprecated, reason = "This module is deprecated")]
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use futures::TryStreamExt;
11use futures::future::try_join_all;
12use futures::pin_mut;
13use itertools::Itertools;
14use vortex_array::Array;
15use vortex_array::ArrayContext;
16use vortex_array::ArrayRef;
17use vortex_array::IntoArray;
18use vortex_array::ToCanonical;
19use vortex_dtype::DType;
20use vortex_dtype::Nullability;
21use vortex_error::VortexError;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_io::kanal_ext::KanalExt;
25use vortex_io::runtime::Handle;
26use vortex_utils::aliases::DefaultHashBuilder;
27use vortex_utils::aliases::hash_set::HashSet;
28
29use crate::IntoLayout as _;
30use crate::LayoutRef;
31use crate::LayoutStrategy;
32use crate::layouts::struct_::StructLayout;
33use crate::segments::SegmentSinkRef;
34use crate::sequence::SendableSequentialStream;
35use crate::sequence::SequenceId;
36use crate::sequence::SequencePointer;
37use crate::sequence::SequentialStreamAdapter;
38use crate::sequence::SequentialStreamExt;
39
40#[derive(Clone)]
46#[deprecated(since = "0.59.0", note = "Use the `TableStrategy` instead.")]
47pub struct StructStrategy {
48 child: Arc<dyn LayoutStrategy>,
49 validity: Arc<dyn LayoutStrategy>,
50}
51
52impl StructStrategy {
54 pub fn new<S: LayoutStrategy, V: LayoutStrategy>(child: S, validity: V) -> Self {
55 Self {
56 child: Arc::new(child),
57 validity: Arc::new(validity),
58 }
59 }
60}
61
62#[async_trait]
63impl LayoutStrategy for StructStrategy {
64 async fn write_stream(
65 &self,
66 ctx: ArrayContext,
67 segment_sink: SegmentSinkRef,
68 stream: SendableSequentialStream,
69 mut eof: SequencePointer,
70 handle: Handle,
71 ) -> VortexResult<LayoutRef> {
72 let dtype = stream.dtype().clone();
73 let Some(struct_dtype) = stream.dtype().as_struct_fields_opt().cloned() else {
74 return self
75 .child
76 .write_stream(ctx, segment_sink, stream, eof, handle)
77 .await;
78 };
79
80 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
82 != struct_dtype.names().len()
83 {
84 vortex_bail!("StructLayout must have unique field names");
85 }
86
87 let is_nullable = dtype.is_nullable();
88
89 if struct_dtype.nfields() == 0 && !is_nullable {
92 let row_count = stream
93 .try_fold(
94 0u64,
95 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
96 )
97 .await?;
98 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
99 }
100
101 let columns_vec_stream = stream.map(move |chunk| {
103 let (sequence_id, chunk) = chunk?;
104 let mut sequence_pointer = sequence_id.descend();
105 let struct_chunk = chunk.to_struct();
106 let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
107 if is_nullable {
108 columns.push((
109 sequence_pointer.advance(),
110 chunk.validity_mask()?.into_array(),
111 ));
112 }
113
114 columns.extend(
115 struct_chunk
116 .unmasked_fields()
117 .iter()
118 .map(|field| (sequence_pointer.advance(), field.to_array())),
119 );
120
121 Ok(columns)
122 });
123
124 let mut stream_count = struct_dtype.nfields();
125 if is_nullable {
126 stream_count += 1;
127 }
128
129 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
130 (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
131
132 handle
134 .spawn(async move {
135 pin_mut!(columns_vec_stream);
136 while let Some(result) = columns_vec_stream.next().await {
137 match result {
138 Ok(columns) => {
139 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
140 {
141 let _ = tx.send(Ok(column)).await;
142 }
143 }
144 Err(e) => {
145 let e: Arc<VortexError> = Arc::new(e);
146 for tx in column_streams_tx.iter() {
147 let _ = tx.send(Err(VortexError::from(e.clone()))).await;
148 }
149 break;
150 }
151 }
152 }
153 })
154 .detach();
155
156 let column_dtypes: Vec<DType> = if is_nullable {
158 std::iter::once(DType::Bool(Nullability::NonNullable))
159 .chain(struct_dtype.fields())
160 .collect()
161 } else {
162 struct_dtype.fields().collect()
163 };
164
165 let layout_futures: Vec<_> = column_dtypes
166 .into_iter()
167 .zip_eq(column_streams_rx)
168 .enumerate()
169 .map(move |(index, (dtype, recv))| {
170 let column_stream =
171 SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
172 .sendable();
173 let child_eof = eof.split_off();
174 handle.spawn_nested(|h| {
175 let child = self.child.clone();
176 let validity = self.validity.clone();
177 let this = self.clone();
178 let ctx = ctx.clone();
179 let dtype = dtype.clone();
180 let segment_sink = segment_sink.clone();
181 async move {
182 if index == 0 && is_nullable {
184 validity
185 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
186 .await
187 } else {
188 if dtype.is_struct() {
191 this.write_stream(ctx, segment_sink, column_stream, child_eof, h)
192 .await
193 } else {
194 child
195 .write_stream(ctx, segment_sink, column_stream, child_eof, h)
196 .await
197 }
198 }
199 }
200 })
201 })
202 .collect();
203
204 let column_layouts = try_join_all(layout_futures).await?;
205 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
208 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
209 }
210
211 fn buffered_bytes(&self) -> u64 {
212 self.child.buffered_bytes()
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::Canonical;
    use vortex_array::IntoArray as _;
    use vortex_array::arrays::ChunkedArray;
    use vortex_array::arrays::StructArray;
    use vortex_array::validity::Validity;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_io::runtime::single::block_on;

    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::struct_::writer::StructStrategy;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    /// Expects a panic when writing a struct whose dtype repeats the field name `a`.
    #[test]
    #[should_panic]
    fn fails_on_duplicate_field() {
        let duplicated_dtype = DType::Struct(
            [
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
                ("a", DType::Primitive(PType::I32, Nullability::NonNullable)),
            ]
            .into_iter()
            .collect(),
            Nullability::NonNullable,
        );
        let empty_input = Canonical::empty(&duplicated_dtype).into_array();

        let strategy =
            StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default());
        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let (ptr, eof) = SequenceId::root().split();

        let outcome = block_on(|handle| {
            let input = empty_input.to_array_stream().sequenced(ptr);
            strategy.write_stream(ctx, segments, input, eof, handle)
        });
        outcome.unwrap();
    }

    /// Two field-less, non-nullable struct chunks (3 and 5 rows) must yield 8 rows.
    #[test]
    fn write_empty_field_struct_array() {
        let first_chunk =
            StructArray::try_new(FieldNames::default(), vec![], 3, Validity::NonNullable)
                .unwrap()
                .into_array();
        let second_chunk =
            StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .into_array();
        let chunked = ChunkedArray::from_iter([first_chunk, second_chunk]).into_array();

        let strategy =
            StructStrategy::new(FlatLayoutStrategy::default(), FlatLayoutStrategy::default());
        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let (ptr, eof) = SequenceId::root().split();

        let res = block_on(|handle| {
            let input = chunked.to_array_stream().sequenced(ptr);
            strategy.write_stream(ctx, segments, input, eof, handle)
        });

        assert_eq!(res.unwrap().row_count(), 8);
    }
}