spatialbench_arrow/driver.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::conversions::string_view_array_from_display_iter;
19use crate::{RecordBatchIterator, DEFAULT_BATCH_SIZE};
20use arrow::array::{Int64Array, RecordBatch, StringViewArray};
21use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
22use spatialbench::generators::{DriverGenerator, DriverGeneratorIterator};
23use std::sync::{Arc, LazyLock};
24
25/// Generate [`Driver`]s in [`RecordBatch`] format
26///
27/// [`Driver`]: spatialbench::generators::Driver
28///
29/// # Example:
30/// ```
31/// # use spatialbench::generators::{DriverGenerator};
32/// # use spatialbench_arrow::DriverArrow;
33///
34/// // Create a SF=1.0 generator and wrap it in an Arrow generator
35/// let generator = DriverGenerator::new(1.0, 1, 1);
36/// let mut arrow_generator = DriverArrow::new(generator)
37/// .with_batch_size(10);
38/// // Read the first 10 batches
39/// let batch = arrow_generator.next().unwrap();
40/// // compare the output by pretty printing it
41/// let formatted_batches = arrow::util::pretty::pretty_format_batches(&[batch])
42/// .unwrap()
43/// .to_string();
44/// let lines = formatted_batches.lines().collect::<Vec<_>>();
45/// assert_eq!(lines, vec![
46/// "+-------------+------------------+-------------------------------------+-------------+----------------+-----------------+",
47/// "| d_driverkey | d_name | d_address | d_region | d_nation | d_phone |",
48/// "+-------------+------------------+-------------------------------------+-------------+----------------+-----------------+",
49/// "| 1 | Driver#000000001 | N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ | AMERICA | PERU | 27-918-335-1736 |",
50/// "| 2 | Driver#000000002 | 89eJ5ksX3ImxJQBvxObC, | AFRICA | ETHIOPIA | 15-679-861-2259 |",
51/// "| 3 | Driver#000000003 | q1,G3Pj6OjIuUYfUoH18BFTKP5aU9bEV3 | AMERICA | ARGENTINA | 11-383-516-1199 |",
52/// "| 4 | Driver#000000004 | Bk7ah4CK8SYQTepEmvMkkgMwg | AFRICA | MOROCCO | 25-843-787-7479 |",
53/// "| 5 | Driver#000000005 | Gcdm2rJRzl5qlTVzc | MIDDLE EAST | IRAQ | 21-151-690-3663 |",
54/// "| 6 | Driver#000000006 | tQxuVm7s7CnK | AFRICA | KENYA | 24-696-997-4969 |",
55/// "| 7 | Driver#000000007 | s,4TicNGB4uO6PaSqNBUq | EUROPE | UNITED KINGDOM | 33-990-965-2201 |",
56/// "| 8 | Driver#000000008 | 9Sq4bBH2FQEmaFOocY45sRTxo6yuoG | AMERICA | PERU | 27-498-742-3860 |",
57/// "| 9 | Driver#000000009 | 1KhUgZegwM3ua7dsYmekYBsK | MIDDLE EAST | IRAN | 20-403-398-8662 |",
58/// "| 10 | Driver#000000010 | Saygah3gYWMp72i PY | AMERICA | UNITED STATES | 34-852-489-8585 |",
59/// "+-------------+------------------+-------------------------------------+-------------+----------------+-----------------+"
60/// ]);
61/// ```
62pub struct DriverArrow {
63 inner: DriverGeneratorIterator<'static>,
64 batch_size: usize,
65}
66
67impl DriverArrow {
68 pub fn new(generator: DriverGenerator<'static>) -> Self {
69 Self {
70 inner: generator.iter(),
71 batch_size: DEFAULT_BATCH_SIZE,
72 }
73 }
74
75 /// Set the batch size
76 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
77 self.batch_size = batch_size;
78 self
79 }
80}
81
82impl RecordBatchIterator for DriverArrow {
83 fn schema(&self) -> &SchemaRef {
84 &DRIVER_SCHEMA
85 }
86}
87
88impl Iterator for DriverArrow {
89 type Item = RecordBatch;
90
91 fn next(&mut self) -> Option<Self::Item> {
92 // Get next rows to convert
93 let rows: Vec<_> = self.inner.by_ref().take(self.batch_size).collect();
94 if rows.is_empty() {
95 return None;
96 }
97
98 let d_driverkey = Int64Array::from_iter_values(rows.iter().map(|r| r.d_driverkey));
99 let d_name = string_view_array_from_display_iter(rows.iter().map(|r| r.d_name));
100 let d_address = string_view_array_from_display_iter(rows.iter().map(|r| &r.d_address));
101 let d_region = StringViewArray::from_iter_values(rows.iter().map(|r| &r.d_region));
102 let d_nation = StringViewArray::from_iter_values(rows.iter().map(|r| &r.d_nation));
103 let d_phone = string_view_array_from_display_iter(rows.iter().map(|r| &r.d_phone));
104
105 let batch = RecordBatch::try_new(
106 Arc::clone(self.schema()),
107 vec![
108 Arc::new(d_driverkey),
109 Arc::new(d_name),
110 Arc::new(d_address),
111 Arc::new(d_region),
112 Arc::new(d_nation),
113 Arc::new(d_phone),
114 ],
115 )
116 .unwrap();
117 Some(batch)
118 }
119}
120
121/// Schema for the PartSupp
122static DRIVER_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(make_driver_schema);
123fn make_driver_schema() -> SchemaRef {
124 Arc::new(Schema::new(vec![
125 Field::new("d_driverkey", DataType::Int64, false),
126 Field::new("d_name", DataType::Utf8View, false),
127 Field::new("d_address", DataType::Utf8View, false),
128 Field::new("d_region", DataType::Utf8View, false),
129 Field::new("d_nation", DataType::Utf8View, false),
130 Field::new("d_phone", DataType::Utf8View, false),
131 ]))
132}