// vortex_layout/layouts/table.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! A configurable writer strategy for tabular data.
//!
//! Allows the caller to override specific leaf fields with custom layout strategies.
use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::try_join_all;
use futures::pin_mut;
use itertools::Itertools;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::StructArray;
use vortex_array::arrays::struct_::StructArrayExt;
use vortex_array::dtype::DType;
use vortex_array::dtype::Field;
use vortex_array::dtype::FieldName;
use vortex_array::dtype::FieldPath;
use vortex_array::dtype::Nullability;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_io::kanal_ext::KanalExt;
use vortex_io::session::RuntimeSessionExt;
use vortex_session::VortexSession;
use vortex_utils::aliases::DefaultHashBuilder;
use vortex_utils::aliases::hash_map::HashMap;
use vortex_utils::aliases::hash_set::HashSet;

use crate::IntoLayout;
use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::layouts::struct_::StructLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequenceId;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;
/// A configurable strategy for writing tables with nested field columns, allowing
/// overrides for specific leaf columns.
///
/// Overrides may not overlap one another and may not target the root path — both are
/// enforced (by panic) when an override is registered.
pub struct TableStrategy {
    /// A set of leaf field overrides, e.g. to force one column to be compact-compressed.
    /// Keys are full field paths relative to this strategy's level of the schema tree.
    leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
    /// The writer for any validity arrays that may be present
    validity: Arc<dyn LayoutStrategy>,
    /// The fallback writer for any fields that do not have an explicit writer set in `leaf_writers`
    fallback: Arc<dyn LayoutStrategy>,
}
58
59impl TableStrategy {
60    /// Create a new writer with the specified write strategies for validity, and for all leaf
61    /// fields, with no overrides.
62    ///
63    /// Additional overrides can be configured using the `with_leaf_strategy` method.
64    ///
65    /// ## Example
66    ///
67    /// ```ignore
68    /// # use std::sync::Arc;
69    /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
70    /// # use vortex_layout::layouts::table::TableStrategy;
71    ///
72    /// // Build a write strategy that does not compress validity or any leaf fields.
73    /// let flat = Arc::new(FlatLayoutStrategy::default());
74    ///
75    /// let strategy = TableStrategy::new(Arc::<FlatLayoutStrategy>::clone(&flat), Arc::<FlatLayoutStrategy>::clone(&flat));
76    /// ```
77    pub fn new(validity: Arc<dyn LayoutStrategy>, fallback: Arc<dyn LayoutStrategy>) -> Self {
78        Self {
79            leaf_writers: Default::default(),
80            validity,
81            fallback,
82        }
83    }
84
85    /// Add a custom write strategy for the given leaf field.
86    ///
87    /// ## Example
88    ///
89    /// ```ignore
90    /// # use std::sync::Arc;
91    /// # use vortex_array::dtype::{field_path, Field, FieldPath};
92    /// # use vortex_btrblocks::BtrBlocksCompressor;
93    /// # use vortex_layout::layouts::compressed::CompressingStrategy;
94    /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
95    /// # use vortex_layout::layouts::table::TableStrategy;
96    ///
97    /// // A strategy for compressing data using the balanced BtrBlocks compressor.
98    /// let compress =
99    ///     CompressingStrategy::new(FlatLayoutStrategy::default(), BtrBlocksCompressor::default());
100    ///
101    /// // Our combined strategy uses no compression for validity buffers, BtrBlocks compression
102    /// // for most columns, and stores a nested binary column uncompressed (flat) because it
103    /// // is pre-compressed or never filtered on.
104    /// let strategy = TableStrategy::new(
105    ///         Arc::new(FlatLayoutStrategy::default()),
106    ///         Arc::new(compress),
107    ///     )
108    ///     .with_field_writer(
109    ///         field_path!(request.body.bytes),
110    ///         Arc::new(FlatLayoutStrategy::default()),
111    ///     );
112    /// ```
113    pub fn with_field_writer(
114        mut self,
115        field_path: impl Into<FieldPath>,
116        writer: Arc<dyn LayoutStrategy>,
117    ) -> Self {
118        self.leaf_writers
119            .insert(self.validate_path(field_path.into()), writer);
120        self
121    }
122
123    /// Set writers for several fields at once.
124    ///
125    /// See also: [`with_field_writer`][Self::with_field_writer].
126    pub fn with_field_writers(
127        mut self,
128        writers: impl IntoIterator<Item = (FieldPath, Arc<dyn LayoutStrategy>)>,
129    ) -> Self {
130        for (field_path, strategy) in writers {
131            self.leaf_writers
132                .insert(self.validate_path(field_path), strategy);
133        }
134        self
135    }
136
137    /// Override the default strategy for leaf columns that don't have overrides.
138    pub fn with_default_strategy(mut self, default: Arc<dyn LayoutStrategy>) -> Self {
139        self.fallback = default;
140        self
141    }
142
143    /// Override the strategy for compressing struct validity at all levels of the schema tree.
144    pub fn with_validity_strategy(mut self, validity: Arc<dyn LayoutStrategy>) -> Self {
145        self.validity = validity;
146        self
147    }
148}
149
150impl TableStrategy {
151    /// Descend into a subfield for the writer.
152    fn descend(&self, field: &Field) -> Self {
153        // Start with the existing set of overrides, then only retain the ones that contain
154        // the current field
155        let mut new_writers = HashMap::with_capacity(self.leaf_writers.len());
156
157        for (field_path, strategy) in &self.leaf_writers {
158            if field_path.starts_with_field(field)
159                && let Some(subpath) = field_path.clone().step_into()
160            {
161                new_writers.insert(subpath, Arc::clone(strategy));
162            }
163        }
164
165        Self {
166            leaf_writers: new_writers,
167            validity: Arc::clone(&self.validity),
168            fallback: Arc::clone(&self.fallback),
169        }
170    }
171
172    fn validate_path(&self, path: FieldPath) -> FieldPath {
173        assert!(
174            !path.is_root(),
175            "Do not set override as a root strategy, instead set the default strategy"
176        );
177
178        // Validate that the field path does not conflict with any overrides
179        // that we've added by overlapping.
180        for field_path in self.leaf_writers.keys() {
181            assert!(
182                !path.overlap(field_path),
183                "Override for field_path {path} conflicts with existing override for {field_path}"
184            );
185        }
186
187        path
188    }
189}
190
/// Specialized strategy for when we exactly know the input schema.
#[async_trait]
impl LayoutStrategy for TableStrategy {
    /// Write a stream of struct chunks as a [`StructLayout`] with one child layout per
    /// column, plus a leading synthetic `__validity` boolean column when the struct
    /// dtype is nullable.
    ///
    /// Non-struct streams are delegated wholesale to the `fallback` strategy. Otherwise
    /// the incoming chunk stream is transposed into per-column streams, and each column
    /// is written by its own spawned task: using the override from `leaf_writers` when
    /// one matches the column name, descending recursively for nested struct columns,
    /// and using `fallback` for plain leaf columns.
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        session: &VortexSession,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();

