1use std::collections::BTreeSet;
5use std::iter::once;
6use std::ops::Range;
7
8use vortex_dtype::FieldMask;
9use vortex_error::VortexResult;
10use vortex_layout::LayoutReader;
11
/// Strategy used by [`SplitBy::splits`] to choose the row offsets at which a row range is split.
#[derive(Debug, Clone, Copy, Default)]
pub enum SplitBy {
    /// Take the split points that the layout reader registers for the requested fields.
    ///
    /// This is the default strategy.
    #[default]
    Layout,
    /// Place a split point every fixed number of rows, counted from the start of the row range.
    RowCount(usize),
}
24
25impl SplitBy {
26 pub fn splits(
29 &self,
30 layout_reader: &dyn LayoutReader,
31 row_range: &Range<u64>,
32 field_mask: &[FieldMask],
33 ) -> VortexResult<BTreeSet<u64>> {
34 Ok(match *self {
35 SplitBy::Layout => {
36 let mut row_splits = BTreeSet::<u64>::new();
37 row_splits.insert(row_range.start);
38
39 layout_reader.register_splits(field_mask, row_range, &mut row_splits)?;
42 row_splits
43 }
44 SplitBy::RowCount(n) => row_range
45 .clone()
46 .step_by(n)
47 .chain(once(row_range.end))
48 .collect(),
49 })
50 }
51}
52
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_buffer::buffer;
    use vortex_dtype::FieldPath;
    use vortex_io::runtime::single::block_on;
    use vortex_layout::LayoutReaderRef;
    use vortex_layout::LayoutStrategy;
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::TestSegments;
    use vortex_layout::sequence::SequenceId;
    use vortex_layout::sequence::SequentialArrayStreamExt;

    use super::*;
    use crate::test::SESSION;

    /// Writes ten `1_i32` values through `FlatLayoutStrategy` into in-memory test segments and
    /// opens a reader over the resulting layout.
    fn flat_reader() -> LayoutReaderRef {
        let segments = Arc::new(TestSegments::default());
        let ctx = ArrayContext::empty();
        let (ptr, eof) = SequenceId::root().split();

        let written = block_on(|handle| async {
            FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    buffer![1_i32; 10]
                        .into_array()
                        .to_array_stream()
                        .sequenced(ptr),
                    eof,
                    handle,
                )
                .await
        });

        written
            .unwrap()
            .new_reader("".into(), segments, &SESSION)
            .unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        let reader = flat_reader();
        let mask = [FieldMask::Exact(FieldPath::root())];

        let splits = SplitBy::Layout
            .splits(reader.as_ref(), &(0..10), &mask)
            .unwrap();

        // The flat layout over ten rows is expected to yield only the range boundaries.
        assert_eq!(splits, BTreeSet::from([0, 10]));
    }

    #[test]
    fn test_row_count_splits() {
        let reader = flat_reader();
        let mask = [FieldMask::Exact(FieldPath::root())];

        let splits = SplitBy::RowCount(3)
            .splits(reader.as_ref(), &(0..10), &mask)
            .unwrap();

        // Every third row from the start, plus the end of the range.
        assert_eq!(splits, BTreeSet::from([0, 3, 6, 9, 10]));
    }
}