vortex_layout/layouts/
table.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::TryStreamExt;
13use futures::future::try_join_all;
14use futures::pin_mut;
15use itertools::Itertools;
16use vortex_array::ArrayContext;
17use vortex_array::ArrayRef;
18use vortex_array::IntoArray;
19use vortex_array::VortexSessionExecute;
20use vortex_array::arrays::StructArray;
21use vortex_array::arrays::struct_::StructArrayExt;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Field;
24use vortex_array::dtype::FieldName;
25use vortex_array::dtype::FieldPath;
26use vortex_array::dtype::Nullability;
27use vortex_error::VortexError;
28use vortex_error::VortexResult;
29use vortex_error::vortex_bail;
30use vortex_io::kanal_ext::KanalExt;
31use vortex_io::session::RuntimeSessionExt;
32use vortex_session::VortexSession;
33use vortex_utils::aliases::DefaultHashBuilder;
34use vortex_utils::aliases::hash_map::HashMap;
35use vortex_utils::aliases::hash_set::HashSet;
36
37use crate::IntoLayout;
38use crate::LayoutRef;
39use crate::LayoutStrategy;
40use crate::layouts::struct_::StructLayout;
41use crate::segments::SegmentSinkRef;
42use crate::sequence::SendableSequentialStream;
43use crate::sequence::SequenceId;
44use crate::sequence::SequencePointer;
45use crate::sequence::SequentialStreamAdapter;
46use crate::sequence::SequentialStreamExt;
47
48pub struct TableStrategy {
51 leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
53 validity: Arc<dyn LayoutStrategy>,
55 fallback: Arc<dyn LayoutStrategy>,
57}
58
59impl TableStrategy {
60 pub fn new(validity: Arc<dyn LayoutStrategy>, fallback: Arc<dyn LayoutStrategy>) -> Self {
78 Self {
79 leaf_writers: Default::default(),
80 validity,
81 fallback,
82 }
83 }
84
85 pub fn with_field_writer(
114 mut self,
115 field_path: impl Into<FieldPath>,
116 writer: Arc<dyn LayoutStrategy>,
117 ) -> Self {
118 self.leaf_writers
119 .insert(self.validate_path(field_path.into()), writer);
120 self
121 }
122
123 pub fn with_field_writers(
127 mut self,
128 writers: impl IntoIterator<Item = (FieldPath, Arc<dyn LayoutStrategy>)>,
129 ) -> Self {
130 for (field_path, strategy) in writers {
131 self.leaf_writers
132 .insert(self.validate_path(field_path), strategy);
133 }
134 self
135 }
136
137 pub fn with_default_strategy(mut self, default: Arc<dyn LayoutStrategy>) -> Self {
139 self.fallback = default;
140 self
141 }
142
143 pub fn with_validity_strategy(mut self, validity: Arc<dyn LayoutStrategy>) -> Self {
145 self.validity = validity;
146 self
147 }
148}
149
150impl TableStrategy {
151 fn descend(&self, field: &Field) -> Self {
153 let mut new_writers = HashMap::with_capacity(self.leaf_writers.len());
156
157 for (field_path, strategy) in &self.leaf_writers {
158 if field_path.starts_with_field(field)
159 && let Some(subpath) = field_path.clone().step_into()
160 {
161 new_writers.insert(subpath, Arc::clone(strategy));
162 }
163 }
164
165 Self {
166 leaf_writers: new_writers,
167 validity: Arc::clone(&self.validity),
168 fallback: Arc::clone(&self.fallback),
169 }
170 }
171
172 fn validate_path(&self, path: FieldPath) -> FieldPath {
173 assert!(
174 !path.is_root(),
175 "Do not set override as a root strategy, instead set the default strategy"
176 );
177
178 for field_path in self.leaf_writers.keys() {
181 assert!(
182 !path.overlap(field_path),
183 "Override for field_path {path} conflicts with existing override for {field_path}"
184 );
185 }
186
187 path
188 }
189}
190
191#[async_trait]
193impl LayoutStrategy for TableStrategy {
194 async fn write_stream(
195 &self,
196 ctx: ArrayContext,
197 segment_sink: SegmentSinkRef,
198 stream: SendableSequentialStream,
199 mut eof: SequencePointer,
200 session: &VortexSession,
201 ) -> VortexResult<LayoutRef> {
202 let dtype = stream.dtype().clone();
203
204 if !dtype.is_struct() {
206 return self
207 .fallback
208 .write_stream(ctx, segment_sink, stream, eof, session)
209 .await;
210 }
211
212 let struct_dtype = dtype.as_struct_fields();
213
214 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
216 != struct_dtype.names().len()
217 {
218 vortex_bail!("StructLayout must have unique field names");
219 }
220 let is_nullable = dtype.is_nullable();
221
222 if struct_dtype.nfields() == 0 && !is_nullable {
225 let row_count = stream
226 .try_fold(
227 0u64,
228 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
229 )
230 .await?;
231 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
232 }
233
234 let columns_session = session.clone();
236 let columns_vec_stream = stream.map(move |chunk| {
237 let (sequence_id, chunk) = chunk?;
238 let mut sequence_pointer = sequence_id.descend();
239 let mut ctx = columns_session.create_execution_ctx();
240 let struct_chunk = chunk.clone().execute::<StructArray>(&mut ctx)?;
241 let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
242 if is_nullable {
243 columns.push((
244 sequence_pointer.advance(),
245 chunk
246 .validity()?
247 .execute_mask(chunk.len(), &mut ctx)?
248 .into_array(),
249 ));
250 }
251
252 columns.extend(
253 struct_chunk
254 .iter_unmasked_fields()
255 .map(|field| (sequence_pointer.advance(), field.clone())),
256 );
257
258 Ok(columns)
259 });
260
261 let mut stream_count = struct_dtype.nfields();
262 if is_nullable {
263 stream_count += 1;
264 }
265
266 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
267 (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
268
269 let handle = session.handle();
271 handle
272 .spawn(async move {
273 pin_mut!(columns_vec_stream);
274 while let Some(result) = columns_vec_stream.next().await {
275 match result {
276 Ok(columns) => {
277 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
278 {
279 let _ = tx.send(Ok(column)).await;
280 }
281 }
282 Err(e) => {
283 let e: Arc<VortexError> = Arc::new(e);
284 for tx in column_streams_tx.iter() {
285 let _ = tx.send(Err(VortexError::from(Arc::clone(&e)))).await;
286 }
287 break;
288 }
289 }
290 }
291 })
292 .detach();
293
294 let column_dtypes: Vec<DType> = if is_nullable {
296 std::iter::once(DType::Bool(Nullability::NonNullable))
297 .chain(struct_dtype.fields())
298 .collect()
299 } else {
300 struct_dtype.fields().collect()
301 };
302
303 let column_names: Vec<FieldName> = if is_nullable {
304 std::iter::once(FieldName::from("__validity"))
305 .chain(struct_dtype.names().iter().cloned())
306 .collect()
307 } else {
308 struct_dtype.names().iter().cloned().collect()
309 };
310
311 let layout_futures: Vec<_> = column_dtypes
312 .into_iter()
313 .zip_eq(column_streams_rx)
314 .zip_eq(column_names)
315 .enumerate()
316 .map(move |(index, ((dtype, recv), name))| {
317 let column_stream =
318 SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
319 .sendable();
320 let child_eof = eof.split_off();
321 let field = Field::Name(name.clone());
322 let session = session.clone();
323 let ctx = ctx.clone();
324 let segment_sink = Arc::clone(&segment_sink);
325 handle.spawn_nested(move |h| {
326 let validity = Arc::clone(&self.validity);
327 let writer = self
329 .leaf_writers
330 .get(&FieldPath::from_name(name))
331 .cloned()
332 .unwrap_or_else(|| {
333 if dtype.is_struct() {
334 Arc::new(self.descend(&field))
336 } else {
337 Arc::clone(&self.fallback)
339 }
340 });
341 let session = session.with_handle(h);
342
343 async move {
344 if index == 0 && is_nullable {
348 validity
349 .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
350 .await
351 } else {
352 writer
354 .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
355 .await
356 }
357 }
358 })
359 })
360 .collect();
361
362 let column_layouts = try_join_all(layout_futures).await?;
363 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
366 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
367 }
368}
369
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::FieldPath;
    use vortex_array::field_path;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::table::TableStrategy;

    /// Builds a `TableStrategy` whose validity and fallback writers are both `flat`.
    fn flat_table(flat: &Arc<FlatLayoutStrategy>) -> TableStrategy {
        TableStrategy::new(
            Arc::<FlatLayoutStrategy>::clone(flat),
            Arc::<FlatLayoutStrategy>::clone(flat),
        )
    }

    #[test]
    #[should_panic(
        expected = "Override for field_path $a.$b conflicts with existing override for $a.$b.$c"
    )]
    fn test_overlapping_paths_fail() {
        let flat = Arc::new(FlatLayoutStrategy::default());

        // Registering a prefix of an existing override must be rejected.
        let _strategy = flat_table(&flat)
            .with_field_writer(field_path!(a.b.c), Arc::<FlatLayoutStrategy>::clone(&flat))
            .with_field_writer(field_path!(a.b), flat);
    }

    #[test]
    #[should_panic(
        expected = "Do not set override as a root strategy, instead set the default strategy"
    )]
    fn test_root_override() {
        let flat = Arc::new(FlatLayoutStrategy::default());

        // The root path is configured through the default strategy, never as an override.
        let _strategy = flat_table(&flat).with_field_writer(FieldPath::root(), flat);
    }
}