use bisection::bisect_left;

use crate::splitter::Splitter;
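
/// Recursively splits text with a [`Splitter`] and then merges the pieces back
/// together, aiming to keep each emitted chunk within `chunk_size` tokens as
/// measured by the caller-supplied `token_counter`.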
pub struct Chunker {
    /// Target maximum number of tokens per chunk.
    chunk_size: usize,
    /// Returns the number of tokens in a piece of text.
    token_counter: Box<dyn Fn(&str) -> usize>,
    /// Strategy used to split text into candidate pieces before merging.
    splitter: Splitter,
}

impl Chunker {
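    /// Creates a `Chunker` with the given maximum chunk size (in tokens) and token
    /// counting function, using the default [`Splitter`].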
    pub fn new(chunk_size: usize, token_counter: Box<dyn Fn(&str) -> usize>) -> Self {
        Chunker {
            chunk_size,
            token_counter,
            splitter: Splitter::default(),
        }
    }
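
    /// Builder-style setter that replaces the default [`Splitter`].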
    pub fn splitter(mut self, splitter: Splitter) -> Self {
        self.splitter = splitter;
        self
    }
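
    /// Recursive worker behind [`chunk`](Self::chunk): splits `text` with the configured
    /// splitter, recurses into any split that is itself larger than `chunk_size`, and
    /// merges the remaining splits into chunks.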
    pub fn _chunk(&self, text: &str, recursion_depth: usize) -> Vec<String> {
        let (separator, separator_is_whitespace, text_splits) = self.splitter.split_text(text);

        let mut chunks: Vec<String> = Vec::new();

        let mut i = 0;
        while i < text_splits.len() {
            if (self.token_counter)(text_splits[i]) > self.chunk_size {
                // This split is too large on its own: recurse to break it down further.
                let sub_chunks = self._chunk(text_splits[i], recursion_depth + 1);
                for sub_chunk in sub_chunks {
                    chunks.push(sub_chunk);
                }
                i += 1;
            } else {
                // Merge as many of the remaining splits as the chunk size allows.
                let (split_idx, merged_chunk) = self.merge_splits(&text_splits[i..], separator);
                chunks.push(merged_chunk);
                i += split_idx;
            }

            let n_chunks = chunks.len();
            if !separator_is_whitespace && i < text_splits.len() {
                // Keep the non-whitespace separator: append it to the last chunk if it still
                // fits, otherwise emit it as a chunk of its own.
                let last_chunk_with_separator = chunks[n_chunks - 1].clone() + separator;
                if (self.token_counter)(&last_chunk_with_separator) <= self.chunk_size {
                    chunks[n_chunks - 1] = last_chunk_with_separator;
                } else {
                    chunks.push(separator.to_string());
                }
            }
        }
        if recursion_depth > 0 {
            // At nested levels, drop any empty chunks before returning.
            chunks = chunks
                .iter()
                .filter(|&c| !c.is_empty())
                .map(|c| c.to_string())
                .collect();
        }
        chunks
    }
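
    /// Merges a leading run of `splits` (joined by `separator`) into a single chunk
    /// targeting `chunk_size` tokens, returning the number of splits consumed and the
    /// merged text. A bisection over cumulative character counts estimates the cut
    /// point so the token counter is called only a few times per merge.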
    pub fn merge_splits(&self, splits: &[&str], separator: &str) -> (usize, String) {
        let mut low = 0;
        let mut high = splits.len();

        let mut n_tokens: usize;
        let mut tokens_per_split = 5.0;
        // Running character total up to and including each split, used to turn a token
        // budget into an estimated position in `splits`.
        let cumulative_split_char_counts = splits
            .iter()
            .scan(0, |acc, &s| {
                *acc += s.len() as u64;
                Some(*acc)
            })
            .collect::<Vec<u64>>();

        while low < high {
            // Guess a cut point from the character counts, then verify the guess with an
            // actual token count of that prefix.
            let increment_by = bisect_left(
                &cumulative_split_char_counts[low..high],
                &((self.chunk_size as f64 * tokens_per_split) as u64),
            );
            let est_midpoint = std::cmp::min(low + increment_by, high - 1);
            n_tokens =
                (self.token_counter)(splits.get(..est_midpoint).unwrap().join(separator).as_ref());

            match n_tokens.cmp(&self.chunk_size) {
                std::cmp::Ordering::Greater => high = est_midpoint,
                std::cmp::Ordering::Equal => {
                    low = est_midpoint;
                    break;
                }
                std::cmp::Ordering::Less => low = est_midpoint + 1,
            }

            // Refresh the token/character ratio estimate used for the next guess.
            if n_tokens > 0 && cumulative_split_char_counts[est_midpoint] > 0 {
                tokens_per_split =
                    n_tokens as f64 / cumulative_split_char_counts[est_midpoint] as f64;
            }
        }
        (low, splits.get(..low).unwrap().join(separator))
    }
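
    /// Splits `text` into chunks, each aimed to contain at most `chunk_size` tokens.
    ///
    /// Illustrative sketch, assuming a simple whitespace word count as the token counter:
    ///
    /// ```ignore
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.split_whitespace().count()));
    /// let chunks = chunker.chunk("The quick brown fox jumps over the lazy dog");
    /// println!("{:?}", chunks);
    /// ```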
    pub fn chunk(&self, text: &str) -> Vec<String> {
        self._chunk(text, 0)
    }
}
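
// The tests below exercise the chunker with a RoBERTa tokenizer (behind the
// `rust_tokenizers` feature) and with a simple word-count token counter. Tokenizer
// vocab/merges files and the Gutenberg corpora are read from the directory named by
// the `DATA_DIR` environment variable (defaulting to the current directory).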
#[cfg(test)]
mod chunker_tests {
    use super::*;
    use std::io::Read;
    use std::path::PathBuf;

    #[cfg(feature = "rust_tokenizers")]
    use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer};

    fn get_data_path() -> PathBuf {
        PathBuf::from(std::env::var("DATA_DIR").unwrap_or_else(|_| ".".to_string()))
    }

    fn get_roberta_vocab_path() -> PathBuf {
        get_data_path().join("roberta-base-vocab.json")
    }

    fn get_roberta_merges_path() -> PathBuf {
        get_data_path().join("roberta-base-merges.txt")
    }

    fn get_gutenberg_path() -> PathBuf {
        get_data_path().join("gutenberg")
    }

    fn get_gutenberg_corpus_path(corpus_filename: &str) -> PathBuf {
        get_gutenberg_path().join(corpus_filename)
    }

    fn read_gutenberg_corpus(corpus_filename: &str) -> String {
        let mut file = std::fs::File::open(get_gutenberg_corpus_path(corpus_filename))
            .expect("Error opening file");
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer).expect("Error reading file");
        String::from_utf8_lossy(&buffer).to_string()
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = "The quick brown fox jumps over the lazy dog.\n\nThe subject is\n\t- \"The quick brown fox\"\n\t- \"jumps over\"\n\t- \"the lazy dog\"";
        println!("Text: {}", text);
        let chunks = chunker.chunk(text);
        assert_eq!(
            chunks,
            vec![
                "The quick brown fox jumps over the lazy dog.",
                "The subject is\n\t- \"The quick brown fox\"",
                "\t- \"jumps over\"\n\t- \"the lazy dog\"",
            ]
        )
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_austen_emma() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("austen-emma.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 606);
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_milton_paradise() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("milton-paradise.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 12196);
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_shakespeare_hamlet() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("shakespeare-hamlet.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 4474);
    }

    #[test]
    fn test_merge_splits_simple() {
        let chunker = Chunker::new(
            2,
            Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1),
        );
        let splits = vec!["Hello", "World", "Goodbye", "World"];
        let separator = " ";
        let (split_idx, merged) = chunker.merge_splits(&splits, separator);
        assert_eq!(split_idx, 2);
        assert_eq!(merged, "Hello World");

        let (split_idx, merged) = chunker.merge_splits(splits.get(split_idx..).unwrap(), separator);
        assert_eq!(split_idx, 2);
        assert_eq!(merged, "Goodbye World");
    }

    #[test]
    fn test_merge_splits_uneven() {
        let chunker = Chunker::new(
            4,
            Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1),
        );
        let splits = vec![
            "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog",
        ];
        let separator = " ";
        let (split_idx, merged) = chunker.merge_splits(&splits, separator);
        assert_eq!(split_idx, 4);
        assert_eq!(merged, "The quick brown fox");

        let (split_idx_2, merged) =
            chunker.merge_splits(splits.get(split_idx..).unwrap(), separator);
        assert_eq!(split_idx_2, 4);
        assert_eq!(merged, "jumps over the lazy");

        let (split_idx_3, merged) =
            chunker.merge_splits(splits.get(split_idx + split_idx_2..).unwrap(), separator);
        assert_eq!(split_idx_3, 1);
        assert_eq!(merged, "dog");
    }
}