// tpcgen_cli/tpch_cli/generator.rs

1use super::generate::Sink;
2use super::output_plan::OutputPlanGenerator;
3use super::parquet::IntoSize;
4use super::plan::DEFAULT_PARQUET_ROW_GROUP_BYTES;
5use super::progress::ProgressTracker;
6use super::runner::PlanRunner;
7use super::statistics::WriteStatistics;
8pub use ::parquet::basic::Compression;
9use log::info;
10use std::fmt::Display;
11use std::fs::File;
12use std::io;
13use std::io::{BufWriter, Stdout, Write};
14use std::str::FromStr;
15use std::sync::Arc;
16use std::time::Instant;
17use tpchgen::distribution::Distributions;
18use tpchgen::text::TextPool;
19
20/// Wrapper around a buffer writer that counts the number of buffers and bytes written
21pub struct WriterSink<W: Write> {
22    statistics: WriteStatistics,
23    inner: W,
24}
25
26impl<W: Write> WriterSink<W> {
27    pub fn new(inner: W) -> Self {
28        Self {
29            inner,
30            statistics: WriteStatistics::new("buffers"),
31        }
32    }
33}
34
35impl<W: Write + Send> Sink for WriterSink<W> {
36    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
37        self.statistics.increment_chunks(1);
38        self.statistics.increment_bytes(buffer.len());
39        self.inner.write_all(buffer)
40    }
41
42    fn flush(mut self) -> Result<(), io::Error> {
43        self.inner.flush()
44    }
45}
46
47impl IntoSize for BufWriter<Stdout> {
48    fn into_size(self) -> Result<usize, io::Error> {
49        // we can't get the size of stdout, so just return 0
50        Ok(0)
51    }
52}
53
54impl IntoSize for BufWriter<File> {
55    fn into_size(self) -> Result<usize, io::Error> {
56        let file = self.into_inner()?;
57        let metadata = file.metadata()?;
58        Ok(metadata.len() as usize)
59    }
60}
61
/// TPC-H table types
///
/// Represents the 8 tables in the TPC-H benchmark schema.
///
/// The declaration order below defines the derived `PartialOrd`/`Ord`. It is
/// *not* strictly by row count. For example, `Region` (5 rows) follows
/// `Nation` (25 rows), and `Part` precedes the smaller `Supplier` table.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Table {
    /// Nation table (25 rows)
    Nation,
    /// Region table (5 rows)
    Region,
    /// Part table (200,000 rows at SF=1)
    Part,
    /// Supplier table (10,000 rows at SF=1)
    Supplier,
    /// Part-Supplier relationship table (800,000 rows at SF=1)
    Partsupp,
    /// Customer table (150,000 rows at SF=1)
    Customer,
    /// Orders table (1,500,000 rows at SF=1)
    Orders,
    /// Line item table (6,000,000 rows at SF=1)
    Lineitem,
}

impl Display for Table {
    /// Formats the table as its lowercase full name, e.g. `partsupp`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

impl FromStr for Table {
    type Err = &'static str;

    /// Returns the table enum value from the given string full name or abbreviation
    ///
    /// Full names are lowercase (`nation`, `partsupp`, ...). Matching is
    /// case-sensitive because the abbreviations depend on case: `s` is
    /// `supplier` while `S` is `partsupp`.
    ///
    /// The original dbgen tool allows some abbreviations to mean two different tables
    /// like 'p' which aliases to both 'part' and 'partsupp'. This implementation does
    /// not support this since it just adds unnecessary complexity and confusion so we
    /// only support the exclusive abbreviations.
    ///
    /// # Errors
    ///
    /// Returns a static message listing every accepted name and abbreviation.
    /// The error type is `&'static str`, so the rejected input cannot be
    /// interpolated. The previous message contained a literal, never-filled
    /// `{s}` placeholder.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "n" | "nation" => Ok(Table::Nation),
            "r" | "region" => Ok(Table::Region),
            "s" | "supplier" => Ok(Table::Supplier),
            "P" | "part" => Ok(Table::Part),
            "S" | "partsupp" => Ok(Table::Partsupp),
            "c" | "customer" => Ok(Table::Customer),
            "O" | "orders" => Ok(Table::Orders),
            "L" | "lineitem" => Ok(Table::Lineitem),
            _ => Err(concat!(
                "Invalid table name. Valid names are: nation (n), region (r), part (P), ",
                "supplier (s), partsupp (S), customer (c), orders (O), lineitem (L)"
            )),
        }
    }
}

