Skip to main content

tpcgen_cli/tpch_cli/
parquet.rs

1//! Parquet output format
2
3use crate::tpch_cli::progress::TableProgress;
4use crate::tpch_cli::statistics::WriteStatistics;
5use arrow::datatypes::SchemaRef;
6use futures::StreamExt;
7use log::debug;
8use parquet::arrow::arrow_writer::{compute_leaves, ArrowColumnChunk};
9use parquet::arrow::ArrowSchemaConverter;
10use parquet::basic::Compression;
11use parquet::file::properties::WriterProperties;
12use parquet::file::writer::SerializedFileWriter;
13use parquet::schema::types::SchemaDescPtr;
14use std::io;
15use std::io::Write;
16use std::sync::Arc;
17use tokio::sync::mpsc::{Receiver, Sender};
18use tpchgen_arrow::RecordBatchIterator;
19
/// Conversion of a finished output sink into a size.
///
/// The Parquet generator in this file calls [`IntoSize::into_size`] on the
/// writer it recovers from the Parquet file writer and records the returned
/// value as the number of bytes written.
pub trait IntoSize {
    /// Consumes `self` and returns its size, or the I/O error that prevented
    /// the size from being determined.
    fn into_size(self) -> io::Result<usize>;
}
24
25/// Converts a set of RecordBatchIterators into a Parquet file
26///
27/// Uses num_threads to generate the data in parallel
28///
29/// Note the input is an iterator of [`RecordBatchIterator`]; The batches
30/// produced by each iterator is encoded as its own row group.
31pub async fn generate_parquet<W: Write + Send + IntoSize + 'static, I>(
32    writer: W,
33    iter_iter: I,
34    num_threads: usize,
35    parquet_compression: Compression,
36) -> Result<(), io::Error>
37where
38    I: Iterator<Item: RecordBatchIterator> + 'static,
39{
40    generate_parquet_with_progress(
41        writer,
42        iter_iter,
43        num_threads,
44        parquet_compression,
45        TableProgress::default(),
46    )
47    .await
48}
49
50pub(crate) async fn generate_parquet_with_progress<W: Write + Send + IntoSize + 'static, I>(
51    writer: W,
52    iter_iter: I,
53    num_threads: usize,
54    parquet_compression: Compression,
55    progress: TableProgress,
56) -> Result<(), io::Error>
57where
58    I: Iterator<Item: RecordBatchIterator> + 'static,
59{
60    debug!(
61        "Generating Parquet with {num_threads} threads, using {parquet_compression} compression"
62    );
63    // Based on example in https://docs.rs/parquet/latest/parquet/arrow/arrow_writer/struct.ArrowColumnWriter.html
64    let mut iter_iter = iter_iter.peekable();
65
66    // get schema from the first iterator
67    let Some(first_iter) = iter_iter.peek() else {
68        return Ok(()); // no data shrug
69    };
70    let schema = Arc::clone(first_iter.schema());
71
72    // Compute the parquet schema
73    let writer_properties = WriterProperties::builder()
74        .set_compression(parquet_compression)
75        .build();
76    let writer_properties = Arc::new(writer_properties);
77    let parquet_schema = Arc::new(
78        ArrowSchemaConverter::new()
79            .with_coerce_types(writer_properties.coerce_types())
80            .convert(&schema)
81            .unwrap(),
82    );
83
84    // create a stream that computes the data for each row group
85    let mut row_group_stream = futures::stream::iter(iter_iter)
86        .map(async |iter| {
87            let parquet_schema = Arc::clone(&parquet_schema);
88            let writer_properties = Arc::clone(&writer_properties);
89            let schema = Arc::clone(&schema);
90            // run on a separate thread
91            tokio::task::spawn(async move {
92                encode_row_group(parquet_schema, writer_properties, schema, iter)
93            })
94            .await
95            .expect("Inner task panicked")
96        })
97        .buffered(num_threads); // generate row groups in parallel
98
99    let mut statistics = WriteStatistics::new("row groups");
100
101    // A blocking task that writes the row groups to the file
102    // done in a blocking task to avoid having a thread waiting on IO
103    // Now, read each completed row group and write it to the file
104    let root_schema = parquet_schema.root_schema_ptr();
105    let writer_properties_captured = Arc::clone(&writer_properties);
106    let (tx, mut rx): (
107        Sender<Vec<ArrowColumnChunk>>,
108        Receiver<Vec<ArrowColumnChunk>>,
109    ) = tokio::sync::mpsc::channel(num_threads);
110    let writer_task = tokio::task::spawn_blocking(move || {
111        // Create parquet writer
112        let mut writer =
113            SerializedFileWriter::new(writer, root_schema, writer_properties_captured).unwrap();
114
115        while let Some(column_chunks) = rx.blocking_recv() {
116            // Start row group
117            let mut row_group_writer = writer.next_row_group().unwrap();
118
119            // Slap the chunks into the row group
120            for column_chunk in column_chunks {
121                column_chunk
122                    .append_to_row_group(&mut row_group_writer)
123                    .unwrap();
124            }
125            row_group_writer.close().unwrap();
126            statistics.increment_chunks(1);
127            progress.increment_output_unit();
128        }
129        let size = writer.into_inner()?.into_size()?;
130        statistics.increment_bytes(size);
131        Ok(()) as Result<(), io::Error>
132    });
133
134    // now, drive the input stream and send results to the writer task
135    while let Some(column_chunks) = row_group_stream.next().await {
136        // send the chunks to the writer task
137        if let Err(e) = tx.send(column_chunks).await {
138            debug!("Error sending row group to writer: {e}");
139            break; // stop early
140        }
141    }
142    // signal the writer task that we are done
143    drop(tx);
144
145    // Wait for the writer task to finish
146    writer_task.await??;
147
148    Ok(())
149}
150
151/// Creates the data for a particular row group
152///
153/// Note at the moment it does not use multiple tasks/threads but it could
154/// potentially encode multiple columns with different threads .
155///
156/// Returns an array of [`ArrowColumnChunk`]
157fn encode_row_group<I>(
158    parquet_schema: SchemaDescPtr,
159    writer_properties: Arc<WriterProperties>,
160    schema: SchemaRef,
161    iter: I,
162) -> Vec<ArrowColumnChunk>
163where
164    I: RecordBatchIterator,
165{
166    // Create writers for each of the leaf columns
167    #[allow(deprecated)]
168    let mut col_writers = parquet::arrow::arrow_writer::get_column_writers(
169        &parquet_schema,
170        &writer_properties,
171        &schema,
172    )
173    .unwrap();
174
175    // generate the data and send it to the tasks (via the sender channels)
176    for batch in iter {
177        let columns = batch.columns().iter();
178        let col_writers = col_writers.iter_mut();
179        let fields = schema.fields().iter();
180
181        for ((col_writer, field), arr) in col_writers.zip(fields).zip(columns) {
182            for leaves in compute_leaves(field.as_ref(), arr).unwrap() {
183                col_writer.write(&leaves).unwrap();
184            }
185        }
186    }
187    // finish the writers and create the column chunks
188    col_writers
189        .into_iter()
190        .map(|col_writer| col_writer.close().unwrap())
191        .collect()
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::{ProgressTracker, RunProgress};
    use crate::tpch_cli::Table;
    use std::fs::File;
    use std::io::BufWriter;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use tpchgen::generators::RegionGenerator;
    use tpchgen_arrow::RegionArrow;

    /// Progress tracker that only sums the increments reported to it.
    #[derive(Debug, Default)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, row_groups: u64) {
            self.increments.fetch_add(row_groups, Ordering::Relaxed);
        }
    }

    /// Builds a region batch source with a batch size of 5.
    fn region_source() -> RegionArrow {
        let generator = RegionGenerator::default();
        RegionArrow::new(generator).with_batch_size(5)
    }

    /// Two input sources must be reported to the tracker as two increments,
    /// and the output file must not be empty.
    #[tokio::test]
    async fn progress_counts_written_row_groups() {
        let scratch_dir = tempfile::tempdir().unwrap();
        let parquet_path = scratch_dir.path().join("progress.parquet");
        let sink = BufWriter::new(File::create(&parquet_path).unwrap());

        let counter = Arc::new(CountingProgress::default());
        let shared_tracker: Arc<dyn ProgressTracker> = counter.clone();
        let table_progress = RunProgress::with_tracker(shared_tracker).for_table(Table::Region);

        // two sources, i.e. two row groups
        let sources = std::iter::repeat_with(region_source).take(2);

        let outcome = generate_parquet_with_progress(
            sink,
            sources,
            1,
            Compression::UNCOMPRESSED,
            table_progress,
        )
        .await;
        outcome.unwrap();

        let reported = counter.increments.load(Ordering::Relaxed);
        assert_eq!(reported, 2);

        let written_len = std::fs::metadata(parquet_path).unwrap().len();
        assert!(written_len > 0);
    }
}