1use std::collections::BTreeSet;
5use std::ops::Range;
6
7use itertools::Itertools;
8use vortex_array::stats::StatBound;
9use vortex_dtype::FieldMask;
10use vortex_error::{VortexResult, vortex_err};
11
12use crate::LayoutReader;
13
/// Strategy used to divide a layout's rows into contiguous row ranges.
///
/// The default strategy is [`SplitBy::Layout`].
#[derive(Clone, Copy, Debug, Default)]
pub enum SplitBy {
    /// Split at the row boundaries that the layout reader registers for the
    /// requested fields.
    #[default]
    Layout,
    /// Split into consecutive ranges of the given number of rows; the final
    /// range holds whatever rows remain and may be shorter.
    RowCount(usize),
}
26
27impl SplitBy {
28 pub(crate) fn splits(
31 &self,
32 layout_reader: &dyn LayoutReader,
33 field_mask: &[FieldMask],
34 ) -> VortexResult<Vec<Range<u64>>> {
35 Ok(match *self {
36 SplitBy::Layout => {
37 let mut row_splits = BTreeSet::<u64>::new();
38 row_splits.insert(0);
39
40 layout_reader.register_splits(field_mask, 0, &mut row_splits)?;
42
43 row_splits
44 .into_iter()
45 .tuple_windows()
46 .map(|(start, end)| start..end)
47 .collect()
48 }
49 SplitBy::RowCount(n) => {
50 let row_count = *layout_reader.row_count().to_exact().ok_or_else(|| {
51 vortex_err!("Cannot split layout by row count, row count is not exact")
52 })?;
53 let mut splits =
54 Vec::with_capacity(usize::try_from((row_count + n as u64) / n as u64)?);
55 for start in (0..row_count).step_by(n) {
56 let end = (start + n as u64).min(row_count);
57 splits.push(start..end);
58 }
59 splits
60 }
61 })
62 }
63}
64
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use futures::executor::block_on;
    use futures::stream;
    use vortex_array::{ArrayContext, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, FieldPath, PType};
    use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
    use vortex_layout::segments::{SegmentSource, SequenceWriter, TestSegments};
    use vortex_layout::sequence::SequenceId;
    use vortex_layout::{LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt as _};

    use super::*;

    /// Writes ten non-nullable `i32` values into a single flat layout, opens a
    /// reader over it, and returns the ranges `split_by` produces for the root field.
    fn flat_splits(split_by: SplitBy) -> Vec<Range<u64>> {
        let segments = TestSegments::default();
        let layout = block_on(
            FlatLayoutStrategy::default().write_stream(
                &ArrayContext::empty(),
                SequenceWriter::new(Box::new(segments.clone())),
                SequentialStreamAdapter::new(
                    DType::Primitive(PType::I32, NonNullable),
                    stream::once(async {
                        Ok((
                            SequenceId::root().downgrade(),
                            buffer![1_i32; 10].into_array(),
                        ))
                    }),
                )
                .sendable(),
            ),
        )
        .unwrap();

        let source: Arc<dyn SegmentSource> = Arc::new(segments);
        let reader = layout.new_reader("".into(), source).unwrap();

        split_by
            .splits(reader.as_ref(), &[FieldMask::Exact(FieldPath::root())])
            .unwrap()
    }

    #[test]
    fn test_layout_splits_flat() {
        // The single flat layout is expected to produce one range covering all rows.
        assert_eq!(flat_splits(SplitBy::Layout), vec![0..10]);
    }

    #[test]
    fn test_row_count_splits() {
        // Ten rows in chunks of three leave a one-row remainder at the end.
        assert_eq!(
            flat_splits(SplitBy::RowCount(3)),
            vec![0..3, 3..6, 6..9, 9..10]
        );
    }
}