wordchipper_cli_util/io/
input_batcher.rs1use std::io::{
2 BufRead,
3 BufReader,
4};
5
6use arrow::array::{
7 Array,
8 StringArray,
9};
10use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
11
12#[derive(Debug, Clone, Copy, clap::ValueEnum)]
14pub enum FileFormat {
15 Text,
17
18 Parquet,
20}
21
22#[derive(clap::Args, Debug)]
24pub struct BatchedInputArgs {
25 files: Vec<String>,
27
28 #[arg(long)]
30 input_format: FileFormat,
31
32 #[arg(long, default_value = "100")]
34 input_batch_size: usize,
35}
36
37impl BatchedInputArgs {
38 pub fn for_each_batch<F>(
40 &self,
41 f: &mut F,
42 ) -> Result<(), Box<dyn std::error::Error>>
43 where
44 F: FnMut(&[String]) -> Result<bool, Box<dyn std::error::Error>>,
45 {
46 InputBatcher::new(self.input_format, self.files.clone())
47 .with_batch_size(self.input_batch_size)
48 .for_each_batch(f)
49 }
50}
51
52pub struct InputBatcher {
54 pub format: FileFormat,
55 pub files: Vec<String>,
56 pub batch_size: usize,
57}
58
59impl InputBatcher {
60 pub fn new(
62 format: FileFormat,
63 files: Vec<String>,
64 ) -> Self {
65 Self {
66 format,
67 files,
68 batch_size: 1000,
69 }
70 }
71
72 pub fn with_batch_size(
74 mut self,
75 batch_size: usize,
76 ) -> Self {
77 self.batch_size = batch_size;
78 self
79 }
80
81 pub fn for_each_batch<F>(
83 &self,
84 f: &mut F,
85 ) -> Result<(), Box<dyn std::error::Error>>
86 where
87 F: FnMut(&[String]) -> Result<bool, Box<dyn std::error::Error>>,
88 {
89 log::info!("InputBatcher: Batch Size: {}", self.batch_size);
90 log::info!("InputBatcher: Input Format: {:?}", self.format);
91 let mut buffer: Vec<String> = Vec::with_capacity(self.batch_size);
92 for (idx, path) in self.files.iter().enumerate() {
93 log::info!("File {idx}: {path}");
94
95 if !self.for_each_path_item(path, &mut |item| {
96 buffer.push(item.to_string());
97
98 if buffer.len() >= self.batch_size {
99 let chunk: Vec<String> = buffer.drain(..self.batch_size).collect();
100 return f(&chunk);
101 }
102
103 Ok(true)
104 })? {
105 return Ok(());
106 }
107 }
108 if !buffer.is_empty() {
109 f(&buffer)?;
110 }
111 Ok(())
112 }
113
114 fn for_each_path_item<F>(
115 &self,
116 path: &str,
117 f: &mut F,
118 ) -> Result<bool, Box<dyn std::error::Error>>
119 where
120 F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
121 {
122 match self.format {
123 FileFormat::Text => self.for_each_text_item(path, f),
124 FileFormat::Parquet => self.for_each_parquet_item(path, f),
125 }
126 }
127
128 fn for_each_text_item<F>(
129 &self,
130 path: &str,
131 f: &mut F,
132 ) -> Result<bool, Box<dyn std::error::Error>>
133 where
134 F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
135 {
136 let reader = BufReader::new(std::fs::File::open(path)?);
137 for line in reader.lines() {
138 let line = line?;
139 if !f(&line)? {
140 return Ok(false);
141 }
142 }
143 Ok(true)
144 }
145
146 fn for_each_parquet_item<F>(
147 &self,
148 path: &str,
149 f: &mut F,
150 ) -> Result<bool, Box<dyn std::error::Error>>
151 where
152 F: FnMut(&String) -> Result<bool, Box<dyn std::error::Error>>,
153 {
154 let file = std::fs::File::open(path)?;
155 let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
156 for batch in reader {
157 let batch = batch?;
158
159 let samples = batch
160 .column_by_name("text")
161 .expect("failed to find 'text' column in batch")
162 .as_any()
163 .downcast_ref::<StringArray>()
164 .unwrap()
165 .iter()
166 .filter_map(|s| s.map(|s| s.to_string()));
167
168 for sample in samples {
169 if !f(&sample)? {
170 return Ok(false);
171 }
172 }
173 }
174
175 Ok(true)
176 }
177}