impl Table {
    /// Lowercase full table name, as accepted by [`Table::from_str`].
    fn name(&self) -> &'static str {
        match self {
            Table::Nation => "nation",
            Table::Region => "region",
            Table::Part => "part",
            Table::Supplier => "supplier",
            Table::Partsupp => "partsupp",
            Table::Customer => "customer",
            Table::Orders => "orders",
            Table::Lineitem => "lineitem",
        }
    }
}
130
/// Serialization format used for the generated files.
///
/// # Format Details
///
/// - **TBL**: pipe-delimited text, matching the layout of the original dbgen tool
/// - **CSV**: delimiter-separated text (comma unless configured otherwise)
/// - **Parquet**: columnar Apache Parquet with a configurable compression codec
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OutputFormat {
    /// TBL format (pipe-delimited, dbgen-compatible)
    Tbl,
    /// CSV format (comma-separated values)
    Csv,
    /// Apache Parquet format (columnar, compressed)
    Parquet,
}

impl FromStr for OutputFormat {
    type Err = String;

    /// Parses `tbl`, `csv` or `parquet`, ignoring letter case.
    ///
    /// Any other input yields an error message that echoes the input and
    /// lists the accepted formats.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let parsed = match normalized.as_str() {
            "tbl" => Self::Tbl,
            "csv" => Self::Csv,
            "parquet" => Self::Parquet,
            _ => {
                return Err(format!(
                    "Invalid output format: {s}. Valid formats are: tbl, csv, parquet"
                ));
            }
        };
        Ok(parsed)
    }
}

