1use std::{iter::once, ops::Range};
8
9use either::Either;
10use itertools::Itertools;
11use pulldown_cmark::{Event, Options, Parser, Tag};
12
13use crate::{
14 splitter::{SemanticLevel, Splitter},
15 trim::Trim,
16 ChunkConfig, ChunkSizer,
17};
18
19use super::ChunkCharIndex;
20
21#[derive(Debug)]
26pub struct MarkdownSplitter<Sizer>
27where
28 Sizer: ChunkSizer,
29{
30 chunk_config: ChunkConfig<Sizer>,
32}
33
34impl<Sizer> MarkdownSplitter<Sizer>
35where
36 Sizer: ChunkSizer,
37{
38 #[must_use]
47 pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
48 Self {
49 chunk_config: chunk_config.into(),
50 }
51 }
52
53 pub fn chunks<'splitter, 'text: 'splitter>(
84 &'splitter self,
85 text: &'text str,
86 ) -> impl Iterator<Item = &'text str> + 'splitter {
87 Splitter::<_>::chunks(self, text)
88 }
89
90 pub fn chunk_indices<'splitter, 'text: 'splitter>(
105 &'splitter self,
106 text: &'text str,
107 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
108 Splitter::<_>::chunk_indices(self, text)
109 }
110
111 pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
131 &'splitter self,
132 text: &'text str,
133 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
134 Splitter::<_>::chunk_char_indices(self, text)
135 }
136}
137
138impl<Sizer> Splitter<Sizer> for MarkdownSplitter<Sizer>
139where
140 Sizer: ChunkSizer,
141{
142 type Level = Element;
143
144 const TRIM: Trim = Trim::PreserveIndentation;
145
146 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
147 &self.chunk_config
148 }
149
150 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
151 Parser::new_ext(text, Options::all())
152 .into_offset_iter()
153 .filter_map(|(event, range)| match event {
154 Event::Start(
155 Tag::Emphasis
156 | Tag::Strong
157 | Tag::Strikethrough
158 | Tag::Link { .. }
159 | Tag::Image { .. }
160 | Tag::Subscript
161 | Tag::Superscript
162 | Tag::TableCell,
163 )
164 | Event::Text(_)
165 | Event::HardBreak
166 | Event::Code(_)
167 | Event::InlineHtml(_)
168 | Event::InlineMath(_)
169 | Event::FootnoteReference(_)
170 | Event::TaskListMarker(_) => Some((Element::Inline, range)),
171 Event::SoftBreak => Some((Element::SoftBreak, range)),
172 Event::Html(_)
173 | Event::DisplayMath(_)
174 | Event::Start(
175 Tag::Paragraph
176 | Tag::CodeBlock(_)
177 | Tag::FootnoteDefinition(_)
178 | Tag::MetadataBlock(_)
179 | Tag::TableHead
180 | Tag::BlockQuote(_)
181 | Tag::TableRow
182 | Tag::Item
183 | Tag::HtmlBlock
184 | Tag::List(_)
185 | Tag::Table(_)
186 | Tag::DefinitionList
187 | Tag::DefinitionListTitle
188 | Tag::DefinitionListDefinition,
189 ) => Some((Element::Block, range)),
190 Event::Rule => Some((Element::Rule, range)),
191 Event::Start(Tag::Heading { level, .. }) => {
192 Some((Element::Heading(level.into()), range))
193 }
194 Event::End(_) => None,
196 })
197 .collect()
198 }
199}
200
/// Level of a Markdown heading, converted from [`pulldown_cmark::HeadingLevel`].
///
/// The variants are declared from `H6` down to `H1`, so the derived `Ord`
/// gives `H6 < H5 < H4 < H3 < H2 < H1`. This declaration order is load-bearing:
/// reordering the variants changes how headings compare.
/// NOTE(review): presumably this lets higher-level headings act as stronger
/// split points when `Element`s are compared — confirm in the `Splitter` logic.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum HeadingLevel {
    /// `######` heading.
    H6,
    /// `#####` heading.
    H5,
    /// `####` heading.
    H4,
    /// `###` heading.
    H3,
    /// `##` heading.
    H2,
    /// `#` heading.
    H1,
}
212
213impl From<pulldown_cmark::HeadingLevel> for HeadingLevel {
214 fn from(value: pulldown_cmark::HeadingLevel) -> Self {
215 match value {
216 pulldown_cmark::HeadingLevel::H1 => HeadingLevel::H1,
217 pulldown_cmark::HeadingLevel::H2 => HeadingLevel::H2,
218 pulldown_cmark::HeadingLevel::H3 => HeadingLevel::H3,
219 pulldown_cmark::HeadingLevel::H4 => HeadingLevel::H4,
220 pulldown_cmark::HeadingLevel::H5 => HeadingLevel::H5,
221 pulldown_cmark::HeadingLevel::H6 => HeadingLevel::H6,
222 }
223 }
224}
225
/// How the text of a matched element is placed when `SemanticLevel::sections`
/// splits text at that element.
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum SemanticSplitPosition {
    /// The element is emitted as its own section, after the text preceding it.
    /// The exception is when `Element::treat_whitespace_as_previous` is true and
    /// the preceding text is all whitespace: that whitespace and the element are
    /// then emitted together as one section.
    Own,
    /// The split happens at the element's start, and the element is not emitted
    /// on its own. It becomes the beginning of the following section (this is
    /// used for headings).
    Next,
}
234
/// Semantic level that `MarkdownSplitter::parse` assigns to each Markdown element.
///
/// The derived `Ord` follows declaration order:
/// `SoftBreak < Inline < Block < Rule < Heading(_)`. Headings are further
/// ordered by their [`HeadingLevel`], where `H1` is the greatest.
/// NOTE(review): presumably greater levels are preferred as split points by the
/// generic splitter — confirm there before relying on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum Element {
    /// A soft line break (`Event::SoftBreak`).
    SoftBreak,
    /// An inline element. This covers text, inline code/HTML/math, emphasis,
    /// strong, strikethrough, links, images, sub/superscript, table cells, hard
    /// breaks, footnote references and task-list markers.
    Inline,
    /// A block element. This covers paragraphs, code blocks, block quotes,
    /// lists and items, tables, table heads and rows, HTML blocks, footnote
    /// definitions, metadata blocks, definition lists, HTML and display math.
    Block,
    /// A thematic break / horizontal rule (`Event::Rule`).
    Rule,
    /// A heading of the given level.
    Heading(HeadingLevel),
}
252
253impl Element {
254 fn split_position(self) -> SemanticSplitPosition {
255 match self {
256 Self::SoftBreak | Self::Block | Self::Rule | Self::Inline => SemanticSplitPosition::Own,
257 Self::Heading(_) => SemanticSplitPosition::Next,
259 }
260 }
261
262 fn treat_whitespace_as_previous(self) -> bool {
263 match self {
264 Self::SoftBreak | Self::Inline | Self::Rule | Self::Heading(_) => false,
265 Self::Block => true,
266 }
267 }
268}
269
270impl SemanticLevel for Element {
271 fn sections(
272 text: &str,
273 level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
274 ) -> impl Iterator<Item = (usize, &str)> {
275 let mut cursor = 0;
276 let mut final_match = false;
277 level_ranges
278 .batching(move |it| {
279 loop {
280 match it.next() {
281 None if final_match => return None,
283 None => {
285 final_match = true;
286 return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
287 }
288 Some((level, range)) => {
290 let offset = cursor;
291 match level.split_position() {
292 SemanticSplitPosition::Own => {
293 if range.start < cursor {
294 continue;
295 }
296 let prev_section = text
297 .get(cursor..range.start)
298 .expect("invalid character sequence");
299 if level.treat_whitespace_as_previous()
300 && prev_section.chars().all(char::is_whitespace)
301 {
302 let section = text
303 .get(cursor..range.end)
304 .expect("invalid character sequence");
305 cursor = range.end;
306 return Some(Either::Left(once((offset, section))));
307 }
308 let separator = text
309 .get(range.start..range.end)
310 .expect("invalid character sequence");
311 cursor = range.end;
312 return Some(Either::Right(
313 [(offset, prev_section), (range.start, separator)]
314 .into_iter(),
315 ));
316 }
317 SemanticSplitPosition::Next => {
318 if range.start < cursor {
319 continue;
320 }
321 let prev_section = text
322 .get(cursor..range.start)
323 .expect("invalid character sequence");
324 cursor = range.start;
326 return Some(Either::Left(once((offset, prev_section))));
327 }
328 }
329 }
330 }
331 }
332 })
333 .flatten()
334 .filter(|(_, s)| !s.is_empty())
335 }
336}
337
// Unit tests for `MarkdownSplitter`: chunking behaviour at the character,
// grapheme, word and sentence levels, and the element ranges produced by
// `parse` for each kind of Markdown construct.
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::splitter::SemanticSplitRanges;

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = MarkdownSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        let max_chunk_size = text.chars().count() / 2 + 1;

        let chunks = MarkdownSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        let text = "éé";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = MarkdownSplitter::new(1)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let chunks = MarkdownSplitter::new(1)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a",
                    byte_offset: 1,
                    char_offset: 1
                },
                ChunkCharIndex {
                    chunk: "b",
                    byte_offset: 3,
                    char_offset: 3,
                },
            ],
            chunks
        );
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(3)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(3)
            .chunk_char_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                ChunkCharIndex {
                    chunk: "a̐é",
                    byte_offset: 2,
                    char_offset: 2
                },
                ChunkCharIndex {
                    chunk: "ö̲",
                    byte_offset: 7,
                    char_offset: 5
                }
            ],
            chunks
        );
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick brown fox can jump 32.3 feet, right?";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "brown fox ",
                "can jump ",
                "32.3 feet,",
                " right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Mr. Fox jumped. ", "The dog was too lazy."], chunks);
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["Mr. Fox jumped. ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn test_no_markdown_separators() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("Some text without any markdown separators"));

        assert_eq!(
            vec![(Element::Block, 0..41), (Element::Inline, 0..41)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_checklist() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![
                (Element::Block, 0..42),
                (Element::Block, 0..22),
                (Element::Inline, 2..5),
                (Element::Inline, 6..21),
                (Element::Block, 22..42),
                (Element::Inline, 24..27),
                (Element::Inline, 28..42),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_reference() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Footnote[^1]"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..8),
                (Element::Inline, 8..12),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_code() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("`bash`"));

        assert_eq!(
            vec![(Element::Block, 0..6), (Element::Inline, 0..6)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_emphasis() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("*emphasis*"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..10),
                (Element::Inline, 1..9),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strong() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("**emphasis**"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strikethrough() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("~~emphasis~~"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_link() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..11),
                (Element::Inline, 0..11),
                (Element::Inline, 1..5),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_image() {
        let splitter = MarkdownSplitter::new(10);
        // 12 bytes in total; the alt text "link" occupies bytes 2..6.
        let markdown = SemanticSplitRanges::new(splitter.parse("![link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..6),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<span>Some text</span>"));

        assert_eq!(
            vec![
                (Element::Block, 0..22),
                (Element::Inline, 0..6),
                (Element::Inline, 6..15),
                (Element::Inline, 15..22),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<div>Some text</div>"));

        assert_eq!(
            vec![(Element::Block, 0..20), (Element::Block, 0..20)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_table() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(
            splitter.parse("| Header 1 | Header 2 |\n| --- | --- |\n| Cell 1 | Cell 2 |"),
        );
        assert_eq!(
            vec![
                (Element::Block, 0..57),
                (Element::Block, 0..24),
                (Element::Inline, 1..11),
                (Element::Inline, 2..10),
                (Element::Inline, 12..22),
                (Element::Inline, 13..21),
                (Element::Block, 38..57),
                (Element::Inline, 39..47),
                (Element::Inline, 40..46),
                (Element::Inline, 48..56),
                (Element::Inline, 49..55)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_softbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\nwith a softbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..26),
                (Element::Inline, 0..9),
                (Element::SoftBreak, 9..10),
                (Element::Inline, 10..26)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_hardbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\\\nwith a hardbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..27),
                (Element::Inline, 0..9),
                (Element::Inline, 9..11),
                (Element::Inline, 11..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_def() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[^first]: Footnote"));

        assert_eq!(
            vec![
                (Element::Block, 0..18),
                (Element::Block, 10..18),
                (Element::Inline, 10..18)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_code_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("```\ncode\n```"));

        assert_eq!(
            vec![(Element::Block, 0..12), (Element::Inline, 4..9)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_block_quote() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("> quote"));

        assert_eq!(
            vec![
                (Element::Block, 0..7),
                (Element::Block, 2..7),
                (Element::Inline, 2..7)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_with_rule() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\n\n---\n\nwith a rule"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..9),
                (Element::Rule, 11..15),
                (Element::Block, 16..27),
                (Element::Inline, 16..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_heading() {
        for (index, (heading, level)) in [
            ("#", HeadingLevel::H1),
            ("##", HeadingLevel::H2),
            ("###", HeadingLevel::H3),
            ("####", HeadingLevel::H4),
            ("#####", HeadingLevel::H5),
            ("######", HeadingLevel::H6),
        ]
        .into_iter()
        .enumerate()
        {
            let splitter = MarkdownSplitter::new(10);
            let markdown = SemanticSplitRanges::new(splitter.parse(&format!("{heading} Heading")));

            assert_eq!(
                vec![
                    (Element::Heading(level), 0..9 + index),
                    (Element::Inline, 2 + index..9 + index)
                ],
                markdown.ranges_after_offset(0).collect::<Vec<_>>()
            );
        }
    }

    #[test]
    fn test_ranges_after_offset_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![(Element::Block, 0..22), (Element::Block, 22..42),],
            markdown
                .level_ranges_after_offset(0, Element::Block)
                .collect::<Vec<_>>()
        );
    }
}