vortex_scan/
split_by.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::BTreeSet;
5use std::iter::once;
6
7use vortex_array::stats::StatBound;
8use vortex_dtype::FieldMask;
9use vortex_error::{VortexResult, vortex_err};
10
11use crate::LayoutReader;
12
/// Defines how the Vortex file is split into batches for reading.
///
/// Note that the row count of each split must fit into the platform's maximum `usize`.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum SplitBy {
    /// Splits any time there is a chunk boundary in the file.
    #[default]
    Layout,
    /// Splits every `n` rows.
    ///
    /// `n` must be non-zero: a step of zero rows cannot partition the file.
    RowCount(usize),
    // UncompressedSize(u64),
}
25
26impl SplitBy {
27    /// Compute the splits for the given layout.
28    // TODO(ngates): remove this once layout readers are stream based.
29    pub fn splits(
30        &self,
31        layout_reader: &dyn LayoutReader,
32        field_mask: &[FieldMask],
33    ) -> VortexResult<BTreeSet<u64>> {
34        Ok(match *self {
35            SplitBy::Layout => {
36                let mut row_splits = BTreeSet::<u64>::new();
37                row_splits.insert(0);
38
39                // Register the splits for all the layouts.
40                layout_reader.register_splits(field_mask, 0, &mut row_splits)?;
41                row_splits
42            }
43            SplitBy::RowCount(n) => {
44                let row_count = *layout_reader.row_count().to_exact().ok_or_else(|| {
45                    vortex_err!("Cannot split layout by row count, row count is not exact")
46                })?;
47                (0..row_count).step_by(n).chain(once(row_count)).collect()
48            }
49        })
50    }
51}
52
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::{ArrayContext, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldPath;
    use vortex_io::runtime::single::block_on;
    use vortex_layout::LayoutStrategy;
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::TestSegments;
    use vortex_layout::sequence::{SequenceId, SequentialArrayStreamExt};

    use super::*;

    /// Writes ten `i32` rows with the default flat layout strategy, opens a reader over the
    /// written layout, and returns the splits `split_by` computes for the root field.
    fn splits_over_ten_flat_rows(split_by: SplitBy) -> BTreeSet<u64> {
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (sequence_ptr, eof) = SequenceId::root().split();

        let written = block_on(|handle| async {
            let rows = buffer![1_i32; 10]
                .into_array()
                .to_array_stream()
                .sequenced(sequence_ptr);
            FlatLayoutStrategy::default()
                .write_stream(ctx, segments.clone(), rows, eof, handle)
                .await
        });
        let layout = written.unwrap();

        let reader = layout.new_reader("".into(), segments).unwrap();
        let root_mask = [FieldMask::Exact(FieldPath::root())];
        split_by.splits(reader.as_ref(), &root_mask).unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        // The single flat layout is expected to contribute only its start and end offsets.
        assert_eq!(
            splits_over_ten_flat_rows(SplitBy::Layout),
            BTreeSet::from([0_u64, 10])
        );
    }

    #[test]
    fn test_row_count_splits() {
        // Every third row, followed by the total row count as the closing boundary.
        assert_eq!(
            splits_over_ten_flat_rows(SplitBy::RowCount(3)),
            BTreeSet::from([0_u64, 3, 6, 9, 10])
        );
    }
}