text_splitter/splitter/text.rs

/*!
# [`TextSplitter`]
Semantic splitting of text documents.
*/

use std::ops::Range;

use itertools::Itertools;
use memchr::memchr2_iter;

use crate::{
    splitter::{SemanticLevel, Splitter},
    ChunkConfig, ChunkSizer,
};

use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};

/// Default plain-text splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. Also will attempt to merge
/// neighboring chunks if they can fit within the given chunk size.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
}

impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`].
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// // By default, the chunk sizer is based on characters.
    /// let splitter = TextSplitter::new(512);
    /// ```
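    ///
    /// A [`ChunkConfig`] can also be passed in for more control, e.g. to keep
    /// surrounding whitespace instead of trimming it, as the tests in this
    /// module do:
    ///
    /// ```
    /// use text_splitter::{ChunkConfig, TextSplitter};
    ///
    /// // Same default character-based sizer, but chunks are not trimmed.
    /// let splitter = TextSplitter::new(ChunkConfig::new(512).with_trim(false));
    /// ```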
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
    ///
    /// ## Method
    ///
    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit within the chunk size. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
    ///
    /// 1. Split the text by increasing semantic levels.
    /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
    /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
    ///    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertently cross semantic boundaries.
    ///
    /// The boundaries used to split the text when using the `chunks` method, in ascending order:
    ///
    /// 1. Characters
    /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
    /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
    /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
    /// 5. Ascending sequence length of newlines. (Newline is `\r\n`, `\n`, or `\r`)
    ///    Each unique length of consecutive newline sequences is treated as its own semantic level. So a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on.
    ///
    /// Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
    /// ```
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Returns an iterator over chunks of the text and their byte offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`TextSplitter::chunks`] for more information.
    ///
    /// ```
    /// use text_splitter::{Characters, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
    /// ```
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }

    /// Returns an iterator over chunks of the text with their byte and character offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`TextSplitter::chunks`] for more information.
    ///
    /// This will be more expensive than just byte offsets, and for most usage in Rust, just
    /// having byte offsets is sufficient. But when interfacing with other languages or systems
    /// that require character offsets, this will track the character offsets for you,
    /// accounting for any trimming that may have occurred.
    ///
    /// ```
    /// use text_splitter::{Characters, ChunkCharIndex, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![ChunkCharIndex { chunk: "Some text", byte_offset: 0, char_offset: 0 }, ChunkCharIndex { chunk: "from a", byte_offset: 11, char_offset: 11 }, ChunkCharIndex { chunk: "document", byte_offset: 18, char_offset: 18 }], chunks);
    /// ```
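    ///
    /// With non-ASCII text the two offsets diverge; this example mirrors the
    /// `grapheme_char_indices` test below:
    ///
    /// ```
    /// use text_splitter::{ChunkCharIndex, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(3);
    /// let text = "\r\na̐éö̲\r\n";
    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![ChunkCharIndex { chunk: "a̐é", byte_offset: 2, char_offset: 2 }, ChunkCharIndex { chunk: "ö̲", byte_offset: 7, char_offset: 5 }], chunks);
    /// ```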
    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
        Splitter::<_>::chunk_char_indices(self, text)
    }
}

impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    type Level = LineBreaks;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
        // Find every newline byte, coalesce consecutive matches into ranges
        // covering each run of newlines, then use the number of grapheme
        // clusters in the run (so `\r\n` counts as a single newline) as the
        // semantic level.
        #[allow(clippy::range_plus_one)]
        memchr2_iter(b'\n', b'\r', text.as_bytes())
            .map(|i| i..i + 1)
            .coalesce(|a, b| {
                if a.end == b.start {
                    Ok(a.start..b.end)
                } else {
                    Err((a, b))
                }
            })
            .map(|range| {
                // Number of grapheme clusters in this newline run.
                let level = GRAPHEME_SEGMENTER
                    .segment_str(text.get(range.start..range.end).unwrap())
                    .tuple_windows::<(usize, usize)>()
                    .count();
                (
                    match level {
                        0 => unreachable!("memchr should always match at least one newline"),
                        n => LineBreaks(n),
                    },
                    range,
                )
            })
            .collect()
    }
}

/// Different semantic levels that text can be split by.
/// Each level provides a method of splitting text into chunks of a given level
/// as well as a fallback in case a given level is too large.
///
/// Split by given number of linebreaks, either `\n`, `\r`, or `\r\n`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

impl SemanticLevel for LineBreaks {}

#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // Round up to one above half so it goes to 2 chunks
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // Check that beginning of first chunk and text 1 matches
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // Check that end of second chunk and text 2 matches
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé"; // Char that is more than one byte
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    // Just for testing
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        let text = "éé"; // Char that is two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        let text = "éé"; // Char that is two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // We can only go so small
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // \r\n is grouped together not separated
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
463    }
464}