vortex_layout/layouts/
table.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::StreamExt;
11use futures::TryStreamExt;
12use futures::future::try_join_all;
13use futures::pin_mut;
14use itertools::Itertools;
15use vortex_array::ArrayContext;
16use vortex_array::ArrayRef;
17use vortex_array::IntoArray;
18use vortex_array::ToCanonical;
19use vortex_dtype::DType;
20use vortex_dtype::Field;
21use vortex_dtype::FieldName;
22use vortex_dtype::FieldPath;
23use vortex_dtype::Nullability;
24use vortex_error::VortexError;
25use vortex_error::VortexResult;
26use vortex_error::vortex_bail;
27use vortex_io::kanal_ext::KanalExt;
28use vortex_io::runtime::Handle;
29use vortex_utils::aliases::DefaultHashBuilder;
30use vortex_utils::aliases::hash_map::HashMap;
31use vortex_utils::aliases::hash_set::HashSet;
32
33use crate::IntoLayout;
34use crate::LayoutRef;
35use crate::LayoutStrategy;
36use crate::layouts::struct_::StructLayout;
37use crate::segments::SegmentSinkRef;
38use crate::sequence::SendableSequentialStream;
39use crate::sequence::SequenceId;
40use crate::sequence::SequencePointer;
41use crate::sequence::SequentialStreamAdapter;
42use crate::sequence::SequentialStreamExt;
43
/// A [`LayoutStrategy`] for struct-typed streams that writes every field as its own child
/// layout and lets individual fields be written with a dedicated strategy.
pub struct TableStrategy {
    /// Per-field strategy overrides keyed by field path. When a column is written, the
    /// override registered under that column's field name (if any) is used.
    leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
    /// Strategy used to write the validity column of a nullable struct.
    validity: Arc<dyn LayoutStrategy>,
    /// Strategy used for non-struct streams and for non-struct fields without an override.
    fallback: Arc<dyn LayoutStrategy>,
}
54
55impl TableStrategy {
56 pub fn new(validity: Arc<dyn LayoutStrategy>, fallback: Arc<dyn LayoutStrategy>) -> Self {
74 Self {
75 leaf_writers: Default::default(),
76 validity,
77 fallback,
78 }
79 }
80
81 pub fn with_field_writer(
108 mut self,
109 field_path: impl Into<FieldPath>,
110 writer: Arc<dyn LayoutStrategy>,
111 ) -> Self {
112 self.leaf_writers
113 .insert(self.validate_path(field_path.into()), writer);
114 self
115 }
116
117 pub fn with_field_writers(
121 mut self,
122 writers: impl IntoIterator<Item = (FieldPath, Arc<dyn LayoutStrategy>)>,
123 ) -> Self {
124 for (field_path, strategy) in writers {
125 self.leaf_writers
126 .insert(self.validate_path(field_path), strategy);
127 }
128 self
129 }
130
131 pub fn with_default_strategy(mut self, default: Arc<dyn LayoutStrategy>) -> Self {
133 self.fallback = default;
134 self
135 }
136
137 pub fn with_validity_strategy(mut self, validity: Arc<dyn LayoutStrategy>) -> Self {
139 self.validity = validity;
140 self
141 }
142}
143
144impl TableStrategy {
145 fn descend(&self, field: &Field) -> Self {
147 let mut new_writers = HashMap::with_capacity(self.leaf_writers.len());
150
151 for (field_path, strategy) in &self.leaf_writers {
152 if field_path.starts_with_field(field)
153 && let Some(subpath) = field_path.clone().step_into()
154 {
155 new_writers.insert(subpath, strategy.clone());
156 }
157 }
158
159 Self {
160 leaf_writers: new_writers,
161 validity: self.validity.clone(),
162 fallback: self.fallback.clone(),
163 }
164 }
165
166 fn validate_path(&self, path: FieldPath) -> FieldPath {
167 assert!(
168 !path.is_root(),
169 "Do not set override as a root strategy, instead set the default strategy"
170 );
171
172 for field_path in self.leaf_writers.keys() {
175 assert!(
176 !path.overlap(field_path),
177 "Override for field_path {path} conflicts with existing override for {field_path}"
178 );
179 }
180
181 path
182 }
183}
184
#[async_trait]
impl LayoutStrategy for TableStrategy {
    /// Writes a stream of struct chunks as a [`StructLayout`] with one child layout per
    /// column.
    ///
    /// Non-struct streams are delegated unchanged to the fallback strategy. For struct
    /// streams, every chunk is split into its columns (preceded by a validity column when
    /// the struct is nullable), each column is forwarded through its own channel, and a
    /// child writer is run per column. The child writer is the validity strategy for the
    /// validity column, a registered override for the field if one exists, a descended
    /// `TableStrategy` for nested structs, and the fallback strategy otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if the struct has duplicate field names, if the input stream
    /// yields an error, or if any child writer fails.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();

