// Path: tpchgen_cli/generate.rs

1//! Parallel data generation: [`Source`] and [`Sink`] and [`generate_in_chunks`]
2//!
3//! These traits and function are used to generate data in parallel and write it to a sink
4//! in streaming fashion (chunks). This is useful for generating large datasets that don't fit in memory.
5
6use futures::StreamExt;
7use log::debug;
8use std::collections::VecDeque;
9use std::io;
10use std::sync::{Arc, Mutex};
11use tokio::task::JoinSet;
12
/// Something that knows how to generate data into a buffer
///
/// For example, this is implemented for the different generators in the tpchgen
/// crate.
///
/// Implementations must be `Send` because each [`Source`] is run on a spawned
/// tokio task (see [`generate_in_chunks`]).
pub trait Source: Send {
    /// Generates the data for this generator into `buffer`, returning the buffer.
    ///
    /// Consumes `self`: each [`Source`] produces its output exactly once.
    /// NOTE(review): the buffer may be a recycled allocation with existing
    /// capacity — presumably implementations append to it rather than assume
    /// it is freshly allocated; confirm against the implementors.
    fn create(self, buffer: Vec<u8>) -> Vec<u8>;

    /// Create the first line for the output, into the buffer
    ///
    /// This will be called before the first call to [`Self::create`] and
    /// exactly once across all [`Source`]es
    fn header(&self, buffer: Vec<u8>) -> Vec<u8>;
}
27
/// Something that can write the contents of a buffer somewhere
///
/// For example, this is implemented for a file writer.
///
/// Implementations must be `Send` because the sink is moved into a blocking
/// writer task (see [`generate_in_chunks`]).
pub trait Sink: Send {
    /// Write all data from the buffer to the sink
    ///
    /// Called once per generated buffer, in the order the [`Source`]es were
    /// supplied to [`generate_in_chunks`].
    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error>;

