1use std::collections::BTreeSet;
5use std::iter::once;
6
7use vortex_array::stats::StatBound;
8use vortex_dtype::FieldMask;
9use vortex_error::{VortexResult, vortex_err};
10use vortex_layout::LayoutReader;
11
/// Strategy that decides where a scan over a layout is partitioned into row splits.
///
/// The default strategy is [`SplitBy::Layout`].
#[derive(Clone, Copy, Debug, Default)]
pub enum SplitBy {
    /// Use the split points that the layout reader registers for the requested fields.
    #[default]
    Layout,
    /// Split every `n` rows, where `n` is the wrapped value.
    RowCount(usize),
}
24
25impl SplitBy {
26 pub fn splits(
29 &self,
30 layout_reader: &dyn LayoutReader,
31 field_mask: &[FieldMask],
32 ) -> VortexResult<BTreeSet<u64>> {
33 Ok(match *self {
34 SplitBy::Layout => {
35 let mut row_splits = BTreeSet::<u64>::new();
36 row_splits.insert(0);
37
38 layout_reader.register_splits(field_mask, 0, &mut row_splits)?;
40 row_splits
41 }
42 SplitBy::RowCount(n) => {
43 let row_count = *layout_reader.row_count().to_exact().ok_or_else(|| {
44 vortex_err!("Cannot split layout by row count, row count is not exact")
45 })?;
46 (0..row_count).step_by(n).chain(once(row_count)).collect()
47 }
48 })
49 }
50}
51
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::{ArrayContext, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldPath;
    use vortex_io::runtime::single::block_on;
    use vortex_layout::LayoutStrategy;
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::TestSegments;
    use vortex_layout::sequence::{SequenceId, SequentialArrayStreamExt};

    use super::*;

    /// Writes ten `1_i32` values as a flat layout, then returns the splits that
    /// `strategy` produces over a reader for it, with the root field selected.
    fn flat_layout_splits(strategy: SplitBy) -> BTreeSet<u64> {
        let ctx = ArrayContext::empty();
        let segments = Arc::new(TestSegments::default());
        let (ptr, eof) = SequenceId::root().split();

        let stream = buffer![1_i32; 10]
            .into_array()
            .to_array_stream()
            .sequenced(ptr);
        let layout = block_on(|handle| async {
            FlatLayoutStrategy::default()
                .write_stream(ctx, segments.clone(), stream, eof, handle)
                .await
        })
        .unwrap();

        let reader = layout.new_reader("".into(), segments).unwrap();
        let mask = [FieldMask::Exact(FieldPath::root())];
        strategy.splits(reader.as_ref(), &mask).unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        let splits = flat_layout_splits(SplitBy::Layout);
        assert_eq!(splits, [0, 10].into_iter().collect());
    }

    #[test]
    fn test_row_count_splits() {
        let splits = flat_layout_splits(SplitBy::RowCount(3));
        assert_eq!(splits, [0, 3, 6, 9, 10].into_iter().collect());
    }
}