Skip to main content

tpcgen_cli/tpch_cli/
runner.rs

1//! [`PlanRunner`] for running [`OutputPlan`]s.
2
3use crate::tpch_cli::csv::*;
4use crate::tpch_cli::generate::generate_in_chunks_with_progress;
5use crate::tpch_cli::generate::Source;
6use crate::tpch_cli::output_plan::{OutputLocation, OutputPlan};
7use crate::tpch_cli::parquet::generate_parquet_with_progress;
8use crate::tpch_cli::progress::ProgressTracker;
9use crate::tpch_cli::progress::RunProgress;
10use crate::tpch_cli::tbl::*;
11use crate::tpch_cli::tbl::{LineItemTblSource, NationTblSource, RegionTblSource};
12use crate::tpch_cli::{OutputFormat, Table, WriterSink};
13use log::{debug, info};
14use std::io;
15use std::io::BufWriter;
16use std::sync::Arc;
17use tokio::task::{JoinError, JoinSet};
18use tpchgen::generators::{
19    CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
20    PartSuppGenerator, RegionGenerator, SupplierGenerator,
21};
22use tpchgen_arrow::{
23    CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow,
24    RecordBatchIterator, RegionArrow, SupplierArrow,
25};
26
27/// Runs multiple [`OutputPlan`]s in parallel, managing the number of threads
28/// used to run them.
29#[derive(Debug)]
30pub struct PlanRunner {
31    plans: Vec<OutputPlan>,
32    num_threads: usize,
33    progress: RunProgress,
34}
35
36impl PlanRunner {
37    /// Create a new [`PlanRunner`] with the given plans and number of threads.
38    /// Progress reporting is disabled by default.
39    pub fn new(plans: Vec<OutputPlan>, num_threads: usize) -> Self {
40        Self {
41            plans,
42            num_threads,
43            progress: RunProgress::default(),
44        }
45    }
46
47    /// Attach a [`ProgressTracker`].
48    ///
49    /// The runner pre-registers each table's output-unit total with the
50    /// tracker before scheduling, calls [`ProgressTracker::increment`]
51    /// after output units are written, and calls [`ProgressTracker::finish`]
52    /// once on the success path. Implementations needing cleanup on the
53    /// error or panic path should use `Drop` as a fallback.
54    pub fn with_progress_tracker(mut self, tracker: Arc<dyn ProgressTracker>) -> Self {
55        self.progress = RunProgress::with_tracker(tracker);
56        self
57    }
58
59    /// Run all the plans in the runner.
60    pub async fn run(self) -> Result<(), io::Error> {
61        debug!(
62            "Running {} plans with {} threads...",
63            self.plans.len(),
64            self.num_threads
65        );
66        let Self {
67            mut plans,
68            num_threads,
69            progress,
70        } = self;
71
72        // Sort the plans by the number of parts so the largest are first
73        plans.sort_unstable_by(|a, b| {
74            let a_cnt = a.chunk_count();
75            let b_cnt = b.chunk_count();
76            a_cnt.cmp(&b_cnt)
77        });
78
79        // Pre-register per-table output-unit totals so trackers can size their
80        // bars before the first `increment`.
81        progress.register_totals(&plans);
82
83        // Do the actual work in parallel, using a worker queue
84        let mut worker_queue = WorkerQueue::new(num_threads, progress.clone());
85        while let Some(plan) = plans.pop() {
86            worker_queue.schedule_plan(plan).await?;
87        }
88        worker_queue.join_all().await?;
89        progress.finish();
90        Ok(())
91    }
92}
93
94/// Manages worker tasks, limiting the number of total outstanding threads
95/// to some fixed number
96///
97/// The runner executes each plan with a number of threads equal to the
98/// number of parts in the plan, but no more than the total number of
99/// threads specified when creating the runner. If a plan does not need all
100/// the threads, the remaining threads are used to run other plans.
101///
102/// This is important to keep all cores busy for smaller tables that may not
103/// have sufficient parts to keep all threads busy (see [`GenerationPlan`]
104/// for more details), but not schedule more tasks than we have threads for.
105///
106/// Scheduling too many tasks requires more memory and leads to context
107/// switching overhead, which can slow down the generation process.
108///
109/// [`GenerationPlan`]: crate::tpch_cli::plan::GenerationPlan
110struct WorkerQueue {
111    join_set: JoinSet<io::Result<usize>>,
112    /// Current number of threads available to commit
113    available_threads: usize,
114    progress: RunProgress,
115}
116
117impl WorkerQueue {
118    pub fn new(max_threads: usize, progress: RunProgress) -> Self {
119        assert!(max_threads > 0);
120        Self {
121            join_set: JoinSet::new(),
122            available_threads: max_threads,
123            progress,
124        }
125    }
126
127    /// Spawns a task to run the plan with as many threads as possible
128    /// without exceeding the maximum number of threads.
129    ///
130    /// If there are no threads available, it will wait for one to finish
131    /// before spawning the new task.
132    ///
133    /// Note this algorithm does not guarantee that all threads are always busy,
134    /// but it should be good enough for most cases. For best thread utilization
135    /// spawn the largest plans first.
136    pub async fn schedule_plan(&mut self, plan: OutputPlan) -> io::Result<()> {
137        debug!("scheduling plan {plan}");
138        loop {
139            if self.available_threads == 0 {
140                debug!("no threads left, wait for one to finish");
141                let Some(result) = self.join_set.join_next().await else {
142                    return Err(io::Error::other(
143                        "Internal Error No more tasks to wait for, but had no threads",
144                    ));
145                };
146                self.available_threads += task_result(result)?;
147                continue; // look for threads again
148            }
149
150            // Check for any other jobs done so we can reuse their threads
151            if let Some(result) = self.join_set.try_join_next() {
152                self.available_threads += task_result(result)?;
153                continue;
154            }
155
156            debug_assert!(
157                self.available_threads > 0,
158                "should have at least one thread to continue"
159            );
160
161            // figure out how many threads to allocate to this plan. Each plan
162            // can use up to `part_count` threads.
163            let chunk_count = plan.chunk_count();
164
165            let num_plan_threads = self.available_threads.min(chunk_count);
166
167            // run the plan in a separate task, which returns the number of threads it used
168            debug!("Spawning plan {plan} with {num_plan_threads} threads");
169
170            let progress = self.progress.clone();
171            self.join_set
172                .spawn(async move { run_plan(plan, num_plan_threads, progress).await });
173            self.available_threads -= num_plan_threads;
174            return Ok(());
175        }
176    }
177
178    // Wait for all tasks to finish
179    pub async fn join_all(mut self) -> io::Result<()> {
180        debug!("Waiting for tasks to finish...");
181        while let Some(result) = self.join_set.join_next().await {
182            task_result(result)?;
183        }
184        debug!("Tasks finished.");
185        Ok(())
186    }
187}
188
189/// unwraps the result of a task and converts it to an `io::Result<T>`.
190fn task_result<T>(result: Result<io::Result<T>, JoinError>) -> io::Result<T> {
191    result.map_err(|e| io::Error::other(format!("Task Panic: {e}")))?
192}
193
194/// Run a single [`OutputPlan`]
195async fn run_plan(
196    plan: OutputPlan,
197    num_threads: usize,
198    progress: RunProgress,
199) -> io::Result<usize> {
200    match plan.table() {
201        Table::Nation => run_nation_plan(plan, num_threads, progress).await,
202        Table::Region => run_region_plan(plan, num_threads, progress).await,
203        Table::Part => run_part_plan(plan, num_threads, progress).await,
204        Table::Supplier => run_supplier_plan(plan, num_threads, progress).await,
205        Table::Partsupp => run_partsupp_plan(plan, num_threads, progress).await,
206        Table::Customer => run_customer_plan(plan, num_threads, progress).await,
207        Table::Orders => run_orders_plan(plan, num_threads, progress).await,
208        Table::Lineitem => run_lineitem_plan(plan, num_threads, progress).await,
209    }
210}
211
212/// If `path` already exists, log a warning, advance progress by the full
213/// output-unit count for this plan, and return `true` so the caller can skip
214/// generation. Returns `false` otherwise.
215fn maybe_skip_existing(path: &std::path::Path, plan: &OutputPlan, progress: &RunProgress) -> bool {
216    if !path.exists() {
217        return false;
218    }
219    log::warn!("{} already exists, skipping generation", path.display());
220    progress.increment_for_existing(plan);
221    true
222}
223
224/// Writes a CSV/TSV output from the sources
225async fn write_file<I>(
226    plan: OutputPlan,
227    num_threads: usize,
228    sources: I,
229    progress: RunProgress,
230) -> Result<(), io::Error>
231where
232    I: Iterator<Item: Source> + 'static,
233{
234    let table = plan.table();
235    let table_progress = progress.for_table(table);
236    // Since generate_in_chunks already buffers, there is no need to buffer
237    // again (aka don't use BufWriter here)
238    match plan.output_location() {
239        OutputLocation::Stdout => {
240            let sink = WriterSink::new(io::stdout());
241            generate_in_chunks_with_progress(sink, sources, num_threads, table_progress).await
242        }
243        OutputLocation::File(path) => {
244            if maybe_skip_existing(path, &plan, &progress) {
245                return Ok(());
246            }
247            // write to a temp file and then rename to avoid partial files
248            let temp_path = path.with_extension("inprogress");
249            let file = std::fs::File::create(&temp_path).map_err(|err| {
250                io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
251            })?;
252            let sink = WriterSink::new(file);
253            generate_in_chunks_with_progress(sink, sources, num_threads, table_progress).await?;
254            // rename the temp file to the final path
255            std::fs::rename(&temp_path, path).map_err(|e| {
256                io::Error::other(format!(
257                    "Failed to rename {temp_path:?} to {path:?} file: {e}"
258                ))
259            })?;
260            Ok(())
261        }
262    }
263}
264
265/// Generates an output parquet file from the sources
266async fn write_parquet<I>(
267    plan: OutputPlan,
268    num_threads: usize,
269    sources: I,
270    progress: RunProgress,
271) -> Result<(), io::Error>
272where
273    I: Iterator<Item: RecordBatchIterator> + 'static,
274{
275    let table = plan.table();
276    let table_progress = progress.for_table(table);
277    match plan.output_location() {
278        OutputLocation::Stdout => {
279            let writer = BufWriter::with_capacity(32 * 1024 * 1024, io::stdout()); // 32MB buffer
280            generate_parquet_with_progress(
281                writer,
282                sources,
283                num_threads,
284                plan.parquet_compression(),
285                table_progress,
286            )
287            .await
288        }
289        OutputLocation::File(path) => {
290            if maybe_skip_existing(path, &plan, &progress) {
291                return Ok(());
292            }
293            // write to a temp file and then rename to avoid partial files
294            let temp_path = path.with_extension("inprogress");
295            let file = std::fs::File::create(&temp_path).map_err(|err| {
296                io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
297            })?;
298            let writer = BufWriter::with_capacity(32 * 1024 * 1024, file); // 32MB buffer
299            generate_parquet_with_progress(
300                writer,
301                sources,
302                num_threads,
303                plan.parquet_compression(),
304                table_progress,
305            )
306            .await?;
307            // rename the temp file to the final path
308            std::fs::rename(&temp_path, path).map_err(|e| {
309                io::Error::other(format!(
310                    "Failed to rename {temp_path:?} to {path:?} file: {e}"
311                ))
312            })?;
313            Ok(())
314        }
315    }
316}
317
/// Macro that defines an
/// `async fn $FUN_NAME(plan, num_threads, progress) -> io::Result<usize>`
/// that generates one table's [`OutputPlan`] in the plan's output format.
///
/// The generated function creates one `$GENERATOR` per `(part, num_parts)`
/// entry of the plan's `GenerationPlan`. It wraps each generator in the
/// source type for the plan's output format and passes the sources to
/// `write_file` (TBL/CSV) or `write_parquet` (Parquet). On success it
/// returns `num_threads`, the thread count it was called with.
///
/// Arguments:
/// - `$FUN_NAME`: name of the function to create
/// - `$GENERATOR`: the generator type to use; called as
///   `$GENERATOR::new(scale_factor, part, num_parts)`
/// - `$TBL_SOURCE`: the [`Source`] type to use for TBL format
/// - `$CSV_SOURCE`: the [`Source`] type to use for CSV format (constructed
///   with the plan's CSV delimiter)
/// - `$PARQUET_SOURCE`: the [`RecordBatchIterator`] type to use for Parquet format
macro_rules! define_run {
    ($FUN_NAME:ident, $GENERATOR:ident, $TBL_SOURCE:ty, $CSV_SOURCE:ty, $PARQUET_SOURCE:ty) => {
        async fn $FUN_NAME(
            plan: OutputPlan,
            num_threads: usize,
            progress: RunProgress,
        ) -> io::Result<usize> {
            use crate::tpch_cli::GenerationPlan;
            let scale_factor = plan.scale_factor();
            info!("Writing {plan} using {num_threads} threads");

            // These interior functions (rather than closures) tell the
            // compiler that the returned iterators are 'static. Each one
            // clones the generation plan, so it owns its data. When these
            // were closures, the compiler could not figure out the lifetime,
            // resulting in errors like this:
            //          let _ = join_set.spawn(async move {
            //                 |  _____________________^
            //              96 | |                 run_plan(plan, num_plan_threads).await
            //              97 | |             });
            //                 | |______________^ implementation of `FnOnce` is not general enough
            fn tbl_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$TBL_SOURCE>::new)
            }

            // Like `tbl_sources`, but each CSV source also receives the delimiter.
            fn csv_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
                delimiter: char,
            ) -> impl Iterator<Item: Source> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(move |gen| <$CSV_SOURCE>::new(gen, delimiter))
            }

            // Like `tbl_sources`, but yields Arrow record-batch iterators for Parquet.
            fn parquet_sources(
                generation_plan: &GenerationPlan,
                scale_factor: f64,
            ) -> impl Iterator<Item: RecordBatchIterator> + 'static {
                generation_plan
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(scale_factor, part, num_parts))
                    .map(<$PARQUET_SOURCE>::new)
            }

            // Dispatch to the appropriate output format. The source
            // iterators own their data, so `plan` can be moved into the
            // writer after `plan.generation_plan()` was borrowed.
            match plan.output_format() {
                OutputFormat::Tbl => {
                    let gens = tbl_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, gens, progress).await?
                }
                OutputFormat::Csv => {
                    let delimiter = plan.csv_delimiter();
                    let gens = csv_sources(plan.generation_plan(), scale_factor, delimiter);
                    write_file(plan, num_threads, gens, progress).await?
                }
                OutputFormat::Parquet => {
                    let gens = parquet_sources(plan.generation_plan(), scale_factor);
                    write_parquet(plan, num_threads, gens, progress).await?
                }
            };
            // Report the thread count this function was called with back to the caller.
            Ok(num_threads)
        }
    };
}
399
400define_run!(
401    run_lineitem_plan,
402    LineItemGenerator,
403    LineItemTblSource,
404    LineItemCsvSource,
405    LineItemArrow
406);
407
408define_run!(
409    run_nation_plan,
410    NationGenerator,
411    NationTblSource,
412    NationCsvSource,
413    NationArrow
414);
415
416define_run!(
417    run_region_plan,
418    RegionGenerator,
419    RegionTblSource,
420    RegionCsvSource,
421    RegionArrow
422);
423
424define_run!(
425    run_part_plan,
426    PartGenerator,
427    PartTblSource,
428    PartCsvSource,
429    PartArrow
430);
431
432define_run!(
433    run_supplier_plan,
434    SupplierGenerator,
435    SupplierTblSource,
436    SupplierCsvSource,
437    SupplierArrow
438);
439define_run!(
440    run_partsupp_plan,
441    PartSuppGenerator,
442    PartSuppTblSource,
443    PartSuppCsvSource,
444    PartSuppArrow
445);
446
447define_run!(
448    run_customer_plan,
449    CustomerGenerator,
450    CustomerTblSource,
451    CustomerCsvSource,
452    CustomerArrow
453);
454
455define_run!(
456    run_orders_plan,
457    OrderGenerator,
458    OrderTblSource,
459    OrderCsvSource,
460    OrderArrow
461);
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::ProgressTracker;
    use crate::tpch_cli::{Compression, GenerationPlan, DEFAULT_PARQUET_ROW_GROUP_BYTES};
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    };

    /// Tracker that sums every unit reported through `increment`.
    #[derive(Debug, Default)]
    struct UnitCounter {
        total: AtomicU64,
    }

    impl UnitCounter {
        /// Total units reported so far.
        fn observed(&self) -> u64 {
            self.total.load(Ordering::Relaxed)
        }
    }

    impl ProgressTracker for UnitCounter {
        fn increment(&self, _table: Table, units: u64) {
            self.total.fetch_add(units, Ordering::Relaxed);
        }
    }

    #[test]
    fn skip_existing_advances_progress_by_full_plan() {
        let dir = tempfile::tempdir().unwrap();
        let existing = dir.path().join("lineitem.tbl");
        std::fs::write(&existing, b"already here").unwrap();

        let gen_plan = GenerationPlan::try_new(
            Table::Lineitem,
            OutputFormat::Tbl,
            1.0,
            Some(1),
            Some(4),
            DEFAULT_PARQUET_ROW_GROUP_BYTES,
        )
        .unwrap();
        let plan = OutputPlan::new(
            Table::Lineitem,
            1.0,
            OutputFormat::Tbl,
            Compression::SNAPPY,
            OutputLocation::File(existing.clone()),
            gen_plan,
            ',',
        );

        // The skip path must account for every chunk, not just one.
        let expected_units = plan.chunk_count() as u64;
        assert!(expected_units > 1, "plan should span several chunks");

        let counter = Arc::new(UnitCounter::default());
        let shared: Arc<dyn ProgressTracker> = counter.clone();
        let progress = RunProgress::with_tracker(shared);

        assert!(maybe_skip_existing(&existing, &plan, &progress));
        assert_eq!(counter.observed(), expected_units);
    }
}