tpcgen_cli/tpch_cli/
generate.rs1use crate::tpch_cli::progress::TableProgress;
7use futures::StreamExt;
8use log::debug;
9use std::collections::VecDeque;
10use std::io;
11use std::sync::{Arc, Mutex};
12use tokio::task::JoinSet;
13
/// Something that can generate one chunk of output into a byte buffer.
///
/// Sources are moved onto spawned tokio tasks for generation, hence the `Send`
/// bound. The buffers they return are written out and then handed back for
/// reuse, so filling the supplied buffer (rather than allocating a new one)
/// lets allocations be recycled.
pub trait Source: Send {
    /// Returns the header bytes for the output, written into `buffer`.
    ///
    /// Only the first source's header is used; it is written once, before any
    /// generated data.
    fn header(&self, buffer: Vec<u8>) -> Vec<u8>;

    /// Consumes this source and returns its generated data in `buffer`.
    ///
    /// `buffer` is handed over empty, possibly carrying capacity left over from
    /// a previously written chunk.
    fn create(self, buffer: Vec<u8>) -> Vec<u8>;
}
28
/// Destination that generated bytes are written to.
///
/// A sink is moved onto a blocking writer thread (hence `Send`). That thread
/// calls `sink` for the header and then for each generated chunk, in order,
/// and finally `flush` once everything has been written.
pub trait Sink: Send {
    /// Writes `buffer` (the header or one generated chunk) to the destination.
    ///
    /// # Errors
    ///
    /// An error stops the write loop and is returned from the generation call.
    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error>;

    /// Completes the output, consuming the sink.
    ///
    /// Only called when every preceding `sink` call succeeded; its result
    /// becomes the result of the generation call.
    fn flush(self) -> Result<(), io::Error>;
}
39
40pub async fn generate_in_chunks<G, I, S>(
53 sink: S,
54 sources: I,
55 num_threads: usize,
56) -> Result<(), io::Error>
57where
58 G: Source + 'static,
59 I: Iterator<Item = G>,
60 S: Sink + 'static,
61{
62 generate_in_chunks_with_progress(sink, sources, num_threads, TableProgress::default()).await
63}
64
65pub(crate) async fn generate_in_chunks_with_progress<G, I, S>(
66 mut sink: S,
67 sources: I,
68 num_threads: usize,
69 progress: TableProgress,
70) -> Result<(), io::Error>
71where
72 G: Source + 'static,
73 I: Iterator<Item = G>,
74 S: Sink + 'static,
75{
76 let recycler = BufferRecycler::new();
77 let mut sources = sources.peekable();
78
79 debug!("Using {num_threads} threads");
81
82 let Some(first) = sources.peek() else {
83 return Ok(()); };
85 let header = first.header(Vec::new());
86
87 let sources_and_recyclers = sources.map(|generator| (generator, recycler.clone()));
88
89 let (tx, mut rx) = tokio::sync::mpsc::channel(num_threads);
91
92 let mut stream = futures::stream::iter(sources_and_recyclers)
94 .map(async |(source, recycler)| {
96 let buffer = recycler.new_buffer(1024 * 1024 * 8);
97 let mut join_set = JoinSet::new();
99 join_set.spawn(async move { source.create(buffer) });
100 join_set
102 .join_next()
103 .await
104 .expect("had one item")
105 .expect("join_next join is infallible unless task panics")
106 })
107 .buffered(num_threads)
109 .map(async |buffer| {
110 if let Err(e) = tx.send(buffer).await {
116 debug!("Error sending buffer to writer: {e}");
117 }
118 });
119
120 let captured_recycler = recycler.clone();
123 let writer_task = tokio::task::spawn_blocking(move || {
124 sink.sink(&header)?;
126 while let Some(buffer) = rx.blocking_recv() {
127 sink.sink(&buffer)?;
128 captured_recycler.return_buffer(buffer);
129 progress.increment_output_unit();
130 }
131 sink.flush()
133 });
134
135 while let Some(write_task) = stream.next().await {
137 if writer_task.is_finished() {
139 debug!("writer task is done early, stopping writer");
140 break;
141 }
142 write_task.await; }
144 drop(stream); drop(tx); debug!("waiting for writer task to complete");
149 writer_task.await.expect("writer task panicked")
150}
151
/// A shared pool of byte buffers.
///
/// It lets the writer hand chunk allocations back so later chunks can reuse
/// them instead of allocating afresh. Cloning is cheap, and all clones share
/// the same underlying pool.
#[derive(Debug, Clone)]
struct BufferRecycler {
    /// Idle buffers; the oldest returned buffer is handed out first.
    buffers: Arc<Mutex<VecDeque<Vec<u8>>>>,
}

