1use super::generate::Sink;
2use super::output_plan::OutputPlanGenerator;
3use super::parquet::IntoSize;
4use super::plan::DEFAULT_PARQUET_ROW_GROUP_BYTES;
5use super::progress::ProgressTracker;
6use super::runner::PlanRunner;
7use super::statistics::WriteStatistics;
8pub use ::parquet::basic::Compression;
9use log::info;
10use std::fmt::Display;
11use std::fs::File;
12use std::io;
13use std::io::{BufWriter, Stdout, Write};
14use std::str::FromStr;
15use std::sync::Arc;
16use std::time::Instant;
17use tpchgen::distribution::Distributions;
18use tpchgen::text::TextPool;
19
20pub struct WriterSink<W: Write> {
22 statistics: WriteStatistics,
23 inner: W,
24}
25
26impl<W: Write> WriterSink<W> {
27 pub fn new(inner: W) -> Self {
28 Self {
29 inner,
30 statistics: WriteStatistics::new("buffers"),
31 }
32 }
33}
34
35impl<W: Write + Send> Sink for WriterSink<W> {
36 fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
37 self.statistics.increment_chunks(1);
38 self.statistics.increment_bytes(buffer.len());
39 self.inner.write_all(buffer)
40 }
41
42 fn flush(mut self) -> Result<(), io::Error> {
43 self.inner.flush()
44 }
45}
46
47impl IntoSize for BufWriter<Stdout> {
48 fn into_size(self) -> Result<usize, io::Error> {
49 Ok(0)
51 }
52}
53
54impl IntoSize for BufWriter<File> {
55 fn into_size(self) -> Result<usize, io::Error> {
56 let file = self.into_inner()?;
57 let metadata = file.metadata()?;
58 Ok(metadata.len() as usize)
59 }
60}
61
/// The eight tables this module can generate.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord`
/// implementations compare tables by declaration order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Table {
    /// The `nation` table.
    Nation,
    /// The `region` table.
    Region,
    /// The `part` table.
    Part,
    /// The `supplier` table.
    Supplier,
    /// The `partsupp` table.
    Partsupp,
    /// The `customer` table.
    Customer,
    /// The `orders` table.
    Orders,
    /// The `lineitem` table.
    Lineitem,
}
85
86impl Display for Table {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.name())
89 }
90}
91
92impl FromStr for Table {
93 type Err = &'static str;
94
95 fn from_str(s: &str) -> Result<Self, Self::Err> {
102 match s {
103 "n" | "nation" => Ok(Table::Nation),
104 "r" | "region" => Ok(Table::Region),
105 "s" | "supplier" => Ok(Table::Supplier),
106 "P" | "part" => Ok(Table::Part),
107 "S" | "partsupp" => Ok(Table::Partsupp),
108 "c" | "customer" => Ok(Table::Customer),
109 "O" | "orders" => Ok(Table::Orders),
110 "L" | "lineitem" => Ok(Table::Lineitem),
111 _ => Err("Invalid table name {s}"),
112 }
113 }
114}
115
116impl Table {
117 fn name(&self) -> &'static str {
118 match self {
119 Table::Nation => "nation",
120 Table::Region => "region",
121 Table::Part => "part",
122 Table::Supplier => "supplier",
123 Table::Partsupp => "partsupp",
124 Table::Customer => "customer",
125 Table::Orders => "orders",
126 Table::Lineitem => "lineitem",
127 }
128 }
129}
130
/// Output encodings supported by the generator.
///
/// The derived ordering follows declaration order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum OutputFormat {
    /// Parsed from / displayed as `tbl`.
    Tbl,
    /// Parsed from / displayed as `csv`.
    Csv,
    /// Parsed from / displayed as `parquet`.
    Parquet,
}
147
148impl FromStr for OutputFormat {
149 type Err = String;
150
151 fn from_str(s: &str) -> Result<Self, Self::Err> {
152 match s.to_lowercase().as_str() {
153 "tbl" => Ok(OutputFormat::Tbl),
154 "csv" => Ok(OutputFormat::Csv),
155 "parquet" => Ok(OutputFormat::Parquet),
156 _ => Err(format!(
157 "Invalid output format: {s}. Valid formats are: tbl, csv, parquet"
158 )),
159 }
160 }
161}
162
163impl Display for OutputFormat {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match self {
166 OutputFormat::Tbl => write!(f, "tbl"),
167 OutputFormat::Csv => write!(f, "csv"),
168 OutputFormat::Parquet => write!(f, "parquet"),
169 }
170 }
171}
172
173#[derive(Debug, Clone)]
178pub struct GeneratorConfig {
179 pub scale_factor: f64,
181 pub output_dir: std::path::PathBuf,
183 pub tables: Option<Vec<Table>>,
185 pub format: OutputFormat,
187 pub num_threads: usize,
189 pub parquet_compression: Compression,
191 pub parquet_row_group_bytes: i64,
193 pub parts: Option<i32>,
195 pub part: Option<i32>,
197 pub stdout: bool,
199 pub csv_delimiter: char,
201}
202
203impl Default for GeneratorConfig {
204 fn default() -> Self {
205 Self {
206 scale_factor: 1.0,
207 output_dir: std::path::PathBuf::from("."),
208 tables: None,
209 format: OutputFormat::Tbl,
210 num_threads: num_cpus::get(),
211 parquet_compression: Compression::SNAPPY,
212 parquet_row_group_bytes: DEFAULT_PARQUET_ROW_GROUP_BYTES,
213 parts: None,
214 part: None,
215 stdout: false,
216 csv_delimiter: ',',
217 }
218 }
219}
220
221pub struct TpchGenerator {
226 config: GeneratorConfig,
227 progress_tracker: Option<Arc<dyn ProgressTracker>>,
228}
229
230impl TpchGenerator {
231 pub fn builder() -> TpchGeneratorBuilder {
233 TpchGeneratorBuilder::new()
234 }
235
236 pub async fn generate(self) -> io::Result<()> {
238 let config = self.config;
239 let progress_tracker = self.progress_tracker;
240
241 if !config.stdout {
243 std::fs::create_dir_all(&config.output_dir)?;
244 }
245
246 let tables: Vec<Table> = if let Some(tables) = config.tables {
248 tables
249 } else {
250 vec![
251 Table::Nation,
252 Table::Region,
253 Table::Part,
254 Table::Supplier,
255 Table::Partsupp,
256 Table::Customer,
257 Table::Orders,
258 Table::Lineitem,
259 ]
260 };
261
262 let mut output_plan_generator = OutputPlanGenerator::new(
264 config.format,
265 config.scale_factor,
266 config.parquet_compression,
267 config.parquet_row_group_bytes,
268 config.stdout,
269 config.output_dir,
270 config.csv_delimiter,
271 );
272
273 for table in tables {
274 output_plan_generator.generate_plans(table, config.part, config.parts)?;
275 }
276 let output_plans = output_plan_generator.build();
277
278 let start = Instant::now();
281 Distributions::static_default();
282 TextPool::get_or_init_default();
283 let elapsed = start.elapsed();
284 info!("Created static distributions and text pools in {elapsed:?}");
285
286 let runner = PlanRunner::new(output_plans, config.num_threads);
287 let runner = if let Some(tracker) = progress_tracker {
288 runner.with_progress_tracker(tracker)
289 } else {
290 runner
291 };
292 runner.run().await?;
293 info!("Generation complete!");
294 Ok(())
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct TpchGeneratorBuilder {
301 config: GeneratorConfig,
302 progress_tracker: Option<Arc<dyn ProgressTracker>>,
303}
304
305impl TpchGeneratorBuilder {
306 pub fn new() -> Self {
308 Self {
309 config: GeneratorConfig::default(),
310 progress_tracker: None,
311 }
312 }
313
314 pub fn scale_factor(&self) -> f64 {
316 self.config.scale_factor
317 }
318
319 pub fn with_scale_factor(mut self, scale_factor: f64) -> Self {
321 self.config.scale_factor = scale_factor;
322 self
323 }
324
325 pub fn with_output_dir(mut self, output_dir: impl Into<std::path::PathBuf>) -> Self {
327 self.config.output_dir = output_dir.into();
328 self
329 }
330
331 pub fn with_tables(mut self, tables: Vec<Table>) -> Self {
333 self.config.tables = Some(tables);
334 self
335 }
336
337 pub fn with_format(mut self, format: OutputFormat) -> Self {
339 self.config.format = format;
340 self
341 }
342
343 pub fn with_num_threads(mut self, num_threads: usize) -> Self {
345 self.config.num_threads = num_threads;
346 self
347 }
348
349 pub fn with_parquet_compression(mut self, compression: Compression) -> Self {
351 self.config.parquet_compression = compression;
352 self
353 }
354
355 pub fn with_parquet_row_group_bytes(mut self, bytes: i64) -> Self {
357 self.config.parquet_row_group_bytes = bytes;
358 self
359 }
360
361 pub fn with_parts(mut self, parts: i32) -> Self {
363 self.config.parts = Some(parts);
364 self
365 }
366
367 pub fn with_part(mut self, part: i32) -> Self {
369 self.config.part = Some(part);
370 self
371 }
372
373 pub fn with_stdout(mut self, stdout: bool) -> Self {
375 self.config.stdout = stdout;
376 self
377 }
378
379 pub fn with_csv_delimiter(mut self, delimiter: char) -> Self {
381 self.config.csv_delimiter = delimiter;
382 self
383 }
384
385 pub fn with_progress_tracker(mut self, tracker: Arc<dyn ProgressTracker>) -> Self {
391 self.progress_tracker = Some(tracker);
392 self
393 }
394
395 pub fn build(self) -> TpchGenerator {
397 TpchGenerator {
398 config: self.config,
399 progress_tracker: self.progress_tracker,
400 }
401 }
402}
403
404impl Default for TpchGeneratorBuilder {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::ProgressTracker;
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    };

    /// Progress tracker that records every callback it receives so the test
    /// can inspect them afterwards.
    #[derive(Debug, Default)]
    struct RecordingProgress {
        registered: Mutex<Vec<(Table, u64)>>,
        increments: Mutex<Vec<(Table, u64)>>,
        finishes: AtomicU64,
    }

    impl ProgressTracker for RecordingProgress {
        fn register(&self, table: Table, total_units: u64) {
            let mut log = self.registered.lock().unwrap();
            log.push((table, total_units));
        }

        fn increment(&self, table: Table, units: u64) {
            let mut log = self.increments.lock().unwrap();
            log.push((table, units));
        }

        fn finish(&self) {
            self.finishes.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// A tracker supplied through the builder must receive the runner's
    /// register, increment and finish callbacks.
    #[tokio::test]
    async fn builder_passes_custom_progress_tracker_to_runner() {
        let scratch = tempfile::tempdir().unwrap();
        let tracker = Arc::new(RecordingProgress::default());
        let as_dyn: Arc<dyn ProgressTracker> = tracker.clone();

        let generator = TpchGenerator::builder()
            .with_output_dir(scratch.path())
            .with_tables(vec![Table::Region])
            .with_num_threads(1)
            .with_progress_tracker(as_dyn)
            .build();
        generator.generate().await.unwrap();

        let expected: Vec<(Table, u64)> = vec![(Table::Region, 1)];
        {
            let registered = tracker.registered.lock().unwrap();
            assert_eq!(*registered, expected);
        }
        {
            let increments = tracker.increments.lock().unwrap();
            assert_eq!(*increments, expected);
        }
        let finish_calls = tracker.finishes.load(Ordering::Relaxed);
        assert_eq!(finish_calls, 1);
    }
}