text_splitter/splitter/text.rs

/*!
# [`TextSplitter`]
Semantic splitting of text documents.
*/

use std::ops::Range;

use itertools::Itertools;
use memchr::memchr2_iter;

use crate::{
    splitter::{SemanticLevel, Splitter},
    ChunkConfig, ChunkSizer,
};

use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};

/// Default plain-text splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. It will also attempt to
/// merge neighboring chunks if they can fit within the given chunk size.
#[derive(Debug)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
}

impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`].
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// // By default, the chunk sizer is based on characters.
    /// let splitter = TextSplitter::new(512);
    /// ```
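    ///
    /// A full [`ChunkConfig`] can be passed in instead of a bare size for more
    /// control. A minimal sketch, using the same builder calls as the tests in
    /// this module:
    ///
    /// ```
    /// use text_splitter::{ChunkConfig, TextSplitter};
    ///
    /// // Keep whitespace at the edges of chunks instead of trimming it.
    /// let splitter = TextSplitter::new(ChunkConfig::new(512).with_trim(false));
    /// ```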
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
    ///
    /// ## Method
    ///
    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
    ///
    /// 1. Split the text by increasing semantic levels.
    /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
    /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
    ///    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertently cross semantic boundaries.
    ///
    /// The boundaries used to split the text when using the `chunks` method, in ascending order:
    ///
    /// 1. Characters
    /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
    /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
    /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
    /// 5. Ascending sequence length of newlines (a newline is `\r\n`, `\n`, or `\r`).
    ///    Each unique length of consecutive newline sequences is treated as its own semantic level, so a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on.
    ///
    /// Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
    /// ```
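    ///
    /// Higher-level boundaries, such as sentences, are kept intact the same
    /// way whenever they fit. A sketch mirroring the sentence tests in this
    /// module:
    ///
    /// ```
    /// use text_splitter::{ChunkConfig, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(ChunkConfig::new(21).with_trim(false));
    /// let text = "Mr. Fox jumped. [...] The dog was too lazy.";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."], chunks);
    /// ```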
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Returns an iterator over chunks of the text and their byte offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`TextSplitter::chunks`] for more information.
    ///
    /// ```
    /// use text_splitter::TextSplitter;
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
    /// ```
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }

    /// Returns an iterator over chunks of the text with their byte and character offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`TextSplitter::chunks`] for more information.
    ///
    /// This will be more expensive than just byte offsets, and for most usage in Rust, just
    /// having byte offsets is sufficient. But when interfacing with other languages or systems
    /// that require character offsets, this will track the character offsets for you,
    /// accounting for any trimming that may have occurred.
    ///
    /// ```
    /// use text_splitter::{ChunkCharIndex, TextSplitter};
    ///
    /// let splitter = TextSplitter::new(10);
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(
    ///     vec![
    ///         ChunkCharIndex { chunk: "Some text", byte_offset: 0, char_offset: 0 },
    ///         ChunkCharIndex { chunk: "from a", byte_offset: 11, char_offset: 11 },
    ///         ChunkCharIndex { chunk: "document", byte_offset: 18, char_offset: 18 },
    ///     ],
    ///     chunks
    /// );
    /// ```
    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
        Splitter::<_>::chunk_char_indices(self, text)
    }
}

impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    type Level = LineBreaks;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
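        // Locate every `\n` and `\r` byte, then merge adjacent newline bytes
        // into contiguous runs so each run can be measured as one boundary.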
        #[expect(clippy::range_plus_one)]
        memchr2_iter(b'\n', b'\r', text.as_bytes())
            .map(|i| i..i + 1)
            .coalesce(|a, b| {
                if a.end == b.start {
                    Ok(a.start..b.end)
                } else {
                    Err((a, b))
                }
            })
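            // Count the graphemes in each run: the segmenter yields boundary
            // offsets, so each adjacent pair of offsets is one grapheme. This
            // is what makes `\r\n` count as a single line break.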
            .map(|range| {
                let level = GRAPHEME_SEGMENTER
                    .segment_str(text.get(range.start..range.end).unwrap())
                    .tuple_windows::<(usize, usize)>()
                    .count();
                (
                    match level {
                        0 => unreachable!("every newline run contains at least one grapheme"),
                        n => LineBreaks(n),
                    },
                    range,
                )
            })
            .collect()
    }
}

/// Different semantic levels that text can be split by.
/// Each level provides a method of splitting text into chunks of a given level,
/// as well as a fallback level in case a chunk at a given level is still too large.
///
/// Split by a given number of linebreaks, either `\n`, `\r`, or `\r\n`.
/// Runs of more consecutive newlines compare as higher levels via the derived ordering.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

impl SemanticLevel for LineBreaks {}

#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // Round up to one above half so it goes to 2 chunks
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // Check that beginning of first chunk and text 1 matches
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // Check that end of second chunk and text 2 matches
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé"; // Chars that are more than one byte each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    // Test-only sizer that measures chunk size by byte length.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        let text = "éé"; // Two chars, two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        let text = "éé"; // Two chars, two bytes each
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // We can only go so small
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // \r\n is grouped together, not separated
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
462    }
463}