// shard_csv/sharded_writer.rs
1use crate::{shard, Error, FileSplitting};
2use csv::StringRecord;
3use std::{
4 collections::{hash_map::Entry, HashMap},
5 io::{BufWriter, Write},
6 path::Path,
7 rc::Rc,
8};
9
/// Entry point for building a [`ShardedWriter`].
///
/// Carries the optional header record through the builder chain; when present,
/// it is written at the top of every output shard file.
pub struct ShardedWriterBuilder {
    // Header row to prepend to each output file, if the input has one.
    header: Option<StringRecord>,
}
13
14impl ShardedWriterBuilder {
15 /// Start creating a sharded writer for data that don't have a header.
16 pub fn new_without_header() -> Self {
17 ShardedWriterBuilder { header: None }
18 }
19
20 /// Start creating a sharded writer for data with the specified `header`.
21 pub fn new_with_header<T>(header: T) -> Self
22 where
23 T: Into<StringRecord>,
24 {
25 ShardedWriterBuilder {
26 header: Some(header.into()),
27 }
28 }
29
30 /// Start creating a sharded writer from the specified [`csv::Reader`]
31 ///
32 /// The reader's header settings will be copied over to the sharded writer.
33 pub fn new_from_csv_reader<T>(csv: &mut csv::Reader<T>) -> Result<Self, Error>
34 where
35 T: std::io::Read,
36 {
37 let header = if csv.has_headers() {
38 Some(csv.headers()?.clone())
39 } else {
40 None
41 };
42
43 Ok(Self { header })
44 }
45
46 /// Specifies how the input will be sharded.
47 ///
48 /// Given a row of input, the key selector determines which shard the record belongs in.
49 pub fn with_key_selector<FKey>(self, key_selector: FKey) -> ShardedWriterWithKey<FKey>
50 where
51 FKey: Fn(&StringRecord) -> String,
52 {
53 ShardedWriterWithKey {
54 header: self.header,
55 key_selector,
56 }
57 }
58}
59
/// Intermediate builder state: header plus key selector, still awaiting the
/// output-file naming function ([`ShardedWriterWithKey::with_output_shard_naming`]).
pub struct ShardedWriterWithKey<FKey> {
    // Header row carried over from ShardedWriterBuilder, if any.
    header: Option<StringRecord>,
    // Maps an input record to the key of the shard it belongs to.
    key_selector: FKey,
}
64
65impl<FKey> ShardedWriterWithKey<FKey>
66where
67 FKey: Fn(&StringRecord) -> String,
68{
69 /// Specifies how output shard files will be named.
70 ///
71 /// The specified function will be called with the key value (derived from the `key_selector`
72 /// passed to [`ShardedWriterBuilder::with_key_selector`]) and the current sequence number,
73 /// which is a zero-based number identifying how many files have been written for this shard.
74 pub fn with_output_shard_naming<FNameFile>(
75 self,
76 create_output_filename: FNameFile,
77 ) -> ShardedWriter<FKey, FNameFile>
78 where
79 FNameFile: Fn(&str, usize) -> String,
80 {
81 let ShardedWriterWithKey {
82 header,
83 key_selector,
84 } = self;
85
86 ShardedWriter {
87 header_record: header,
88 key_selector,
89 output_splitting: FileSplitting::NoSplit,
90 output_delimiter: b',',
91 on_file_completion: None,
92 create_file_writer: default_create_file_writer,
93 create_output_filename: Rc::new(create_output_filename),
94 handles: HashMap::new(),
95 }
96 }
97}
98
/// Writes CSV records into per-shard output files, creating and splitting
/// files according to the configured [`FileSplitting`] policy.
pub struct ShardedWriter<FKey, FNameFile>
where
    FNameFile: Fn(&str, usize) -> String,
{
    /// How the input file should be split
    output_splitting: FileSplitting,

    /// The field delimiter; default is ','
    ///
    /// NOTE(review): this value is also used as the *input* delimiter when
    /// parsing in `process_file`/`process_reader` — confirm that is intended.
    output_delimiter: u8,

    /// A closure that accepts a CSV row and returns a String identifying which shard it belongs to.
    key_selector: FKey,

    /// An optional header record that will be written to every output file.
    header_record: Option<StringRecord>,

    /// A function that will be called when an intermediate file is completed
    on_file_completion: Option<fn(&Path, &str)>,

    // Builds an output filename from the shard key and a zero-based file
    // sequence number; wrapped in Rc so each Shard can share it.
    create_output_filename: Rc<FNameFile>,

    /// A function that creates a writer for a requested output file path
    create_file_writer: crate::shard::CreateFileWriter,

    /// A mapping of shard keys to the shards that output to files
    handles: HashMap<String, shard::Shard<FNameFile>>,
}
126
127impl<FKey, FNameFile> std::fmt::Debug for ShardedWriter<FKey, FNameFile>
128where
129 FNameFile: Fn(&str, usize) -> String,
130{
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("ShardedWriter")
133 .field("output_splitting", &self.output_splitting)
134 .field("delimiter", &self.output_delimiter)
135 .finish()
136 }
137}
138
139impl<FKey, FNameFile> ShardedWriter<FKey, FNameFile>
140where
141 FKey: Fn(&StringRecord) -> String,
142 FNameFile: Fn(&str, usize) -> String,
143{
144 /// Creates a new writer.
145 ///
146 /// You must specify the directory into which the output will be written, a function
147 /// that extracts the shard key from a csv [StringRecord], and how output files will
148 /// be named. The file naming function accepts the shard key and a zero-based number
149 /// indicating how many files have been created for this shard.
150 ///
151 /// This function can return an error if the output directory can't be created.
152 ///
153 /// ```
154 /// let writer = ShardedWriter::new(
155 /// "./foo-sharded/",
156 /// |record| record.get(7).unwrap_or("_unknown").to_string(),
157 /// |shard, seq| format!("{}-file{}.csv", shard, seq)
158 /// )?;
159 /// ```
160
161 /// Specifies when sharded output files should be split.
162 pub fn with_output_splitting(mut self, output_splitting: FileSplitting) -> Self {
163 self.output_splitting = output_splitting;
164 self
165 }
166
167 /// Sets the field delimiter to be used for output files. Default is ','.
168 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
169 self.output_delimiter = delimiter;
170 self
171 }
172
173 /// Sets an optional function that will be called when individual files are completed, either
174 /// because they have been split by the number of rows or bytes or because processing is
175 /// complete and the values are being dropped.
176 pub fn on_file_completion(mut self, f: fn(&Path, &str)) -> Self {
177 self.on_file_completion = Some(f);
178 self
179 }
180
181 /// Takes a closure that specifies how to create output files.
182 ///
183 /// The closure provides the [Path] of the output file to be created. If you don't
184 /// provide your own way to create output files, the default implementation will simply create
185 /// a new [BufWriter] for the output file, which is the same as:
186 ///
187 /// ```
188 /// my_sharded_writer.on_create_file(|path| Ok(BufWriter::new(File::create(path)?)));
189 /// ```
190 ///
191 /// This function may be useful if, for example, you want to inject gzip compression into the
192 /// output writer.
193 pub fn on_create_file(mut self, f: fn(&Path) -> std::io::Result<Box<dyn Write>>) -> Self {
194 self.create_file_writer = f;
195 self
196 }
197
198 /// Processes the input `filename`, creating output files according to the specified key
199 /// selector.
200 ///
201 /// This function will fail if the output directory or an output file can't be created or if a
202 /// row can't be written. It can also fail if it is called multiple times with files that have
203 /// different column counts.
204 ///
205 /// On success, the number of records written is returned.
206 pub fn process_file(&mut self, filename: &str) -> Result<usize, Error> {
207 let mut reader = csv::ReaderBuilder::new()
208 .delimiter(self.output_delimiter)
209 .has_headers(self.header_record.is_some())
210 .from_path(filename)?;
211
212 let records = reader.records().filter_map(|r| r.ok());
213 self.process_iter(records)
214 }
215
216 /// Processes the input reader, creating output files as appropriate.
217 ///
218 /// This function will fail if the output directory or an output file can't be created or if a
219 /// row can't be written. It can also fail if it is called multiple times with files that have
220 /// different column counts.
221 ///
222 /// On success, the number of records written is returned.
223 pub fn process_csv<T: std::io::Read>(
224 &mut self,
225 csv_reader: &mut csv::Reader<T>,
226 ) -> Result<usize, Error> {
227 let records = csv_reader.records().filter_map(|r| r.ok());
228
229 self.process_iter(records)
230 }
231
232 /// Processes an iterator of [std::io::Read], creating output files as appropriate.
233 pub fn process_reader(&mut self, reader: impl std::io::Read) -> Result<usize, Error> {
234 let mut reader = csv::ReaderBuilder::new()
235 .delimiter(self.output_delimiter)
236 .has_headers(self.header_record.is_some())
237 .from_reader(reader);
238
239 let records = reader.records().filter_map(|r| r.ok());
240
241 self.process_iter(records)
242 }
243
244 /// Iterates over every record, calculating the shard key for each, getting or creating the shard file,
245 /// and writing the record.
246 pub fn process_iter<T>(&mut self, records: T) -> Result<usize, Error>
247 where
248 T: IntoIterator<Item = StringRecord>,
249 {
250 let mut records_written = 0;
251 for record in records {
252 let key = (self.key_selector)(&record);
253
254 match self.handles.entry(key.clone()) {
255 Entry::Occupied(mut e) => {
256 e.get_mut().write_record(&record)?;
257 }
258 Entry::Vacant(e) => {
259 let header_record = self.header_record.clone();
260 let create_output_filename = self.create_output_filename.clone();
261 let mut shard = shard::Shard::new(
262 self.output_splitting,
263 key,
264 header_record,
265 self.create_file_writer,
266 create_output_filename,
267 self.on_file_completion,
268 );
269
270 shard.write_record(&record)?;
271 e.insert(shard);
272 }
273 };
274
275 records_written += 1;
276 }
277
278 Ok(records_written)
279 }
280
281 /// Checks if `key` has been seen in the processed data.
282 pub fn is_shard_key_seen(&self, key: &str) -> bool {
283 self.handles.contains_key(key)
284 }
285
286 /// Returns a vec of all keys that have been seen.
287 pub fn shard_keys_seen(&self) -> Vec<String> {
288 self.handles.keys().cloned().collect()
289 }
290}
291
/// Default file-writer factory: creates the file at `path` and wraps it in a
/// buffered writer.
///
/// To do something different (such as gzipping output), pass an alternate
/// function with this signature to [ShardedWriter::on_create_file].
fn default_create_file_writer(path: &Path) -> std::io::Result<Box<dyn Write>> {
    std::fs::File::create(path).map(|file| Box::new(BufWriter::new(file)) as Box<dyn Write>)
}