1use crate::tpch_cli::csv::*;
4use crate::tpch_cli::generate::generate_in_chunks_with_progress;
5use crate::tpch_cli::generate::Source;
6use crate::tpch_cli::output_plan::{OutputLocation, OutputPlan};
7use crate::tpch_cli::parquet::generate_parquet_with_progress;
8use crate::tpch_cli::progress::ProgressTracker;
9use crate::tpch_cli::progress::RunProgress;
10use crate::tpch_cli::tbl::*;
11use crate::tpch_cli::tbl::{LineItemTblSource, NationTblSource, RegionTblSource};
12use crate::tpch_cli::{OutputFormat, Table, WriterSink};
13use log::{debug, info};
14use std::io;
15use std::io::BufWriter;
16use std::sync::Arc;
17use tokio::task::{JoinError, JoinSet};
18use tpchgen::generators::{
19 CustomerGenerator, LineItemGenerator, NationGenerator, OrderGenerator, PartGenerator,
20 PartSuppGenerator, RegionGenerator, SupplierGenerator,
21};
22use tpchgen_arrow::{
23 CustomerArrow, LineItemArrow, NationArrow, OrderArrow, PartArrow, PartSuppArrow,
24 RecordBatchIterator, RegionArrow, SupplierArrow,
25};
26
27#[derive(Debug)]
30pub struct PlanRunner {
31 plans: Vec<OutputPlan>,
32 num_threads: usize,
33 progress: RunProgress,
34}
35
36impl PlanRunner {
37 pub fn new(plans: Vec<OutputPlan>, num_threads: usize) -> Self {
40 Self {
41 plans,
42 num_threads,
43 progress: RunProgress::default(),
44 }
45 }
46
47 pub fn with_progress_tracker(mut self, tracker: Arc<dyn ProgressTracker>) -> Self {
55 self.progress = RunProgress::with_tracker(tracker);
56 self
57 }
58
59 pub async fn run(self) -> Result<(), io::Error> {
61 debug!(
62 "Running {} plans with {} threads...",
63 self.plans.len(),
64 self.num_threads
65 );
66 let Self {
67 mut plans,
68 num_threads,
69 progress,
70 } = self;
71
72 plans.sort_unstable_by(|a, b| {
74 let a_cnt = a.chunk_count();
75 let b_cnt = b.chunk_count();
76 a_cnt.cmp(&b_cnt)
77 });
78
79 progress.register_totals(&plans);
82
83 let mut worker_queue = WorkerQueue::new(num_threads, progress.clone());
85 while let Some(plan) = plans.pop() {
86 worker_queue.schedule_plan(plan).await?;
87 }
88 worker_queue.join_all().await?;
89 progress.finish();
90 Ok(())
91 }
92}
93
94struct WorkerQueue {
111 join_set: JoinSet<io::Result<usize>>,
112 available_threads: usize,
114 progress: RunProgress,
115}
116
117impl WorkerQueue {
118 pub fn new(max_threads: usize, progress: RunProgress) -> Self {
119 assert!(max_threads > 0);
120 Self {
121 join_set: JoinSet::new(),
122 available_threads: max_threads,
123 progress,
124 }
125 }
126
127 pub async fn schedule_plan(&mut self, plan: OutputPlan) -> io::Result<()> {
137 debug!("scheduling plan {plan}");
138 loop {
139 if self.available_threads == 0 {
140 debug!("no threads left, wait for one to finish");
141 let Some(result) = self.join_set.join_next().await else {
142 return Err(io::Error::other(
143 "Internal Error No more tasks to wait for, but had no threads",
144 ));
145 };
146 self.available_threads += task_result(result)?;
147 continue; }
149
150 if let Some(result) = self.join_set.try_join_next() {
152 self.available_threads += task_result(result)?;
153 continue;
154 }
155
156 debug_assert!(
157 self.available_threads > 0,
158 "should have at least one thread to continue"
159 );
160
161 let chunk_count = plan.chunk_count();
164
165 let num_plan_threads = self.available_threads.min(chunk_count);
166
167 debug!("Spawning plan {plan} with {num_plan_threads} threads");
169
170 let progress = self.progress.clone();
171 self.join_set
172 .spawn(async move { run_plan(plan, num_plan_threads, progress).await });
173 self.available_threads -= num_plan_threads;
174 return Ok(());
175 }
176 }
177
178 pub async fn join_all(mut self) -> io::Result<()> {
180 debug!("Waiting for tasks to finish...");
181 while let Some(result) = self.join_set.join_next().await {
182 task_result(result)?;
183 }
184 debug!("Tasks finished.");
185 Ok(())
186 }
187}
188
189fn task_result<T>(result: Result<io::Result<T>, JoinError>) -> io::Result<T> {
191 result.map_err(|e| io::Error::other(format!("Task Panic: {e}")))?
192}
193
194async fn run_plan(
196 plan: OutputPlan,
197 num_threads: usize,
198 progress: RunProgress,
199) -> io::Result<usize> {
200 match plan.table() {
201 Table::Nation => run_nation_plan(plan, num_threads, progress).await,
202 Table::Region => run_region_plan(plan, num_threads, progress).await,
203 Table::Part => run_part_plan(plan, num_threads, progress).await,
204 Table::Supplier => run_supplier_plan(plan, num_threads, progress).await,
205 Table::Partsupp => run_partsupp_plan(plan, num_threads, progress).await,
206 Table::Customer => run_customer_plan(plan, num_threads, progress).await,
207 Table::Orders => run_orders_plan(plan, num_threads, progress).await,
208 Table::Lineitem => run_lineitem_plan(plan, num_threads, progress).await,
209 }
210}
211
212fn maybe_skip_existing(path: &std::path::Path, plan: &OutputPlan, progress: &RunProgress) -> bool {
216 if !path.exists() {
217 return false;
218 }
219 log::warn!("{} already exists, skipping generation", path.display());
220 progress.increment_for_existing(plan);
221 true
222}
223
224async fn write_file<I>(
226 plan: OutputPlan,
227 num_threads: usize,
228 sources: I,
229 progress: RunProgress,
230) -> Result<(), io::Error>
231where
232 I: Iterator<Item: Source> + 'static,
233{
234 let table = plan.table();
235 let table_progress = progress.for_table(table);
236 match plan.output_location() {
239 OutputLocation::Stdout => {
240 let sink = WriterSink::new(io::stdout());
241 generate_in_chunks_with_progress(sink, sources, num_threads, table_progress).await
242 }
243 OutputLocation::File(path) => {
244 if maybe_skip_existing(path, &plan, &progress) {
245 return Ok(());
246 }
247 let temp_path = path.with_extension("inprogress");
249 let file = std::fs::File::create(&temp_path).map_err(|err| {
250 io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
251 })?;
252 let sink = WriterSink::new(file);
253 generate_in_chunks_with_progress(sink, sources, num_threads, table_progress).await?;
254 std::fs::rename(&temp_path, path).map_err(|e| {
256 io::Error::other(format!(
257 "Failed to rename {temp_path:?} to {path:?} file: {e}"
258 ))
259 })?;
260 Ok(())
261 }
262 }
263}
264
265async fn write_parquet<I>(
267 plan: OutputPlan,
268 num_threads: usize,
269 sources: I,
270 progress: RunProgress,
271) -> Result<(), io::Error>
272where
273 I: Iterator<Item: RecordBatchIterator> + 'static,
274{
275 let table = plan.table();
276 let table_progress = progress.for_table(table);
277 match plan.output_location() {
278 OutputLocation::Stdout => {
279 let writer = BufWriter::with_capacity(32 * 1024 * 1024, io::stdout()); generate_parquet_with_progress(
281 writer,
282 sources,
283 num_threads,
284 plan.parquet_compression(),
285 table_progress,
286 )
287 .await
288 }
289 OutputLocation::File(path) => {
290 if maybe_skip_existing(path, &plan, &progress) {
291 return Ok(());
292 }
293 let temp_path = path.with_extension("inprogress");
295 let file = std::fs::File::create(&temp_path).map_err(|err| {
296 io::Error::other(format!("Failed to create {temp_path:?}: {err}"))
297 })?;
298 let writer = BufWriter::with_capacity(32 * 1024 * 1024, file); generate_parquet_with_progress(
300 writer,
301 sources,
302 num_threads,
303 plan.parquet_compression(),
304 table_progress,
305 )
306 .await?;
307 std::fs::rename(&temp_path, path).map_err(|e| {
309 io::Error::other(format!(
310 "Failed to rename {temp_path:?} to {path:?} file: {e}"
311 ))
312 })?;
313 Ok(())
314 }
315 }
316}
317
// Defines `async fn $FUN_NAME(plan, num_threads, progress) -> io::Result<usize>`.
//
// The generated function writes one table in whichever format `plan` requests.
// Arguments, in order:
//   - the name of the function to define;
//   - the row generator type;
//   - the TBL, CSV and Arrow (Parquet) source adapters that wrap that generator.
// The generated function resolves to `num_threads`, so the caller can return those
// threads to its pool.
macro_rules! define_run {
    ($FUN_NAME:ident, $GENERATOR:ident, $TBL_SOURCE:ty, $CSV_SOURCE:ty, $PARQUET_SOURCE:ty) => {
        async fn $FUN_NAME(
            plan: OutputPlan,
            num_threads: usize,
            progress: RunProgress,
        ) -> io::Result<usize> {
            use crate::tpch_cli::GenerationPlan;

            // Each helper clones the plan so the returned iterator owns its data
            // (hence `'static`). It builds one generator per (part, num_parts) chunk
            // and wraps it in the format-specific source.
            fn tbl_sources(
                chunks: &GenerationPlan,
                sf: f64,
            ) -> impl Iterator<Item: Source> + 'static {
                chunks
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(sf, part, num_parts))
                    .map(<$TBL_SOURCE>::new)
            }

            fn csv_sources(
                chunks: &GenerationPlan,
                sf: f64,
                delimiter: char,
            ) -> impl Iterator<Item: Source> + 'static {
                chunks
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(sf, part, num_parts))
                    .map(move |generator| <$CSV_SOURCE>::new(generator, delimiter))
            }

            fn parquet_sources(
                chunks: &GenerationPlan,
                sf: f64,
            ) -> impl Iterator<Item: RecordBatchIterator> + 'static {
                chunks
                    .clone()
                    .into_iter()
                    .map(move |(part, num_parts)| $GENERATOR::new(sf, part, num_parts))
                    .map(<$PARQUET_SOURCE>::new)
            }

            let scale_factor = plan.scale_factor();
            info!("Writing {plan} using {num_threads} threads");

            match plan.output_format() {
                OutputFormat::Tbl => {
                    let sources = tbl_sources(plan.generation_plan(), scale_factor);
                    write_file(plan, num_threads, sources, progress).await?;
                }
                OutputFormat::Csv => {
                    let delimiter = plan.csv_delimiter();
                    let sources = csv_sources(plan.generation_plan(), scale_factor, delimiter);
                    write_file(plan, num_threads, sources, progress).await?;
                }
                OutputFormat::Parquet => {
                    let sources = parquet_sources(plan.generation_plan(), scale_factor);
                    write_parquet(plan, num_threads, sources, progress).await?;
                }
            }
            Ok(num_threads)
        }
    };
}
399
400define_run!(
401 run_lineitem_plan,
402 LineItemGenerator,
403 LineItemTblSource,
404 LineItemCsvSource,
405 LineItemArrow
406);
407
408define_run!(
409 run_nation_plan,
410 NationGenerator,
411 NationTblSource,
412 NationCsvSource,
413 NationArrow
414);
415
416define_run!(
417 run_region_plan,
418 RegionGenerator,
419 RegionTblSource,
420 RegionCsvSource,
421 RegionArrow
422);
423
424define_run!(
425 run_part_plan,
426 PartGenerator,
427 PartTblSource,
428 PartCsvSource,
429 PartArrow
430);
431
432define_run!(
433 run_supplier_plan,
434 SupplierGenerator,
435 SupplierTblSource,
436 SupplierCsvSource,
437 SupplierArrow
438);
439define_run!(
440 run_partsupp_plan,
441 PartSuppGenerator,
442 PartSuppTblSource,
443 PartSuppCsvSource,
444 PartSuppArrow
445);
446
447define_run!(
448 run_customer_plan,
449 CustomerGenerator,
450 CustomerTblSource,
451 CustomerCsvSource,
452 CustomerArrow
453);
454
455define_run!(
456 run_orders_plan,
457 OrderGenerator,
458 OrderTblSource,
459 OrderCsvSource,
460 OrderArrow
461);
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::ProgressTracker;
    use crate::tpch_cli::{Compression, GenerationPlan, DEFAULT_PARQUET_ROW_GROUP_BYTES};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Progress tracker that sums every increment it receives, ignoring the table.
    #[derive(Debug, Default)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, units: u64) {
            self.increments.fetch_add(units, Ordering::Relaxed);
        }
    }

    /// Skipping an existing output must still advance progress by every chunk the
    /// plan would otherwise have produced.
    #[test]
    fn skip_existing_advances_progress_by_full_plan() {
        let scratch = tempfile::tempdir().unwrap();
        let existing = scratch.path().join("lineitem.tbl");
        std::fs::write(&existing, b"already here").unwrap();

        let chunks = GenerationPlan::try_new(
            Table::Lineitem,
            OutputFormat::Tbl,
            1.0,
            Some(1),
            Some(4),
            DEFAULT_PARQUET_ROW_GROUP_BYTES,
        )
        .unwrap();
        let plan = OutputPlan::new(
            Table::Lineitem,
            1.0,
            OutputFormat::Tbl,
            Compression::SNAPPY,
            OutputLocation::File(existing.clone()),
            chunks,
            ',',
        );

        let want = plan.chunk_count() as u64;
        assert!(want > 1);

        let counter = Arc::new(CountingProgress::default());
        let tracker: Arc<dyn ProgressTracker> = counter.clone();
        let run_progress = RunProgress::with_tracker(tracker);

        assert!(maybe_skip_existing(&existing, &plan, &run_progress));
        assert_eq!(counter.increments.load(Ordering::Relaxed), want);
    }
}