1use std::ops::Range;
7
8use itertools::Itertools;
9use memchr::memchr2_iter;
10
11use crate::{
12 splitter::{SemanticLevel, Splitter},
13 ChunkConfig, ChunkSizer,
14};
15
16use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
17
/// Splits plain text into chunks that fit within a configured capacity.
///
/// Sizing and trimming behavior are controlled by the wrapped [`ChunkConfig`];
/// chunk size is measured by the `Sizer` (any [`ChunkSizer`] implementation).
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    // Capacity, sizer, overlap, and trim settings used when chunking.
    chunk_config: ChunkConfig<Sizer>,
}
30
31impl<Sizer> TextSplitter<Sizer>
32where
33 Sizer: ChunkSizer,
34{
35 #[must_use]
44 pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
45 Self {
46 chunk_config: chunk_config.into(),
47 }
48 }
49
50 pub fn chunks<'splitter, 'text: 'splitter>(
82 &'splitter self,
83 text: &'text str,
84 ) -> impl Iterator<Item = &'text str> + 'splitter {
85 Splitter::<_>::chunks(self, text)
86 }
87
88 pub fn chunk_indices<'splitter, 'text: 'splitter>(
103 &'splitter self,
104 text: &'text str,
105 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
106 Splitter::<_>::chunk_indices(self, text)
107 }
108
109 pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
128 &'splitter self,
129 text: &'text str,
130 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
131 Splitter::<_>::chunk_char_indices(self, text)
132 }
133}
134
135impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
136where
137 Sizer: ChunkSizer,
138{
139 type Level = LineBreaks;
140
141 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
142 &self.chunk_config
143 }
144
145 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
146 #[allow(clippy::range_plus_one)]
147 memchr2_iter(b'\n', b'\r', text.as_bytes())
148 .map(|i| i..i + 1)
149 .coalesce(|a, b| {
150 if a.end == b.start {
151 Ok(a.start..b.end)
152 } else {
153 Err((a, b))
154 }
155 })
156 .map(|range| {
157 let level = GRAPHEME_SEGMENTER
158 .segment_str(text.get(range.start..range.end).unwrap())
159 .tuple_windows::<(usize, usize)>()
160 .count();
161 (
162 match level {
163 0 => unreachable!("regex should always match at least one newline"),
164 n => LineBreaks(n),
165 },
166 range,
167 )
168 })
169 .collect()
170 }
171}
172
/// Semantic level for plain text: the number of grapheme clusters in a run of
/// consecutive newline characters. More line breaks in a row (e.g. a blank
/// line separating paragraphs) compare greater than a single line break.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

impl SemanticLevel for LineBreaks {}
182
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        // Capacity exactly matches the input, so nothing needs splitting.
        let capacity = text.chars().count();
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(capacity).with_trim(false))
            .chunks(&text)
            .collect();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // Just over half the input, so the splitter must produce two pieces.
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect();

        // Every chunk respects the configured capacity.
        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The first chunk lines up with the start of `text1`...
        let prefix = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..prefix], chunks[0][..prefix]);
        // ...and the second chunk with the end of `text2`.
        let suffix = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - suffix)..],
            chunks[1][chunks[1].len() - suffix..]
        );

        // Without trimming, rejoining the chunks reproduces the input.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each char here is more than one byte.
        let text = "éé";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect();
        assert_eq!(vec!["é", "é"], chunks);
    }

    /// Sizer that measures a chunk by its length in bytes.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        // Each `é` is two bytes, so with a two-byte capacity each gets its
        // own chunk.
        let text = "éé";
        let chunks: Vec<_> =
            TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
                .chunks(text)
                .collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // Each `é` is two bytes — over the one-byte capacity — but a single
        // char can never be split further, so each is emitted as its own
        // (oversized) chunk.
        let text = "éé";
        let chunks: Vec<_> =
            TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
                .chunks(text)
                .collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect();

        // Grapheme clusters (including the "\r\n" pair) are kept intact.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks: Vec<_> = TextSplitter::new(1).chunk_indices(text).collect();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks: Vec<_> = TextSplitter::new(1).chunk_char_indices(text).collect();

        let expected = vec![
            ChunkCharIndex {
                chunk: "a",
                byte_offset: 1,
                char_offset: 1,
            },
            ChunkCharIndex {
                chunk: "b",
                byte_offset: 3,
                char_offset: 3,
            },
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        // Capacity of one char forces splitting below the grapheme level.
        let text = "a̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect();
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(3).chunk_indices(text).collect();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks: Vec<_> = TextSplitter::new(3).chunk_char_indices(text).collect();

        // Byte and char offsets diverge because of multi-byte graphemes.
        let expected = vec![
            ChunkCharIndex {
                chunk: "a̐é",
                byte_offset: 2,
                char_offset: 2,
            },
            ChunkCharIndex {
                chunk: "ö̲",
                byte_offset: 7,
                char_offset: 5,
            },
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect();

        let expected = vec![
            "The quick ",
            "(\"brown\") ",
            "fox can't ",
            "jump 32.3 ",
            "feet, ",
            "right?",
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks: Vec<_> = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks: Vec<_> = TextSplitter::new(10).chunk_indices(text).collect();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        // "\r\n\r\n" is two grapheme clusters; "\n\n\n" is three.
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}