Skip to main content

wordchipper_cli_util/io/
input_batcher.rs

1use std::io::{
2    BufRead,
3    BufReader,
4};
5
6use arrow::array::{
7    Array,
8    StringArray,
9};
10use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
11
/// File formats for the train command.
// Deriving `clap::ValueEnum` allows this enum to be used as the value type of
// a command-line option (it is the type of `BatchedInputArgs::input_format`).
// NOTE(review): clap may surface the `///` comments below as help text, so
// treat them as user-facing strings when editing.
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum FileFormat {
    /// Simple text files.
    // Read line by line; every line becomes one input item.
    Text,

    /// Parquet files.
    // Input items are taken from the column named `text` of each record batch.
    Parquet,
}
21
/// Args for batched input.
// Argument group consumed by `BatchedInputArgs::for_each_batch`, which turns
// these values into an `InputBatcher`.
// NOTE(review): clap may surface the `///` comments below as help text, so
// treat them as user-facing strings when editing.
#[derive(clap::Args, Debug)]
pub struct BatchedInputArgs {
    /// Input files.
    // No `#[arg(long)]` here, so clap presumably treats these as positional
    // values — confirm against the clap derive reference.
    files: Vec<String>,

    /// The input shard file format.
    // Applies to every entry of `files`; formats cannot be mixed in one run.
    #[arg(long)]
    input_format: FileFormat,

    /// The input batch size.
    // The CLI default (100) is distinct from the default that
    // `InputBatcher::new` uses on its own (1000); this value always overrides it.
    #[arg(long, default_value = "100")]
    input_batch_size: usize,
}
36
37impl BatchedInputArgs {
38    /// Run the function for each batch.
39    pub fn for_each_batch<F>(
40        &self,
41        f: &mut F,
42    ) -> Result<(), Box<dyn std::error::Error>>
43    where
44        F: FnMut(&[String]) -> Result<bool, Box<dyn std::error::Error>>,
45    {
46        InputBatcher::new(self.input_format, self.files.clone())
47            .with_batch_size(self.input_batch_size)
48            .for_each_batch(f)
49    }
50}
51
/// Batcher for input files.
///
/// Reads items from a list of files that all share one [`FileFormat`] and
/// hands them to a callback in groups of at most `batch_size` items.
pub struct InputBatcher {
    /// Format used to read every file in `files`.
    pub format: FileFormat,
    /// Paths of the input files; they are processed in the order given.
    pub files: Vec<String>,
    /// Number of items collected before the callback is invoked.
    pub batch_size: usize,
}
58
59impl InputBatcher {
60    /// Create a new input batcher.
61    pub fn new(
62        format: FileFormat,
63        files: Vec<String>,
64    ) -> Self {
65        Self {
66            format,
67            files,
68            batch_size: 1000,
69        }
70    }
71
72    /// Set the batch size.
73    pub fn with_batch_size(
74        mut self,
75        batch_size: usize,
76    ) -> Self {
77        self.batch_size = batch_size;
78        self
79    }
80
81    /// Run the function for each batch.
82    pub fn for_each_batch<F>(
83        &self,
84        f: &mut F,
85    ) -> Result<(), Box<dyn std::error::Error>>
86    where
87        F: FnMut(&[String]) -> Result<bool, Box<dyn std::error::Error>>,
88    {
89        log::info!("InputBatcher: Batch Size: {}", self.batch_size);
90        log::info!("InputBatcher: Input Format: {:?}", self.format);
91        let mut buffer: Vec<String> = Vec::with_capacity(self.batch_size);
92        for (idx, path) in self.files.iter().enumerate() {
93            log::info!("File {idx}: {path}");
94
95            if !self.for_each_path_item(path, &mut |item| {
96                buffer.push(item.to_string());
97
98                if buffer.len() >= self.batch_size {
99                    let chunk: Vec<String> = buffer.drain(..self.batch_size).collect();
100                    return f(&chunk);
101                }
102
103                Ok(true)
104            })? {
105                return Ok(());
106            }
107        }
108        if !buffer.is_empty() {
109            f(&buffer)?;
110        }
111        Ok(())
112    }
113
114    fn for_each_path_item<F>(
115        &self,
116        path: &str,
117        f: &mut F,
118    ) -> Result<bool, Box<dyn std::error::Error>>
119    where
120        F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
121    {
122        match self.format {
123            FileFormat::Text => self.for_each_text_item(path, f),
124            FileFormat::Parquet => self.for_each_parquet_item(path, f),
125        }
126    }
127
128    fn for_each_text_item<F>(
129        &self,
130        path: &str,
131        f: &mut F,
132    ) -> Result<bool, Box<dyn std::error::Error>>
133    where
134        F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
135    {
136        let reader = BufReader::new(std::fs::File::open(path)?);
137        for line in reader.lines() {
138            let line = line?;
139            if !f(&line)? {
140                return Ok(false);
141            }
142        }
143        Ok(true)
144    }
145
146    fn for_each_parquet_item<F>(
147        &self,
148        path: &str,
149        f: &mut F,
150    ) -> Result<bool, Box<dyn std::error::Error>>
151    where
152        F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
153    {
154        let file = std::fs::File::open(path)?;
155        let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
156        for batch in reader {
157            let batch = batch?;
158
159            let samples = batch
160                .column_by_name("text")
161                .expect("failed to find 'text' column in batch")
162                .as_any()
163                .downcast_ref::<StringArray>()
164                .unwrap()
165                .iter()
166                .filter_map(|s| s.map(|s| s.to_string()));
167
168            for sample in samples {
169                if !f(&sample)? {
170                    return Ok(false);
171                }
172            }
173        }
174
175        Ok(true)
176    }
177}