        // Anything that is not a struct is handed to the fallback strategy as-is.
        if !dtype.is_struct() {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let struct_dtype = dtype.as_struct_fields();

        // Collecting the names into a set drops duplicates, so a smaller set means at
        // least one field name is repeated.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }
        let is_nullable = dtype.is_nullable();

        // With no fields and no validity there are no child columns to write; only the
        // row count has to be determined, by summing the chunk lengths.
        if struct_dtype.nfields() == 0 && !is_nullable {
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // Turn every struct chunk into a vector of per-column arrays. Each column gets its
        // own sequence id, taken in column order from a pointer descended from the chunk's
        // sequence id.
        let columns_vec_stream = stream.map(move |chunk| {
            let (sequence_id, chunk) = chunk?;
            let mut sequence_pointer = sequence_id.descend();
            let struct_chunk = chunk.to_struct();
            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
            // For nullable structs the validity mask is emitted as the first column.
            if is_nullable {
                columns.push((
                    sequence_pointer.advance(),
                    chunk.validity_mask()?.into_array(),
                ));
            }

            columns.extend(
                struct_chunk
                    .unmasked_fields()
                    .iter()
                    .map(|field| (sequence_pointer.advance(), field.to_array())),
            );

            Ok(columns)
        });

        // One output stream per field, plus one for validity when the struct is nullable.
        let mut stream_count = struct_dtype.nfields();
        if is_nullable {
            stream_count += 1;
        }

        // One channel per column, each created with `bounded_async(1)`
        // (presumably so that a slow column writer applies backpressure — TODO confirm).
        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();

        // Fan-out task: reads the chunk stream and forwards each column to its channel.
        // The task handle is detached rather than awaited here.
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            // `zip_eq` requires the chunk to have exactly one array per
                            // channel and panics on a length mismatch.
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                // The send result is deliberately ignored (presumably the
                                // receiver may already have been dropped — verify).
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            // Share the error behind an `Arc` so that every column stream
                            // receives it, then stop reading further chunks.
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        // The column dtypes, in the same order as the columns produced per chunk: the
        // validity column (non-nullable bool) first when present, then the struct fields.
        let column_dtypes: Vec<DType> = if is_nullable {
            std::iter::once(DType::Bool(Nullability::NonNullable))
                .chain(struct_dtype.fields())
                .collect()
        } else {
            struct_dtype.fields().collect()
        };

        // The column names, in the same order; the validity column is named "__validity".
        let column_names: Vec<FieldName> = if is_nullable {
            std::iter::once(FieldName::from("__validity"))
                .chain(struct_dtype.names().iter().cloned())
                .collect()
        } else {
            struct_dtype.names().iter().cloned().collect()
        };

        // Start one child write per column. Each column's receiver is wrapped as a
        // sequential stream of its dtype, and each child gets its own pointer split off
        // from `eof`, in column order.
        let layout_futures: Vec<_> = column_dtypes
            .into_iter()
            .zip_eq(column_streams_rx)
            .zip_eq(column_names)
            .enumerate()
            .map(move |(index, ((dtype, recv), name))| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
                        .sendable();
                let child_eof = eof.split_off();
                let field = Field::Name(name.clone());
                // `spawn_nested` hands the closure a handle `h`, which is passed on to the
                // child writer.
                handle.spawn_nested(|h| {
                    let validity = self.validity.clone();
                    // Writer selection: an override registered under this field name wins;
                    // otherwise nested structs get a descended `TableStrategy` and all
                    // other columns use the fallback strategy.
                    let writer = self
                        .leaf_writers
                        .get(&FieldPath::from_name(name))
                        .cloned()
                        .unwrap_or_else(|| {
                            if dtype.is_struct() {
                                Arc::new(self.descend(&field))
                            } else {
                                self.fallback.clone()
                            }
                        });
                    let ctx = ctx.clone();
                    let segment_sink = segment_sink.clone();

                    async move {
                        // For nullable structs, column 0 is the validity column and is
                        // written with the validity strategy instead of `writer`.
                        if index == 0 && is_nullable {
                            validity
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        } else {
                            writer
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        }
                    }
                })
            })
            .collect();

        // Wait for all child writes; the first error aborts the join.
        let column_layouts = try_join_all(layout_futures).await?;
        // The struct's row count is taken from the first child layout (0 if there is none).
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }
}
356
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::FieldPath;
    use vortex_dtype::field_path;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::table::TableStrategy;

    /// Registering `a.b` after `a.b.c` must be rejected because the two paths overlap.
    #[test]
    #[should_panic(
        expected = "Override for field_path $a.$b conflicts with existing override for $a.$b.$c"
    )]
    fn test_overlapping_paths_fail() {
        let flat = Arc::new(FlatLayoutStrategy::default());

        // The first override is accepted; the second one overlaps it and panics.
        let _strategy = TableStrategy::new(flat.clone(), flat.clone())
            .with_field_writer(field_path!(a.b.c), flat.clone())
            .with_field_writer(field_path!(a.b), flat);
    }

    /// The root path cannot be overridden; the default strategy must be set instead.
    #[test]
    #[should_panic(
        expected = "Do not set override as a root strategy, instead set the default strategy"
    )]
    fn test_root_override() {
        let flat = Arc::new(FlatLayoutStrategy::default());
        let base = TableStrategy::new(flat.clone(), flat.clone());

        let _strategy = base.with_field_writer(FieldPath::root(), flat);
    }
}