Skip to main content

tpcgen_cli/tpch_cli/
generate.rs

1//! Parallel data generation: [`Source`] and [`Sink`] and [`generate_in_chunks`]
2//!
3//! These traits and function are used to generate data in parallel and write it to a sink
4//! in streaming fashion (chunks). This is useful for generating large datasets that don't fit in memory.
5
6use crate::tpch_cli::progress::TableProgress;
7use futures::StreamExt;
8use log::debug;
9use std::collections::VecDeque;
10use std::io;
11use std::sync::{Arc, Mutex};
12use tokio::task::JoinSet;
13
/// A unit of work that renders generated rows into a caller-supplied byte buffer.
///
/// Implementations take ownership of a `Vec<u8>`, append their output to it and
/// hand the same vector back, so the allocation can be reused by the caller.
/// The generators of the tpchgen crate are the intended implementors.
pub trait Source: Send {
    /// Appends the first line of the output to `buffer` and returns it.
    ///
    /// This is invoked before the first call to [`Self::create`], and only once
    /// across all [`Source`]es of one output (on the first of them).
    fn header(&self, buffer: Vec<u8>) -> Vec<u8>;

    /// Consumes this source, appends everything it generates to `buffer`, and
    /// returns the filled buffer.
    fn create(self, buffer: Vec<u8>) -> Vec<u8>;
}
28
/// A byte-oriented output destination (an output file, for example).
///
/// [`generate_in_chunks`] writes every chunk through [`Sink::sink`] in output
/// order and finishes with a single call to [`Sink::flush`].
pub trait Sink: Send {
    /// Writes every byte of `buffer` to the destination.
    ///
    /// # Errors
    /// Returns the underlying I/O error if the write fails.
    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error>;

    /// Pushes out any data still held back and completes the sink.
    ///
    /// Takes `self` by value, so nothing can be written after flushing.
    ///
    /// # Errors
    /// Returns the underlying I/O error if flushing fails.
    fn flush(self) -> Result<(), io::Error>;
}
39
40/// Generates data in parallel from a series of [`Source`] and writes to a [`Sink`]
41///
42/// Each [`Source`] is a data generator that generates data directly into an in
43/// memory buffer.
44///
45/// This function will run the [`Source`]es in parallel up to num_threads.
46/// Data is written to the [`Sink`] in the order of the [`Source`]es in
47/// the input iterator.
48///
49/// G: Generator
50/// I: Iterator<Item = G>
51/// S: Sink that writes buffers somewhere
52pub async fn generate_in_chunks<G, I, S>(
53    sink: S,
54    sources: I,
55    num_threads: usize,
56) -> Result<(), io::Error>
57where
58    G: Source + 'static,
59    I: Iterator<Item = G>,
60    S: Sink + 'static,
61{
62    generate_in_chunks_with_progress(sink, sources, num_threads, TableProgress::default()).await
63}
64
65pub(crate) async fn generate_in_chunks_with_progress<G, I, S>(
66    mut sink: S,
67    sources: I,
68    num_threads: usize,
69    progress: TableProgress,
70) -> Result<(), io::Error>
71where
72    G: Source + 'static,
73    I: Iterator<Item = G>,
74    S: Sink + 'static,
75{
76    let recycler = BufferRecycler::new();
77    let mut sources = sources.peekable();
78
79    // use all cores to make data
80    debug!("Using {num_threads} threads");
81
82    let Some(first) = sources.peek() else {
83        return Ok(()); // no sources
84    };
85    let header = first.header(Vec::new());
86
87    let sources_and_recyclers = sources.map(|generator| (generator, recycler.clone()));
88
89    // create a channel to communicate between the generator tasks and the writer task
90    let (tx, mut rx) = tokio::sync::mpsc::channel(num_threads);
91
92    // convert to an async stream to run on tokio
93    let mut stream = futures::stream::iter(sources_and_recyclers)
94        // each generator writes to a buffer
95        .map(async |(source, recycler)| {
96            let buffer = recycler.new_buffer(1024 * 1024 * 8);
97            // do the work in a task (on a different thread)
98            let mut join_set = JoinSet::new();
99            join_set.spawn(async move { source.create(buffer) });
100            // wait for the task to be done and return the result
101            join_set
102                .join_next()
103                .await
104                .expect("had one item")
105                .expect("join_next join is infallible unless task panics")
106        })
107        // run in parallel
108        .buffered(num_threads)
109        .map(async |buffer| {
110            // send the buffer to the writer task, in order.
111
112            // Note we ignore errors writing because if the write errors it
113            // means the channel is closed / the program is exiting so there
114            // is nothing listening to send errors
115            if let Err(e) = tx.send(buffer).await {
116                debug!("Error sending buffer to writer: {e}");
117            }
118        });
119
120    // The writer task runs in a blocking thread to avoid blocking the async
121    // runtime. It reads from the channel and writes to the sink (doing File IO)
122    let captured_recycler = recycler.clone();
123    let writer_task = tokio::task::spawn_blocking(move || {
124        // The header is not an output unit; only generated chunks from the channel advance progress.
125        sink.sink(&header)?;
126        while let Some(buffer) = rx.blocking_recv() {
127            sink.sink(&buffer)?;
128            captured_recycler.return_buffer(buffer);
129            progress.increment_output_unit();
130        }
131        // No more input, flush the sink and return
132        sink.flush()
133    });
134
135    // drive the stream to completion
136    while let Some(write_task) = stream.next().await {
137        // break early if the writer stream is done (errored)
138        if writer_task.is_finished() {
139            debug!("writer task is done early, stopping writer");
140            break;
141        }
142        write_task.await; // sends the buffer to the writer task
143    }
144    drop(stream); // drop any stream references
145    drop(tx); // drop last tx reference to tell the writer it is done.
146
147    // wait for writer to finish
148    debug!("waiting for writer task to complete");
149    writer_task.await.expect("writer task panicked")
150}
151
/// A small pool of byte buffers, so each chunk can reuse an earlier allocation
/// instead of allocating a fresh multi-megabyte `Vec`.
///
/// Clones share the same underlying pool (an `Arc<Mutex<..>>`), so a buffer
/// returned through one clone can be handed out again by another. All access
/// goes through the mutex, which makes the recycler safe to use from several
/// threads at once (in this module: the generator tasks and the blocking writer).
#[derive(Debug, Clone)]
struct BufferRecycler {
    buffers: Arc<Mutex<VecDeque<Vec<u8>>>>,
}

