1use std::collections::BTreeSet;
5use std::iter::once;
6
7use vortex_array::stats::StatBound;
8use vortex_dtype::FieldMask;
9use vortex_error::{VortexResult, vortex_err};
10
11use crate::LayoutReader;
12
/// Strategy for choosing the row offsets at which a layout scan is split into
/// separate pieces of work.
#[derive(Clone, Copy, Debug, Default)]
pub enum SplitBy {
    /// Split at the boundaries the layout reader registers for the requested fields.
    ///
    /// This is the default strategy.
    #[default]
    Layout,
    /// Split into consecutive chunks of the given number of rows; the last chunk may be
    /// shorter than the others.
    RowCount(usize),
}
25
26impl SplitBy {
27 pub fn splits(
30 &self,
31 layout_reader: &dyn LayoutReader,
32 field_mask: &[FieldMask],
33 ) -> VortexResult<BTreeSet<u64>> {
34 Ok(match *self {
35 SplitBy::Layout => {
36 let mut row_splits = BTreeSet::<u64>::new();
37 row_splits.insert(0);
38
39 layout_reader.register_splits(field_mask, 0, &mut row_splits)?;
41 row_splits
42 }
43 SplitBy::RowCount(n) => {
44 let row_count = *layout_reader.row_count().to_exact().ok_or_else(|| {
45 vortex_err!("Cannot split layout by row count, row count is not exact")
46 })?;
47 (0..row_count).step_by(n).chain(once(row_count)).collect()
48 }
49 })
50 }
51}
52
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use vortex_array::{ArrayContext, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldPath;
    use vortex_io::runtime::single::block_on;
    use vortex_layout::LayoutStrategy;
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::TestSegments;
    use vortex_layout::sequence::{SequenceId, SequentialArrayStreamExt};

    use super::*;

    /// Writes ten `1_i32` values as a single flat layout into in-memory test segments, then
    /// passes a reader over that layout to `inspect` and returns whatever `inspect` returns.
    fn with_flat_reader<R>(inspect: impl FnOnce(&dyn LayoutReader) -> R) -> R {
        let array_ctx = ArrayContext::empty();
        let test_segments = Arc::new(TestSegments::default());
        let (sequence_ptr, sequence_eof) = SequenceId::root().split();

        let flat_layout = block_on(|handle| async {
            FlatLayoutStrategy::default()
                .write_stream(
                    array_ctx,
                    test_segments.clone(),
                    buffer![1_i32; 10]
                        .into_array()
                        .to_array_stream()
                        .sequenced(sequence_ptr),
                    sequence_eof,
                    handle,
                )
                .await
        })
        .unwrap();

        let owned_reader = flat_layout.new_reader("".into(), test_segments).unwrap();
        inspect(owned_reader.as_ref())
    }

    #[test]
    fn test_layout_splits_flat() {
        let splits = with_flat_reader(|reader| {
            SplitBy::Layout.splits(reader, &[FieldMask::Exact(FieldPath::root())])
        })
        .unwrap();
        assert_eq!(splits, BTreeSet::from([0, 10]));
    }

    #[test]
    fn test_row_count_splits() {
        let splits = with_flat_reader(|reader| {
            SplitBy::RowCount(3).splits(reader, &[FieldMask::Exact(FieldPath::root())])
        })
        .unwrap();
        assert_eq!(splits, BTreeSet::from([0, 3, 6, 9, 10]));
    }
}