        // Fallback: if the array is not a struct, fallback to writing a single array.
        if !dtype.is_struct() {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, session)
                .await;
        }

        let struct_dtype = dtype.as_struct_fields();

        // Check for unique field names at write time.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }
        let is_nullable = dtype.is_nullable();

        // Optimization: when there are no fields, don't spawn any work and just write a trivial
        // StructLayout. The stream is still drained so the total row count is known.
        if struct_dtype.nfields() == 0 && !is_nullable {
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // stream<struct_chunk> -> stream<vec<column_chunk>>
        // Each struct chunk is split into its column arrays, each tagged with a
        // sequence id one level below the parent chunk's id.
        let columns_session = session.clone();
        let columns_vec_stream = stream.map(move |chunk| {
            let (sequence_id, chunk) = chunk?;
            let mut sequence_pointer = sequence_id.descend();
            let mut ctx = columns_session.create_execution_ctx();
            let struct_chunk = chunk.clone().execute::<StructArray>(&mut ctx)?;
            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
            if is_nullable {
                // The validity mask is materialized first so it becomes column 0.
                columns.push((
                    sequence_pointer.advance(),
                    chunk
                        .validity()?
                        .execute_mask(chunk.len(), &mut ctx)?
                        .into_array(),
                ));
            }

            columns.extend(
                struct_chunk
                    .iter_unmasked_fields()
                    .map(|field| (sequence_pointer.advance(), field.clone())),
            );

            Ok(columns)
        });

        // One transposed output stream per field, plus one for validity if needed.
        let mut stream_count = struct_dtype.nfields();
        if is_nullable {
            stream_count += 1;
        }

        // NOTE(review): bounded(1) channels mean a slow column writer backpressures the
        // whole fan-out — presumably intended to bound memory; confirm.
        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();

        // Spawn a task to fan out column chunks to their respective transposed streams
        let handle = session.handle();
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                // Send errors (receiver dropped) are ignored: the column
                                // writer side has already terminated.
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            // Share one error across all columns, then stop fanning out;
                            // dropping the senders closes every column stream.
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(Arc::clone(&e)))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        // First child column is the validity, subsequence children are the individual struct fields
        let column_dtypes: Vec<DType> = if is_nullable {
            std::iter::once(DType::Bool(Nullability::NonNullable))
                .chain(struct_dtype.fields())
                .collect()
        } else {
            struct_dtype.fields().collect()
        };

        let column_names: Vec<FieldName> = if is_nullable {
            std::iter::once(FieldName::from("__validity"))
                .chain(struct_dtype.names().iter().cloned())
                .collect()
        } else {
            struct_dtype.names().iter().cloned().collect()
        };

        // Spawn one writer task per column; each receives its transposed stream and a
        // split-off slice of the parent's eof sequence pointer.
        let layout_futures: Vec<_> = column_dtypes
            .into_iter()
            .zip_eq(column_streams_rx)
            .zip_eq(column_names)
            .enumerate()
            .map(move |(index, ((dtype, recv), name))| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
                        .sendable();
                let child_eof = eof.split_off();
                let field = Field::Name(name.clone());
                let session = session.clone();
                let ctx = ctx.clone();
                let segment_sink = Arc::clone(&segment_sink);
                handle.spawn_nested(move |h| {
                    let validity = Arc::clone(&self.validity);
                    // descend further and try with new fields
                    // Resolve the column's writer: explicit override wins; otherwise
                    // struct columns descend into a narrowed TableStrategy, and leaf
                    // columns use the fallback.
                    let writer = self
                        .leaf_writers
                        .get(&FieldPath::from_name(name))
                        .cloned()
                        .unwrap_or_else(|| {
                            if dtype.is_struct() {
                                // Step into the field path for struct columns
                                Arc::new(self.descend(&field))
                            } else {
                                // Use fallback for leaf columns
                                Arc::clone(&self.fallback)
                            }
                        });
                    let session = session.with_handle(h);

                    async move {
                        // If we have a matching writer, we use it.
                        // Otherwise, we descend into a new modified one.
                        // Write validity stream
                        if index == 0 && is_nullable {
                            validity
                                .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
                                .await
                        } else {
                            // Use the underlying writer, otherwise use the fallback writer.
                            writer
                                .write_stream(ctx, segment_sink, column_stream, child_eof, &session)
                                .await
                        }
                    }
                })
            })
            .collect();

        let column_layouts = try_join_all(layout_futures).await?;
        // TODO(os): transposed stream could count row counts as well,
        // This must hold though, all columns must have the same row count of the struct layout
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }
}
369
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::FieldPath;
    use vortex_array::field_path;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::table::TableStrategy;

    /// Registering an override nested inside an existing one must panic.
    #[test]
    #[should_panic(
        expected = "Override for field_path $a.$b conflicts with existing override for $a.$b.$c"
    )]
    fn test_overlapping_paths_fail() {
        let writer = Arc::new(FlatLayoutStrategy::default());

        // The first override registers cleanly.
        let strategy = TableStrategy::new(
            Arc::<FlatLayoutStrategy>::clone(&writer),
            Arc::<FlatLayoutStrategy>::clone(&writer),
        )
        .with_field_writer(
            field_path!(a.b.c),
            Arc::<FlatLayoutStrategy>::clone(&writer),
        );

        // `a.b` is a prefix of the existing `a.b.c` override, so this panics.
        let _strategy = strategy.with_field_writer(field_path!(a.b), writer);
    }

    /// The root path may not carry an override; the default strategy covers it instead.
    #[test]
    #[should_panic(
        expected = "Do not set override as a root strategy, instead set the default strategy"
    )]
    fn test_root_override() {
        let writer = Arc::new(FlatLayoutStrategy::default());
        let _strategy = TableStrategy::new(
            Arc::<FlatLayoutStrategy>::clone(&writer),
            Arc::<FlatLayoutStrategy>::clone(&writer),
        )
        .with_field_writer(FieldPath::root(), writer);
    }
}