Skip to main content

vortex_layout/scan/
split_by.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::collections::BTreeSet;
5use std::iter::once;
6use std::ops::Range;
7
8use vortex_array::dtype::FieldMask;
9use vortex_error::VortexResult;
10
11use crate::LayoutReader;
12use crate::SplitRange;
13
/// Strategy used to divide a Vortex file into batches when reading.
///
/// Note that each split must fit into the platform's maximum usize.
#[derive(Debug, Default, Clone, Copy)]
pub enum SplitBy {
    /// Splits any time there is a chunk boundary in the file.
    ///
    /// This is the default strategy.
    #[default]
    Layout,
    /// Splits every n rows, where n is the wrapped row count.
    RowCount(usize),
    // UncompressedSize(u64),
}
26
27impl SplitBy {
28    /// Compute the splits for the given layout.
29    // TODO(ngates): remove this once layout readers are stream based.
30    pub fn splits(
31        &self,
32        layout_reader: &dyn LayoutReader,
33        row_range: &Range<u64>,
34        field_mask: &[FieldMask],
35    ) -> VortexResult<BTreeSet<u64>> {
36        Ok(match *self {
37            SplitBy::Layout => {
38                let mut row_splits = BTreeSet::<u64>::new();
39                row_splits.insert(row_range.start);
40
41                // Register all splits in the row range for all layouts that are needed
42                // to read the field mask.
43                layout_reader.register_splits(
44                    field_mask,
45                    &SplitRange::root(row_range.clone())?,
46                    &mut row_splits,
47                )?;
48                row_splits
49            }
50            SplitBy::RowCount(n) => row_range
51                .clone()
52                .step_by(n)
53                .chain(once(row_range.end))
54                .collect(),
55        })
56    }
57}
58
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::dtype::FieldPath;
    use vortex_buffer::buffer;
    use vortex_io::runtime::single::block_on;
    use vortex_io::session::RuntimeSessionExt;

    use super::*;
    use crate::LayoutReaderRef;
    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::scan::test::SCAN_SESSION;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    /// Writes a ten-element `i32` buffer with the flat layout strategy into in-memory test
    /// segments and returns a reader over the resulting layout.
    fn reader() -> LayoutReaderRef {
        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let (start_id, eof_id) = SequenceId::root().split();

        let flat_layout = block_on(|handle| async {
            let session = SCAN_SESSION.clone().with_handle(handle);
            let input = buffer![1_i32; 10]
                .into_array()
                .to_array_stream()
                .sequenced(start_id);

            FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    input,
                    eof_id,
                    &session,
                )
                .await
        })
        .unwrap();

        flat_layout
            .new_reader("".into(), segments, &SCAN_SESSION)
            .unwrap()
    }

    /// Layout-based splitting of the flat reader over rows `0..10` yields only the two
    /// range endpoints.
    #[test]
    fn test_layout_splits_flat() {
        let layout_reader = reader();
        let rows = 0..10;
        let mask = [FieldMask::Exact(FieldPath::root())];

        let actual = SplitBy::Layout
            .splits(layout_reader.as_ref(), &rows, &mask)
            .unwrap();

        assert_eq!(actual, BTreeSet::from([0, 10]));
    }

    /// Splitting rows `0..10` every three rows yields each step plus the range end.
    #[test]
    fn test_row_count_splits() {
        let layout_reader = reader();
        let rows = 0..10;
        let mask = [FieldMask::Exact(FieldPath::root())];

        let actual = SplitBy::RowCount(3)
            .splits(layout_reader.as_ref(), &rows, &mask)
            .unwrap();

        assert_eq!(actual, BTreeSet::from([0, 3, 6, 9, 10]));
    }
}