vortex_layout/layouts/
table.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use futures::TryStreamExt;
13use futures::future::try_join_all;
14use futures::pin_mut;
15use itertools::Itertools;
16use vortex_array::ArrayContext;
17use vortex_array::ArrayRef;
18use vortex_array::IntoArray;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::struct_::StructArrayExt;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::Field;
23use vortex_array::dtype::FieldName;
24use vortex_array::dtype::FieldPath;
25use vortex_array::dtype::Nullability;
26use vortex_error::VortexError;
27use vortex_error::VortexResult;
28use vortex_error::vortex_bail;
29use vortex_io::kanal_ext::KanalExt;
30use vortex_io::session::RuntimeSessionExt;
31use vortex_session::VortexSession;
32use vortex_utils::aliases::DefaultHashBuilder;
33use vortex_utils::aliases::hash_map::HashMap;
34use vortex_utils::aliases::hash_set::HashSet;
35
36use crate::IntoLayout;
37use crate::LayoutRef;
38use crate::LayoutStrategy;
39use crate::layouts::struct_::StructLayout;
40use crate::segments::SegmentSinkRef;
41use crate::sequence::SendableSequentialStream;
42use crate::sequence::SequenceId;
43use crate::sequence::SequencePointer;
44use crate::sequence::SequentialStreamAdapter;
45use crate::sequence::SequentialStreamExt;
46
47pub struct TableStrategy {
50 leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
52 validity: Arc<dyn LayoutStrategy>,
54 fallback: Arc<dyn LayoutStrategy>,
56}
57
58impl TableStrategy {
59 pub fn new(validity: Arc<dyn LayoutStrategy>, fallback: Arc<dyn LayoutStrategy>) -> Self {
77 Self {
78 leaf_writers: Default::default(),
79 validity,
80 fallback,
81 }
82 }
83
84 pub fn with_field_writer(
113 mut self,
114 field_path: impl Into<FieldPath>,
115 writer: Arc<dyn LayoutStrategy>,
116 ) -> Self {
117 self.leaf_writers
118 .insert(self.validate_path(field_path.into()), writer);
119 self
120 }
121
122 pub fn with_field_writers(
126 mut self,
127 writers: impl IntoIterator<Item = (FieldPath, Arc<dyn LayoutStrategy>)>,
128 ) -> Self {
129 for (field_path, strategy) in writers {
130 self.leaf_writers
131 .insert(self.validate_path(field_path), strategy);
132 }
133 self
134 }
135
136 pub fn with_default_strategy(mut self, default: Arc<dyn LayoutStrategy>) -> Self {
138 self.fallback = default;
139 self
140 }
141
142 pub fn with_validity_strategy(mut self, validity: Arc<dyn LayoutStrategy>) -> Self {
144 self.validity = validity;
145 self
146 }
147}
148
149impl TableStrategy {
150 fn descend(&self, field: &Field) -> Self {
152 let mut new_writers = HashMap::with_capacity(self.leaf_writers.len());
155
156 for (field_path, strategy) in &self.leaf_writers {
157 if field_path.starts_with_field(field)
158 && let Some(subpath) = field_path.clone().step_into()
159 {
160 new_writers.insert(subpath, Arc::clone(strategy));
161 }
162 }
163
164 Self {
165 leaf_writers: new_writers,
166 validity: Arc::clone(&self.validity),
167 fallback: Arc::clone(&self.fallback),
168 }
169 }
170
171 fn validate_path(&self, path: FieldPath) -> FieldPath {
172 assert!(
173 !path.is_root(),
174 "Do not set override as a root strategy, instead set the default strategy"
175 );
176
177 for field_path in self.leaf_writers.keys() {
180 assert!(
181 !path.overlap(field_path),
182 "Override for field_path {path} conflicts with existing override for {field_path}"
183 );
184 }
185
186 path
187 }
188}
189
190#[async_trait]
192impl LayoutStrategy for TableStrategy {
193 async fn write_stream(
194 &self,
195 ctx: ArrayContext,
196 segment_sink: SegmentSinkRef,
197 stream: SendableSequentialStream,
198 mut eof: SequencePointer,
199 session: &VortexSession,
200 ) -> VortexResult<LayoutRef> {
201 let dtype = stream.dtype().clone();
202
203 if !dtype.is_struct() {
205 return self
206 .fallback
207 .write_stream(ctx, segment_sink, stream, eof, session)
208 .await;
209 }
210
211 let struct_dtype = dtype.as_struct_fields();
212
213 if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
215 != struct_dtype.names().len()
216 {
217 vortex_bail!("StructLayout must have unique field names");
218 }
219 let is_nullable = dtype.is_nullable();
220
221 if struct_dtype.nfields() == 0 && !is_nullable {
224 let row_count = stream
225 .try_fold(
226 0u64,
227 |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
228 )
229 .await?;
230 return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
231 }
232
233 let columns_vec_stream = stream.map(move |chunk| {
235 let (sequence_id, chunk) = chunk?;
236 let mut sequence_pointer = sequence_id.descend();
237 let struct_chunk = chunk.to_struct();
238 let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
239 if is_nullable {
240 columns.push((
241 sequence_pointer.advance(),
242 chunk.validity_mask()?.into_array(),
243 ));
244 }
245
246 columns.extend(
247 struct_chunk
248 .iter_unmasked_fields()
249 .map(|field| (sequence_pointer.advance(), field.clone())),
250 );
251
252 Ok(columns)
253 });
254
255 let mut stream_count = struct_dtype.nfields();
256 if is_nullable {
257 stream_count += 1;
258 }
259
260 let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
261 (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();
262
263 let handle = session.handle();
265 handle
266 .spawn(async move {
267 pin_mut!(columns_vec_stream);
268 while let Some(result) = columns_vec_stream.next().await {
269 match result {
270 Ok(columns) => {
271 for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
272 {
273 let _ = tx.send(Ok(column)).await;
274 }
275 }
276 Err(e) => {
277 let e: Arc<VortexError> = Arc::new(e);
278 for tx in column_streams_tx.iter() {
279 let _ = tx.send(Err(VortexError::from(Arc::clone(&e)))).await;
280 }
281 break;
282 }
283 }
284 }
285 })
286 .detach();
287
288 let column_dtypes: Vec<DType> = if is_nullable {
290 std::iter::once(DType::Bool(Nullability::NonNullable))
291 .chain(struct_dtype.fields())
292 .collect()
293 } else {
294 struct_dtype.fields().collect()
295 };
296
297 let column_names: Vec<FieldName> = if is_nullable {
298 std::iter::once(FieldName::from("__validity"))
299 .chain(struct_dtype.names().iter().cloned())
300 .collect()
301 } else {
302 struct_dtype.names().iter().cloned().collect()
303 };
304
305 let layout_futures: Vec<_> = column_dtypes
306 .into_iter()
307 .zip_eq(column_streams_rx)
308 .zip_eq(column_names)
309 .enumerate()
310 .map(move |(index, ((dtype, recv), name))| {
311 let column_stream =
312 SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
313 .sendable();
314 let child_eof = eof.split_off();
315 let field = Field::Name(name.clone());
316 let session = session.clone();
317 let ctx = ctx.clone();
318 let segment_sink = Arc::clone(&segment_sink);
319 handle.spawn_nested(move |h| {
320 let validity = Arc::clone(&self.validity);
321 let writer = self
323 .leaf_writers
324 .get(&FieldPath::from_name(name))
325 .cloned()
326 .unwrap_or_else(|| {
327 if dtype.is_struct() {
328 Arc::new(self.descend(&field))
330 } else {
331 Arc::clone(&self.fallback)
333 }
334 });
335 let session = session.with_handle(h);
336
337 async move {
338 if index == 0 && is_nullable {
342 validity
343 .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
344 .await
345 } else {
346 writer
348 .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
349 .await
350 }
351 }
352 })
353 })
354 .collect();
355
356 let column_layouts = try_join_all(layout_futures).await?;
357 let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
360 Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
361 }
362}
363
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::FieldPath;
    use vortex_array::field_path;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::table::TableStrategy;

    /// Builds a table strategy that uses `flat` for both validity and fallback.
    fn flat_table(flat: &Arc<FlatLayoutStrategy>) -> TableStrategy {
        TableStrategy::new(
            Arc::<FlatLayoutStrategy>::clone(flat),
            Arc::<FlatLayoutStrategy>::clone(flat),
        )
    }

    /// Registering a prefix of an already overridden path must panic.
    #[test]
    #[should_panic(
        expected = "Override for field_path $a.$b conflicts with existing override for $a.$b.$c"
    )]
    fn test_overlapping_paths_fail() {
        let flat = Arc::new(FlatLayoutStrategy::default());

        let with_leaf_override = flat_table(&flat)
            .with_field_writer(field_path!(a.b.c), Arc::<FlatLayoutStrategy>::clone(&flat));

        // `a.b` overlaps with the existing `a.b.c` override.
        let _conflicting = with_leaf_override.with_field_writer(field_path!(a.b), flat);
    }

    /// Overriding the root path is rejected in favour of the default strategy.
    #[test]
    #[should_panic(
        expected = "Do not set override as a root strategy, instead set the default strategy"
    )]
    fn test_root_override() {
        let flat = Arc::new(FlatLayoutStrategy::default());
        let _strategy = flat_table(&flat).with_field_writer(FieldPath::root(), flat);
    }
}