1use std::ops::Range;
7
8use itertools::Itertools;
9use memchr::memchr2_iter;
10
11use crate::{
12 splitter::{SemanticLevel, Splitter},
13 ChunkConfig, ChunkSizer,
14};
15
16use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
17
/// Splits plain text into chunks, where each chunk fits within the size
/// limit of the given [`ChunkConfig`], as measured by its [`ChunkSizer`].
#[derive(Debug)]
pub struct TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Configuration controlling how chunks are produced (maximum size,
    /// sizer implementation, and whether chunks are trimmed).
    chunk_config: ChunkConfig<Sizer>,
}
29
impl<Sizer> TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`TextSplitter`] from the given chunk configuration.
    ///
    /// Accepts anything convertible into a [`ChunkConfig`], e.g. a plain
    /// `usize` maximum chunk size (see the tests using `TextSplitter::new(10)`).
    #[must_use]
    pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
        Self {
            chunk_config: chunk_config.into(),
        }
    }

    /// Generates a lazy iterator over chunks of `text`, in order.
    ///
    /// Each chunk is a sub-slice of `text` that fits within the configured
    /// chunk size. Delegates to [`Splitter::chunks`].
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        Splitter::<_>::chunks(self, text)
    }

    /// Like [`Self::chunks`], but each item also carries the chunk's byte
    /// offset within `text`. Delegates to [`Splitter::chunk_indices`].
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }

    /// Like [`Self::chunk_indices`], but yields [`ChunkCharIndex`] items
    /// carrying both byte and character offsets.
    /// Delegates to [`Splitter::chunk_char_indices`].
    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
        Splitter::<_>::chunk_char_indices(self, text)
    }
}
133
134impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
135where
136 Sizer: ChunkSizer,
137{
138 type Level = LineBreaks;
139
140 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
141 &self.chunk_config
142 }
143
144 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
145 #[expect(clippy::range_plus_one)]
146 memchr2_iter(b'\n', b'\r', text.as_bytes())
147 .map(|i| i..i + 1)
148 .coalesce(|a, b| {
149 if a.end == b.start {
150 Ok(a.start..b.end)
151 } else {
152 Err((a, b))
153 }
154 })
155 .map(|range| {
156 let level = GRAPHEME_SEGMENTER
157 .segment_str(text.get(range.start..range.end).unwrap())
158 .tuple_windows::<(usize, usize)>()
159 .count();
160 (
161 match level {
162 0 => unreachable!("regex should always match at least one newline"),
163 n => LineBreaks(n),
164 },
165 range,
166 )
167 })
168 .collect()
169 }
170}
171
/// The number of consecutive line-break grapheme clusters found at one
/// position in the text, as produced by `parse`.
///
/// The derived `Ord` means runs with more breaks rank as higher semantic
/// levels — e.g. a blank line (two breaks) outranks a single newline.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);

impl SemanticLevel for LineBreaks {}
181
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        let max_chunk_size = text.chars().count() / 2 + 1;
        let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // Chunk boundaries may not line up exactly with the text boundaries,
        // so only compare the overlapping prefix/suffix of each half.
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        // With trimming disabled, the chunks should reassemble the input.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["é", "é"], chunks);
    }

    /// A custom sizer that measures chunks by byte length instead of
    /// character count.
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        // Each "é" is two bytes, so a byte-based size of 2 fits one char.
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // A single char can exceed the max size; it still becomes its own chunk.
        let text = "éé";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // Combining characters stay attached to their base character.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = TextSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        // With max size 1, graphemes don't fit and splitting falls back to
        // individual chars.
        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = TextSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        // byte_offset and char_offset differ once multi-byte chars appear.
        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "(\"brown\") ",
                "fox can't ",
                "jump 32.3 ",
                "feet, ",
                "right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
            chunks
        );
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let chunks = TextSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();
        assert_eq!(
            vec![(0, "Some text"), (11, "from a"), (18, "document")],
            chunks
        );
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
        // "\r\n\r\n" is two grapheme clusters; "\n\n\n" is three.
        assert_eq!(
            vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
            linebreaks.ranges
        );
    }
}