1use std::{cmp::Ordering, fmt, iter::once, ops::Range};
2
3use either::Either;
4use itertools::Itertools;
5use strum::IntoEnumIterator;
6
7use self::fallback::FallbackLevel;
8use crate::{chunk_size::MemoizedChunkSizer, trim::Trim, ChunkCapacity, ChunkConfig, ChunkSizer};
9
10#[cfg(feature = "code")]
11mod code;
12mod fallback;
13#[cfg(feature = "markdown")]
14mod markdown;
15mod text;
16
17#[cfg(feature = "code")]
18pub use code::{CodeSplitter, CodeSplitterError};
19#[cfg(feature = "markdown")]
20pub use markdown::MarkdownSplitter;
21pub use text::TextSplitter;
22
23trait Splitter<Sizer>
26where
27 Sizer: ChunkSizer,
28{
29 type Level: SemanticLevel;
30
31 const TRIM: Trim = Trim::All;
33
34 fn chunk_config(&self) -> &ChunkConfig<Sizer>;
36
37 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)>;
39
40 fn chunk_indices<'splitter, 'text: 'splitter>(
43 &'splitter self,
44 text: &'text str,
45 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter
46 where
47 Sizer: 'splitter,
48 {
49 TextChunks::<Sizer, Self::Level>::new(
50 self.chunk_config(),
51 text,
52 self.parse(text),
53 Self::TRIM,
54 )
55 }
56
57 fn chunk_char_indices<'splitter, 'text: 'splitter>(
64 &'splitter self,
65 text: &'text str,
66 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter
67 where
68 Sizer: 'splitter,
69 {
70 TextChunksWithCharIndices::<Sizer, Self::Level>::new(
71 self.chunk_config(),
72 text,
73 self.parse(text),
74 Self::TRIM,
75 )
76 }
77
78 fn chunks<'splitter, 'text: 'splitter>(
81 &'splitter self,
82 text: &'text str,
83 ) -> impl Iterator<Item = &'text str> + 'splitter
84 where
85 Sizer: 'splitter,
86 {
87 self.chunk_indices(text).map(|(_, t)| t)
88 }
89}
90
91trait SemanticLevel: Copy + fmt::Debug + Ord + PartialOrd + 'static {
93 fn sections(
98 text: &str,
99 level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
100 ) -> impl Iterator<Item = (usize, &str)> {
101 let mut cursor = 0;
102 let mut final_match = false;
103 level_ranges
104 .batching(move |it| {
105 loop {
106 match it.next() {
107 None if final_match => return None,
109 None => {
111 final_match = true;
112 return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
113 }
114 Some((_, range)) => {
116 if range.start < cursor {
117 continue;
118 }
119 let offset = cursor;
120 let prev_section = text
121 .get(offset..range.start)
122 .expect("invalid character sequence");
123 let separator = text
124 .get(range.start..range.end)
125 .expect("invalid character sequence");
126 cursor = range.end;
127 return Some(Either::Right(
128 [(offset, prev_section), (range.start, separator)].into_iter(),
129 ));
130 }
131 }
132 }
133 })
134 .flatten()
135 .filter(|(_, s)| !s.is_empty())
136 }
137}
138
139#[derive(Debug)]
142struct SemanticSplitRanges<Level>
143where
144 Level: SemanticLevel,
145{
146 cursor: usize,
149 ranges: Vec<(Level, Range<usize>)>,
151}
152
153impl<Level> SemanticSplitRanges<Level>
154where
155 Level: SemanticLevel,
156{
157 fn new(mut ranges: Vec<(Level, Range<usize>)>) -> Self {
158 ranges.sort_unstable_by(|(_, a), (_, b)| {
160 a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end))
161 });
162 Self { cursor: 0, ranges }
163 }
164
165 fn ranges_after_offset(
167 &self,
168 offset: usize,
169 ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
170 self.ranges[self.cursor..]
171 .iter()
172 .filter(move |(_, sep)| sep.start >= offset)
173 .map(|(l, r)| (*l, r.start..r.end))
174 }
175 fn level_ranges_after_offset(
177 &self,
178 offset: usize,
179 level: Level,
180 ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
181 let first_item = self
184 .ranges_after_offset(offset)
185 .position(|(l, _)| l == level)
186 .and_then(|i| {
187 self.ranges_after_offset(offset)
188 .skip(i)
189 .coalesce(|(a_level, a_range), (b_level, b_range)| {
190 if a_level == b_level && a_range.start == b_range.start && i == 0 {
192 Ok((b_level, b_range))
193 } else {
194 Err(((a_level, a_range), (b_level, b_range)))
195 }
196 })
197 .next()
199 });
200 self.ranges_after_offset(offset)
202 .filter(move |(l, _)| l >= &level)
203 .skip_while(move |(l, r)| {
204 first_item.as_ref().is_some_and(|(_, fir)| {
205 (l > &level && r.contains(&fir.start))
206 || (l == &level && r.start == fir.start && r.end > fir.end)
207 })
208 })
209 }
210
211 fn levels_in_remaining_text(&self, offset: usize) -> impl Iterator<Item = Level> + '_ {
214 self.ranges_after_offset(offset)
215 .map(|(l, _)| l)
216 .sorted()
217 .dedup()
218 }
219
220 fn semantic_chunks<'splitter, 'text: 'splitter>(
222 &'splitter self,
223 offset: usize,
224 text: &'text str,
225 semantic_level: Level,
226 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
227 Level::sections(
228 text,
229 self.level_ranges_after_offset(offset, semantic_level)
230 .map(move |(l, sep)| (l, sep.start - offset..sep.end - offset)),
231 )
232 .map(move |(i, str)| (offset + i, str))
233 }
234
235 fn update_cursor(&mut self, cursor: usize) {
237 self.cursor += self.ranges[self.cursor..]
238 .iter()
239 .position(|(_, range)| range.start >= cursor)
240 .unwrap_or_else(|| self.ranges.len() - self.cursor);
241 }
242}
243
244#[derive(Debug)]
246struct TextChunks<'text, 'sizer, Sizer, Level>
247where
248 Sizer: ChunkSizer,
249 Level: SemanticLevel,
250{
251 capacity: ChunkCapacity,
253 chunk_sizer: MemoizedChunkSizer<'sizer, Sizer>,
255 chunk_stats: ChunkStats,
257 cursor: usize,
259 next_sections: Vec<(usize, &'text str)>,
261 overlap: ChunkCapacity,
263 prev_item_end: usize,
265 semantic_split: SemanticSplitRanges<Level>,
267 text: &'text str,
269 trim: Trim,
271}
272
273impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level>
274where
275 Sizer: ChunkSizer,
276 Level: SemanticLevel,
277{
278 fn new(
281 chunk_config: &'sizer ChunkConfig<Sizer>,
282 text: &'text str,
283 offsets: Vec<(Level, Range<usize>)>,
284 trim: Trim,
285 ) -> Self {
286 let ChunkConfig {
287 capacity,
288 overlap,
289 sizer,
290 trim: trim_enabled,
291 } = chunk_config;
292 Self {
293 capacity: *capacity,
294 chunk_sizer: MemoizedChunkSizer::new(sizer),
295 chunk_stats: ChunkStats::new(),
296 cursor: 0,
297 next_sections: Vec::new(),
298 overlap: (*overlap).into(),
299 prev_item_end: 0,
300 semantic_split: SemanticSplitRanges::new(offsets),
301 text,
302 trim: if *trim_enabled { trim } else { Trim::None },
303 }
304 }
305
306 fn next_chunk(&mut self) -> Option<(usize, &'text str)> {
310 self.semantic_split.update_cursor(self.cursor);
311 let low = self.update_next_sections();
312 let (start, end) = self.binary_search_next_chunk(low)?;
313 let chunk = self.text.get(start..end)?;
314 self.chunk_stats.update_max_chunk_size(end - start);
315
316 self.chunk_sizer.clear_cache();
318 self.update_cursor(end);
320
321 Some(self.trim.trim(start, chunk))
323 }
324
325 fn binary_search_next_chunk(&mut self, mut low: usize) -> Option<(usize, usize)> {
327 let start = self.cursor;
328 let mut end = self.cursor;
329 let mut equals_found = false;
330 let mut high = self.next_sections.len().saturating_sub(1);
331 let mut successful_index = None;
332 let mut successful_chunk_size = None;
333
334 while low <= high {
335 let mid = low + (high - low) / 2;
336 let (offset, str) = self.next_sections[mid];
337 let text_end = offset + str.len();
338 let chunk = self.text.get(start..text_end)?;
339 let chunk_size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
340 let fits = self.capacity.fits(chunk_size);
341
342 match fits {
343 Ordering::Less => {
344 if text_end > end {
346 end = text_end;
347 successful_index = Some(mid);
348 successful_chunk_size = Some(chunk_size);
349 }
350 }
351 Ordering::Equal => {
352 if text_end < end || !equals_found {
354 end = text_end;
355 successful_index = Some(mid);
356 successful_chunk_size = Some(chunk_size);
357 }
358 equals_found = true;
359 }
360 Ordering::Greater => {
361 if mid == 0 && start == end {
363 end = text_end;
364 successful_index = Some(mid);
365 successful_chunk_size = Some(chunk_size);
366 }
367 }
368 }
369
370 if fits.is_lt() {
372 low = mid + 1;
373 } else if mid > 0 {
374 high = mid - 1;
375 } else {
376 break;
378 }
379 }
380
381 if let (Some(successful_index), Some(chunk_size)) =
382 (successful_index, successful_chunk_size)
383 {
384 let mut range = successful_index..self.next_sections.len();
385 range.next();
387
388 for index in range {
389 let (offset, str) = self.next_sections[index];
390 let text_end = offset + str.len();
391 let chunk = self.text.get(start..text_end)?;
392 let size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
393 if size <= chunk_size {
394 if text_end > end {
395 end = text_end;
396 }
397 } else {
398 break;
399 }
400 }
401 }
402
403 Some((start, end))
404 }
405
406 fn update_cursor(&mut self, end: usize) {
409 if self.overlap.max == 0 {
410 self.cursor = end;
411 return;
412 }
413
414 let mut start = end;
416 let mut low = 0;
417 let mut high = match self
419 .next_sections
420 .binary_search_by_key(&end, |(offset, str)| offset + str.len())
421 {
422 Ok(i) | Err(i) => i,
423 };
424
425 while low <= high {
426 let mid = low + (high - low) / 2;
427 let (offset, _) = self.next_sections[mid];
428 let chunk_size = self.chunk_sizer.chunk_size(
429 offset,
430 self.text.get(offset..end).expect("Invalid range"),
431 self.trim,
432 );
433 let fits = self.overlap.fits(chunk_size);
434
435 if fits.is_le() && offset < start && offset > self.cursor {
437 start = offset;
438 }
439
440 if fits.is_lt() && mid > 0 {
442 high = mid - 1;
443 } else {
444 low = mid + 1;
445 }
446 }
447
448 self.cursor = start;
449 }
450
451 #[expect(clippy::too_many_lines)]
455 fn update_next_sections(&mut self) -> usize {
456 self.next_sections.clear();
458
459 let remaining_text = self.text.get(self.cursor..).unwrap();
460 let mut lower_level = None;
461
462 let (semantic_level, mut max_offset) = self.chunk_sizer.find_correct_level(
463 self.cursor,
464 &self.capacity,
465 self.semantic_split
466 .levels_in_remaining_text(self.cursor)
467 .filter_map(|level| {
468 let first_chunk = self
469 .semantic_split
470 .semantic_chunks(self.cursor, remaining_text, level)
471 .next();
472
473 let result = first_chunk.map(|(_, str)| {
474 let candidate_lower_level = lower_level;
475 (level, str, candidate_lower_level)
476 });
477
478 lower_level = Some(level);
479 result
480 }),
481 |lower_level, chunk_end| {
482 lower_level.map_or_else(
483 || Either::Left(std::iter::empty()),
484 |lower_level| {
485 Either::Right(
486 self.semantic_split
487 .semantic_chunks(self.cursor, remaining_text, lower_level)
488 .map(|(offset, text)| offset + text.len())
489 .take_while(move |end| *end <= chunk_end),
490 )
491 },
492 )
493 },
494 self.trim,
495 );
496
497 let sections = if let Some(semantic_level) = semantic_level {
498 Either::Left(self.semantic_split.semantic_chunks(
499 self.cursor,
500 remaining_text,
501 semantic_level,
502 ))
503 } else {
504 let (semantic_level, fallback_max_offset) = self.chunk_sizer.find_correct_level(
505 self.cursor,
506 &self.capacity,
507 FallbackLevel::iter().filter_map(|level| {
508 level
509 .sections(remaining_text)
510 .next()
511 .map(|(_, str)| (level, str, level.boundary_level_for_probe()))
512 }),
513 |lower_level, chunk_end| {
514 lower_level.map_or_else(
515 || Either::Left(std::iter::empty()),
516 |lower_level| {
517 Either::Right(
518 lower_level
519 .sections(remaining_text)
520 .map(|(offset, text)| self.cursor + offset + text.len())
521 .take_while(move |end| *end <= chunk_end),
522 )
523 },
524 )
525 },
526 self.trim,
527 );
528
529 max_offset = match (fallback_max_offset, max_offset) {
530 (Some(fallback), Some(max)) => Some(fallback.min(max)),
531 (fallback, max) => fallback.or(max),
532 };
533
534 let fallback_level = semantic_level.unwrap_or(FallbackLevel::Char);
535
536 Either::Right(
537 fallback_level
538 .sections(remaining_text)
539 .map(|(offset, text)| (self.cursor + offset, text)),
540 )
541 };
542
543 let mut sections = sections
544 .take_while(move |(offset, _)| max_offset.is_none_or(|max| *offset <= max))
545 .filter(|(_, str)| !str.is_empty());
546
547 let mut low = 0;
550 let mut prev_equals: Option<usize> = None;
551 let max = self.capacity.max;
552 let mut target_offset = self.chunk_stats.max_chunk_size.unwrap_or(max);
553
554 loop {
555 let prev_num = self.next_sections.len();
556 for (offset, str) in sections.by_ref() {
557 self.next_sections.push((offset, str));
558 if offset + str.len() > (self.cursor.saturating_add(target_offset)) {
559 break;
560 }
561 }
562 let new_num = self.next_sections.len();
563 if new_num - prev_num == 0 {
565 break;
566 }
567
568 if let Some(&(offset, str)) = self.next_sections.last() {
570 let text_end = offset + str.len();
571 if (text_end - self.cursor) < target_offset {
572 break;
573 }
574 let chunk_size = self.chunk_sizer.chunk_size(
575 offset,
576 self.text.get(self.cursor..text_end).expect("Invalid range"),
577 self.trim,
578 );
579 let fits = self.capacity.fits(chunk_size);
580
581 if fits.is_le() {
582 let final_offset = offset + str.len() - self.cursor;
583 let size = chunk_size.max(1);
584 let diff = (max - size).max(1);
585 let avg_size = final_offset.div_ceil(size);
586
587 target_offset = final_offset
588 .saturating_add(diff.saturating_mul(avg_size))
589 .saturating_add(final_offset.div_ceil(10));
590 }
591
592 match fits {
593 Ordering::Less => {
594 low = new_num.saturating_sub(1);
596 }
597 Ordering::Equal => {
598 if let Some(prev) = prev_equals {
601 if prev < chunk_size {
602 break;
603 }
604 }
605 prev_equals = Some(chunk_size);
606 }
607 Ordering::Greater => {
608 break;
609 }
610 }
611 }
612 }
613
614 low
615 }
616}
617
618impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator for TextChunks<'text, 'sizer, Sizer, Level>
619where
620 Sizer: ChunkSizer,
621 Level: SemanticLevel,
622{
623 type Item = (usize, &'text str);
624
625 fn next(&mut self) -> Option<Self::Item> {
626 loop {
627 if self.cursor >= self.text.len() {
629 return None;
630 }
631
632 match self.next_chunk()? {
633 (_, "") => {}
636 c => {
637 let item_end = c.0 + c.1.len();
638 if item_end <= self.prev_item_end {
640 continue;
641 }
642 self.prev_item_end = item_end;
643 return Some(c);
644 }
645 }
646 }
647 }
648}
649
/// A chunk of text together with where it starts in the original text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ChunkCharIndex<'text> {
    /// The text of the chunk.
    pub chunk: &'text str,
    /// Byte offset of the chunk within the original text.
    pub byte_offset: usize,
    /// Character offset of the chunk within the original text.
    pub char_offset: usize,
}
660
661#[derive(Debug)]
663struct TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
664where
665 Sizer: ChunkSizer,
666 Level: SemanticLevel,
667{
668 text: &'text str,
670 text_chunks: TextChunks<'text, 'sizer, Sizer, Level>,
672 byte_offset: usize,
674 char_offset: usize,
676}
677
678impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
679where
680 Sizer: ChunkSizer,
681 Level: SemanticLevel,
682{
683 fn new(
686 chunk_config: &'sizer ChunkConfig<Sizer>,
687 text: &'text str,
688 offsets: Vec<(Level, Range<usize>)>,
689 trim: Trim,
690 ) -> Self {
691 Self {
692 text,
693 text_chunks: TextChunks::new(chunk_config, text, offsets, trim),
694 byte_offset: 0,
695 char_offset: 0,
696 }
697 }
698}
699
700impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator
701 for TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
702where
703 Sizer: ChunkSizer,
704 Level: SemanticLevel,
705{
706 type Item = ChunkCharIndex<'text>;
707
708 fn next(&mut self) -> Option<Self::Item> {
709 let (byte_offset, chunk) = self.text_chunks.next()?;
710 let preceding_text = self
711 .text
712 .get(self.byte_offset..byte_offset)
713 .expect("Invalid byte sequence");
714 self.byte_offset = byte_offset;
715 self.char_offset += preceding_text.chars().count();
716 Some(ChunkCharIndex {
717 chunk,
718 byte_offset,
719 char_offset: self.char_offset,
720 })
721 }
722}
723
/// Running statistics about the chunks generated so far.
#[derive(Debug, Default)]
struct ChunkStats {
    /// Largest chunk size recorded, or `None` before the first update.
    max_chunk_size: Option<usize>,
}

impl ChunkStats {
    /// Creates statistics with no chunk recorded yet.
    fn new() -> Self {
        Self {
            max_chunk_size: None,
        }
    }

    /// Records `size`, keeping the largest value seen so far.
    fn update_max_chunk_size(&mut self, size: usize) {
        let largest = match self.max_chunk_size {
            Some(previous) => previous.max(size),
            None => size,
        };
        self.max_chunk_size = Some(largest);
    }
}
741
#[cfg(test)]
mod tests {
    use super::*;

    // Lets plain integers stand in for semantic levels in the range tests.
    impl SemanticLevel for usize {}

    #[test]
    fn chunk_stats_empty() {
        // Nothing has been recorded yet.
        assert_eq!(ChunkStats::new().max_chunk_size, None);
    }

    #[test]
    fn chunk_stats_one() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);

        assert_eq!(stats.max_chunk_size, Some(10));
    }

    #[test]
    fn chunk_stats_multiple() {
        let mut stats = ChunkStats::new();
        for size in [10, 20, 30] {
            stats.update_max_chunk_size(size);
        }

        assert_eq!(stats.max_chunk_size, Some(30));
    }

    #[test]
    fn semantic_ranges_are_sorted() {
        let unsorted = vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)];
        // Ordered by start; for equal starts the longer range comes first.
        let expected = vec![(2, 0..4), (1, 0..2), (0, 0..1), (0, 1..2)];

        let split = SemanticSplitRanges::new(unsorted);

        assert_eq!(split.ranges, expected);
    }

    #[test]
    fn semantic_ranges_skip_previous_ranges() {
        let mut split =
            SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        split.update_cursor(1);

        // Ranges starting before the cursor are no longer reported, even when
        // asking from offset 0.
        let remaining: Vec<_> = split.ranges_after_offset(0).collect();
        assert_eq!(remaining, vec![(0, 1..2)]);
    }
}