    /// Complete and flush any remaining data from the sink
    ///
    /// Consumes the sink; no further writes are possible afterwards.
    fn flush(self) -> Result<(), io::Error>;
}
38
/// Generates data in parallel from a series of [`Source`] and writes to a [`Sink`]
///
/// Each [`Source`] is a data generator that generates data directly into an in
/// memory buffer.
///
/// This function will run the [`Source`]es in parallel up to `num_threads`,
/// while writing output buffers to the [`Sink`] in the order of the
/// [`Source`]es in the input iterator (ordering is preserved by
/// `StreamExt::buffered`, which yields results in input order even though the
/// futures run concurrently).
///
/// Returns `Ok(())` on success (including the no-source case), or the first
/// I/O error produced by the sink.
///
/// Type parameters:
/// G: Generator ([`Source`])
/// I: Iterator<Item = G>
/// S: Sink that writes buffers somewhere
pub async fn generate_in_chunks<G, I, S>(
    mut sink: S,
    sources: I,
    num_threads: usize,
) -> Result<(), io::Error>
where
    G: Source + 'static,
    I: Iterator<Item = G>,
    S: Sink + 'static,
{
    let recycler = BufferRecycler::new();
    // peekable so the first source can be borrowed for the header without
    // consuming it: it still produces its own data buffer below
    let mut sources = sources.peekable();

    // use all cores to make data
    debug!("Using {num_threads} threads");

    // create a channel to communicate between the generator tasks and the
    // writer task; bounded to num_threads so generation backpressures when
    // the writer falls behind
    let (tx, mut rx) = tokio::sync::mpsc::channel(num_threads);

    // write the header: sent through the same channel as the data buffers,
    // and before any of them, so it is guaranteed to be written first
    let Some(first) = sources.peek() else {
        return Ok(()); // no sources
    };
    let header = first.header(Vec::new());
    tx.send(header)
        .await
        .expect("tx just created, it should not be closed");

    // pair each source with a recycler handle so the per-source closure can
    // obtain a (possibly reused) buffer
    let sources_and_recyclers = sources.map(|generator| (generator, recycler.clone()));

    // convert to an async stream to run on tokio
    let mut stream = futures::stream::iter(sources_and_recyclers)
        // each generator writes to a buffer
        .map(async |(source, recycler)| {
            let buffer = recycler.new_buffer(1024 * 1024 * 8);
            // do the work in a task (on a different thread)
            let mut join_set = JoinSet::new();
            join_set.spawn(async move { source.create(buffer) });
            // wait for the task to be done and return the result
            join_set
                .join_next()
                .await
                .expect("had one item")
                .expect("join_next join is infallible unless task panics")
        })
        // run in parallel: up to num_threads futures at once, results yielded
        // in input order
        .buffered(num_threads)
        .map(async |buffer| {
            // send the buffer to the writer task, in order.

            // Note we ignore errors writing because if the write errors it
            // means the channel is closed / the program is exiting so there
            // is nothing listening to send errors
            if let Err(e) = tx.send(buffer).await {
                debug!("Error sending buffer to writer: {e}");
            }
        });

    // The writer task runs in a blocking thread to avoid blocking the async
    // runtime. It reads from the channel and writes to the sink (doing File IO)
    let captured_recycler = recycler.clone();
    let writer_task = tokio::task::spawn_blocking(move || {
        while let Some(buffer) = rx.blocking_recv() {
            // a sink error ends this task early via `?`; the error is the
            // value awaited at the bottom of generate_in_chunks
            sink.sink(&buffer)?;
            // hand the spent buffer back so a later source can reuse it
            captured_recycler.return_buffer(buffer);
        }
        // No more input (all tx handles dropped), flush the sink and return
        sink.flush()
    });

    // drive the stream to completion
    while let Some(write_task) = stream.next().await {
        // break early if the writer stream is done (errored)
        if writer_task.is_finished() {
            debug!("writer task is done early, stopping writer");
            break;
        }
        write_task.await; // sends the buffer to the writer task
    }
    drop(stream); // drop any stream references
    drop(tx); // drop last tx reference to tell the writer it is done.

    // wait for writer to finish; propagates any sink error as the result
    debug!("waiting for writer task to complete");
    writer_task.await.expect("writer task panicked")
}
137
/// A simple buffer recycler to avoid allocating new buffers for each part
///
/// Clones are shallow: they share the same underlying pool via an [`Arc`].
/// The pool is protected by a [`Mutex`], so a recycler (and its clones) can
/// be used from multiple threads, e.g. generator tasks returning buffers via
/// the blocking writer task.
#[derive(Debug, Clone)]
struct BufferRecycler {
    // FIFO pool of previously-used buffers, kept with their allocations intact
    buffers: Arc<Mutex<VecDeque<Vec<u8>>>>,
}

impl BufferRecycler {
    /// Create a recycler with an empty pool.
    fn new() -> Self {
        Self {
            buffers: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Return a new empty buffer, with at least `size` bytes capacity,
    /// reusing a pooled allocation when one is available.
    fn new_buffer(&self, size: usize) -> Vec<u8> {
        let mut buffers = self.buffers.lock().unwrap();
        if let Some(mut buffer) = buffers.pop_front() {
            buffer.clear();
            // After `clear()` the length is 0, so `reserve(size)` guarantees
            // `capacity() >= size` and is a no-op when the pooled buffer is
            // already large enough. (The previous
            // `reserve(size - capacity())` only guaranteed
            // `capacity() >= size - old_capacity`, so a pooled buffer with,
            // say, capacity 6 could be returned for a request of 10.)
            buffer.reserve(size);
            buffer
        } else {
            Vec::with_capacity(size)
        }
    }

    /// Give a buffer back to the pool so its allocation can be reused.
    fn return_buffer(&self, buffer: Vec<u8>) {
        let mut buffers = self.buffers.lock().unwrap();
        buffers.push_back(buffer);
    }
}
170}