// text-splitter/src/splitter/text.rs

/*!
# [`TextSplitter`]
Semantic splitting of text documents.
*/
5
use std::{ops::Range, sync::LazyLock};

use itertools::Itertools;
use regex::Regex;

use crate::{
    splitter::{SemanticLevel, Splitter},
    ChunkConfig, ChunkSizer,
};

use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};

/// Default plain-text splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. Also will attempt to merge
/// neighboring chunks if they can fit within the given chunk size.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
}
30
31impl<Sizer> TextSplitter<Sizer>
32where
33    Sizer: ChunkSizer,
34{
35    /// Creates a new [`TextSplitter`].
36    ///
37    /// ```
38    /// use text_splitter::TextSplitter;
39    ///
40    /// // By default, the chunk sizer is based on characters.
41    /// let splitter = TextSplitter::new(512);
42    /// ```
43    #[must_use]
44    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
45        Self {
46            chunk_config: chunk_config.into(),
47        }
48    }
49
50    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
51    ///
52    /// ## Method
53    ///
54    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
55    //
56    // 1. Split the text by a increasing semantic levels.
57    // 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
58    // 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
59    //    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertantly cross semantic boundaries.
60    //
61    // The boundaries used to split the text if using the `chunks` method, in ascending order:
62    //
63    // 1. Characters
64    // 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
65    // 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
66    // 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
67    // 5. Ascending sequence length of newlines. (Newline is `\r\n`, `\n`, or `\r`)
68    //    Each unique length of consecutive newline sequences is treated as its own semantic level. So a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on.
69    //
70    // Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
71    ///
72    /// ```
73    /// use text_splitter::TextSplitter;
74    ///
75    /// let splitter = TextSplitter::new(10);
76    /// let text = "Some text\n\nfrom a\ndocument";
77    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
78    ///
79    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
80    /// ```
81    pub fn chunks<'splitter, 'text: 'splitter>(
82        &'splitter self,
83        text: &'text str,
84    ) -> impl Iterator<Item = &'text str> + 'splitter {
85        Splitter::<_>::chunks(self, text)
86    }
87
88    /// Returns an iterator over chunks of the text and their byte offsets.
89    /// Each chunk will be up to the `chunk_capacity`.
90    ///
91    /// See [`TextSplitter::chunks`] for more information.
92    ///
93    /// ```
94    /// use text_splitter::{Characters, TextSplitter};
95    ///
96    /// let splitter = TextSplitter::new(10);
97    /// let text = "Some text\n\nfrom a\ndocument";
98    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
99    ///
100    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
101    /// ```
102    pub fn chunk_indices<'splitter, 'text: 'splitter>(
103        &'splitter self,
104        text: &'text str,
105    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
106        Splitter::<_>::chunk_indices(self, text)
107    }
108
109    /// Returns an iterator over chunks of the text with their byte and character offsets.
110    /// Each chunk will be up to the `chunk_capacity`.
111    ///
112    /// See [`TextSplitter::chunks`] for more information.
113    ///
114    /// This will be more expensive than just byte offsets, and for most usage in Rust, just
115    /// having byte offsets is sufficient. But when interfacing with other languages or systems
116    /// that require character offsets, this will track the character offsets for you,
117    /// accounting for any trimming that may have occurred.
118    ///
119    /// ```
120    /// use text_splitter::{Characters, ChunkCharIndex, TextSplitter};
121    ///
122    /// let splitter = TextSplitter::new(10);
123    /// let text = "Some text\n\nfrom a\ndocument";
124    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
125    ///
126    /// assert_eq!(vec![ChunkCharIndex { chunk: "Some text", byte_offset: 0, char_offset: 0 }, ChunkCharIndex { chunk: "from a", byte_offset: 11, char_offset: 11 }, ChunkCharIndex { chunk: "document", byte_offset: 18, char_offset: 18 }], chunks);
127    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
128        &'splitter self,
129        text: &'text str,
130    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
131        Splitter::<_>::chunk_char_indices(self, text)
132    }
133}
134
135impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
136where
137    Sizer: ChunkSizer,
138{
139    type Level = LineBreaks;
140
141    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
142        &self.chunk_config
143    }
144
145    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
146        CAPTURE_LINEBREAKS
147            .find_iter(text)
148            .map(|m| {
149                let range = m.range();
150                let level = GRAPHEME_SEGMENTER
151                    .segment_str(text.get(range.start..range.end).unwrap())
152                    .tuple_windows::<(usize, usize)>()
153                    .count();
154                (
155                    match level {
156                        0 => unreachable!("regex should always match at least one newline"),
157                        n => LineBreaks(n),
158                    },
159                    range,
160                )
161            })
162            .collect()
163    }
164}
165
/// Different semantic levels that text can be split by.
/// Each level provides a method of splitting text into chunks of a given level
/// as well as a fallback in case a given chunk at that level is too large.
///
/// Split by given number of linebreaks, either `\n`, `\r`, or `\r\n`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

// Lazy so that we don't have to compile it more than once.
// Matches runs of `\r\n`, `\r`, or `\n`; the alternation order puts `\r\n`
// first so a Windows newline is captured as one unit.
static CAPTURE_LINEBREAKS: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(\r\n)+|\r+|\n+").unwrap());

impl SemanticLevel for LineBreaks {}
179
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // Round up to one above half so it goes to 2 chunks
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // Check that beginning of first chunk and text 1 matches
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // Check that end of second chunk and text 2 matches
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        // With trimming disabled, rejoining the chunks must reproduce the input.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé"; // Char that is more than one byte
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    // Just for testing: a sizer that measures chunks in bytes, not chars.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        let text = "éé"; // Char that is two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        let text = "éé"; // Char that is two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // We can only go so small: splitting never goes below char boundaries.
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // \r\n is grouped together not separated
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        // Byte offsets point at the trimmed chunk, not the whitespace before it.
        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        // Capacity of 1 is smaller than a grapheme, so combining marks split off.
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        // Multi-byte graphemes make byte and char offsets diverge here.
        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        // `\r\n\r\n` counts as 2 newline graphemes, `\n\n\n` as 3.
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}
461}