use crate::statistics::WriteStatistics;
use arrow::datatypes::SchemaRef;
use futures::StreamExt;
use log::debug;
use parquet::arrow::arrow_writer::{compute_leaves, ArrowColumnChunk};
use parquet::arrow::ArrowSchemaConverter;
use parquet::basic::Compression;
use parquet::file::properties::WriterProperties;
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::types::SchemaDescPtr;
use std::io;
use std::io::Write;
use std::sync::Arc;
use tokio::sync::mpsc::{Receiver, Sender};
use tpchgen_arrow::RecordBatchIterator;

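/// Consumes a writer and reports the total number of bytes it wrote, used to
/// update the final write statistics.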
pub trait IntoSize {
    fn into_size(self) -> Result<usize, io::Error>;
}

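/// Writes a single Parquet file to `writer`: each inner [`RecordBatchIterator`]
/// yielded by `iter_iter` is encoded as one row group, with up to
/// `num_threads` row groups encoded concurrently and appended in order.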
pub async fn generate_parquet<W: Write + Send + IntoSize + 'static, I>(
    writer: W,
    iter_iter: I,
    num_threads: usize,
    parquet_compression: Compression,
) -> Result<(), io::Error>
where
    I: Iterator<Item: RecordBatchIterator> + 'static,
{
    debug!(
        "Generating Parquet with {num_threads} threads, using {parquet_compression} compression"
    );
    let mut iter_iter = iter_iter.peekable();

    let Some(first_iter) = iter_iter.peek() else {
        return Ok(());
    };
    let schema = Arc::clone(first_iter.schema());

    let writer_properties = WriterProperties::builder()
        .set_compression(parquet_compression)
        .build();
    let writer_properties = Arc::new(writer_properties);
    let parquet_schema = Arc::new(
        ArrowSchemaConverter::new()
            .with_coerce_types(writer_properties.coerce_types())
            .convert(&schema)
            .unwrap(),
    );

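    // Encode each chunk into a row group on its own Tokio task. `buffered`
    // keeps up to `num_threads` encode tasks in flight while still yielding
    // finished row groups in input order.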
    let mut row_group_stream = futures::stream::iter(iter_iter)
        .map(async |iter| {
            let parquet_schema = Arc::clone(&parquet_schema);
            let writer_properties = Arc::clone(&writer_properties);
            let schema = Arc::clone(&schema);
            tokio::task::spawn(async move {
                encode_row_group(parquet_schema, writer_properties, schema, iter)
            })
            .await
            .expect("Inner task panicked")
        })
        .buffered(num_threads);
    let mut statistics = WriteStatistics::new("row groups");

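    // A dedicated blocking task owns the `SerializedFileWriter`, receiving
    // finished column chunks over a bounded channel and appending each batch
    // of chunks to the file as one row group.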
    let root_schema = parquet_schema.root_schema_ptr();
    let writer_properties_captured = Arc::clone(&writer_properties);
    let (tx, mut rx): (
        Sender<Vec<ArrowColumnChunk>>,
        Receiver<Vec<ArrowColumnChunk>>,
    ) = tokio::sync::mpsc::channel(num_threads);
    let writer_task = tokio::task::spawn_blocking(move || {
        let mut writer =
            SerializedFileWriter::new(writer, root_schema, writer_properties_captured).unwrap();

        while let Some(chunks) = rx.blocking_recv() {
            let mut row_group_writer = writer.next_row_group().unwrap();

            for chunk in chunks {
                chunk.append_to_row_group(&mut row_group_writer).unwrap();
            }
            row_group_writer.close().unwrap();
            statistics.increment_chunks(1);
        }
        let size = writer.into_inner()?.into_size()?;
        statistics.increment_bytes(size);
        Ok(()) as Result<(), io::Error>
    });

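    // Drive the encoding stream, forwarding each finished row group to the
    // writer task. A send error means the writer has already shut down.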
    while let Some(chunks) = row_group_stream.next().await {
        if let Err(e) = tx.send(chunks).await {
            debug!("Error sending chunks to writer: {e}");
            break;
        }
    }
    drop(tx);

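    // Dropping `tx` closed the channel, so the writer's receive loop ends and
    // the task finishes the file. The double `?` propagates both a task panic
    // (`JoinError` converts to `io::Error`) and any writer error.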
    writer_task.await??;

    Ok(())
}

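/// Encodes all record batches produced by `iter` into the column chunks of a
/// single row group, using one Arrow column writer per leaf column.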
fn encode_row_group<I>(
    parquet_schema: SchemaDescPtr,
    writer_properties: Arc<WriterProperties>,
    schema: SchemaRef,
    iter: I,
) -> Vec<ArrowColumnChunk>
where
    I: RecordBatchIterator,
{
    #[allow(deprecated)]
    let mut col_writers = parquet::arrow::arrow_writer::get_column_writers(
        &parquet_schema,
        &writer_properties,
        &schema,
    )
    .unwrap();

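    // Fan each batch out column-by-column to the writers. `compute_leaves`
    // flattens a (possibly nested) field into the leaf arrays that Parquet
    // actually stores.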
    for batch in iter {
        let columns = batch.columns().iter();
        let col_writers = col_writers.iter_mut();
        let fields = schema.fields().iter();

        for ((col_writer, field), arr) in col_writers.zip(fields).zip(columns) {
            for leaves in compute_leaves(field.as_ref(), arr).unwrap() {
                col_writer.write(&leaves).unwrap();
            }
        }
    }
    col_writers
        .into_iter()
        .map(|col_writer| col_writer.close().unwrap())
        .collect()
}