// vortex_layout/scan/split_by.rs
use std::collections::BTreeSet;
use std::iter::once;
use std::ops::Range;

use vortex_array::dtype::FieldMask;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::LayoutReader;

/// Strategy for dividing a scan's row range into splits.
///
/// The default is [`SplitBy::Layout`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum SplitBy {
    /// Split at the row offsets registered by the layout reader
    /// (via `LayoutReader::register_splits`).
    #[default]
    Layout,
    /// Split every `n` rows, counted from the start of the row range. `n` must be non-zero.
    RowCount(usize),
}

26impl SplitBy {
27 pub fn splits(
30 &self,
31 layout_reader: &dyn LayoutReader,
32 row_range: &Range<u64>,
33 field_mask: &[FieldMask],
34 ) -> VortexResult<BTreeSet<u64>> {
35 Ok(match *self {
36 SplitBy::Layout => {
37 let mut row_splits = BTreeSet::<u64>::new();
38 row_splits.insert(row_range.start);
39
40 layout_reader.register_splits(field_mask, row_range, &mut row_splits)?;
43 row_splits
44 }
45 SplitBy::RowCount(n) => row_range
46 .clone()
47 .step_by(n)
48 .chain(once(row_range.end))
49 .collect(),
50 })
51 }
52}
53
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::dtype::FieldPath;
    use vortex_buffer::buffer;
    use vortex_io::runtime::single::block_on;

    use super::*;
    use crate::LayoutReaderRef;
    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::scan::test::SCAN_SESSION;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;

    /// Writes ten `i32` rows with `FlatLayoutStrategy` and returns a reader over them.
    fn reader() -> LayoutReaderRef {
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();
        let layout = block_on(|handle| async {
            FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    buffer![1_i32; 10]
                        .into_array()
                        .to_array_stream()
                        .sequenced(ptr),
                    eof,
                    handle,
                )
                .await
        })
        .unwrap();

        layout
            .new_reader("".into(), segments, &SCAN_SESSION)
            .unwrap()
    }

    /// Computes `SplitBy::RowCount(n)` splits of `row_range` against the test reader.
    fn row_count_splits(n: usize, row_range: Range<u64>) -> BTreeSet<u64> {
        let reader = reader();
        SplitBy::RowCount(n)
            .splits(
                reader.as_ref(),
                &row_range,
                &[FieldMask::Exact(FieldPath::root())],
            )
            .unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        let reader = reader();

        let splits = SplitBy::Layout
            .splits(
                reader.as_ref(),
                &(0..10),
                &[FieldMask::Exact(FieldPath::root())],
            )
            .unwrap();
        assert_eq!(splits, [0, 10].into_iter().collect());
    }

    #[test]
    fn test_row_count_splits() {
        assert_eq!(
            row_count_splits(3, 0..10),
            [0, 3, 6, 9, 10].into_iter().collect()
        );
    }

    #[test]
    fn test_row_count_splits_exact_multiple() {
        // A range length divisible by `n` must not produce an extra trailing offset.
        assert_eq!(
            row_count_splits(3, 0..9),
            [0, 3, 6, 9].into_iter().collect()
        );
    }

    #[test]
    fn test_row_count_splits_offset_range() {
        // Steps are counted from the start of the range, not from row zero.
        assert_eq!(
            row_count_splits(3, 2..10),
            [2, 5, 8, 10].into_iter().collect()
        );
    }

    #[test]
    fn test_row_count_larger_than_range() {
        assert_eq!(row_count_splits(100, 0..10), [0, 10].into_iter().collect());
    }

    #[test]
    fn test_row_count_empty_range() {
        assert_eq!(row_count_splits(3, 4..4), [4].into_iter().collect());
    }
}