1use std::{ops::Range, sync::LazyLock};
7
8use itertools::Itertools;
9use regex::Regex;
10
11use crate::{
12 splitter::{SemanticLevel, Splitter},
13 ChunkConfig, ChunkSizer,
14};
15
16use super::fallback::GRAPHEME_SEGMENTER;
17
/// Plain-text splitter. Chunks are produced according to the sizing and
/// trimming behavior configured in the given [`ChunkConfig`], using runs of
/// line breaks as the semantic boundaries (see [`LineBreaks`]).
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// How chunks should be sized, trimmed, and overlapped.
    chunk_config: ChunkConfig<Sizer>,
}
30
impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`] from anything convertible into a
    /// [`ChunkConfig`] (e.g. a bare capacity, per the `Into` bound).
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generates chunks of `text`, each borrowed from the original string.
    /// Delegates to [`Splitter::chunks`] for the actual splitting logic.
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Same as [`TextSplitter::chunks`], but each item also carries the byte
    /// offset of the chunk within the original `text`.
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }
}
108
109impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
110where
111 Sizer: ChunkSizer,
112{
113 type Level = LineBreaks;
114
115 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
116 &self.chunk_config
117 }
118
119 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
120 CAPTURE_LINEBREAKS
121 .find_iter(text)
122 .map(|m| {
123 let range = m.range();
124 let level = GRAPHEME_SEGMENTER
125 .segment_str(text.get(range.start..range.end).unwrap())
126 .tuple_windows::<(usize, usize)>()
127 .count();
128 (
129 match level {
130 0 => unreachable!("regex should always match at least one newline"),
131 n => LineBreaks(n),
132 },
133 range,
134 )
135 })
136 .collect()
137 }
138}
139
/// Semantic level for plain text: a run of line breaks. The wrapped count is
/// the number of line-break graphemes in the run, so (via the derived `Ord`)
/// `"\n\n"` ranks higher than `"\n"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

/// Matches a run of line breaks: consecutive `\r\n` pairs, or consecutive
/// carriage returns, or consecutive newlines. Compiled once, on first use.
static CAPTURE_LINEBREAKS: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(\r\n)+|\r+|\n+").unwrap());

impl SemanticLevel for LineBreaks {}
153
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::splitter::SemanticSplitRanges;

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The chunk boundary may fall inside either generated string, so only
        // compare the overlapping prefix of chunk 0 and suffix of chunk 1.
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        // With trimming disabled, the chunks reassemble the original text.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each char is more than one byte.
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    /// Sizer that measures chunks in bytes rather than characters.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        // Each char is two bytes under the byte-length sizer.
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // Each grapheme is larger than the capacity; they still split cleanly.
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        // Two "\r\n" graphemes => level 2; three "\n" graphemes => level 3.
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}