text_splitter/splitter/text.rs

/*!
# [`TextSplitter`]
Semantic splitting of text documents.
*/

use std::{ops::Range, sync::LazyLock};

use itertools::Itertools;
use regex::Regex;

use crate::{
    splitter::{SemanticLevel, Splitter},
    ChunkConfig, ChunkSizer,
};

use super::fallback::GRAPHEME_SEGMENTER;

/// Default plain-text splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. It will also attempt to
/// merge neighboring chunks if they can fit within the given chunk size.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
}

impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`].
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// // By default, the chunk sizer is based on characters.
    /// let splitter = TextSplitter::new(512);
    /// ```
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
    ///
    /// ## Method
    ///
    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
    ///
    /// 1. Split the text by increasing semantic levels.
    /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
    /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
    ///    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertently cross semantic boundaries.
    ///
    /// The boundaries used to split the text when using the `chunks` method, in ascending order:
    ///
    /// 1. Characters
    /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
    /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
    /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
    /// 5. Ascending sequence length of newlines. (Newline is `\r\n`, `\n`, or `\r`.)
    ///    Each unique length of consecutive newline sequences is treated as its own semantic level. So a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on.
    ///
    /// Splitting doesn't occur below the character level; otherwise you could get partial bytes of a char, which may not be a valid unicode str.
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
    /// ```
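    ///
    /// An explicit [`ChunkConfig`] can be passed in as well. This is a minimal
    /// sketch assuming the default `Characters` sizer is named explicitly, so
    /// the output matches the example above:
    ///
    /// ```
    /// use text_splitter::{Characters, ChunkConfig, TextSplitter};
    ///
    /// // Same capacity as above, with the character-count sizer spelled out.
    /// let splitter = TextSplitter::new(ChunkConfig::new(10).with_sizer(Characters));
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
    /// ```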
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Returns an iterator over chunks of the text and their byte offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`TextSplitter::chunks`] for more information.
    ///
    /// ```
    /// use text_splitter::{Characters, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
    /// ```
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }
}

impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    type Level = LineBreaks;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
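        // Each regex match is a run of newline characters. Counting grapheme
        // clusters within the match (rather than bytes or chars) means a
        // `\r\n` pair counts as a single newline when determining the level.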
        CAPTURE_LINEBREAKS
            .find_iter(text)
            .map(|m| {
                let range = m.range();
                let level = GRAPHEME_SEGMENTER
                    .segment_str(text.get(range.start..range.end).unwrap())
                    .tuple_windows::<(usize, usize)>()
                    .count();
                (
                    match level {
                        0 => unreachable!("regex should always match at least one newline"),
                        n => LineBreaks(n),
                    },
                    range,
                )
            })
            .collect()
    }
}

/// Different semantic levels that text can be split by.
/// Each level provides a method of splitting text into chunks of a given level,
/// as well as a fallback in case a chunk at a given level is too large.
///
/// Split by a given number of linebreaks, either `\n`, `\r`, or `\r\n`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

// Lazy so that we don't have to compile it more than once
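// Note: `regex` alternation is leftmost-first, so `(\r\n)+` is tried before the
// single-character branches and a Windows `\r\n` pair is never split into
// separate `\r` and `\n` matches.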
static CAPTURE_LINEBREAKS: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(\r\n)+|\r+|\n+").unwrap());

impl SemanticLevel for LineBreaks {}

#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::splitter::SemanticSplitRanges;

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // Round up to one above half so it goes to 2 chunks
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // Check that the beginning of the first chunk matches the start of text1
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // Check that the end of the second chunk matches the end of text2
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé"; // Chars that are more than one byte each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    // Just for testing
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        let text = "éé"; // Chars that are two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        let text = "éé"; // Chars that are two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // We can only go so small
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // \r\n is grouped together, not separated
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

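    // A minimal sketch of the newline-level semantics exercised by `parse`,
    // assuming the same internals as the test below: a run of Windows-style
    // `\r\n` pairs is counted by grapheme clusters, so three pairs yield
    // `LineBreaks(3)` rather than a level based on six characters.
    #[test]
    fn windows_newline_runs_count_by_grapheme_clusters() {
        let text = "a\r\n\r\n\r\nb";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        assert_eq!(vec![(LineBreaks(3), 1..7)], linebreaks.ranges);
    }
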
    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}
387}