vortex_layout/layouts/table.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! A more configurable variant of the `StructStrategy` that allows overriding
//! specific leaf fields with custom write strategies.

use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::try_join_all;
use futures::pin_mut;
use itertools::Itertools;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::ToCanonical;
use vortex_dtype::DType;
use vortex_dtype::Field;
use vortex_dtype::FieldName;
use vortex_dtype::FieldPath;
use vortex_dtype::Nullability;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_io::kanal_ext::KanalExt;
use vortex_io::runtime::Handle;
use vortex_utils::aliases::DefaultHashBuilder;
use vortex_utils::aliases::hash_map::HashMap;
use vortex_utils::aliases::hash_set::HashSet;

use crate::IntoLayout;
use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::layouts::struct_::StructLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequenceId;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;

/// A configurable strategy for writing tables with nested field columns, allowing
/// overrides for specific leaf columns.
pub struct TableStrategy {
    /// A set of leaf field overrides, e.g. to force one column to be compact-compressed.
    leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
    /// The writer for any validity arrays that may be present
    validity: Arc<dyn LayoutStrategy>,
    /// The fallback writer for any fields that do not have an explicit writer set in `leaf_writers`
    fallback: Arc<dyn LayoutStrategy>,
}

impl TableStrategy {
    /// Create a new writer with the specified write strategies for validity and for all leaf
    /// fields, with no overrides.
    ///
    /// Additional per-field overrides can be configured using the
    /// [`with_field_writer`][Self::with_field_writer] method.
    ///
    /// ## Example
    ///
    /// ```ignore
    /// # use std::sync::Arc;
    /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    /// # use vortex_layout::layouts::table::TableStrategy;
    ///
    /// // Build a write strategy that does not compress validity or any leaf fields.
    /// let flat = Arc::new(FlatLayoutStrategy::default());
    ///
    /// let strategy = TableStrategy::new(flat.clone(), flat.clone());
    /// ```
    pub fn new(validity: Arc<dyn LayoutStrategy>, fallback: Arc<dyn LayoutStrategy>) -> Self {
        Self {
            leaf_writers: Default::default(),
            validity,
            fallback,
        }
    }

    /// Add a custom write strategy for the given leaf field.
    ///
    /// ## Example
    ///
    /// ```ignore
    /// # use std::sync::Arc;
    /// # use vortex_dtype::{field_path, Field, FieldPath};
    /// # use vortex_layout::layouts::compressed::CompressingStrategy;
    /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    /// # use vortex_layout::layouts::table::TableStrategy;
    ///
    /// // A strategy for compressing data using the balanced BtrBlocks compressor.
    /// let compress = CompressingStrategy::new_btrblocks(FlatLayoutStrategy::default(), true);
    ///
    /// // Our combined strategy uses no compression for validity buffers, BtrBlocks compression
    /// // for most columns, and stores a nested binary column uncompressed (flat) because it
    /// // is pre-compressed or never filtered on.
    /// let strategy = TableStrategy::new(
    ///         Arc::new(FlatLayoutStrategy::default()),
    ///         Arc::new(compress),
    ///     )
    ///     .with_field_writer(
    ///         field_path!(request.body.bytes),
    ///         Arc::new(FlatLayoutStrategy::default()),
    ///     );
    /// ```
    pub fn with_field_writer(
        mut self,
        field_path: impl Into<FieldPath>,
        writer: Arc<dyn LayoutStrategy>,
    ) -> Self {
        self.leaf_writers
            .insert(self.validate_path(field_path.into()), writer);
        self
    }

    /// Set writers for several fields at once.
    ///
    /// See also: [`with_field_writer`][Self::with_field_writer].
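    ///
    /// ## Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest); the
    /// `request.headers` and `response.body.bytes` paths are hypothetical leaf columns:
    ///
    /// ```ignore
    /// # use std::sync::Arc;
    /// # use vortex_dtype::field_path;
    /// # use vortex_layout::LayoutStrategy;
    /// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    /// # use vortex_layout::layouts::table::TableStrategy;
    ///
    /// let flat: Arc<dyn LayoutStrategy> = Arc::new(FlatLayoutStrategy::default());
    ///
    /// // Register flat writers for two non-overlapping leaf paths in one call.
    /// let strategy = TableStrategy::new(flat.clone(), flat.clone()).with_field_writers([
    ///     (field_path!(request.headers), flat.clone()),
    ///     (field_path!(response.body.bytes), flat.clone()),
    /// ]);
    /// ```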
    pub fn with_field_writers(
        mut self,
        writers: impl IntoIterator<Item = (FieldPath, Arc<dyn LayoutStrategy>)>,
    ) -> Self {
        for (field_path, strategy) in writers {
            self.leaf_writers
                .insert(self.validate_path(field_path), strategy);
        }
        self
    }

    /// Override the default strategy for leaf columns that don't have overrides.
    pub fn with_default_strategy(mut self, default: Arc<dyn LayoutStrategy>) -> Self {
        self.fallback = default;
        self
    }

    /// Override the strategy for compressing struct validity at all levels of the schema tree.
    pub fn with_validity_strategy(mut self, validity: Arc<dyn LayoutStrategy>) -> Self {
        self.validity = validity;
        self
    }
}

impl TableStrategy {
    /// Descend into a subfield, producing the strategy used to write that child.
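    ///
    /// Overrides registered under `field` are rebased onto the child: for example, an override
    /// for `a.b.c` is kept as `b.c` when descending into field `a`, while overrides rooted at
    /// other fields are dropped. The validity and fallback strategies are carried over unchanged.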
    fn descend(&self, field: &Field) -> Self {
        // Start from the existing set of overrides and retain only the ones rooted at the
        // current field, stepping each retained path into the child.
        let mut new_writers = HashMap::with_capacity(self.leaf_writers.len());

        for (field_path, strategy) in &self.leaf_writers {
            if field_path.starts_with_field(field)
                && let Some(subpath) = field_path.clone().step_into()
            {
                new_writers.insert(subpath, strategy.clone());
            }
        }

        Self {
            leaf_writers: new_writers,
            validity: self.validity.clone(),
            fallback: self.fallback.clone(),
        }
    }

    fn validate_path(&self, path: FieldPath) -> FieldPath {
        assert!(
            !path.is_root(),
            "Do not set override as a root strategy, instead set the default strategy"
        );

        // Validate that the new field path does not overlap with any override that has
        // already been added.
        for field_path in self.leaf_writers.keys() {
            assert!(
                !path.overlap(field_path),
                "Override for field_path {path} conflicts with existing override for {field_path}"
            );
        }

        path
    }
}

/// Specialized strategy for when the input schema is known exactly.
#[async_trait]
impl LayoutStrategy for TableStrategy {
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();

        // If the input is not a struct, fall back to writing it as a single array.
        if !dtype.is_struct() {
            return self
                .fallback
                .write_stream(ctx, segment_sink, stream, eof, handle)
                .await;
        }

        let struct_dtype = dtype.as_struct_fields();

        // Check for unique field names at write time.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }
        let is_nullable = dtype.is_nullable();

        // Optimization: when there are no fields, don't spawn any work and just write a trivial
        // StructLayout.
        if struct_dtype.nfields() == 0 && !is_nullable {
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // stream<struct_chunk> -> stream<vec<column_chunk>>
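        // When the struct is nullable, the chunk's validity mask is prepended as an extra
        // leading column so it gets its own child layout.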
        let columns_vec_stream = stream.map(move |chunk| {
            let (sequence_id, chunk) = chunk?;
            let mut sequence_pointer = sequence_id.descend();
            let struct_chunk = chunk.to_struct();
            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
            if is_nullable {
                columns.push((
                    sequence_pointer.advance(),
                    chunk.validity_mask()?.into_array(),
                ));
            }

            columns.extend(
                struct_chunk
                    .unmasked_fields()
                    .iter()
                    .map(|field| (sequence_pointer.advance(), field.to_array())),
            );

            Ok(columns)
        });

        let mut stream_count = struct_dtype.nfields();
        if is_nullable {
            stream_count += 1;
        }

        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();

        // Spawn a task to fan out column chunks to their respective transposed streams
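        // If any upstream chunk fails, the error is cloned and broadcast to every column
        // stream before the fan-out stops.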
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        // When nullable, the first child column is the validity; subsequent children are the
        // individual struct fields.
        let column_dtypes: Vec<DType> = if is_nullable {
            std::iter::once(DType::Bool(Nullability::NonNullable))
                .chain(struct_dtype.fields())
                .collect()
        } else {
            struct_dtype.fields().collect()
        };

        let column_names: Vec<FieldName> = if is_nullable {
            std::iter::once(FieldName::from("__validity"))
                .chain(struct_dtype.names().iter().cloned())
                .collect()
        } else {
            struct_dtype.names().iter().cloned().collect()
        };

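        // Each column is written by its own nested task with its own split-off EOF pointer.
        // Explicit overrides match on the first path segment here; nested overrides are handled
        // by descending into struct columns below.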
        let layout_futures: Vec<_> = column_dtypes
            .into_iter()
            .zip_eq(column_streams_rx)
            .zip_eq(column_names)
            .enumerate()
            .map(move |(index, ((dtype, recv), name))| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
                        .sendable();
                let child_eof = eof.split_off();
                let field = Field::Name(name.clone());
                handle.spawn_nested(|h| {
                    let validity = self.validity.clone();
                    // Select the writer: use an explicit override for this field if one exists,
                    // otherwise descend for struct columns or fall back for leaf columns.
                    let writer = self
                        .leaf_writers
                        .get(&FieldPath::from_name(name))
                        .cloned()
                        .unwrap_or_else(|| {
                            if dtype.is_struct() {
                                // Step into the field path for struct columns
                                Arc::new(self.descend(&field))
                            } else {
                                // Use fallback for leaf columns
                                self.fallback.clone()
                            }
                        });
                    let ctx = ctx.clone();
                    let segment_sink = segment_sink.clone();

                    async move {
                        // Column 0 carries the validity stream when the struct is nullable; it is
                        // written with the dedicated validity strategy.
                        if index == 0 && is_nullable {
                            validity
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        } else {
                            // Otherwise use the selected writer: override, descended strategy,
                            // or fallback.
                            writer
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        }
                    }
                })
            })
            .collect();

        let column_layouts = try_join_all(layout_futures).await?;
        // TODO(os): the transposed stream could count rows as well.
        // This must hold: every column has the same row count as the struct layout.
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::FieldPath;
    use vortex_dtype::field_path;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::table::TableStrategy;

    #[test]
    #[should_panic(
        expected = "Override for field_path $a.$b conflicts with existing override for $a.$b.$c"
    )]
    fn test_overlapping_paths_fail() {
        let flat = Arc::new(FlatLayoutStrategy::default());

        // Adding a non-overlapping override succeeds.
        let path = TableStrategy::new(flat.clone(), flat.clone())
            .with_field_writer(field_path!(a.b.c), flat.clone());

        // Should panic right here.
        let _path = path.with_field_writer(field_path!(a.b), flat);
    }

    #[test]
    #[should_panic(
        expected = "Do not set override as a root strategy, instead set the default strategy"
    )]
    fn test_root_override() {
        let flat = Arc::new(FlatLayoutStrategy::default());
        let _strategy = TableStrategy::new(flat.clone(), flat.clone())
            .with_field_writer(FieldPath::root(), flat);
    }
}