vortex_scan/
split_by.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::BTreeSet;
5use std::iter::once;
6
7use vortex_array::stats::StatBound;
8use vortex_dtype::FieldMask;
9use vortex_error::{VortexResult, vortex_err};
10use vortex_layout::LayoutReader;
11
/// Strategy for dividing a Vortex file into batches when reading.
///
/// Every individual split must fit into the platform's maximum `usize`.
#[derive(Debug, Clone, Copy, Default)]
pub enum SplitBy {
    /// Split wherever the file has a chunk boundary. This is the default strategy.
    #[default]
    Layout,
    /// Split every `n` rows, where `n` is the wrapped value and must be greater than zero.
    RowCount(usize),
    // UncompressedSize(u64),
}
24
25impl SplitBy {
26    /// Compute the splits for the given layout.
27    // TODO(ngates): remove this once layout readers are stream based.
28    pub fn splits(
29        &self,
30        layout_reader: &dyn LayoutReader,
31        field_mask: &[FieldMask],
32    ) -> VortexResult<BTreeSet<u64>> {
33        Ok(match *self {
34            SplitBy::Layout => {
35                let mut row_splits = BTreeSet::<u64>::new();
36                row_splits.insert(0);
37
38                // Register the splits for all the layouts.
39                layout_reader.register_splits(field_mask, 0, &mut row_splits)?;
40                row_splits
41            }
42            SplitBy::RowCount(n) => {
43                let row_count = *layout_reader.row_count().to_exact().ok_or_else(|| {
44                    vortex_err!("Cannot split layout by row count, row count is not exact")
45                })?;
46                (0..row_count).step_by(n).chain(once(row_count)).collect()
47            }
48        })
49    }
50}
51
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::{ArrayContext, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldPath;
    use vortex_io::runtime::single::block_on;
    use vortex_layout::LayoutStrategy;
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::TestSegments;
    use vortex_layout::sequence::{SequenceId, SequentialArrayStreamExt};

    use super::*;

    /// Writes ten `i32` values through the flat layout strategy, opens a reader over the
    /// result and returns the splits that `split_by` computes for the root field.
    fn flat_layout_splits(split_by: SplitBy) -> BTreeSet<u64> {
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let layout = block_on(|handle| async {
            FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    buffer![1_i32; 10]
                        .into_array()
                        .to_array_stream()
                        .sequenced(ptr),
                    eof,
                    handle,
                )
                .await
        })
        .unwrap();

        let reader = layout.new_reader("".into(), segments).unwrap();
        let root_mask = [FieldMask::Exact(FieldPath::root())];

        split_by.splits(reader.as_ref(), &root_mask).unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        let expected = BTreeSet::from([0, 10]);
        assert_eq!(flat_layout_splits(SplitBy::Layout), expected);
    }

    #[test]
    fn test_row_count_splits() {
        let expected = BTreeSet::from([0, 3, 6, 9, 10]);
        assert_eq!(flat_layout_splits(SplitBy::RowCount(3)), expected);
    }
}