shard_csv/
sharded_writer.rs

1use crate::{shard, Error, FileSplitting};
2use csv::StringRecord;
3use std::{
4    collections::{hash_map::Entry, HashMap},
5    io::{BufWriter, Write},
6    path::Path,
7    rc::Rc,
8};
9
10pub struct ShardedWriterBuilder {
11    header: Option<StringRecord>,
12}
13
14impl ShardedWriterBuilder {
15    /// Start creating a sharded writer for data that don't have a header.
16    pub fn new_without_header() -> Self {
17        ShardedWriterBuilder { header: None }
18    }
19
20    /// Start creating a sharded writer for data with the specified `header`.
21    pub fn new_with_header<T>(header: T) -> Self
22    where
23        T: Into<StringRecord>,
24    {
25        ShardedWriterBuilder {
26            header: Some(header.into()),
27        }
28    }
29
30    /// Start creating a sharded writer from the specified [`csv::Reader`]
31    ///
32    /// The reader's header settings will be copied over to the sharded writer.
33    pub fn new_from_csv_reader<T>(csv: &mut csv::Reader<T>) -> Result<Self, Error>
34    where
35        T: std::io::Read,
36    {
37        let header = if csv.has_headers() {
38            Some(csv.headers()?.clone())
39        } else {
40            None
41        };
42
43        Ok(Self { header })
44    }
45
46    /// Specifies how the input will be sharded.
47    ///
48    /// Given a row of input, the key selector determines which shard the record belongs in.
49    pub fn with_key_selector<FKey>(self, key_selector: FKey) -> ShardedWriterWithKey<FKey>
50    where
51        FKey: Fn(&StringRecord) -> String,
52    {
53        ShardedWriterWithKey {
54            header: self.header,
55            key_selector,
56        }
57    }
58}
59
60pub struct ShardedWriterWithKey<FKey> {
61    header: Option<StringRecord>,
62    key_selector: FKey,
63}
64
65impl<FKey> ShardedWriterWithKey<FKey>
66where
67    FKey: Fn(&StringRecord) -> String,
68{
69    /// Specifies how output shard files will be named.
70    ///
71    /// The specified function will be called with the key value (derived from the `key_selector`
72    /// passed to [`ShardedWriterBuilder::with_key_selector`]) and the current sequence number,
73    /// which is a zero-based number identifying how many files have been written for this shard.
74    pub fn with_output_shard_naming<FNameFile>(
75        self,
76        create_output_filename: FNameFile,
77    ) -> ShardedWriter<FKey, FNameFile>
78    where
79        FNameFile: Fn(&str, usize) -> String,
80    {
81        let ShardedWriterWithKey {
82            header,
83            key_selector,
84        } = self;
85
86        ShardedWriter {
87            header_record: header,
88            key_selector,
89            output_splitting: FileSplitting::NoSplit,
90            output_delimiter: b',',
91            on_file_completion: None,
92            create_file_writer: default_create_file_writer,
93            create_output_filename: Rc::new(create_output_filename),
94            handles: HashMap::new(),
95        }
96    }
97}
98
99pub struct ShardedWriter<FKey, FNameFile>
100where
101    FNameFile: Fn(&str, usize) -> String,
102{
103    /// How the input file should be split
104    output_splitting: FileSplitting,
105
106    /// The field delimiter; default is ','
107    output_delimiter: u8,
108
109    /// A closure that accepts a CSV row and returns a String identifying which shard it belongs to.
110    key_selector: FKey,
111
112    /// An optional header record that will be written to every output file.
113    header_record: Option<StringRecord>,
114
115    /// A function that will be called when an intermediate file is completed
116    on_file_completion: Option<fn(&Path, &str)>,
117
118    create_output_filename: Rc<FNameFile>,
119
120    /// A function that creates a writer for a requested output file path
121    create_file_writer: crate::shard::CreateFileWriter,
122
123    /// A mapping of shard keys to the shards that output to files
124    handles: HashMap<String, shard::Shard<FNameFile>>,
125}
126
127impl<FKey, FNameFile> std::fmt::Debug for ShardedWriter<FKey, FNameFile>
128where
129    FNameFile: Fn(&str, usize) -> String,
130{
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.debug_struct("ShardedWriter")
133            .field("output_splitting", &self.output_splitting)
134            .field("delimiter", &self.output_delimiter)
135            .finish()
136    }
137}
138
139impl<FKey, FNameFile> ShardedWriter<FKey, FNameFile>
140where
141    FKey: Fn(&StringRecord) -> String,
142    FNameFile: Fn(&str, usize) -> String,
143{
144    /// Creates a new writer.
145    ///
146    /// You must specify the directory into which the output will be written, a function
147    /// that extracts the shard key from a csv [StringRecord], and how output files will
148    /// be named. The file naming function accepts the shard key and a zero-based number
149    /// indicating how many files have been created for this shard.
150    ///
151    /// This function can return an error if the output directory can't be created.
152    ///
153    /// ```
154    /// let writer = ShardedWriter::new(
155    ///     "./foo-sharded/",
156    ///     |record| record.get(7).unwrap_or("_unknown").to_string(),
157    ///     |shard, seq| format!("{}-file{}.csv", shard, seq)
158    /// )?;
159    /// ```
160
161    /// Specifies when sharded output files should be split.
162    pub fn with_output_splitting(mut self, output_splitting: FileSplitting) -> Self {
163        self.output_splitting = output_splitting;
164        self
165    }
166
167    /// Sets the field delimiter to be used for output files. Default is ','.
168    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
169        self.output_delimiter = delimiter;
170        self
171    }
172
173    /// Sets an optional function that will be called when individual files are completed, either
174    /// because they have been split by the number of rows or bytes or because processing is
175    /// complete and the values are being dropped.
176    pub fn on_file_completion(mut self, f: fn(&Path, &str)) -> Self {
177        self.on_file_completion = Some(f);
178        self
179    }
180
181    /// Takes a closure that specifies how to create output files.
182    ///
183    /// The closure provides the [Path] of the output file to be created. If you don't
184    /// provide your own way to create output files, the default implementation will simply create
185    /// a new [BufWriter] for the output file, which is the same as:
186    ///
187    /// ```
188    /// my_sharded_writer.on_create_file(|path| Ok(BufWriter::new(File::create(path)?)));
189    /// ```
190    ///
191    /// This function may be useful if, for example, you want to inject gzip compression into the
192    /// output writer.
193    pub fn on_create_file(mut self, f: fn(&Path) -> std::io::Result<Box<dyn Write>>) -> Self {
194        self.create_file_writer = f;
195        self
196    }
197
198    /// Processes the input `filename`, creating output files according to the specified key
199    /// selector.
200    ///
201    /// This function will fail if the output directory or an output file can't be created or if a
202    /// row can't be written. It can also fail if it is called multiple times with files that have
203    /// different column counts.
204    ///
205    /// On success, the number of records written is returned.
206    pub fn process_file(&mut self, filename: &str) -> Result<usize, Error> {
207        let mut reader = csv::ReaderBuilder::new()
208            .delimiter(self.output_delimiter)
209            .has_headers(self.header_record.is_some())
210            .from_path(filename)?;
211
212        let records = reader.records().filter_map(|r| r.ok());
213        self.process_iter(records)
214    }
215
216    /// Processes the input reader, creating output files as appropriate.
217    ///
218    /// This function will fail if the output directory or an output file can't be created or if a
219    /// row can't be written. It can also fail if it is called multiple times with files that have
220    /// different column counts.
221    ///
222    /// On success, the number of records written is returned.
223    pub fn process_csv<T: std::io::Read>(
224        &mut self,
225        csv_reader: &mut csv::Reader<T>,
226    ) -> Result<usize, Error> {
227        let records = csv_reader.records().filter_map(|r| r.ok());
228
229        self.process_iter(records)
230    }
231
232    /// Processes an iterator of [std::io::Read], creating output files as appropriate.
233    pub fn process_reader(&mut self, reader: impl std::io::Read) -> Result<usize, Error> {
234        let mut reader = csv::ReaderBuilder::new()
235            .delimiter(self.output_delimiter)
236            .has_headers(self.header_record.is_some())
237            .from_reader(reader);
238
239        let records = reader.records().filter_map(|r| r.ok());
240
241        self.process_iter(records)
242    }
243
244    /// Iterates over every record, calculating the shard key for each, getting or creating the shard file,
245    /// and writing the record.
246    pub fn process_iter<T>(&mut self, records: T) -> Result<usize, Error>
247    where
248        T: IntoIterator<Item = StringRecord>,
249    {
250        let mut records_written = 0;
251        for record in records {
252            let key = (self.key_selector)(&record);
253
254            match self.handles.entry(key.clone()) {
255                Entry::Occupied(mut e) => {
256                    e.get_mut().write_record(&record)?;
257                }
258                Entry::Vacant(e) => {
259                    let header_record = self.header_record.clone();
260                    let create_output_filename = self.create_output_filename.clone();
261                    let mut shard = shard::Shard::new(
262                        self.output_splitting,
263                        key,
264                        header_record,
265                        self.create_file_writer,
266                        create_output_filename,
267                        self.on_file_completion,
268                    );
269
270                    shard.write_record(&record)?;
271                    e.insert(shard);
272                }
273            };
274
275            records_written += 1;
276        }
277
278        Ok(records_written)
279    }
280
281    /// Checks if `key` has been seen in the processed data.
282    pub fn is_shard_key_seen(&self, key: &str) -> bool {
283        self.handles.contains_key(key)
284    }
285
286    /// Returns a vec of all keys that have been seen.
287    pub fn shard_keys_seen(&self) -> Vec<String> {
288        self.handles.keys().cloned().collect()
289    }
290}
291
/// Default output-file factory: creates (or truncates) the file at `path` and buffers it.
///
/// The file is wrapped in a [`BufWriter`]. To do something different, such as gzipping
/// output, pass an alternative function with this signature to
/// [`ShardedWriter::on_create_file`].
fn default_create_file_writer(path: &Path) -> std::io::Result<Box<dyn Write>> {
    std::fs::File::create(path)
        .map(BufWriter::new)
        .map(|buffered| Box::new(buffered) as Box<dyn Write>)
}