1use std::{ops::Range, sync::LazyLock};
7
8use itertools::Itertools;
9use regex::Regex;
10
11use crate::{
12 splitter::{SemanticLevel, Splitter},
13 ChunkConfig, ChunkSizer,
14};
15
16use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
17
/// Default plain-text splitter. Splits a text into chunks that each fit
/// within the configured chunk capacity, preferring to break on the largest
/// semantic unit available (runs of line breaks, then falling back to
/// sentences/words/graphemes via the shared [`Splitter`] machinery).
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Chunk capacity, sizer, trim and overlap settings used for every split.
    chunk_config: ChunkConfig<Sizer>,
}
30
impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`] from anything convertible into a
    /// [`ChunkConfig`] (e.g. a plain max chunk size, or a full config with a
    /// custom sizer/trim behavior).
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generates chunks from the given text. Each chunk fits within the
    /// configured chunk capacity.
    ///
    /// Delegates to the shared [`Splitter::chunks`] implementation.
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Same as [`TextSplitter::chunks`], but each item also carries the byte
    /// offset of the chunk within the original text.
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }

    /// Same as [`TextSplitter::chunks`], but each item is a
    /// [`ChunkCharIndex`] carrying both the byte offset and the character
    /// offset of the chunk within the original text.
    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
        Splitter::<_>::chunk_char_indices(self, text)
    }
}
134
135impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
136where
137 Sizer: ChunkSizer,
138{
139 type Level = LineBreaks;
140
141 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
142 &self.chunk_config
143 }
144
145 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
146 CAPTURE_LINEBREAKS
147 .find_iter(text)
148 .map(|m| {
149 let range = m.range();
150 let level = GRAPHEME_SEGMENTER
151 .segment_str(text.get(range.start..range.end).unwrap())
152 .tuple_windows::<(usize, usize)>()
153 .count();
154 (
155 match level {
156 0 => unreachable!("regex should always match at least one newline"),
157 n => LineBreaks(n),
158 },
159 range,
160 )
161 })
162 .collect()
163 }
164}
165
/// Semantic level for plain text: the number of consecutive line breaks in a
/// run of newline characters. The derived `Ord` means longer runs compare as
/// higher levels (e.g. a blank line separating paragraphs outranks a single
/// newline).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

/// Matches a run of line breaks: one or more `\r\n` pairs, or one or more
/// bare `\r`, or one or more bare `\n`. Compiled lazily, once, on first use.
static CAPTURE_LINEBREAKS: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(\r\n)+|\r+|\n+").unwrap());

// Marker impl: lets the generic splitter treat `LineBreaks` as a level to
// split on. All required behavior comes from the trait's defaults.
impl SemanticLevel for LineBreaks {}
179
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let config = ChunkConfig::new(text.chars().count()).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(&text).collect();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        let max_chunk_size = text.chars().count() / 2 + 1;
        let config = ChunkConfig::new(max_chunk_size).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(&text).collect();

        // No chunk may exceed the configured capacity.
        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The first chunk lines up with the start of `text1`...
        let prefix_len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..prefix_len], chunks[0][..prefix_len]);
        // ...and the second chunk lines up with the end of `text2`.
        let suffix_len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - suffix_len)..],
            chunks[1][chunks[1].len() - suffix_len..]
        );

        // With trimming disabled, the chunks reconstruct the input exactly.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let config = ChunkConfig::new(100).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé";
        let config = ChunkConfig::new(1).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    /// Sizer that measures chunks by byte length instead of character count.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        // Each `é` is two bytes, so a byte capacity of 2 fits one per chunk.
        let text = "éé";
        let config = ChunkConfig::new(2).with_sizer(Str).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // Each `é` is two bytes, which already exceeds the capacity of 1.
        let text = "éé";
        let config = ChunkConfig::new(1).with_sizer(Str).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // A chunk can't be smaller than a single char, so one is emitted anyway.
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let config = ChunkConfig::new(3).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // Combining marks stay attached to their base character.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks: Vec<_> = TextSplitter::new(1).chunk_indices(text).collect();

        // Offsets point at the trimmed chunks within the original text.
        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks: Vec<_> = TextSplitter::new(1).chunk_char_indices(text).collect();

        // ASCII input: byte offsets and char offsets coincide.
        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let config = ChunkConfig::new(1).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // Capacity 1 can't fit whole graphemes, so chunks fall back to chars.
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(3).chunk_indices(text).collect();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(3).chunk_char_indices(text).collect();

        // Multi-byte graphemes: byte and char offsets diverge for later chunks.
        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let config = ChunkConfig::new(10).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let config = ChunkConfig::new(2).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // Words longer than the capacity are split at grapheme boundaries.
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();

        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let config = ChunkConfig::new(21).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // "Mr." does not end the sentence; segmentation is smarter than '.'
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let config = ChunkConfig::new(16).with_trim(false);
        let chunks: Vec<_> = TextSplitter::new(config).chunks(text).collect();

        // The last sentence exceeds the capacity and is split at word level.
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();

        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();

        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));

        // `\r\n\r\n` counts as two breaks; `\n\n\n` counts as three.
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}