impl BufferRecycler {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            buffers: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Returns an empty buffer whose capacity is at least `size` bytes.
    ///
    /// A pooled allocation is reused when one is available.
    fn new_buffer(&self, size: usize) -> Vec<u8> {
        // Only hold the lock for the pop; clearing and any reallocation happen
        // after the guard is released.
        let recycled = self.buffers.lock().unwrap().pop_front();
        match recycled {
            Some(mut buffer) => {
                buffer.clear();
                // `reserve` is relative to `len` (0 after `clear`) and is a no-op
                // when capacity already suffices, so this guarantees
                // `capacity() >= size`. The previous `reserve(size - capacity)`
                // could leave a smaller recycled buffer still short of `size`.
                buffer.reserve(size);
                buffer
            }
            None => Vec::with_capacity(size),
        }
    }

    /// Puts `buffer` back into the pool for a later `new_buffer` call.
    fn return_buffer(&self, buffer: Vec<u8>) {
        self.buffers.lock().unwrap().push_back(buffer);
    }
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::{ProgressTracker, RunProgress};
    use crate::tpch_cli::Table;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Tracker that sums every increment it receives, regardless of table.
    #[derive(Debug)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, units: u64) {
            self.increments.fetch_add(units, Ordering::Relaxed);
        }
    }

    /// Returns a fresh counter plus a `Table::Region` progress handle that
    /// reports into it.
    fn counting_progress() -> (Arc<CountingProgress>, TableProgress) {
        let tracker = Arc::new(CountingProgress {
            increments: AtomicU64::new(0),
        });
        let as_dyn: Arc<dyn ProgressTracker> = tracker.clone();
        let progress = RunProgress::with_tracker(as_dyn).for_table(Table::Region);
        (tracker, progress)
    }

    /// Source emitting a fixed header and a fixed data chunk.
    struct TestSource {
        header: &'static [u8],
        data: &'static [u8],
    }

    impl Source for TestSource {
        fn header(&self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.header);
            buffer
        }

        fn create(self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.data);
            buffer
        }
    }

    /// Sink that records every write so tests can inspect content and order.
    struct CapturingSink {
        writes: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl Sink for CapturingSink {
        fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
            self.writes.lock().unwrap().push(buffer.to_vec());
            Ok(())
        }

        fn flush(self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    /// Sink whose every write fails.
    struct FailingSink;

    impl Sink for FailingSink {
        fn sink(&mut self, _buffer: &[u8]) -> Result<(), io::Error> {
            Err(io::Error::other("sink failed"))
        }

        fn flush(self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn progress_counts_generated_chunks_not_header() {
        let writes = Arc::new(Mutex::new(Vec::new()));
        let (tracker, progress) = counting_progress();

        let sources = vec![
            TestSource {
                header: b"header\n",
                data: b"row-1\n",
            },
            TestSource {
                header: b"header\n",
                data: b"row-2\n",
            },
        ];

        generate_in_chunks_with_progress(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            sources.into_iter(),
            1,
            progress,
        )
        .await
        .unwrap();

        assert_eq!(tracker.increments.load(Ordering::Relaxed), 2);
        assert_eq!(
            *writes.lock().unwrap(),
            vec![
                b"header\n".to_vec(),
                b"row-1\n".to_vec(),
                b"row-2\n".to_vec()
            ]
        );
    }

    #[tokio::test]
    async fn empty_sources_write_nothing() {
        let writes = Arc::new(Mutex::new(Vec::new()));
        let (tracker, progress) = counting_progress();

        generate_in_chunks_with_progress(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            std::iter::empty::<TestSource>(),
            1,
            progress,
        )
        .await
        .unwrap();

        assert_eq!(tracker.increments.load(Ordering::Relaxed), 0);
        assert!(writes.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn chunks_are_written_in_source_order_with_many_threads() {
        static ROWS: [&[u8]; 5] = [b"row-1\n", b"row-2\n", b"row-3\n", b"row-4\n", b"row-5\n"];
        let writes = Arc::new(Mutex::new(Vec::new()));
        let (tracker, progress) = counting_progress();

        let sources = ROWS.iter().map(|&data| TestSource {
            header: b"header\n",
            data,
        });

        generate_in_chunks_with_progress(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            sources,
            3,
            progress,
        )
        .await
        .unwrap();

        let mut expected = vec![b"header\n".to_vec()];
        expected.extend(ROWS.iter().map(|row| row.to_vec()));
        assert_eq!(*writes.lock().unwrap(), expected);
        assert_eq!(
            tracker.increments.load(Ordering::Relaxed),
            ROWS.len() as u64
        );
    }

    #[tokio::test]
    async fn sink_error_is_returned() {
        let (_tracker, progress) = counting_progress();
        let sources = (0..4).map(|_| TestSource {
            header: b"header\n",
            data: b"row\n",
        });

        let err = generate_in_chunks_with_progress(FailingSink, sources, 2, progress)
            .await
            .unwrap_err();

        assert_eq!(err.to_string(), "sink failed");
    }
}