tpcgen_cli/tpch_cli/
parquet.rs1use crate::tpch_cli::progress::TableProgress;
4use crate::tpch_cli::statistics::WriteStatistics;
5use arrow::datatypes::SchemaRef;
6use futures::StreamExt;
7use log::debug;
8use parquet::arrow::arrow_writer::{compute_leaves, ArrowColumnChunk};
9use parquet::arrow::ArrowSchemaConverter;
10use parquet::basic::Compression;
11use parquet::file::properties::WriterProperties;
12use parquet::file::writer::SerializedFileWriter;
13use parquet::schema::types::SchemaDescPtr;
14use std::io;
15use std::io::Write;
16use std::sync::Arc;
17use tokio::sync::mpsc::{Receiver, Sender};
18use tpchgen_arrow::RecordBatchIterator;
19
/// Consumes a finished output writer and reports how many bytes were written to it.
///
/// The Parquet generator calls this on the writer handed back by
/// `SerializedFileWriter::into_inner` and records the value in its write statistics.
pub trait IntoSize {
    /// Consumes `self` and returns the number of bytes it wrote.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the implementation cannot determine or finalize the size.
    fn into_size(self) -> Result<usize, io::Error>;
}
24
25pub async fn generate_parquet<W: Write + Send + IntoSize + 'static, I>(
32 writer: W,
33 iter_iter: I,
34 num_threads: usize,
35 parquet_compression: Compression,
36) -> Result<(), io::Error>
37where
38 I: Iterator<Item: RecordBatchIterator> + 'static,
39{
40 generate_parquet_with_progress(
41 writer,
42 iter_iter,
43 num_threads,
44 parquet_compression,
45 TableProgress::default(),
46 )
47 .await
48}
49
50pub(crate) async fn generate_parquet_with_progress<W: Write + Send + IntoSize + 'static, I>(
51 writer: W,
52 iter_iter: I,
53 num_threads: usize,
54 parquet_compression: Compression,
55 progress: TableProgress,
56) -> Result<(), io::Error>
57where
58 I: Iterator<Item: RecordBatchIterator> + 'static,
59{
60 debug!(
61 "Generating Parquet with {num_threads} threads, using {parquet_compression} compression"
62 );
63 let mut iter_iter = iter_iter.peekable();
65
66 let Some(first_iter) = iter_iter.peek() else {
68 return Ok(()); };
70 let schema = Arc::clone(first_iter.schema());
71
72 let writer_properties = WriterProperties::builder()
74 .set_compression(parquet_compression)
75 .build();
76 let writer_properties = Arc::new(writer_properties);
77 let parquet_schema = Arc::new(
78 ArrowSchemaConverter::new()
79 .with_coerce_types(writer_properties.coerce_types())
80 .convert(&schema)
81 .unwrap(),
82 );
83
84 let mut row_group_stream = futures::stream::iter(iter_iter)
86 .map(async |iter| {
87 let parquet_schema = Arc::clone(&parquet_schema);
88 let writer_properties = Arc::clone(&writer_properties);
89 let schema = Arc::clone(&schema);
90 tokio::task::spawn(async move {
92 encode_row_group(parquet_schema, writer_properties, schema, iter)
93 })
94 .await
95 .expect("Inner task panicked")
96 })
97 .buffered(num_threads); let mut statistics = WriteStatistics::new("row groups");
100
101 let root_schema = parquet_schema.root_schema_ptr();
105 let writer_properties_captured = Arc::clone(&writer_properties);
106 let (tx, mut rx): (
107 Sender<Vec<ArrowColumnChunk>>,
108 Receiver<Vec<ArrowColumnChunk>>,
109 ) = tokio::sync::mpsc::channel(num_threads);
110 let writer_task = tokio::task::spawn_blocking(move || {
111 let mut writer =
113 SerializedFileWriter::new(writer, root_schema, writer_properties_captured).unwrap();
114
115 while let Some(column_chunks) = rx.blocking_recv() {
116 let mut row_group_writer = writer.next_row_group().unwrap();
118
119 for column_chunk in column_chunks {
121 column_chunk
122 .append_to_row_group(&mut row_group_writer)
123 .unwrap();
124 }
125 row_group_writer.close().unwrap();
126 statistics.increment_chunks(1);
127 progress.increment_output_unit();
128 }
129 let size = writer.into_inner()?.into_size()?;
130 statistics.increment_bytes(size);
131 Ok(()) as Result<(), io::Error>
132 });
133
134 while let Some(column_chunks) = row_group_stream.next().await {
136 if let Err(e) = tx.send(column_chunks).await {
138 debug!("Error sending row group to writer: {e}");
139 break; }
141 }
142 drop(tx);
144
145 writer_task.await??;
147
148 Ok(())
149}
150
151fn encode_row_group<I>(
158 parquet_schema: SchemaDescPtr,
159 writer_properties: Arc<WriterProperties>,
160 schema: SchemaRef,
161 iter: I,
162) -> Vec<ArrowColumnChunk>
163where
164 I: RecordBatchIterator,
165{
166 #[allow(deprecated)]
168 let mut col_writers = parquet::arrow::arrow_writer::get_column_writers(
169 &parquet_schema,
170 &writer_properties,
171 &schema,
172 )
173 .unwrap();
174
175 for batch in iter {
177 let columns = batch.columns().iter();
178 let col_writers = col_writers.iter_mut();
179 let fields = schema.fields().iter();
180
181 for ((col_writer, field), arr) in col_writers.zip(fields).zip(columns) {
182 for leaves in compute_leaves(field.as_ref(), arr).unwrap() {
183 col_writer.write(&leaves).unwrap();
184 }
185 }
186 }
187 col_writers
189 .into_iter()
190 .map(|col_writer| col_writer.close().unwrap())
191 .collect()
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::{ProgressTracker, RunProgress};
    use crate::tpch_cli::Table;
    use std::fs::File;
    use std::io::BufWriter;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use tpchgen::generators::RegionGenerator;
    use tpchgen_arrow::RegionArrow;

    /// Progress tracker that sums every reported increment so a test can inspect the total.
    #[derive(Debug)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, row_groups: u64) {
            self.increments.fetch_add(row_groups, Ordering::Relaxed);
        }
    }

    /// Builds a `Region` Arrow source that emits batches of at most 5 rows.
    fn region_source() -> RegionArrow {
        let generator = RegionGenerator::default();
        RegionArrow::new(generator).with_batch_size(5)
    }

    /// Two input sources should yield two row-group progress increments and a
    /// non-empty output file.
    #[tokio::test]
    async fn progress_counts_written_row_groups() {
        let temp_dir = tempfile::tempdir().unwrap();
        let parquet_path = temp_dir.path().join("progress.parquet");
        let file = File::create(&parquet_path).unwrap();
        let writer = BufWriter::new(file);

        let counter = Arc::new(CountingProgress {
            increments: AtomicU64::new(0),
        });
        let shared_tracker: Arc<dyn ProgressTracker> = counter.clone();
        let table_progress = RunProgress::with_tracker(shared_tracker).for_table(Table::Region);

        let sources = vec![region_source(), region_source()];
        generate_parquet_with_progress(
            writer,
            sources.into_iter(),
            1,
            Compression::UNCOMPRESSED,
            table_progress,
        )
        .await
        .unwrap();

        let reported = counter.increments.load(Ordering::Relaxed);
        assert_eq!(reported, 2);

        let written_bytes = std::fs::metadata(&parquet_path).unwrap().len();
        assert!(written_bytes > 0);
    }
}