1use std::{cmp::Ordering, fmt, iter::once, ops::Range};
2
3use either::Either;
4use itertools::Itertools;
5use strum::IntoEnumIterator;
6
7use self::fallback::FallbackLevel;
8use crate::{chunk_size::MemoizedChunkSizer, trim::Trim, ChunkCapacity, ChunkConfig, ChunkSizer};
9
10#[cfg(feature = "code")]
11mod code;
12mod fallback;
13#[cfg(feature = "markdown")]
14mod markdown;
15mod text;
16
17#[cfg(feature = "code")]
18#[allow(clippy::module_name_repetitions)]
19pub use code::{CodeSplitter, CodeSplitterError};
20#[cfg(feature = "markdown")]
21#[allow(clippy::module_name_repetitions)]
22pub use markdown::MarkdownSplitter;
23#[allow(clippy::module_name_repetitions)]
24pub use text::TextSplitter;
25
/// Shared interface for all splitter types (text, and the feature-gated
/// markdown/code splitters).
///
/// Implementors supply a [`ChunkConfig`] and a `parse` step that maps text to
/// `(level, byte range)` semantic boundaries; the chunking iterators are
/// provided as default methods on top of [`TextChunks`].
trait Splitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// The semantic-level type this splitter understands.
    type Level: SemanticLevel;

    /// How each emitted chunk is trimmed. Defaults to trimming everything;
    /// implementors may override with a different [`Trim`] variant.
    const TRIM: Trim = Trim::All;

    /// The configuration (capacity, overlap, sizer, trim flag) driving this splitter.
    fn chunk_config(&self) -> &ChunkConfig<Sizer>;

    /// Parse `text` into `(level, byte range)` pairs marking semantic boundaries.
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)>;

    /// Iterate over chunks as `(byte offset, chunk)` pairs.
    fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunks::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Iterate over chunks along with both their byte and character offsets.
    /// Character offsets require counting chars over the preceding text
    /// (see [`TextChunksWithCharIndices`]), hence the separate method.
    fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunksWithCharIndices::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Iterate over just the chunk strings, discarding offsets.
    fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter
    where
        Sizer: 'splitter,
    {
        self.chunk_indices(text).map(|(_, t)| t)
    }
}
93
/// A level of semantic granularity at which text can be split. Levels are
/// totally ordered (`Ord`); comparisons elsewhere in this module treat
/// greater levels as coarser boundaries.
trait SemanticLevel: Copy + fmt::Debug + Ord + PartialOrd + 'static {
    /// Split `text` into `(byte offset, section)` pairs around the given
    /// separator ranges.
    ///
    /// Each separator range yields two items: the text since the previous
    /// separator, then the separator itself. Whatever follows the last
    /// separator is emitted once as a trailing section. Empty sections are
    /// filtered out at the end.
    fn sections(
        text: &str,
        level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
    ) -> impl Iterator<Item = (usize, &str)> {
        // Byte position just past the last separator consumed so far.
        let mut cursor = 0;
        // Set once the trailing section has been emitted so the iterator
        // terminates instead of emitting it repeatedly.
        let mut final_match = false;
        level_ranges
            .batching(move |it| {
                loop {
                    match it.next() {
                        // Trailing section already produced: done.
                        None if final_match => return None,
                        // Ranges exhausted: emit everything after the last
                        // separator as the final section.
                        None => {
                            final_match = true;
                            return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
                        }
                        Some((_, range)) => {
                            // Skip ranges starting before the cursor — they
                            // overlap a separator we already consumed.
                            if range.start < cursor {
                                continue;
                            }
                            let offset = cursor;
                            let prev_section = text
                                .get(offset..range.start)
                                .expect("invalid character sequence");
                            let separator = text
                                .get(range.start..range.end)
                                .expect("invalid character sequence");
                            cursor = range.end;
                            // Emit the section before the separator and the
                            // separator itself as two items.
                            return Some(Either::Right(
                                [(offset, prev_section), (range.start, separator)].into_iter(),
                            ));
                        }
                    }
                }
            })
            .flatten()
            .filter(|(_, s)| !s.is_empty())
    }
}
141
/// Holds the parsed semantic ranges of a text, sorted so they can be
/// consumed incrementally as the chunk cursor advances through the text.
#[derive(Debug)]
struct SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    // Index into `ranges` of the first entry that may still be relevant;
    // everything before it starts prior to the chunker's current position.
    cursor: usize,
    // All `(level, byte range)` pairs, sorted by start ascending and, for
    // equal starts, by end descending (longest range first).
    ranges: Vec<(Level, Range<usize>)>,
}
155
impl<Level> SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    /// Sort the ranges (start ascending, then end descending, so at equal
    /// starts the longest range comes first) and start the cursor at 0.
    fn new(mut ranges: Vec<(Level, Range<usize>)>) -> Self {
        ranges.sort_unstable_by(|(_, a), (_, b)| {
            a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end))
        });
        Self { cursor: 0, ranges }
    }

    /// All ranges (any level) that start at or after `offset`, cloned out
    /// as owned `(level, range)` pairs.
    fn ranges_after_offset(
        &self,
        offset: usize,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        self.ranges[self.cursor..]
            .iter()
            .filter(move |(_, sep)| sep.start >= offset)
            .map(|(l, r)| (*l, r.start..r.end))
    }
    /// Ranges at `level` or coarser starting at or after `offset`, with
    /// ranges that would enclose or duplicate the first range at `level`
    /// skipped so the resulting sections do not overlap.
    fn level_ranges_after_offset(
        &self,
        offset: usize,
        level: Level,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        // Locate the first range at exactly `level`. When it is the very
        // first remaining range (i == 0) and several same-level ranges share
        // its start, coalesce to the later one (shorter, per the sort order)
        // so we compare against the innermost range.
        let first_item = self
            .ranges_after_offset(offset)
            .position(|(l, _)| l == level)
            .and_then(|i| {
                self.ranges_after_offset(offset)
                    .skip(i)
                    .coalesce(|(a_level, a_range), (b_level, b_range)| {
                        // Only collapse same-level/same-start pairs, and only
                        // when this was the first remaining range overall.
                        if a_level == b_level && a_range.start == b_range.start && i == 0 {
                            Ok((b_level, b_range))
                        } else {
                            Err(((a_level, a_range), (b_level, b_range)))
                        }
                    })
                    .next()
            });
        self.ranges_after_offset(offset)
            .filter(move |(l, _)| l >= &level)
            .skip_while(move |(l, r)| {
                first_item.as_ref().is_some_and(|(_, fir)| {
                    // Skip coarser-level ranges containing the first item's
                    // start, and same-level ranges that share its start but
                    // extend further (i.e. enclose it).
                    (l > &level && r.contains(&fir.start))
                        || (l == &level && r.start == fir.start && r.end > fir.end)
                })
            })
    }

    /// Distinct semantic levels present at or after `offset`, ascending.
    fn levels_in_remaining_text(&self, offset: usize) -> impl Iterator<Item = Level> + '_ {
        self.ranges_after_offset(offset)
            .map(|(l, _)| l)
            .sorted()
            .dedup()
    }

    /// Split `text` (the slice of the document beginning at byte `offset`)
    /// into sections at `semantic_level`, yielding absolute
    /// `(byte offset, section)` pairs.
    fn semantic_chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        offset: usize,
        text: &'text str,
        semantic_level: Level,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Level::sections(
            text,
            self.level_ranges_after_offset(offset, semantic_level)
                // Rebase separator ranges to be relative to `text`...
                .map(move |(l, sep)| (l, sep.start - offset..sep.end - offset)),
        )
        // ...then rebase the resulting offsets back to absolute positions.
        .map(move |(i, str)| (offset + i, str))
    }

    /// Advance the internal index past every range that starts before
    /// `cursor`, so future queries skip them without re-filtering.
    fn update_cursor(&mut self, cursor: usize) {
        self.cursor += self.ranges[self.cursor..]
            .iter()
            .position(|(_, range)| range.start >= cursor)
            // No remaining range qualifies: move the index to the end.
            .unwrap_or_else(|| self.ranges.len() - self.cursor);
    }
}
246
/// Iterator that produces chunks of `text` by sizing candidate sections and
/// binary-searching for the largest span that fits the configured capacity.
#[derive(Debug)]
struct TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    // Target size each chunk should fit within.
    capacity: ChunkCapacity,
    // Sizer wrapper that caches sizes during a single chunk search.
    chunk_sizer: MemoizedChunkSizer<'sizer, Sizer>,
    // Running statistics (largest chunk byte length so far), used to tune
    // how far ahead sections are loaded.
    chunk_stats: ChunkStats,
    // Byte offset in `text` where the next chunk starts.
    cursor: usize,
    // Candidate `(offset, section)` pairs for the next chunk.
    next_sections: Vec<(usize, &'text str)>,
    // Desired overlap between consecutive chunks; max of 0 disables overlap.
    overlap: ChunkCapacity,
    // End byte offset of the previously yielded item, so overlapping chunks
    // that don't advance past it are skipped.
    prev_item_end: usize,
    // Pre-parsed semantic boundaries of `text`.
    semantic_split: SemanticSplitRanges<Level>,
    // The full original text being chunked.
    text: &'text str,
    // Trimming behavior applied to each produced chunk.
    trim: Trim,
}
275
impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    /// Build a chunk iterator from the chunk configuration, the text, and
    /// its pre-parsed semantic `(level, range)` offsets.
    fn new(
        chunk_config: &'sizer ChunkConfig<Sizer>,
        text: &'text str,
        offsets: Vec<(Level, Range<usize>)>,
        trim: Trim,
    ) -> Self {
        let ChunkConfig {
            capacity,
            overlap,
            sizer,
            trim: trim_enabled,
        } = chunk_config;
        Self {
            capacity: *capacity,
            chunk_sizer: MemoizedChunkSizer::new(sizer),
            chunk_stats: ChunkStats::new(),
            cursor: 0,
            next_sections: Vec::new(),
            overlap: (*overlap).into(),
            prev_item_end: 0,
            semantic_split: SemanticSplitRanges::new(offsets),
            text,
            // The config flag gates trimming entirely; when disabled, the
            // splitter-specific trim behavior is replaced with `Trim::None`.
            trim: if *trim_enabled { trim } else { Trim::None },
        }
    }

    /// Produce the next chunk as a trimmed `(byte offset, chunk)` pair, or
    /// `None` when no valid chunk can be formed.
    fn next_chunk(&mut self) -> Option<(usize, &'text str)> {
        // Drop semantic ranges the cursor has already moved past.
        self.semantic_split.update_cursor(self.cursor);
        // Refresh candidate sections; `low` is a lower bound for the binary
        // search, discovered while sizing sections during the refresh.
        let low = self.update_next_sections();
        let (start, end) = self.binary_search_next_chunk(low)?;
        let chunk = self.text.get(start..end)?;
        self.chunk_stats.update_max_chunk_size(end - start);

        // The size cache is only valid for this one chunk search.
        self.chunk_sizer.clear_cache();
        // Advance the cursor (possibly to before `end`, to create overlap).
        self.update_cursor(end);

        Some(self.trim.trim(start, chunk))
    }

    /// Binary-search `next_sections` for the largest end offset whose chunk
    /// still fits the capacity, starting from lower bound `low`.
    /// Returns the chosen `(start, end)` byte range.
    fn binary_search_next_chunk(&mut self, mut low: usize) -> Option<(usize, usize)> {
        let start = self.cursor;
        let mut end = self.cursor;
        // Whether we've already seen a section whose size equals the target.
        let mut equals_found = false;
        let mut high = self.next_sections.len().saturating_sub(1);
        // Best candidate found so far: its index and its measured size.
        let mut successful_index = None;
        let mut successful_chunk_size = None;

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, str) = self.next_sections[mid];
            let text_end = offset + str.len();
            let chunk = self.text.get(start..text_end)?;
            let chunk_size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
            let fits = self.capacity.fits(chunk_size);

            match fits {
                Ordering::Less => {
                    // Under capacity: keep it only if it extends the chunk.
                    if text_end > end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
                Ordering::Equal => {
                    // Exactly at capacity: prefer the shortest such span
                    // (trimming can make longer spans measure the same).
                    if text_end < end || !equals_found {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                    equals_found = true;
                }
                Ordering::Greater => {
                    // Over capacity: only accept when it's the very first
                    // section and nothing has been taken yet, so we always
                    // make progress.
                    if mid == 0 && start == end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
            };

            if fits.is_lt() {
                low = mid + 1;
            } else if mid > 0 {
                high = mid - 1;
            } else {
                break;
            }
        }

        if let (Some(successful_index), Some(chunk_size)) =
            (successful_index, successful_chunk_size)
        {
            // Scan forward past the chosen section: absorb any following
            // sections that do not increase the measured size (e.g. content
            // removed by trimming), stopping at the first one that does.
            let mut range = successful_index..self.next_sections.len();
            range.next();

            for index in range {
                let (offset, str) = self.next_sections[index];
                let text_end = offset + str.len();
                let chunk = self.text.get(start..text_end)?;
                let size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
                if size <= chunk_size {
                    if text_end > end {
                        end = text_end;
                    }
                } else {
                    break;
                }
            }
        }

        Some((start, end))
    }

    /// Move the cursor after a chunk ending at `end`. Without overlap this
    /// is just `end`; with overlap, binary-search backwards for the earliest
    /// section start whose suffix up to `end` still fits the overlap budget.
    fn update_cursor(&mut self, end: usize) {
        if self.overlap.max == 0 {
            self.cursor = end;
            return;
        }

        let mut start = end;
        let mut low = 0;
        // Upper bound: the last section that ends at or before `end`.
        let mut high = match self
            .next_sections
            .binary_search_by_key(&end, |(offset, str)| offset + str.len())
        {
            Ok(i) | Err(i) => i,
        };

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, _) = self.next_sections[mid];
            let chunk_size = self.chunk_sizer.chunk_size(
                offset,
                self.text.get(offset..end).expect("Invalid range"),
                self.trim,
            );
            let fits = self.overlap.fits(chunk_size);

            // Take the earliest start that fits the overlap, but never move
            // the cursor backwards past its current position.
            if fits.is_le() && offset < start && offset > self.cursor {
                start = offset;
            }

            if fits.is_lt() && mid > 0 {
                high = mid - 1;
            } else {
                low = mid + 1;
            }
        }

        self.cursor = start;
    }

    /// Refill `next_sections` with candidate split points for the next
    /// chunk and return a lower-bound index for the following binary search
    /// (the last loaded section already known to be under capacity).
    #[allow(clippy::too_many_lines)]
    fn update_next_sections(&mut self) -> usize {
        self.next_sections.clear();

        let remaining_text = self.text.get(self.cursor..).unwrap();

        // Pick the finest semantic level whose first section still fits,
        // along with the furthest offset worth loading at that level.
        let (semantic_level, mut max_offset) = self.chunk_sizer.find_correct_level(
            self.cursor,
            &self.capacity,
            self.semantic_split
                .levels_in_remaining_text(self.cursor)
                .filter_map(|level| {
                    self.semantic_split
                        .semantic_chunks(self.cursor, remaining_text, level)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
            self.trim,
        );

        let sections = if let Some(semantic_level) = semantic_level {
            Either::Left(self.semantic_split.semantic_chunks(
                self.cursor,
                remaining_text,
                semantic_level,
            ))
        } else {
            // No semantic level fits: fall back to the generic levels
            // (chars/graphemes/words/sentences via `FallbackLevel`).
            let (semantic_level, fallback_max_offset) = self.chunk_sizer.find_correct_level(
                self.cursor,
                &self.capacity,
                FallbackLevel::iter().filter_map(|level| {
                    level
                        .sections(remaining_text)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
                self.trim,
            );

            // Combine both level searches' offset limits, taking the
            // tighter bound when both exist.
            max_offset = match (fallback_max_offset, max_offset) {
                (Some(fallback), Some(max)) => Some(fallback.min(max)),
                (fallback, max) => fallback.or(max),
            };

            // Chars always produce at least one section, so this is a safe
            // last resort.
            let fallback_level = semantic_level.unwrap_or(FallbackLevel::Char);

            Either::Right(
                fallback_level
                    .sections(remaining_text)
                    .map(|(offset, text)| (self.cursor + offset, text)),
            )
        };

        let mut sections = sections
            .take_while(move |(offset, _)| max_offset.map_or(true, |max| *offset <= max))
            .filter(|(_, str)| !str.is_empty());

        // Lower bound for the caller's binary search.
        let mut low = 0;
        // Size of the most recent section that measured exactly at capacity.
        let mut prev_equals: Option<usize> = None;
        let max = self.capacity.max;
        // Load sections lazily: start from the largest chunk seen so far
        // (or the capacity max) and grow the window until capacity is hit.
        let mut target_offset = self.chunk_stats.max_chunk_size.unwrap_or(max);

        loop {
            let prev_num = self.next_sections.len();
            for (offset, str) in sections.by_ref() {
                self.next_sections.push((offset, str));
                if offset + str.len() > (self.cursor.saturating_add(target_offset)) {
                    break;
                }
            }
            let new_num = self.next_sections.len();
            // Source iterator is exhausted: nothing more to load.
            if new_num - prev_num == 0 {
                break;
            }

            if let Some(&(offset, str)) = self.next_sections.last() {
                let text_end = offset + str.len();
                // Didn't even reach the target window: no more text.
                if (text_end - self.cursor) < target_offset {
                    break;
                }
                let chunk_size = self.chunk_sizer.chunk_size(
                    offset,
                    self.text.get(self.cursor..text_end).expect("Invalid range"),
                    self.trim,
                );
                let fits = self.capacity.fits(chunk_size);

                if fits.is_le() {
                    // Still under capacity: extrapolate how much further to
                    // load from the average bytes-per-size-unit so far, plus
                    // ~10% headroom.
                    let final_offset = offset + str.len() - self.cursor;
                    let size = chunk_size.max(1);
                    let diff = (max - size).max(1);
                    let avg_size = final_offset.div_ceil(size);

                    target_offset = final_offset
                        .saturating_add(diff.saturating_mul(avg_size))
                        .saturating_add(final_offset.div_ceil(10));
                }

                match fits {
                    Ordering::Less => {
                        // Everything loaded so far fits, so the binary
                        // search can start at the last loaded index.
                        low = new_num.saturating_sub(1);
                        continue;
                    }
                    Ordering::Equal => {
                        // Keep loading while the measured size stays flat;
                        // stop once it starts growing past a previous equal.
                        if let Some(prev) = prev_equals {
                            if prev < chunk_size {
                                break;
                            }
                        }
                        prev_equals = Some(chunk_size);
                        continue;
                    }
                    Ordering::Greater => {
                        break;
                    }
                };
            }
        }

        low
    }
}
587
588impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator for TextChunks<'text, 'sizer, Sizer, Level>
589where
590 Sizer: ChunkSizer,
591 Level: SemanticLevel,
592{
593 type Item = (usize, &'text str);
594
595 fn next(&mut self) -> Option<Self::Item> {
596 loop {
597 if self.cursor >= self.text.len() {
599 return None;
600 }
601
602 match self.next_chunk()? {
603 (_, "") => continue,
606 c => {
607 let item_end = c.0 + c.1.len();
608 if item_end <= self.prev_item_end {
610 continue;
611 }
612 self.prev_item_end = item_end;
613 return Some(c);
614 }
615 }
616 }
617 }
618}
619
/// A chunk together with its byte and character offsets in the original text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkCharIndex<'text> {
    /// The chunk text itself.
    pub chunk: &'text str,
    /// Byte offset of the chunk's start within the original text.
    pub byte_offset: usize,
    /// Character offset of the chunk's start within the original text.
    pub char_offset: usize,
}
630
/// Wraps [`TextChunks`] to also report character offsets, computed
/// incrementally by counting chars between successive chunk starts.
#[derive(Debug)]
struct TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    // The full original text, used to count chars between chunk starts.
    text: &'text str,
    // Underlying byte-offset chunk iterator.
    text_chunks: TextChunks<'text, 'sizer, Sizer, Level>,
    // Byte offset of the most recently seen chunk start.
    byte_offset: usize,
    // Running character count up to `byte_offset`.
    char_offset: usize,
}
647
impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    /// Wrap a new [`TextChunks`] iterator over `text`, starting both the
    /// byte and character trackers at the beginning of the text.
    fn new(
        chunk_config: &'sizer ChunkConfig<Sizer>,
        text: &'text str,
        offsets: Vec<(Level, Range<usize>)>,
        trim: Trim,
    ) -> Self {
        Self {
            text,
            text_chunks: TextChunks::new(chunk_config, text, offsets, trim),
            byte_offset: 0,
            char_offset: 0,
        }
    }
}
669
670impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator
671 for TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
672where
673 Sizer: ChunkSizer,
674 Level: SemanticLevel,
675{
676 type Item = ChunkCharIndex<'text>;
677
678 fn next(&mut self) -> Option<Self::Item> {
679 let (byte_offset, chunk) = self.text_chunks.next()?;
680 let preceding_text = self
681 .text
682 .get(self.byte_offset..byte_offset)
683 .expect("Invalid byte sequence");
684 self.byte_offset = byte_offset;
685 self.char_offset += preceding_text.chars().count();
686 Some(ChunkCharIndex {
687 chunk,
688 byte_offset,
689 char_offset: self.char_offset,
690 })
691 }
692}
693
/// Statistics about the chunks produced so far, used by
/// `TextChunks::update_next_sections` to tune how much lookahead to load.
#[derive(Debug, Default)]
struct ChunkStats {
    /// Byte length of the largest chunk produced so far; `None` until the
    /// first chunk has been recorded.
    max_chunk_size: Option<usize>,
}

impl ChunkStats {
    /// Create empty stats (no chunks recorded yet).
    fn new() -> Self {
        Self::default()
    }

    /// Record the byte length of a newly produced chunk, keeping the
    /// running maximum.
    fn update_max_chunk_size(&mut self, size: usize) {
        // `map_or` takes `size` for the first chunk and the larger of the
        // two thereafter — one combinator instead of the previous
        // `.map(..).or(Some(..))` two-step.
        self.max_chunk_size = Some(self.max_chunk_size.map_or(size, |s| s.max(size)));
    }
}
711
#[cfg(test)]
mod tests {
    use super::*;

    /// New stats start with no recorded maximum.
    #[test]
    fn chunk_stats_empty() {
        let stats = ChunkStats::new();
        assert_eq!(stats.max_chunk_size, None);
    }

    /// A single update sets the maximum.
    #[test]
    fn chunk_stats_one() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        assert_eq!(stats.max_chunk_size, Some(10));
    }

    /// The maximum tracks the largest of multiple updates.
    #[test]
    fn chunk_stats_multiple() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        stats.update_max_chunk_size(20);
        stats.update_max_chunk_size(30);
        assert_eq!(stats.max_chunk_size, Some(30));
    }

    // `usize` stands in as a trivial semantic level for these tests.
    impl SemanticLevel for usize {}

    /// Construction sorts by start ascending, then end descending.
    #[test]
    fn semantic_ranges_are_sorted() {
        let ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        assert_eq!(
            ranges.ranges,
            vec![(2, 0..4), (1, 0..2), (0, 0..1), (0, 1..2)]
        );
    }

    /// Advancing the cursor hides ranges starting before it, even from
    /// queries with a smaller offset.
    #[test]
    fn semantic_ranges_skip_previous_ranges() {
        let mut ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        ranges.update_cursor(1);

        assert_eq!(
            ranges.ranges_after_offset(0).collect::<Vec<_>>(),
            vec![(0, 1..2)]
        );
    }
}