Skip to main content

text_splitter/splitter/
text.rs

1/*!
2# [`TextSplitter`]
3Semantic splitting of text documents.
4*/
5
6use std::ops::Range;
7
8use itertools::Itertools;
9use memchr::memchr2_iter;
10
11use crate::{
12    splitter::{SemanticLevel, Splitter},
13    ChunkConfig, ChunkSizer,
14};
15
16use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
17
18/// Default plain-text splitter. Recursively splits chunks into the largest
19/// semantic units that fit within the chunk size. Also will attempt to merge
20/// neighboring chunks if they can fit within the given chunk size.
21#[derive(Debug)]
22pub struct TextSplitter<Sizer>
23where
24    Sizer: ChunkSizer,
25{
26    /// Method of determining chunk sizes.
27    chunk_config: ChunkConfig<Sizer>,
28}
29
30impl<Sizer> TextSplitter<Sizer>
31where
32    Sizer: ChunkSizer,
33{
34    /// Creates a new [`TextSplitter`].
35    ///
36    /// ```
37    /// use text_splitter::TextSplitter;
38    ///
39    /// // By default, the chunk sizer is based on characters.
40    /// let splitter = TextSplitter::new(512);
41    /// ```
42    #[must_use]
43    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
44        Self {
45            chunk_config: chunk_config.into(),
46        }
47    }
48
49    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
50    ///
51    /// ## Method
52    ///
53    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
54    ///
55    /// 1. Split the text by a increasing semantic levels.
56    /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
57    /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
58    ///    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertantly cross semantic boundaries.
59    ///
60    /// The boundaries used to split the text if using the `chunks` method, in ascending order:
61    ///
62    /// 1. Characters
63    /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
64    /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
65    /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
66    /// 5. Ascending sequence length of newlines. (Newline is `\r\n`, `\n`, or `\r`)
67    ///    Each unique length of consecutive newline sequences is treated as its own semantic level. So a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on.
68    ///
69    /// Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
70    ///
71    /// ```
72    /// use text_splitter::TextSplitter;
73    ///
74    /// let splitter = TextSplitter::new(10);
75    /// let text = "Some text\n\nfrom a\ndocument";
76    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
77    ///
78    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
79    /// ```
80    pub fn chunks<'splitter, 'text: 'splitter>(
81        &'splitter self,
82        text: &'text str,
83    ) -> impl Iterator<Item = &'text str> + 'splitter {
84        Splitter::<_>::chunks(self, text)
85    }
86
87    /// Returns an iterator over chunks of the text and their byte offsets.
88    /// Each chunk will be up to the `chunk_capacity`.
89    ///
90    /// See [`TextSplitter::chunks`] for more information.
91    ///
92    /// ```
93    /// use text_splitter::{Characters, TextSplitter};
94    ///
95    /// let splitter = TextSplitter::new(10);
96    /// let text = "Some text\n\nfrom a\ndocument";
97    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
98    ///
99    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
100    /// ```
101    pub fn chunk_indices<'splitter, 'text: 'splitter>(
102        &'splitter self,
103        text: &'text str,
104    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
105        Splitter::<_>::chunk_indices(self, text)
106    }
107
108    /// Returns an iterator over chunks of the text with their byte and character offsets.
109    /// Each chunk will be up to the `chunk_capacity`.
110    ///
111    /// See [`TextSplitter::chunks`] for more information.
112    ///
113    /// This will be more expensive than just byte offsets, and for most usage in Rust, just
114    /// having byte offsets is sufficient. But when interfacing with other languages or systems
115    /// that require character offsets, this will track the character offsets for you,
116    /// accounting for any trimming that may have occurred.
117    ///
118    /// ```
119    /// use text_splitter::{Characters, ChunkCharIndex, TextSplitter};
120    ///
121    /// let splitter = TextSplitter::new(10);
122    /// let text = "Some text\n\nfrom a\ndocument";
123    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
124    ///
125    /// assert_eq!(vec![ChunkCharIndex { chunk: "Some text", byte_offset: 0, char_offset: 0 }, ChunkCharIndex { chunk: "from a", byte_offset: 11, char_offset: 11 }, ChunkCharIndex { chunk: "document", byte_offset: 18, char_offset: 18 }], chunks);
126    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
127        &'splitter self,
128        text: &'text str,
129    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
130        Splitter::<_>::chunk_char_indices(self, text)
131    }
132}
133
134impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
135where
136    Sizer: ChunkSizer,
137{
138    type Level = LineBreaks;
139
140    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
141        &self.chunk_config
142    }
143
144    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
145        memchr2_iter(b'\n', b'\r', text.as_bytes())
146            .map(|i| i..i + 1)
147            .coalesce(|a, b| {
148                if a.end == b.start {
149                    Ok(a.start..b.end)
150                } else {
151                    Err((a, b))
152                }
153            })
154            .map(|range| {
155                let level = GRAPHEME_SEGMENTER
156                    .segment_str(text.get(range.start..range.end).unwrap())
157                    .tuple_windows::<(usize, usize)>()
158                    .count();
159                (
160                    match level {
161                        0 => unreachable!("regex should always match at least one newline"),
162                        n => LineBreaks(n),
163                    },
164                    range,
165                )
166            })
167            .collect()
168    }
169}
170
171/// Different semantic levels that text can be split by.
172/// Each level provides a method of splitting text into chunks of a given level
173/// as well as a fallback in case a given fallback is too large.
174///
175/// Split by given number of linebreaks, either `\n`, `\r`, or `\r\n`.
176#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
177pub struct LineBreaks(usize);
178
179impl SemanticLevel for LineBreaks {}
180
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    /// Test-only sizer that measures a chunk by its length in bytes.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    /// Splits `text` with a [`TextSplitter`] built from `config` (trimming disabled)
    /// and collects every chunk it yields.
    fn untrimmed_chunks<'t, S>(config: ChunkConfig<S>, text: &'t str) -> Vec<&'t str>
    where
        S: ChunkSizer,
    {
        TextSplitter::new(config.with_trim(false))
            .chunks(text)
            .collect()
    }

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text: String = Faker.fake();
        let capacity = text.chars().count();

        let chunks = untrimmed_chunks(ChunkConfig::new(capacity), &text);

        assert_eq!(vec![text.as_str()], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let first: String = Faker.fake();
        let second: String = Faker.fake();
        let text = format!("{first}{second}");
        // One more than half the length pushes the text into two chunks.
        let capacity = text.chars().count() / 2 + 1;

        let chunks = untrimmed_chunks(ChunkConfig::new(capacity), &text);

        assert!(chunks.iter().all(|chunk| chunk.chars().count() <= capacity));

        // The first chunk starts the same way `first` does.
        let prefix_len = min(first.len(), chunks[0].len());
        assert_eq!(first[..prefix_len], chunks[0][..prefix_len]);
        // The second chunk ends the same way `second` does.
        let suffix_len = min(second.len(), chunks[1].len());
        assert_eq!(
            second[second.len() - suffix_len..],
            chunks[1][chunks[1].len() - suffix_len..]
        );

        assert_eq!(chunks.concat(), text);
    }

    #[test]
    fn empty_string() {
        let chunks = untrimmed_chunks(ChunkConfig::new(100), "");

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each "é" is a single char that needs more than one byte.
        let chunks = untrimmed_chunks(ChunkConfig::new(1), "éé");

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn custom_len_function() {
        // Each "é" is two bytes, so exactly one fits per 2-byte chunk.
        let chunks = untrimmed_chunks(ChunkConfig::new(2).with_sizer(Str), "éé");

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // Each "é" is two bytes, which exceeds the 1-byte capacity, but a char is the
        // smallest unit the splitter will ever produce.
        let chunks = untrimmed_chunks(ChunkConfig::new(1).with_sizer(Str), "éé");

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let chunks = untrimmed_chunks(ChunkConfig::new(3), "a̐éö̲\r\n");

        // `\r\n` stays together as one grapheme instead of being separated.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let indices: Vec<_> = TextSplitter::new(1).chunk_indices(" a b ").collect();

        assert_eq!(vec![(1, "a"), (3, "b")], indices);
    }

    #[test]
    fn chunk_char_indices() {
        let indices: Vec<_> = TextSplitter::new(1)
            .chunk_char_indices(" a b ")
            .collect();

        let expected = vec![
            ChunkCharIndex {
                chunk: "a",
                byte_offset: 1,
                char_offset: 1,
            },
            ChunkCharIndex {
                chunk: "b",
                byte_offset: 3,
                char_offset: 3,
            },
        ];
        assert_eq!(expected, indices);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let chunks = untrimmed_chunks(ChunkConfig::new(1), "a̐éö̲\r\n");

        let expected = vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn trim_grapheme_indices() {
        let indices: Vec<_> = TextSplitter::new(3)
            .chunk_indices("\r\na̐éö̲\r\n")
            .collect();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], indices);
    }

    #[test]
    fn grapheme_char_indices() {
        let indices: Vec<_> = TextSplitter::new(3)
            .chunk_char_indices("\r\na̐éö̲\r\n")
            .collect();

        let expected = vec![
            ChunkCharIndex {
                chunk: "a̐é",
                byte_offset: 2,
                char_offset: 2,
            },
            ChunkCharIndex {
                chunk: "ö̲",
                byte_offset: 7,
                char_offset: 5,
            },
        ];
        assert_eq!(expected, indices);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";

        let chunks = untrimmed_chunks(ChunkConfig::new(10), text);

        let expected = vec![
            "The quick ",
            "(\"brown\") ",
            "fox can't ",
            "jump 32.3 ",
            "feet, ",
            "right?",
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let chunks = untrimmed_chunks(ChunkConfig::new(2), "Thé quick\r\n");

        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let indices: Vec<_> = TextSplitter::new(10)
            .chunk_indices("Some text from a document")
            .collect();

        let expected = vec![(0, "Some text"), (10, "from a"), (17, "document")];
        assert_eq!(expected, indices);
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";

        let chunks = untrimmed_chunks(ChunkConfig::new(21), text);

        let expected = vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";

        let chunks = untrimmed_chunks(ChunkConfig::new(16), text);

        let expected = vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn trim_sentence_indices() {
        let indices: Vec<_> = TextSplitter::new(10)
            .chunk_indices("Some text. From a document.")
            .collect();

        let expected = vec![(0, "Some text."), (11, "From a"), (18, "document.")];
        assert_eq!(expected, indices);
    }

    #[test]
    fn trim_paragraph_indices() {
        let indices: Vec<_> = TextSplitter::new(10)
            .chunk_indices("Some text\n\nfrom a\ndocument")
            .collect();

        let expected = vec![(0, "Some text"), (11, "from a"), (18, "document")];
        assert_eq!(expected, indices);
    }

    #[test]
    fn correctly_determines_newlines() {
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse("\r\n\r\ntext\n\n\ntext2"));

        let expected = vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)];
        assert_eq!(expected, linebreaks.ranges);
    }
}