tpchgen_arrow/part.rs
1use crate::conversions::{decimal128_array_from_iter, string_view_array_from_display_iter};
2use crate::{DEFAULT_BATCH_SIZE, RecordBatchIterator};
3use arrow::array::{Int32Array, Int64Array, RecordBatch, StringViewArray};
4use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
5use std::sync::{Arc, LazyLock};
6use tpchgen::generators::{PartGenerator, PartGeneratorIterator};
7
8/// Generate [`Part`]s in [`RecordBatch`] format
9///
10/// [`Part`]: tpchgen::generators::Part
11///
12/// # Example
13/// ```
14/// # use tpchgen::generators::{PartGenerator};
15/// # use tpchgen_arrow::PartArrow;
16///
17/// // Create a SF=1.0 generator and wrap it in an Arrow generator
18/// let generator = PartGenerator::new(1.0, 1, 1);
19/// let mut arrow_generator = PartArrow::new(generator)
20/// .with_batch_size(10);
/// // Read the first batch (10 rows)
22/// let batch = arrow_generator.next().unwrap();
23/// // compare the output by pretty printing it
24/// let formatted_batches = arrow::util::pretty::pretty_format_batches(&[batch])
25/// .unwrap()
26/// .to_string();
27/// let lines = formatted_batches.lines().collect::<Vec<_>>();
28/// assert_eq!(lines, vec![
29/// "+-----------+------------------------------------------+----------------+----------+-------------------------+--------+-------------+---------------+----------------------+",
30/// "| p_partkey | p_name | p_mfgr | p_brand | p_type | p_size | p_container | p_retailprice | p_comment |",
31/// "+-----------+------------------------------------------+----------------+----------+-------------------------+--------+-------------+---------------+----------------------+",
32/// "| 1 | goldenrod lavender spring chocolate lace | Manufacturer#1 | Brand#13 | PROMO BURNISHED COPPER | 7 | JUMBO PKG | 901.00 | ly. slyly ironi |",
33/// "| 2 | blush thistle blue yellow saddle | Manufacturer#1 | Brand#13 | LARGE BRUSHED BRASS | 1 | LG CASE | 902.00 | lar accounts amo |",
34/// "| 3 | spring green yellow purple cornsilk | Manufacturer#4 | Brand#42 | STANDARD POLISHED BRASS | 21 | WRAP CASE | 903.00 | egular deposits hag |",
35/// "| 4 | cornflower chocolate smoke green pink | Manufacturer#3 | Brand#34 | SMALL PLATED BRASS | 14 | MED DRUM | 904.00 | p furiously r |",
36/// "| 5 | forest brown coral puff cream | Manufacturer#3 | Brand#32 | STANDARD POLISHED TIN | 15 | SM PKG | 905.00 | wake carefully |",
37/// "| 6 | bisque cornflower lawn forest magenta | Manufacturer#2 | Brand#24 | PROMO PLATED STEEL | 4 | MED BAG | 906.00 | sual a |",
38/// "| 7 | moccasin green thistle khaki floral | Manufacturer#1 | Brand#11 | SMALL PLATED COPPER | 45 | SM BAG | 907.00 | lyly. ex |",
39/// "| 8 | misty lace thistle snow royal | Manufacturer#4 | Brand#44 | PROMO BURNISHED TIN | 41 | LG DRUM | 908.00 | eposi |",
40/// "| 9 | thistle dim navajo dark gainsboro | Manufacturer#4 | Brand#43 | SMALL BURNISHED STEEL | 12 | WRAP CASE | 909.00 | ironic foxe |",
41/// "| 10 | linen pink saddle puff powder | Manufacturer#5 | Brand#54 | LARGE BURNISHED STEEL | 44 | LG CAN | 910.01 | ithely final deposit |",
42/// "+-----------+------------------------------------------+----------------+----------+-------------------------+--------+-------------+---------------+----------------------+"
43/// ]);
44/// ```
45pub struct PartArrow {
46 inner: PartGeneratorIterator<'static>,
47 batch_size: usize,
48}
49
50impl PartArrow {
51 pub fn new(generator: PartGenerator<'static>) -> Self {
52 Self {
53 inner: generator.iter(),
54 batch_size: DEFAULT_BATCH_SIZE,
55 }
56 }
57
58 /// Set the batch size
59 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
60 self.batch_size = batch_size;
61 self
62 }
63}
64
65impl RecordBatchIterator for PartArrow {
66 fn schema(&self) -> &SchemaRef {
67 &PART_SCHEMA
68 }
69}
70
71impl Iterator for PartArrow {
72 type Item = RecordBatch;
73
74 fn next(&mut self) -> Option<Self::Item> {
75 // Get next rows to convert
76 let rows: Vec<_> = self.inner.by_ref().take(self.batch_size).collect();
77 if rows.is_empty() {
78 return None;
79 }
80
81 let p_partkey = Int64Array::from_iter_values(rows.iter().map(|r| r.p_partkey));
82 let p_name = string_view_array_from_display_iter(rows.iter().map(|r| &r.p_name));
83 let p_mfgr = string_view_array_from_display_iter(rows.iter().map(|r| r.p_mfgr));
84 let p_brand = string_view_array_from_display_iter(rows.iter().map(|r| r.p_brand));
85 let p_type = StringViewArray::from_iter_values(rows.iter().map(|r| r.p_type));
86 let p_size = Int32Array::from_iter_values(rows.iter().map(|r| r.p_size));
87 let p_container = StringViewArray::from_iter_values(rows.iter().map(|r| r.p_container));
88 let p_retailprice = decimal128_array_from_iter(rows.iter().map(|r| r.p_retailprice));
89 let p_comment = StringViewArray::from_iter_values(rows.iter().map(|r| r.p_comment));
90
91 let batch = RecordBatch::try_new(
92 Arc::clone(self.schema()),
93 vec![
94 Arc::new(p_partkey),
95 Arc::new(p_name),
96 Arc::new(p_mfgr),
97 Arc::new(p_brand),
98 Arc::new(p_type),
99 Arc::new(p_size),
100 Arc::new(p_container),
101 Arc::new(p_retailprice),
102 Arc::new(p_comment),
103 ],
104 )
105 .unwrap();
106 Some(batch)
107 }
108}
109
110/// Schema for the Part
111static PART_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(make_part_schema);
112fn make_part_schema() -> SchemaRef {
113 Arc::new(Schema::new(vec![
114 Field::new("p_partkey", DataType::Int64, false),
115 Field::new("p_name", DataType::Utf8View, false),
116 Field::new("p_mfgr", DataType::Utf8View, false),
117 Field::new("p_brand", DataType::Utf8View, false),
118 Field::new("p_type", DataType::Utf8View, false),
119 Field::new("p_size", DataType::Int32, false),
120 Field::new("p_container", DataType::Utf8View, false),
121 Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
122 Field::new("p_comment", DataType::Utf8View, false),
123 ]))
124}