semchunk_rs/chunker.rs

// MIT License
//
// Copyright (c) 2024 Dominic Tarro
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

use bisection::bisect_left;

use crate::splitter::Splitter;

/// A struct for chunking texts into segments based on a maximum number of tokens per chunk and a token counter function.
///
/// # Fields
///
/// * `chunk_size` - The maximum number of tokens that can be in a chunk.
/// * `token_counter` - A function that counts the number of tokens in a string.
/// * `splitter` - The `Splitter` instance used to split the text.
///
/// # Example
///
/// ```
/// use semchunk_rs::Chunker;
/// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1));
/// let text = "The quick brown fox jumps over the lazy dog.";
/// let chunks = chunker.chunk(text);
/// assert_eq!(chunks, vec!["The quick brown fox", "jumps over the lazy", "dog."]);
/// ```
///
/// With `rust_tokenizers`:
///
/// ```
/// use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer};
/// use semchunk_rs::Chunker;
/// let tokenizer = RobertaTokenizer::from_file("data/roberta-base-vocab.json", "data/roberta-base-merges.txt", false, false)
///     .expect("Error loading tokenizer");
/// let token_counter = Box::new(move |s: &str| {
///     tokenizer.tokenize(s).len()
/// });
/// let chunker = Chunker::new(10, token_counter);
/// ```
pub struct Chunker {
    chunk_size: usize,
    token_counter: Box<dyn Fn(&str) -> usize>,
    splitter: Splitter,
}

impl Chunker {
    /// Creates a new `Chunker` instance, using the default `Splitter` instance.
    ///
    /// # Arguments
    ///
    /// * `chunk_size` - The maximum number of tokens that can be in a chunk.
    /// * `token_counter` - A function that counts the number of tokens in a string.
    ///
    /// # Returns
    ///
    /// A new `Chunker` instance.
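    ///
    /// # Example
    ///
    /// A minimal sketch using the same word-counting closure as the struct-level example:
    ///
    /// ```
    /// use semchunk_rs::Chunker;
    /// // Counts whitespace-delimited words: removed spaces give the word count minus one.
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1));
    /// ```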
    pub fn new(chunk_size: usize, token_counter: Box<dyn Fn(&str) -> usize>) -> Self {
        Chunker {
            chunk_size,
            token_counter,
            splitter: Splitter::default(),
        }
    }

    /// Sets the splitter for the `Chunker` instance.
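    ///
    /// # Example
    ///
    /// A minimal sketch; the re-export path of `Splitter` is assumed here, so the
    /// example is not compiled as a doctest:
    ///
    /// ```ignore
    /// // Assumes `Splitter` is re-exported at the crate root alongside `Chunker`.
    /// use semchunk_rs::{Chunker, Splitter};
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1))
    ///     .splitter(Splitter::default());
    /// ```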
    pub fn splitter(mut self, splitter: Splitter) -> Self {
        self.splitter = splitter;
        self
    }

    /// Recursively chunks the given text into segments based on the maximum number of tokens per chunk.
    ///
    /// # Arguments
    ///
    /// * `text` - A string slice that holds the text to be chunked.
    /// * `recursion_depth` - The current recursion depth.
    ///
    /// # Returns
    ///
    /// A vector of strings representing the chunks of the split text.
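    ///
    /// # Examples
    ///
    /// Calling `_chunk` with a `recursion_depth` of 0 is what [`Chunker::chunk`] does internally:
    ///
    /// ```
    /// use semchunk_rs::Chunker;
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1));
    /// let text = "The quick brown fox jumps over the lazy dog.";
    /// let chunks = chunker._chunk(text, 0);
    /// assert_eq!(chunks, vec!["The quick brown fox", "jumps over the lazy", "dog."]);
    /// ```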
    pub fn _chunk(&self, text: &str, recursion_depth: usize) -> Vec<String> {
        let (separator, separator_is_whitespace, text_splits) = self.splitter.split_text(text);

        let mut chunks: Vec<String> = Vec::new();

        // Iterate through the splits
        let mut i = 0;
        while i < text_splits.len() {
            if (self.token_counter)(text_splits[i]) > self.chunk_size {
                // If the split is over the chunk size, recursively chunk it.
                let sub_chunks = self._chunk(text_splits[i], recursion_depth + 1);
                for sub_chunk in sub_chunks {
                    chunks.push(sub_chunk);
                }
                i += 1;
            } else {
                // If the split is equal to or under the chunk size, add it and any subsequent splits to a new chunk until the chunk size is reached.
                let (split_idx, merged_chunk) = self.merge_splits(&text_splits[i..], separator);
                chunks.push(merged_chunk);
                i += split_idx;
            }

            let n_chunks = chunks.len();
            // If the separator is not whitespace and this is not the last split, append the
            // separator to the last chunk if doing so would not exceed the chunk size;
            // otherwise push the separator as a new chunk.
            if !separator_is_whitespace && i < text_splits.len() {
                let last_chunk_with_separator = chunks[n_chunks - 1].clone() + separator;
                if (self.token_counter)(&last_chunk_with_separator) <= self.chunk_size {
                    chunks[n_chunks - 1] = last_chunk_with_separator;
                } else {
                    chunks.push(separator.to_string());
                }
            }
        }
        if recursion_depth > 0 {
            chunks = chunks
                .iter()
                .filter(|&c| !c.is_empty())
                .map(|c| c.to_string())
                .collect();
        }
        chunks
    }

    /// Merges the first N splits into a chunk that has at most `chunk_size` tokens.
    ///
    /// # Arguments
    ///
    /// * `splits` - A slice of string slices representing the splits to merge.
    /// * `separator` - The separator used to split the text.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// * The index merging stopped at (exclusive).
    /// * The merged text.
    ///
    /// # Examples
    ///
    /// ```
    /// use semchunk_rs::Chunker;
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1));
    /// let splits = vec!["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"];
    /// let separator = " ";
    /// let (split_idx, merged) = chunker.merge_splits(&splits, separator);
    /// assert_eq!(split_idx, 4);
    /// assert_eq!(merged, "The quick brown fox");
    /// ```
    pub fn merge_splits(&self, splits: &[&str], separator: &str) -> (usize, String) {
        let mut low = 0;
        let mut high = splits.len();

        let mut n_tokens: usize;
        let mut tokens_per_split = 5.0;
        let cumulative_split_char_counts = splits
            .iter()
            .scan(0, |acc, &s| {
                *acc += s.len() as u64;
                Some(*acc)
            })
            .collect::<Vec<u64>>();

        while low < high {
            // estimate number of splits to increment by using the number of tokens per split
            let increment_by = bisect_left(
                &cumulative_split_char_counts[low..high],
                &((self.chunk_size as f64 * tokens_per_split) as u64),
            );
            let est_midpoint = std::cmp::min(low + increment_by, high - 1);
            n_tokens =
                (self.token_counter)(splits.get(..est_midpoint).unwrap().join(separator).as_ref());

            match n_tokens.cmp(&self.chunk_size) {
                std::cmp::Ordering::Greater => high = est_midpoint,
                std::cmp::Ordering::Equal => {
                    low = est_midpoint;
                    break;
                }
                std::cmp::Ordering::Less => low = est_midpoint + 1,
            }

            if n_tokens > 0 && cumulative_split_char_counts[est_midpoint] > 0 {
                tokens_per_split =
                    n_tokens as f64 / cumulative_split_char_counts[est_midpoint] as f64;
            }
        }
        (low, splits.get(..low).unwrap().join(separator))
    }

    /// Chunks the given text into segments based on the maximum number of tokens per chunk.
    ///
    /// # Arguments
    ///
    /// * `text` - A string slice that holds the text to be chunked.
    ///
    /// # Examples
    ///
    /// ```
    /// use semchunk_rs::Chunker;
    ///
    /// let chunker = Chunker::new(4, Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1));
    /// let text = "The quick brown fox jumps over the lazy dog.";
    /// let chunks = chunker.chunk(text);
    /// assert_eq!(chunks, vec!["The quick brown fox", "jumps over the lazy", "dog."]);
    /// ```
    pub fn chunk(&self, text: &str) -> Vec<String> {
        self._chunk(text, 0)
    }
}


#[cfg(test)]
mod chunker_tests {
    use super::*;
    use std::io::Read;
    use std::path::PathBuf;

    #[cfg(feature = "rust_tokenizers")]
    use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer};

    fn get_data_path() -> PathBuf {
        PathBuf::from(std::env::var("DATA_DIR").unwrap_or_else(|_| ".".to_string()))
    }

    fn get_roberta_vocab_path() -> PathBuf {
        get_data_path().join("roberta-base-vocab.json")
    }

    fn get_roberta_merges_path() -> PathBuf {
        get_data_path().join("roberta-base-merges.txt")
    }

    fn get_gutenberg_path() -> PathBuf {
        get_data_path().join("gutenberg")
    }

    fn get_gutenberg_corpus_path(corpus_filename: &str) -> PathBuf {
        get_gutenberg_path().join(corpus_filename)
    }

    fn read_gutenberg_corpus(corpus_filename: &str) -> String {
        let mut file = std::fs::File::open(get_gutenberg_corpus_path(corpus_filename))
            .expect("Error opening file");
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer).expect("Error reading file");
        String::from_utf8_lossy(&buffer).to_string()
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = "The quick brown fox jumps over the lazy dog.\n\nThe subject is\n\t- \"The quick brown fox\"\n\t- \"jumps over\"\n\t- \"the lazy dog\"";
        println!("Text: {}", text);
        let chunks = chunker.chunk(text);
        assert_eq!(
            chunks,
            vec![
                "The quick brown fox jumps over the lazy dog.",
                "The subject is\n\t- \"The quick brown fox\"",
                "\t- \"jumps over\"\n\t- \"the lazy dog\"",
            ]
        )
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_austen_emma() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("austen-emma.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 606);
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_milton_paradise() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("milton-paradise.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 12196);
    }

    #[test]
    #[cfg(feature = "rust_tokenizers")]
    fn test_chunk_rust_tokenizers_gutenberg_shakespeare_hamlet() {
        let tokenizer = RobertaTokenizer::from_file(
            get_roberta_vocab_path(),
            get_roberta_merges_path(),
            false,
            false,
        )
        .expect("Error loading tokenizer");

        let token_counter = Box::new(move |s: &str| tokenizer.tokenize(s).len());
        let chunker = Chunker::new(10, token_counter);
        let text = read_gutenberg_corpus("shakespeare-hamlet.txt");
        let chunks = chunker.chunk(&text);
        assert_eq!(chunks.len(), 4474);
    }

    #[test]
    fn test_merge_splits_simple() {
        let chunker = Chunker::new(
            2,
            Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1),
        );
        let splits = vec!["Hello", "World", "Goodbye", "World"];
        let separator = " ";
        let (split_idx, merged) = chunker.merge_splits(&splits, separator);
        assert_eq!(split_idx, 2);
        assert_eq!(merged, "Hello World");

        let (split_idx, merged) = chunker.merge_splits(splits.get(split_idx..).unwrap(), separator);
        assert_eq!(split_idx, 2);
        assert_eq!(merged, "Goodbye World");
    }

    #[test]
    fn test_merge_splits_uneven() {
        let chunker = Chunker::new(
            4,
            Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1),
        );
        let splits = vec![
            "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog",
        ]; // 9 tokens
        let separator = " ";
        let (split_idx, merged) = chunker.merge_splits(&splits, separator);
        assert_eq!(split_idx, 4);
        assert_eq!(merged, "The quick brown fox");

        let (split_idx_2, merged) =
            chunker.merge_splits(splits.get(split_idx..).unwrap(), separator);
        assert_eq!(split_idx_2, 4);
        assert_eq!(merged, "jumps over the lazy");

        let (split_idx_3, merged) =
            chunker.merge_splits(splits.get(split_idx + split_idx_2..).unwrap(), separator);
        assert_eq!(split_idx_3, 1);
        assert_eq!(merged, "dog");
    }
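
    // A feature-independent sketch mirroring the struct-level doc example: the token
    // counter counts whitespace-delimited words, and the expected chunks are taken
    // directly from that example.
    #[test]
    fn test_chunk_word_counter() {
        let chunker = Chunker::new(
            4,
            Box::new(|s: &str| s.len() - s.replace(" ", "").len() + 1),
        );
        let text = "The quick brown fox jumps over the lazy dog.";
        let chunks = chunker.chunk(text);
        assert_eq!(
            chunks,
            vec!["The quick brown fox", "jumps over the lazy", "dog."]
        );
    }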
}