use std::{cmp::Ordering, fmt, iter::once, ops::Range};

use either::Either;
use itertools::Itertools;
use strum::IntoEnumIterator;

use self::fallback::FallbackLevel;
use crate::{chunk_size::MemoizedChunkSizer, trim::Trim, ChunkCapacity, ChunkConfig, ChunkSizer};

#[cfg(feature = "code")]
mod code;
mod fallback;
#[cfg(feature = "markdown")]
mod markdown;
mod text;

#[cfg(feature = "code")]
pub use code::{CodeSplitter, CodeSplitterError};
#[cfg(feature = "markdown")]
pub use markdown::MarkdownSplitter;
pub use text::TextSplitter;

/// Shared interface for the splitter implementations. Implementors provide the
/// parsed semantic levels for a given text, and the trait supplies the chunking
/// iterators on top of them.
trait Splitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// The semantic levels this splitter can split on.
    type Level: SemanticLevel;

    /// Default trimming behaviour applied when the chunk config has trimming enabled.
    const TRIM: Trim = Trim::All;

    /// The chunk configuration (capacity, overlap, sizer, trim) for this splitter.
    fn chunk_config(&self) -> &ChunkConfig<Sizer>;

    /// Parse the text into semantic levels and the byte ranges they cover.
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)>;

    /// Iterator over chunks and their byte offsets in the original text.
    fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunks::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Iterator over chunks with both their byte and character offsets.
    fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunksWithCharIndices::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Iterator over just the chunk strings.
    fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter
    where
        Sizer: 'splitter,
    {
        self.chunk_indices(text).map(|(_, t)| t)
    }
}
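
// Illustrative usage sketch (added; not part of the original source): concrete
// splitters such as `TextSplitter` implement `Splitter` and expose these iterators
// through their public methods. Assuming this crate's `ChunkConfig` builder, a
// caller might write roughly:
//
//     let splitter = TextSplitter::new(ChunkConfig::new(100));
//     for (offset, chunk) in splitter.chunk_indices("some long document...") {
//         println!("{offset}: {chunk}");
//     }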

/// A level of semantic granularity that a splitter knows how to split on, such as a
/// line break, sentence, or heading. Levels are ordered from least to most
/// significant.
trait SemanticLevel: Copy + fmt::Debug + Ord + PartialOrd + 'static {
    /// Split the text into sections using the given level ranges as separators,
    /// interleaving the text preceding each separator with the separator itself.
    fn sections(
        text: &str,
        level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
    ) -> impl Iterator<Item = (usize, &str)> {
        let mut cursor = 0;
        let mut final_match = false;
        level_ranges
            .batching(move |it| {
                loop {
                    match it.next() {
                        // The trailing text has already been emitted, so we are done.
                        None if final_match => return None,
                        // No more ranges: emit whatever text remains after the cursor.
                        None => {
                            final_match = true;
                            return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
                        }
                        Some((_, range)) => {
                            // Skip ranges nested inside one we already consumed.
                            if range.start < cursor {
                                continue;
                            }
                            let offset = cursor;
                            let prev_section = text
                                .get(offset..range.start)
                                .expect("invalid character sequence");
                            let separator = text
                                .get(range.start..range.end)
                                .expect("invalid character sequence");
                            cursor = range.end;
                            // Emit the text before the separator, then the separator.
                            return Some(Either::Right(
                                [(offset, prev_section), (range.start, separator)].into_iter(),
                            ));
                        }
                    }
                }
            })
            .flatten()
            .filter(|(_, s)| !s.is_empty())
    }
}
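
// Worked example (added for illustration): for `text = "a b"` with a single
// separator range covering the space (`1..2`), `sections` yields `(0, "a")`,
// `(1, " ")`, and finally the trailing remainder `(2, "b")`. Empty pieces are
// filtered out, and ranges that start before the cursor (nested inside an earlier
// match) are skipped.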

/// Ranges of each semantic level found in the text, stored in one sorted list so
/// they can be scanned incrementally as chunking moves through the text.
#[derive(Debug)]
struct SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    /// Index into `ranges` of the first range we haven't moved past yet.
    cursor: usize,
    /// Ranges for each level, sorted by start ascending, then end descending.
    ranges: Vec<(Level, Range<usize>)>,
}

impl<Level> SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    fn new(mut ranges: Vec<(Level, Range<usize>)>) -> Self {
        // Sort by start ascending and end descending so outer ranges come before
        // the ranges they contain.
        ranges.sort_unstable_by(|(_, a), (_, b)| {
            a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end))
        });
        Self { cursor: 0, ranges }
    }

    /// All remaining ranges that start at or after the given byte offset.
    fn ranges_after_offset(
        &self,
        offset: usize,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        self.ranges[self.cursor..]
            .iter()
            .filter(move |(_, sep)| sep.start >= offset)
            .map(|(l, r)| (*l, r.start..r.end))
    }

    /// Ranges at or above the given level that start at or after the given offset,
    /// skipping higher-level ranges that would swallow the first range of this level.
    fn level_ranges_after_offset(
        &self,
        offset: usize,
        level: Level,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        // Find the first range of exactly this level so we know where it begins.
        let first_item = self
            .ranges_after_offset(offset)
            .position(|(l, _)| l == level)
            .and_then(|i| {
                self.ranges_after_offset(offset)
                    .skip(i)
                    .coalesce(|(a_level, a_range), (b_level, b_range)| {
                        if a_level == b_level && a_range.start == b_range.start && i == 0 {
                            Ok((b_level, b_range))
                        } else {
                            Err(((a_level, a_range), (b_level, b_range)))
                        }
                    })
                    .next()
            });
        self.ranges_after_offset(offset)
            .filter(move |(l, _)| l >= &level)
            .skip_while(move |(l, r)| {
                first_item.as_ref().is_some_and(|(_, fir)| {
                    (l > &level && r.contains(&fir.start))
                        || (l == &level && r.start == fir.start && r.end > fir.end)
                })
            })
    }

    /// Distinct semantic levels that still appear at or after the given offset.
    fn levels_in_remaining_text(&self, offset: usize) -> impl Iterator<Item = Level> + '_ {
        self.ranges_after_offset(offset)
            .map(|(l, _)| l)
            .sorted()
            .dedup()
    }

    /// Split the remaining text into sections at the given semantic level, reporting
    /// offsets relative to the start of the original text.
    fn semantic_chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        offset: usize,
        text: &'text str,
        semantic_level: Level,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Level::sections(
            text,
            self.level_ranges_after_offset(offset, semantic_level)
                .map(move |(l, sep)| (l, sep.start - offset..sep.end - offset)),
        )
        .map(move |(i, str)| (offset + i, str))
    }

    /// Advance the internal cursor past ranges that start before the given byte
    /// offset so future scans have less to look through.
    fn update_cursor(&mut self, cursor: usize) {
        self.cursor += self.ranges[self.cursor..]
            .iter()
            .position(|(_, range)| range.start >= cursor)
            .unwrap_or_else(|| self.ranges.len() - self.cursor);
    }
}
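
// Worked example (added for illustration): `semantic_chunks` re-bases the stored
// ranges onto the remaining text before handing them to `Level::sections`, then
// shifts the results back. With `offset = 10`, a stored range `12..14` is passed
// down as `2..4`, and a section reported at local offset `0` comes back as global
// offset `10`.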

/// Iterator over the chunks of a text, produced by walking the semantic levels and
/// packing as many sections as possible into each chunk.
#[derive(Debug)]
struct TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    /// How big each chunk is allowed (and desired) to be.
    capacity: ChunkCapacity,
    /// Caches chunk sizes so the sizer isn't called twice for the same range.
    chunk_sizer: MemoizedChunkSizer<'sizer, Sizer>,
    /// Statistics about the chunks produced so far.
    chunk_stats: ChunkStats,
    /// Byte offset in the text where the next chunk starts.
    cursor: usize,
    /// Candidate sections (offset and text) for the next chunk.
    next_sections: Vec<(usize, &'text str)>,
    /// Desired overlap between consecutive chunks.
    overlap: ChunkCapacity,
    /// End offset of the previously emitted chunk, used to skip chunks fully
    /// contained in it.
    prev_item_end: usize,
    /// Semantic level ranges parsed from the text.
    semantic_split: SemanticSplitRanges<Level>,
    /// The original text being chunked.
    text: &'text str,
    /// Whether and how chunks are trimmed before being returned.
    trim: Trim,
}

impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    fn new(
        chunk_config: &'sizer ChunkConfig<Sizer>,
        text: &'text str,
        offsets: Vec<(Level, Range<usize>)>,
        trim: Trim,
    ) -> Self {
        let ChunkConfig {
            capacity,
            overlap,
            sizer,
            trim: trim_enabled,
        } = chunk_config;
        Self {
            capacity: *capacity,
            chunk_sizer: MemoizedChunkSizer::new(sizer),
            chunk_stats: ChunkStats::new(),
            cursor: 0,
            next_sections: Vec::new(),
            overlap: (*overlap).into(),
            prev_item_end: 0,
            semantic_split: SemanticSplitRanges::new(offsets),
            text,
            trim: if *trim_enabled { trim } else { Trim::None },
        }
    }

    /// Generate the next chunk: pick the right semantic level, gather candidate
    /// sections, and binary search for the largest span that fits the capacity.
    fn next_chunk(&mut self) -> Option<(usize, &'text str)> {
        self.semantic_split.update_cursor(self.cursor);
        let low = self.update_next_sections();
        let (start, end) = self.binary_search_next_chunk(low)?;
        let chunk = self.text.get(start..end)?;
        self.chunk_stats.update_max_chunk_size(end - start);

        // Reset the size cache and advance the cursor (accounting for overlap).
        self.chunk_sizer.clear_cache();
        self.update_cursor(end);

        Some(self.trim.trim(start, chunk))
    }

    /// Binary search over the candidate sections to find the largest chunk, starting
    /// at the cursor, that still fits within the chunk capacity.
    fn binary_search_next_chunk(&mut self, mut low: usize) -> Option<(usize, usize)> {
        let start = self.cursor;
        let mut end = self.cursor;
        let mut equals_found = false;
        let mut high = self.next_sections.len().saturating_sub(1);
        let mut successful_index = None;
        let mut successful_chunk_size = None;

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, str) = self.next_sections[mid];
            let text_end = offset + str.len();
            let chunk = self.text.get(start..text_end)?;
            let chunk_size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
            let fits = self.capacity.fits(chunk_size);

            match fits {
                Ordering::Less => {
                    // Smaller than the target range; keep it as the best so far.
                    if text_end > end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
                Ordering::Equal => {
                    // Within the target range. Once we have one match, prefer the
                    // smallest end that still qualifies.
                    if text_end < end || !equals_found {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                    equals_found = true;
                }
                Ordering::Greater => {
                    // Even the very first section is too big; take it anyway if we
                    // have nothing yet so we still make progress.
                    if mid == 0 && start == end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
            }

            // Adjust the search window.
            if fits.is_lt() {
                low = mid + 1;
            } else if mid > 0 {
                high = mid - 1;
            } else {
                break;
            }
        }

        // With some tokenizers, a longer section can still produce the same chunk
        // size, so scan forward while the reported size doesn't grow.
        if let (Some(successful_index), Some(chunk_size)) =
            (successful_index, successful_chunk_size)
        {
            let mut range = successful_index..self.next_sections.len();
            // We've already checked the successful index itself.
            range.next();

            for index in range {
                let (offset, str) = self.next_sections[index];
                let text_end = offset + str.len();
                let chunk = self.text.get(start..text_end)?;
                let size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
                if size <= chunk_size {
                    if text_end > end {
                        end = text_end;
                    }
                } else {
                    break;
                }
            }
        }

        Some((start, end))
    }
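
    // Worked example (added for illustration), assuming a character-based sizer so
    // the chunk size is just its character count, with a maximum capacity of 10 and
    // three candidate boundaries where the chunk from the cursor measures 4, 8, and
    // 12 respectively: the midpoint boundary (8) is kept as the best candidate and
    // the search moves upward, the largest (12) exceeds the maximum and shrinks the
    // window, and the chunk settles on the boundary measuring 8. The forward scan
    // afterwards stops immediately because 12 is larger than 8.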

    /// Advance the cursor after emitting a chunk that ends at `end`. If overlap is
    /// configured, binary search backwards for how far the next chunk can reach into
    /// the previous one while staying within the overlap capacity.
    fn update_cursor(&mut self, end: usize) {
        if self.overlap.max == 0 {
            self.cursor = end;
            return;
        }

        let mut start = end;
        let mut low = 0;
        // Find the closest section index whose end matches the chunk end.
        let mut high = match self
            .next_sections
            .binary_search_by_key(&end, |(offset, str)| offset + str.len())
        {
            Ok(i) | Err(i) => i,
        };

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, _) = self.next_sections[mid];
            let chunk_size = self.chunk_sizer.chunk_size(
                offset,
                self.text.get(offset..end).expect("Invalid range"),
                self.trim,
            );
            let fits = self.overlap.fits(chunk_size);

            // This offset still fits in the overlap and reaches further back, so keep it.
            if fits.is_le() && offset < start && offset > self.cursor {
                start = offset;
            }

            // Adjust the search window.
            if fits.is_lt() && mid > 0 {
                high = mid - 1;
            } else {
                low = mid + 1;
            }
        }

        self.cursor = start;
    }
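
    // Note (added for clarity): when overlap is requested, the cursor moves back to
    // the earliest candidate section start whose text up to `end` still fits within
    // the overlap capacity (and which lies after the previous cursor), so the next
    // chunk begins with that overlapping text. If no section start qualifies, the
    // cursor simply moves to `end` and this boundary gets no overlap.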

    /// Refresh `next_sections` with candidate sections for the next chunk: find the
    /// highest semantic level whose first section still fits, fall back to the
    /// generic fallback levels if none do, and then lazily pull in sections until we
    /// are confident we have enough to fill a chunk. Returns a starting index for
    /// the subsequent binary search.
    #[expect(clippy::too_many_lines)]
    fn update_next_sections(&mut self) -> usize {
        // Clear out the previous sections, but keep the allocation.
        self.next_sections.clear();

        let remaining_text = self.text.get(self.cursor..).unwrap();

        // Find the highest semantic level whose first section still fits the capacity.
        let (semantic_level, mut max_offset) = self.chunk_sizer.find_correct_level(
            self.cursor,
            &self.capacity,
            self.semantic_split
                .levels_in_remaining_text(self.cursor)
                .filter_map(|level| {
                    self.semantic_split
                        .semantic_chunks(self.cursor, remaining_text, level)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
            self.trim,
        );

        let sections = if let Some(semantic_level) = semantic_level {
            Either::Left(self.semantic_split.semantic_chunks(
                self.cursor,
                remaining_text,
                semantic_level,
            ))
        } else {
            // No semantic level fits, so fall back to the generic levels.
            let (semantic_level, fallback_max_offset) = self.chunk_sizer.find_correct_level(
                self.cursor,
                &self.capacity,
                FallbackLevel::iter().filter_map(|level| {
                    level
                        .sections(remaining_text)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
                self.trim,
            );

            max_offset = match (fallback_max_offset, max_offset) {
                (Some(fallback), Some(max)) => Some(fallback.min(max)),
                (fallback, max) => fallback.or(max),
            };

            let fallback_level = semantic_level.unwrap_or(FallbackLevel::Char);

            Either::Right(
                fallback_level
                    .sections(remaining_text)
                    .map(|(offset, text)| (self.cursor + offset, text)),
            )
        };

        let mut sections = sections
            .take_while(move |(offset, _)| max_offset.is_none_or(|max| *offset <= max))
            .filter(|(_, str)| !str.is_empty());

        let mut low = 0;
        let mut prev_equals: Option<usize> = None;
        let max = self.capacity.max;
        // Prime the target with the largest chunk we've produced so far, if any.
        let mut target_offset = self.chunk_stats.max_chunk_size.unwrap_or(max);

        loop {
            let prev_num = self.next_sections.len();
            // Pull in sections until we pass the current target offset.
            for (offset, str) in sections.by_ref() {
                self.next_sections.push((offset, str));
                if offset + str.len() > (self.cursor.saturating_add(target_offset)) {
                    break;
                }
            }
            let new_num = self.next_sections.len();
            // Nothing new was added, so the text is exhausted.
            if new_num - prev_num == 0 {
                break;
            }

            if let Some(&(offset, str)) = self.next_sections.last() {
                let text_end = offset + str.len();
                // We ran out of sections before reaching the target offset.
                if (text_end - self.cursor) < target_offset {
                    break;
                }
                let chunk_size = self.chunk_sizer.chunk_size(
                    offset,
                    self.text.get(self.cursor..text_end).expect("Invalid range"),
                    self.trim,
                );
                let fits = self.capacity.fits(chunk_size);
                // Still at or below capacity, so grow the target offset based on how
                // much room is left and the average bytes per unit of chunk size.
                if fits.is_le() {
                    let final_offset = offset + str.len() - self.cursor;
                    let size = chunk_size.max(1);
                    // Use saturating arithmetic like the surrounding math so a
                    // degenerate capacity can't underflow.
                    let diff = max.saturating_sub(size).max(1);
                    let avg_size = final_offset.div_ceil(size);

                    target_offset = final_offset
                        .saturating_add(diff.saturating_mul(avg_size))
                        .saturating_add(final_offset.div_ceil(10));
                }

                match fits {
                    Ordering::Less => {
                        low = new_num.saturating_sub(1);
                    }
                    Ordering::Equal => {
                        if let Some(prev) = prev_equals {
                            if prev < chunk_size {
                                break;
                            }
                        }
                        prev_equals = Some(chunk_size);
                    }
                    Ordering::Greater => {
                        break;
                    }
                }
            }
        }

        low
    }
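
    // Worked example (added for illustration): if the sections gathered so far span
    // 40 bytes from the cursor and the sizer reports a chunk size of 4 against a
    // maximum of 16, then `diff = 12`, `avg_size = ceil(40 / 4) = 10`, and the new
    // target offset becomes `40 + 12 * 10 + ceil(40 / 10) = 164` bytes, so the next
    // pass pulls in sections up to roughly that far ahead before re-measuring.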
}

impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator for TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    type Item = (usize, &'text str);

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Done once we've reached the end of the text.
            if self.cursor >= self.text.len() {
                return None;
            }

            match self.next_chunk()? {
                // Skip chunks that trimmed down to nothing.
                (_, "") => {}
                c => {
                    let item_end = c.0 + c.1.len();
                    // Skip chunks entirely contained within the previous one, which
                    // can happen with overlap.
                    if item_end <= self.prev_item_end {
                        continue;
                    }
                    self.prev_item_end = item_end;
                    return Some(c);
                }
            }
        }
    }
}

/// A chunk of text along with its byte and character offsets within the original
/// text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkCharIndex<'text> {
    /// The text of the chunk.
    pub chunk: &'text str,
    /// Byte offset of the chunk within the original text.
    pub byte_offset: usize,
    /// Character (not byte) offset of the chunk within the original text.
    pub char_offset: usize,
}

/// Wrapper around `TextChunks` that also tracks character offsets, which are more
/// expensive to compute than byte offsets.
#[derive(Debug)]
struct TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    /// The original text being chunked.
    text: &'text str,
    /// The underlying chunk iterator.
    text_chunks: TextChunks<'text, 'sizer, Sizer, Level>,
    /// Byte offset of the most recently emitted chunk.
    byte_offset: usize,
    /// Character offset of the most recently emitted chunk.
    char_offset: usize,
}

impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    fn new(
        chunk_config: &'sizer ChunkConfig<Sizer>,
        text: &'text str,
        offsets: Vec<(Level, Range<usize>)>,
        trim: Trim,
    ) -> Self {
        Self {
            text,
            text_chunks: TextChunks::new(chunk_config, text, offsets, trim),
            byte_offset: 0,
            char_offset: 0,
        }
    }
}

impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator
    for TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    type Item = ChunkCharIndex<'text>;

    fn next(&mut self) -> Option<Self::Item> {
        let (byte_offset, chunk) = self.text_chunks.next()?;
        // Only count the characters between the previous chunk's start and this one,
        // rather than re-counting from the beginning of the text each time.
        let preceding_text = self
            .text
            .get(self.byte_offset..byte_offset)
            .expect("Invalid byte sequence");
        self.byte_offset = byte_offset;
        self.char_offset += preceding_text.chars().count();
        Some(ChunkCharIndex {
            chunk,
            byte_offset,
            char_offset: self.char_offset,
        })
    }
}

/// Statistics about the chunks produced so far, used to tune how much text is pulled
/// in when gathering candidate sections.
#[derive(Debug, Default)]
struct ChunkStats {
    /// Largest chunk produced so far, in bytes, before trimming.
    max_chunk_size: Option<usize>,
}

impl ChunkStats {
    fn new() -> Self {
        Self::default()
    }

    /// Record the byte length of a new chunk, keeping the running maximum.
    fn update_max_chunk_size(&mut self, size: usize) {
        self.max_chunk_size = self.max_chunk_size.map(|s| s.max(size)).or(Some(size));
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunk_stats_empty() {
        let stats = ChunkStats::new();
        assert_eq!(stats.max_chunk_size, None);
    }

    #[test]
    fn chunk_stats_one() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        assert_eq!(stats.max_chunk_size, Some(10));
    }

    #[test]
    fn chunk_stats_multiple() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        stats.update_max_chunk_size(20);
        stats.update_max_chunk_size(30);
        assert_eq!(stats.max_chunk_size, Some(30));
    }
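
    // Added illustrative test (not in the original source): the running maximum
    // should not decrease when a smaller chunk is recorded later.
    #[test]
    fn chunk_stats_does_not_decrease() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(30);
        stats.update_max_chunk_size(10);
        assert_eq!(stats.max_chunk_size, Some(30));
    }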

    impl SemanticLevel for usize {}

    #[test]
    fn semantic_ranges_are_sorted() {
        let ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        assert_eq!(
            ranges.ranges,
            vec![(2, 0..4), (1, 0..2), (0, 0..1), (0, 1..2)]
        );
    }

    #[test]
    fn semantic_ranges_skip_previous_ranges() {
        let mut ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        ranges.update_cursor(1);

        assert_eq!(
            ranges.ranges_after_offset(0).collect::<Vec<_>>(),
            vec![(0, 1..2)]
        );
    }
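
    // Added illustrative test (not in the original source): the distinct levels
    // remaining after an offset come back sorted and deduplicated.
    #[test]
    fn semantic_ranges_levels_are_sorted_and_deduped() {
        let ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        assert_eq!(
            ranges.levels_in_remaining_text(0).collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
    }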
}