1use std::{iter::once, ops::Range};
8
9use either::Either;
10use itertools::Itertools;
11use pulldown_cmark::{Event, Options, Parser, Tag};
12
13use crate::{
14 splitter::{SemanticLevel, Splitter},
15 trim::Trim,
16 ChunkConfig, ChunkSizer,
17};
18
19use super::ChunkCharIndex;
20
21#[derive(Debug)]
26#[allow(clippy::module_name_repetitions)]
27pub struct MarkdownSplitter<Sizer>
28where
29 Sizer: ChunkSizer,
30{
31 chunk_config: ChunkConfig<Sizer>,
33}
34
35impl<Sizer> MarkdownSplitter<Sizer>
36where
37 Sizer: ChunkSizer,
38{
39 #[must_use]
48 pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
49 Self {
50 chunk_config: chunk_config.into(),
51 }
52 }
53
54 pub fn chunks<'splitter, 'text: 'splitter>(
85 &'splitter self,
86 text: &'text str,
87 ) -> impl Iterator<Item = &'text str> + 'splitter {
88 Splitter::<_>::chunks(self, text)
89 }
90
91 pub fn chunk_indices<'splitter, 'text: 'splitter>(
106 &'splitter self,
107 text: &'text str,
108 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
109 Splitter::<_>::chunk_indices(self, text)
110 }
111
112 pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
132 &'splitter self,
133 text: &'text str,
134 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
135 Splitter::<_>::chunk_char_indices(self, text)
136 }
137}
138
139impl<Sizer> Splitter<Sizer> for MarkdownSplitter<Sizer>
140where
141 Sizer: ChunkSizer,
142{
143 type Level = Element;
144
145 const TRIM: Trim = Trim::PreserveIndentation;
146
147 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
148 &self.chunk_config
149 }
150
151 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
152 Parser::new_ext(text, Options::all())
153 .into_offset_iter()
154 .filter_map(|(event, range)| match event {
155 Event::Start(
156 Tag::Emphasis
157 | Tag::Strong
158 | Tag::Strikethrough
159 | Tag::Link { .. }
160 | Tag::Image { .. }
161 | Tag::Subscript
162 | Tag::Superscript
163 | Tag::TableCell,
164 )
165 | Event::Text(_)
166 | Event::HardBreak
167 | Event::Code(_)
168 | Event::InlineHtml(_)
169 | Event::InlineMath(_)
170 | Event::FootnoteReference(_)
171 | Event::TaskListMarker(_) => Some((Element::Inline, range)),
172 Event::SoftBreak => Some((Element::SoftBreak, range)),
173 Event::Html(_)
174 | Event::DisplayMath(_)
175 | Event::Start(
176 Tag::Paragraph
177 | Tag::CodeBlock(_)
178 | Tag::FootnoteDefinition(_)
179 | Tag::MetadataBlock(_)
180 | Tag::TableHead
181 | Tag::BlockQuote(_)
182 | Tag::TableRow
183 | Tag::Item
184 | Tag::HtmlBlock
185 | Tag::List(_)
186 | Tag::Table(_)
187 | Tag::DefinitionList
188 | Tag::DefinitionListTitle
189 | Tag::DefinitionListDefinition,
190 ) => Some((Element::Block, range)),
191 Event::Rule => Some((Element::Rule, range)),
192 Event::Start(Tag::Heading { level, .. }) => {
193 Some((Element::Heading(level.into()), range))
194 }
195 Event::End(_) => None,
197 })
198 .collect()
199 }
200}
201
/// Heading levels of a Markdown document.
///
/// Variants are declared from `H6` down to `H1` so that the derived ordering
/// makes `H1` the greatest value and `H6` the smallest. Do not reorder them.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum HeadingLevel {
    /// Level-6 heading; compares lowest.
    H6,
    /// Level-5 heading.
    H5,
    /// Level-4 heading.
    H4,
    /// Level-3 heading.
    H3,
    /// Level-2 heading.
    H2,
    /// Level-1 heading; compares highest.
    H1,
}
213
214impl From<pulldown_cmark::HeadingLevel> for HeadingLevel {
215 fn from(value: pulldown_cmark::HeadingLevel) -> Self {
216 match value {
217 pulldown_cmark::HeadingLevel::H1 => HeadingLevel::H1,
218 pulldown_cmark::HeadingLevel::H2 => HeadingLevel::H2,
219 pulldown_cmark::HeadingLevel::H3 => HeadingLevel::H3,
220 pulldown_cmark::HeadingLevel::H4 => HeadingLevel::H4,
221 pulldown_cmark::HeadingLevel::H5 => HeadingLevel::H5,
222 pulldown_cmark::HeadingLevel::H6 => HeadingLevel::H6,
223 }
224 }
225}
226
/// Where a matched element's text ends up when the input is cut into sections
/// (see `Element::sections`).
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum SemanticSplitPosition {
    /// The matched range is emitted as a section of its own (or merged with
    /// whitespace-only text that precedes it).
    Own,
    /// The section ends just before the matched range, which therefore starts
    /// the section that follows.
    Next,
}
235
/// Semantic levels recognised in Markdown text.
///
/// Variants are declared from lowest to highest for the derived ordering
/// (`SoftBreak` is the smallest, `Heading` the greatest). Do not reorder them.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum Element {
    /// A soft line break reported by the parser.
    SoftBreak,
    /// Inline content: text, emphasis, links, inline code, and similar.
    Inline,
    /// Block-level content: paragraphs, lists, code blocks, tables, and similar.
    Block,
    /// A horizontal rule.
    Rule,
    /// A heading of the given level; a higher heading level compares greater.
    Heading(HeadingLevel),
}
253
254impl Element {
255 fn split_position(self) -> SemanticSplitPosition {
256 match self {
257 Self::SoftBreak | Self::Block | Self::Rule | Self::Inline => SemanticSplitPosition::Own,
258 Self::Heading(_) => SemanticSplitPosition::Next,
260 }
261 }
262
263 fn treat_whitespace_as_previous(self) -> bool {
264 match self {
265 Self::SoftBreak | Self::Inline | Self::Rule | Self::Heading(_) => false,
266 Self::Block => true,
267 }
268 }
269}
270
271impl SemanticLevel for Element {
272 fn sections(
273 text: &str,
274 level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
275 ) -> impl Iterator<Item = (usize, &str)> {
276 let mut cursor = 0;
277 let mut final_match = false;
278 level_ranges
279 .batching(move |it| {
280 loop {
281 match it.next() {
282 None if final_match => return None,
284 None => {
286 final_match = true;
287 return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
288 }
289 Some((level, range)) => {
291 let offset = cursor;
292 match level.split_position() {
293 SemanticSplitPosition::Own => {
294 if range.start < cursor {
295 continue;
296 }
297 let prev_section = text
298 .get(cursor..range.start)
299 .expect("invalid character sequence");
300 if level.treat_whitespace_as_previous()
301 && prev_section.chars().all(char::is_whitespace)
302 {
303 let section = text
304 .get(cursor..range.end)
305 .expect("invalid character sequence");
306 cursor = range.end;
307 return Some(Either::Left(once((offset, section))));
308 }
309 let separator = text
310 .get(range.start..range.end)
311 .expect("invalid character sequence");
312 cursor = range.end;
313 return Some(Either::Right(
314 [(offset, prev_section), (range.start, separator)]
315 .into_iter(),
316 ));
317 }
318 SemanticSplitPosition::Next => {
319 if range.start < cursor {
320 continue;
321 }
322 let prev_section = text
323 .get(cursor..range.start)
324 .expect("invalid character sequence");
325 cursor = range.start;
327 return Some(Either::Left(once((offset, prev_section))));
328 }
329 }
330 }
331 }
332 }
333 })
334 .flatten()
335 .filter(|(_, s)| !s.is_empty())
336 }
337}
338
/// Unit tests for the Markdown splitter: generic chunking behaviour first,
/// then the element ranges produced by `parse` for each Markdown construct.
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::splitter::SemanticSplitRanges;

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = MarkdownSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // One above half, so the text needs exactly two chunks.
        let max_chunk_size = text.chars().count() / 2 + 1;

        let chunks = MarkdownSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The first chunk and the first text share a common beginning.
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // The second chunk and the second text share a common ending.
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each character takes more than one byte.
        let text = "éé";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // "\r\n" is kept together as a single grapheme.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = MarkdownSplitter::new(1)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = MarkdownSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(3)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick brown fox can jump 32.3 feet, right?";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "brown fox ",
                "can jump ",
                "32.3 feet,",
                " right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Mr. Fox jumped. ", "The dog was too lazy."], chunks);
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["Mr. Fox jumped. ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn test_no_markdown_separators() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("Some text without any markdown separators"));

        assert_eq!(
            vec![(Element::Block, 0..41), (Element::Inline, 0..41)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_checklist() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![
                (Element::Block, 0..42),
                (Element::Block, 0..22),
                (Element::Inline, 2..5),
                (Element::Inline, 6..21),
                (Element::Block, 22..42),
                (Element::Inline, 24..27),
                (Element::Inline, 28..42),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_reference() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Footnote[^1]"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..8),
                (Element::Inline, 8..12),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_code() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("`bash`"));

        assert_eq!(
            vec![(Element::Block, 0..6), (Element::Inline, 0..6)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_emphasis() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("*emphasis*"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..10),
                (Element::Inline, 1..9),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strong() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("**emphasis**"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strikethrough() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("~~emphasis~~"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_link() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..11),
                (Element::Inline, 0..11),
                (Element::Inline, 1..5),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_image() {
        let splitter = MarkdownSplitter::new(10);
        // A 12-byte image: the paragraph and the image span 0..12 and the alt
        // text "link" sits at 2..6.
        let markdown = SemanticSplitRanges::new(splitter.parse("![link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..6),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<span>Some text</span>"));

        assert_eq!(
            vec![
                (Element::Block, 0..22),
                (Element::Inline, 0..6),
                (Element::Inline, 6..15),
                (Element::Inline, 15..22),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<div>Some text</div>"));

        assert_eq!(
            vec![(Element::Block, 0..20), (Element::Block, 0..20)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_table() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(
            splitter.parse("| Header 1 | Header 2 |\n| --- | --- |\n| Cell 1 | Cell 2 |"),
        );
        assert_eq!(
            vec![
                (Element::Block, 0..57),
                (Element::Block, 0..24),
                (Element::Inline, 1..11),
                (Element::Inline, 2..10),
                (Element::Inline, 12..22),
                (Element::Inline, 13..21),
                (Element::Block, 38..57),
                (Element::Inline, 39..47),
                (Element::Inline, 40..46),
                (Element::Inline, 48..56),
                (Element::Inline, 49..55)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_softbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\nwith a softbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..26),
                (Element::Inline, 0..9),
                (Element::SoftBreak, 9..10),
                (Element::Inline, 10..26)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_hardbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\\\nwith a hardbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..27),
                (Element::Inline, 0..9),
                (Element::Inline, 9..11),
                (Element::Inline, 11..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_def() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[^first]: Footnote"));

        assert_eq!(
            vec![
                (Element::Block, 0..18),
                (Element::Block, 10..18),
                (Element::Inline, 10..18)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_code_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("```\ncode\n```"));

        assert_eq!(
            vec![(Element::Block, 0..12), (Element::Inline, 4..9)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_block_quote() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("> quote"));

        assert_eq!(
            vec![
                (Element::Block, 0..7),
                (Element::Block, 2..7),
                (Element::Inline, 2..7)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_with_rule() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\n\n---\n\nwith a rule"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..9),
                (Element::Rule, 11..15),
                (Element::Block, 16..27),
                (Element::Inline, 16..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_heading() {
        // `index` equals the number of extra `#` characters beyond the first,
        // which shifts every expected range by the same amount.
        for (index, (heading, level)) in [
            ("#", HeadingLevel::H1),
            ("##", HeadingLevel::H2),
            ("###", HeadingLevel::H3),
            ("####", HeadingLevel::H4),
            ("#####", HeadingLevel::H5),
            ("######", HeadingLevel::H6),
        ]
        .into_iter()
        .enumerate()
        {
            let splitter = MarkdownSplitter::new(10);
            let markdown = SemanticSplitRanges::new(splitter.parse(&format!("{heading} Heading")));

            assert_eq!(
                vec![
                    (Element::Heading(level), 0..9 + index),
                    (Element::Inline, 2 + index..9 + index)
                ],
                markdown.ranges_after_offset(0).collect::<Vec<_>>()
            );
        }
    }

    #[test]
    fn test_ranges_after_offset_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![(Element::Block, 0..22), (Element::Block, 22..42),],
            markdown
                .level_ranges_after_offset(0, Element::Block)
                .collect::<Vec<_>>()
        );
    }
}