impl Display for OutputFormat {
    /// Writes the lowercase format name accepted by `from_str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Tbl => "tbl",
            Self::Csv => "csv",
            Self::Parquet => "parquet",
        };
        f.write_str(label)
    }
}
172
173/// Configuration for TPC-H data generation
174///
175/// This struct holds all the parameters needed to generate TPC-H benchmark data.
176/// It's typically not constructed directly - use [`TpchGeneratorBuilder`] instead.
177#[derive(Debug, Clone)]
178pub struct GeneratorConfig {
179    /// Scale factor (e.g., 1.0 for 1GB, 10.0 for 10GB)
180    pub scale_factor: f64,
181    /// Output directory for generated files
182    pub output_dir: std::path::PathBuf,
183    /// Tables to generate (if None, generates all tables)
184    pub tables: Option<Vec<Table>>,
185    /// Output format (TBL, CSV, or Parquet)
186    pub format: OutputFormat,
187    /// Number of threads for parallel generation
188    pub num_threads: usize,
189    /// Parquet compression format
190    pub parquet_compression: Compression,
191    /// Target row group size in bytes for Parquet files
192    pub parquet_row_group_bytes: i64,
193    /// Number of partitions to generate (if None, generates a single file per table)
194    pub parts: Option<i32>,
195    /// Specific partition to generate (1-based, requires parts to be set)
196    pub part: Option<i32>,
197    /// Write output to stdout instead of files
198    pub stdout: bool,
199    /// CSV delimiter character (only applies to CSV format)
200    pub csv_delimiter: char,
201}
202
203impl Default for GeneratorConfig {
204    fn default() -> Self {
205        Self {
206            scale_factor: 1.0,
207            output_dir: std::path::PathBuf::from("."),
208            tables: None,
209            format: OutputFormat::Tbl,
210            num_threads: num_cpus::get(),
211            parquet_compression: Compression::SNAPPY,
212            parquet_row_group_bytes: DEFAULT_PARQUET_ROW_GROUP_BYTES,
213            parts: None,
214            part: None,
215            stdout: false,
216            csv_delimiter: ',',
217        }
218    }
219}
220
221/// TPC-H data generator
222///
223/// The main entry point for generating TPC-H benchmark data.
224/// Use the builder pattern via [`TpchGenerator::builder()`] to configure and create instances.
225pub struct TpchGenerator {
226    config: GeneratorConfig,
227    progress_tracker: Option<Arc<dyn ProgressTracker>>,
228}
229
230impl TpchGenerator {
231    /// Create a new builder for configuring the generator.
232    pub fn builder() -> TpchGeneratorBuilder {
233        TpchGeneratorBuilder::new()
234    }
235
236    /// Generate TPC-H data with the configured settings.
237    pub async fn generate(self) -> io::Result<()> {
238        let config = self.config;
239        let progress_tracker = self.progress_tracker;
240
241        // Create output directory if it doesn't exist and we are not writing to stdout
242        if !config.stdout {
243            std::fs::create_dir_all(&config.output_dir)?;
244        }
245
246        // Determine which tables to generate
247        let tables: Vec<Table> = if let Some(tables) = config.tables {
248            tables
249        } else {
250            vec![
251                Table::Nation,
252                Table::Region,
253                Table::Part,
254                Table::Supplier,
255                Table::Partsupp,
256                Table::Customer,
257                Table::Orders,
258                Table::Lineitem,
259            ]
260        };
261
262        // Determine what files to generate
263        let mut output_plan_generator = OutputPlanGenerator::new(
264            config.format,
265            config.scale_factor,
266            config.parquet_compression,
267            config.parquet_row_group_bytes,
268            config.stdout,
269            config.output_dir,
270            config.csv_delimiter,
271        );
272
273        for table in tables {
274            output_plan_generator.generate_plans(table, config.part, config.parts)?;
275        }
276        let output_plans = output_plan_generator.build();
277
278        // Force the creation of the distributions and text pool so it doesn't
279        // get charged to the first table.
280        let start = Instant::now();
281        Distributions::static_default();
282        TextPool::get_or_init_default();
283        let elapsed = start.elapsed();
284        info!("Created static distributions and text pools in {elapsed:?}");
285
286        let runner = PlanRunner::new(output_plans, config.num_threads);
287        let runner = if let Some(tracker) = progress_tracker {
288            runner.with_progress_tracker(tracker)
289        } else {
290            runner
291        };
292        runner.run().await?;
293        info!("Generation complete!");
294        Ok(())
295    }
296}
297
298/// Builder for constructing a [`TpchGenerator`].
299#[derive(Debug, Clone)]
300pub struct TpchGeneratorBuilder {
301    config: GeneratorConfig,
302    progress_tracker: Option<Arc<dyn ProgressTracker>>,
303}
304
305impl TpchGeneratorBuilder {
306    /// Create a new builder with default configuration.
307    pub fn new() -> Self {
308        Self {
309            config: GeneratorConfig::default(),
310            progress_tracker: None,
311        }
312    }
313
314    /// Returns the scale factor.
315    pub fn scale_factor(&self) -> f64 {
316        self.config.scale_factor
317    }
318
319    /// Set the scale factor (e.g., 1.0 for 1GB, 10.0 for 10GB).
320    pub fn with_scale_factor(mut self, scale_factor: f64) -> Self {
321        self.config.scale_factor = scale_factor;
322        self
323    }
324
325    /// Set the output directory.
326    pub fn with_output_dir(mut self, output_dir: impl Into<std::path::PathBuf>) -> Self {
327        self.config.output_dir = output_dir.into();
328        self
329    }
330
331    /// Set which tables to generate (default: all tables).
332    pub fn with_tables(mut self, tables: Vec<Table>) -> Self {
333        self.config.tables = Some(tables);
334        self
335    }
336
337    /// Set the output format (default: TBL).
338    pub fn with_format(mut self, format: OutputFormat) -> Self {
339        self.config.format = format;
340        self
341    }
342
343    /// Set the number of threads for parallel generation (default: number of CPUs).
344    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
345        self.config.num_threads = num_threads;
346        self
347    }
348
349    /// Set Parquet compression format (default: SNAPPY).
350    pub fn with_parquet_compression(mut self, compression: Compression) -> Self {
351        self.config.parquet_compression = compression;
352        self
353    }
354
355    /// Set target row group size in bytes for Parquet files (default: 7MB).
356    pub fn with_parquet_row_group_bytes(mut self, bytes: i64) -> Self {
357        self.config.parquet_row_group_bytes = bytes;
358        self
359    }
360
361    /// Set the number of partitions to generate.
362    pub fn with_parts(mut self, parts: i32) -> Self {
363        self.config.parts = Some(parts);
364        self
365    }
366
367    /// Set the specific partition to generate (1-based, requires parts to be set).
368    pub fn with_part(mut self, part: i32) -> Self {
369        self.config.part = Some(part);
370        self
371    }
372
373    /// Write output to stdout instead of files.
374    pub fn with_stdout(mut self, stdout: bool) -> Self {
375        self.config.stdout = stdout;
376        self
377    }
378
379    /// Set the CSV delimiter character (only applies to CSV format, default: ',').
380    pub fn with_csv_delimiter(mut self, delimiter: char) -> Self {
381        self.config.csv_delimiter = delimiter;
382        self
383    }
384
385    /// Attach a custom [`ProgressTracker`] to receive generation progress updates.
386    ///
387    /// The runner calls [`ProgressTracker::finish`] on successful completion.
388    /// Trackers that need error or panic cleanup should use `Drop` as a
389    /// fallback. See [`crate::tpch_cli::progress`] for the full contract and examples.
390    pub fn with_progress_tracker(mut self, tracker: Arc<dyn ProgressTracker>) -> Self {
391        self.progress_tracker = Some(tracker);
392        self
393    }
394
395    /// Build the [`TpchGenerator`] with the configured settings.
396    pub fn build(self) -> TpchGenerator {
397        TpchGenerator {
398            config: self.config,
399            progress_tracker: self.progress_tracker,
400        }
401    }
402}
403
404impl Default for TpchGeneratorBuilder {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::ProgressTracker;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, Mutex};

    /// Test double that logs every progress callback it receives.
    #[derive(Debug, Default)]
    struct RecordingProgress {
        /// `(table, total_units)` pairs passed to `register`, in call order.
        registered: Mutex<Vec<(Table, u64)>>,
        /// `(table, units)` pairs passed to `increment`, in call order.
        increments: Mutex<Vec<(Table, u64)>>,
        /// How many times `finish` was called.
        finishes: AtomicU64,
    }

    impl ProgressTracker for RecordingProgress {
        fn register(&self, table: Table, total_units: u64) {
            let mut log = self.registered.lock().unwrap();
            log.push((table, total_units));
        }

        fn increment(&self, table: Table, units: u64) {
            let mut log = self.increments.lock().unwrap();
            log.push((table, units));
        }

        fn finish(&self) {
            self.finishes.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[tokio::test]
    async fn builder_passes_custom_progress_tracker_to_runner() {
        let tmp = tempfile::tempdir().unwrap();
        let recorder = Arc::new(RecordingProgress::default());
        let as_tracker: Arc<dyn ProgressTracker> = recorder.clone();

        let generator = TpchGenerator::builder()
            .with_output_dir(tmp.path())
            .with_tables(vec![Table::Region])
            .with_num_threads(1)
            .with_progress_tracker(as_tracker)
            .build();
        generator.generate().await.unwrap();

        // A single Region table: one unit registered, one unit completed.
        let expected: Vec<(Table, u64)> = vec![(Table::Region, 1)];
        assert_eq!(*recorder.registered.lock().unwrap(), expected);
        assert_eq!(*recorder.increments.lock().unwrap(), expected);
        assert_eq!(recorder.finishes.load(Ordering::Relaxed), 1);
    }
}