impl BufferRecycler {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            buffers: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Returns an empty buffer with at least `size` bytes of capacity.
    ///
    /// Reuses the oldest returned buffer when one is available, growing it if
    /// its capacity is below `size`. Otherwise it allocates a new buffer.
    ///
    /// # Panics
    /// Panics if the pool's mutex is poisoned.
    fn new_buffer(&self, size: usize) -> Vec<u8> {
        // Pop in a separate statement so the lock guard is released before any
        // (potentially large) allocation below.
        let recycled = self.buffers.lock().unwrap().pop_front();
        match recycled {
            Some(mut buffer) => {
                buffer.clear();
                // `reserve(n)` guarantees `capacity >= len + n`. After `clear()`
                // `len` is 0, so this yields `capacity >= size` and is a no-op
                // when the buffer is already large enough.
                buffer.reserve(size);
                buffer
            }
            None => Vec::with_capacity(size),
        }
    }

    /// Puts `buffer` back into the pool for a later [`Self::new_buffer`] call.
    ///
    /// # Panics
    /// Panics if the pool's mutex is poisoned.
    fn return_buffer(&self, buffer: Vec<u8>) {
        self.buffers.lock().unwrap().push_back(buffer);
    }
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::{ProgressTracker, RunProgress};
    use crate::tpch_cli::Table;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Progress tracker that just sums the units it is asked to record.
    #[derive(Debug)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, units: u64) {
            self.increments.fetch_add(units, Ordering::Relaxed);
        }
    }

    /// Source that emits fixed header and data bytes.
    struct TestSource {
        header: &'static [u8],
        data: &'static [u8],
    }

    impl Source for TestSource {
        fn header(&self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.header);
            buffer
        }

        fn create(self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.data);
            buffer
        }
    }

    /// Sink that records every write as a separate chunk. The record is shared
    /// through an `Arc` so the test can inspect it after the sink is consumed.
    struct CapturingSink {
        writes: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl Sink for CapturingSink {
        fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
            self.writes.lock().unwrap().push(buffer.to_vec());
            Ok(())
        }

        fn flush(self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn progress_counts_generated_chunks_not_header() {
        let writes = Arc::new(Mutex::new(Vec::new()));
        let tracker = Arc::new(CountingProgress {
            increments: AtomicU64::new(0),
        });
        let progress: Arc<dyn ProgressTracker> = tracker.clone();
        let progress = RunProgress::with_tracker(progress).for_table(Table::Region);

        let sources = vec![
            TestSource {
                header: b"header\n",
                data: b"row-1\n",
            },
            TestSource {
                header: b"header\n",
                data: b"row-2\n",
            },
        ];

        generate_in_chunks_with_progress(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            sources.into_iter(),
            1,
            progress,
        )
        .await
        .unwrap();

        assert_eq!(tracker.increments.load(Ordering::Relaxed), 2);
        assert_eq!(
            *writes.lock().unwrap(),
            vec![
                b"header\n".to_vec(),
                b"row-1\n".to_vec(),
                b"row-2\n".to_vec()
            ]
        );
    }

    /// With no sources there is no header to take, so nothing is written at all.
    #[tokio::test]
    async fn empty_sources_write_nothing() {
        let writes = Arc::new(Mutex::new(Vec::new()));

        generate_in_chunks(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            Vec::<TestSource>::new().into_iter(),
            4,
        )
        .await
        .unwrap();

        assert!(writes.lock().unwrap().is_empty());
    }

    /// Chunks generated concurrently must still reach the sink in input order,
    /// preceded by a single header.
    #[tokio::test]
    async fn parallel_generation_preserves_source_order() {
        static ROWS: [&[u8]; 8] = [
            b"r0\n", b"r1\n", b"r2\n", b"r3\n", b"r4\n", b"r5\n", b"r6\n", b"r7\n",
        ];
        let writes = Arc::new(Mutex::new(Vec::new()));
        let sources = ROWS
            .iter()
            .map(|&data| TestSource { header: b"h\n", data });

        generate_in_chunks(
            CapturingSink {
                writes: Arc::clone(&writes),
            },
            sources,
            3,
        )
        .await
        .unwrap();

        let expected: Vec<Vec<u8>> = std::iter::once(b"h\n".to_vec())
            .chain(ROWS.iter().map(|row| row.to_vec()))
            .collect();
        assert_eq!(*writes.lock().unwrap(), expected);
    }
}