1use std::{iter::once, ops::Range};
8
9use either::Either;
10use itertools::Itertools;
11use pulldown_cmark::{Event, Options, Parser, Tag};
12
13use crate::{
14 splitter::{SemanticLevel, Splitter},
15 trim::Trim,
16 ChunkConfig, ChunkSizer,
17};
18
19#[derive(Debug)]
24#[allow(clippy::module_name_repetitions)]
25pub struct MarkdownSplitter<Sizer>
26where
27 Sizer: ChunkSizer,
28{
29 chunk_config: ChunkConfig<Sizer>,
31}
32
33impl<Sizer> MarkdownSplitter<Sizer>
34where
35 Sizer: ChunkSizer,
36{
37 #[must_use]
46 pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
47 Self {
48 chunk_config: chunk_config.into(),
49 }
50 }
51
52 pub fn chunks<'splitter, 'text: 'splitter>(
83 &'splitter self,
84 text: &'text str,
85 ) -> impl Iterator<Item = &'text str> + 'splitter {
86 Splitter::<_>::chunks(self, text)
87 }
88
89 pub fn chunk_indices<'splitter, 'text: 'splitter>(
103 &'splitter self,
104 text: &'text str,
105 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
106 Splitter::<_>::chunk_indices(self, text)
107 }
108}
109
110impl<Sizer> Splitter<Sizer> for MarkdownSplitter<Sizer>
111where
112 Sizer: ChunkSizer,
113{
114 type Level = Element;
115
116 const TRIM: Trim = Trim::PreserveIndentation;
117
118 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
119 &self.chunk_config
120 }
121
122 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
123 Parser::new_ext(text, Options::all())
124 .into_offset_iter()
125 .filter_map(|(event, range)| match event {
126 Event::Start(
127 Tag::Emphasis
128 | Tag::Strong
129 | Tag::Strikethrough
130 | Tag::Link { .. }
131 | Tag::Image { .. }
132 | Tag::Subscript
133 | Tag::Superscript
134 | Tag::TableCell,
135 )
136 | Event::Text(_)
137 | Event::HardBreak
138 | Event::Code(_)
139 | Event::InlineHtml(_)
140 | Event::InlineMath(_)
141 | Event::FootnoteReference(_)
142 | Event::TaskListMarker(_) => Some((Element::Inline, range)),
143 Event::SoftBreak => Some((Element::SoftBreak, range)),
144 Event::Html(_)
145 | Event::DisplayMath(_)
146 | Event::Start(
147 Tag::Paragraph
148 | Tag::CodeBlock(_)
149 | Tag::FootnoteDefinition(_)
150 | Tag::MetadataBlock(_)
151 | Tag::TableHead
152 | Tag::BlockQuote(_)
153 | Tag::TableRow
154 | Tag::Item
155 | Tag::HtmlBlock
156 | Tag::List(_)
157 | Tag::Table(_)
158 | Tag::DefinitionList
159 | Tag::DefinitionListTitle
160 | Tag::DefinitionListDefinition,
161 ) => Some((Element::Block, range)),
162 Event::Rule => Some((Element::Rule, range)),
163 Event::Start(Tag::Heading { level, .. }) => {
164 Some((Element::Heading(level.into()), range))
165 }
166 Event::End(_) => None,
168 })
169 .collect()
170 }
171}
172
/// Heading levels in Markdown, `H1` through `H6`.
///
/// The variants are deliberately declared from `H6` to `H1`: the derived
/// `Ord`/`PartialOrd` follow declaration order, so `H6 < H5 < … < H1`.
/// Reordering the variants changes every comparison between heading levels.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum HeadingLevel {
    /// Level-6 heading; compares lowest.
    H6,
    /// Level-5 heading.
    H5,
    /// Level-4 heading.
    H4,
    /// Level-3 heading.
    H3,
    /// Level-2 heading.
    H2,
    /// Level-1 heading; compares highest.
    H1,
}
184
185impl From<pulldown_cmark::HeadingLevel> for HeadingLevel {
186 fn from(value: pulldown_cmark::HeadingLevel) -> Self {
187 match value {
188 pulldown_cmark::HeadingLevel::H1 => HeadingLevel::H1,
189 pulldown_cmark::HeadingLevel::H2 => HeadingLevel::H2,
190 pulldown_cmark::HeadingLevel::H3 => HeadingLevel::H3,
191 pulldown_cmark::HeadingLevel::H4 => HeadingLevel::H4,
192 pulldown_cmark::HeadingLevel::H5 => HeadingLevel::H5,
193 pulldown_cmark::HeadingLevel::H6 => HeadingLevel::H6,
194 }
195 }
196}
197
/// Where the text of a matched element ends up relative to the sections
/// produced by `Element::sections`.
#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum SemanticSplitPosition {
    /// The element's range is emitted as a section of its own, separate from
    /// the text that precedes it (block elements may absorb preceding
    /// whitespace-only text instead).
    Own,
    /// The preceding section ends where the element starts; the element's own
    /// text is left for the section that follows.
    Next,
}
206
/// Semantic levels of Markdown structure at which text can be split.
///
/// The derived `Ord`/`PartialOrd` follow declaration order, so
/// `SoftBreak < Inline < Block < Rule < Heading(_)`, and headings compare among
/// themselves by their [`HeadingLevel`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum Element {
    /// A soft line break (the parser's `SoftBreak` event).
    SoftBreak,
    /// Inline content: text, inline code/HTML/math, hard breaks, footnote
    /// references, task-list markers, and the start of emphasis, strong,
    /// strikethrough, link, image, subscript, superscript and table-cell tags.
    Inline,
    /// Block-level content: paragraphs, code blocks, block quotes, lists and
    /// their items, tables and their head/rows, HTML blocks, display math,
    /// footnote definitions, metadata blocks and definition lists.
    Block,
    /// A thematic break / horizontal rule (the parser's `Rule` event).
    Rule,
    /// A heading of the given level.
    Heading(HeadingLevel),
}
224
225impl Element {
226 fn split_position(self) -> SemanticSplitPosition {
227 match self {
228 Self::SoftBreak | Self::Block | Self::Rule | Self::Inline => SemanticSplitPosition::Own,
229 Self::Heading(_) => SemanticSplitPosition::Next,
231 }
232 }
233
234 fn treat_whitespace_as_previous(self) -> bool {
235 match self {
236 Self::SoftBreak | Self::Inline | Self::Rule | Self::Heading(_) => false,
237 Self::Block => true,
238 }
239 }
240}
241
242impl SemanticLevel for Element {
243 fn sections(
244 text: &str,
245 level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
246 ) -> impl Iterator<Item = (usize, &str)> {
247 let mut cursor = 0;
248 let mut final_match = false;
249 level_ranges
250 .batching(move |it| {
251 loop {
252 match it.next() {
253 None if final_match => return None,
255 None => {
257 final_match = true;
258 return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
259 }
260 Some((level, range)) => {
262 let offset = cursor;
263 match level.split_position() {
264 SemanticSplitPosition::Own => {
265 if range.start < cursor {
266 continue;
267 }
268 let prev_section = text
269 .get(cursor..range.start)
270 .expect("invalid character sequence");
271 if level.treat_whitespace_as_previous()
272 && prev_section.chars().all(char::is_whitespace)
273 {
274 let section = text
275 .get(cursor..range.end)
276 .expect("invalid character sequence");
277 cursor = range.end;
278 return Some(Either::Left(once((offset, section))));
279 }
280 let separator = text
281 .get(range.start..range.end)
282 .expect("invalid character sequence");
283 cursor = range.end;
284 return Some(Either::Right(
285 [(offset, prev_section), (range.start, separator)]
286 .into_iter(),
287 ));
288 }
289 SemanticSplitPosition::Next => {
290 if range.start < cursor {
291 continue;
292 }
293 let prev_section = text
294 .get(cursor..range.start)
295 .expect("invalid character sequence");
296 cursor = range.start;
298 return Some(Either::Left(once((offset, prev_section))));
299 }
300 }
301 }
302 }
303 }
304 })
305 .flatten()
306 .filter(|(_, s)| !s.is_empty())
307 }
308}
309
/// Unit tests for the Markdown splitter.
///
/// The first group checks chunking behaviour on plain text (characters,
/// graphemes, words, sentences, trimming); the second group pins the
/// `(Element, range)` pairs produced by `MarkdownSplitter::parse` for each
/// Markdown construct.
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::splitter::SemanticSplitRanges;

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        let chunks = MarkdownSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let text1 = Faker.fake::<String>();
        let text2 = Faker.fake::<String>();
        let text = format!("{text1}{text2}");
        // One more than half, so the text cannot fit in a single chunk.
        let max_chunk_size = text.chars().count() / 2 + 1;

        let chunks = MarkdownSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
            .chunks(&text)
            .collect::<Vec<_>>();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The first chunk and `text1` share a common prefix.
        let len = min(text1.len(), chunks[0].len());
        assert_eq!(text1[..len], chunks[0][..len]);
        // The second chunk and `text2` share a common suffix.
        let len = min(text2.len(), chunks[1].len());
        assert_eq!(
            text2[(text2.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        // Nothing is lost or duplicated when trimming is disabled.
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(100).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each character is more than one byte in UTF-8.
        let text = "éé";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(3).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        // "\r\n" stays together as a single grapheme.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let chunks = MarkdownSplitter::new(1)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(1).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
            chunks
        );
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let chunks = MarkdownSplitter::new(3)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick brown fox can jump 32.3 feet, right?";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(10).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![
                "The quick ",
                "brown fox ",
                "can jump ",
                "32.3 feet,",
                " right?"
            ],
            chunks
        );
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(2).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text"), (10, "from a"), (17, "document")],
            chunks
        );
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(21).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(vec!["Mr. Fox jumped. ", "The dog was too lazy."], chunks);
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. The dog was too lazy.";
        let chunks = MarkdownSplitter::new(ChunkConfig::new(16).with_trim(false))
            .chunks(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec!["Mr. Fox jumped. ", "The dog was too ", "lazy."],
            chunks
        );
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let chunks = MarkdownSplitter::new(10)
            .chunk_indices(text)
            .collect::<Vec<_>>();

        assert_eq!(
            vec![(0, "Some text."), (11, "From a"), (18, "document.")],
            chunks
        );
    }

    #[test]
    fn test_no_markdown_separators() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("Some text without any markdown separators"));

        assert_eq!(
            vec![(Element::Block, 0..41), (Element::Inline, 0..41)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_checklist() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![
                (Element::Block, 0..42),
                (Element::Block, 0..22),
                (Element::Inline, 2..5),
                (Element::Inline, 6..21),
                (Element::Block, 22..42),
                (Element::Inline, 24..27),
                (Element::Inline, 28..42),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_reference() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Footnote[^1]"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..8),
                (Element::Inline, 8..12),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_code() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("`bash`"));

        assert_eq!(
            vec![(Element::Block, 0..6), (Element::Inline, 0..6)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_emphasis() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("*emphasis*"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..10),
                (Element::Inline, 1..9),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strong() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("**emphasis**"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_strikethrough() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("~~emphasis~~"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..10),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_link() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..11),
                (Element::Inline, 0..11),
                (Element::Inline, 1..5),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_image() {
        let splitter = MarkdownSplitter::new(10);
        // A 12-byte image: the alt text "link" sits at bytes 2..6, matching the
        // ranges asserted below. (An empty input would produce no ranges.)
        let markdown = SemanticSplitRanges::new(splitter.parse("![link](url)"));

        assert_eq!(
            vec![
                (Element::Block, 0..12),
                (Element::Inline, 0..12),
                (Element::Inline, 2..6),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_inline_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<span>Some text</span>"));

        assert_eq!(
            vec![
                (Element::Block, 0..22),
                (Element::Inline, 0..6),
                (Element::Inline, 6..15),
                (Element::Inline, 15..22),
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_html() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("<div>Some text</div>"));

        assert_eq!(
            vec![(Element::Block, 0..20), (Element::Block, 0..20)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_table() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(
            splitter.parse("| Header 1 | Header 2 |\n| --- | --- |\n| Cell 1 | Cell 2 |"),
        );
        assert_eq!(
            vec![
                (Element::Block, 0..57),
                (Element::Block, 0..24),
                (Element::Inline, 1..11),
                (Element::Inline, 2..10),
                (Element::Inline, 12..22),
                (Element::Inline, 13..21),
                (Element::Block, 38..57),
                (Element::Inline, 39..47),
                (Element::Inline, 40..46),
                (Element::Inline, 48..56),
                (Element::Inline, 49..55)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_softbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\nwith a softbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..26),
                (Element::Inline, 0..9),
                (Element::SoftBreak, 9..10),
                (Element::Inline, 10..26)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_hardbreak() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\\\nwith a hardbreak"));

        assert_eq!(
            vec![
                (Element::Block, 0..27),
                (Element::Inline, 0..9),
                (Element::Inline, 9..11),
                (Element::Inline, 11..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_footnote_def() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("[^first]: Footnote"));

        assert_eq!(
            vec![
                (Element::Block, 0..18),
                (Element::Block, 10..18),
                (Element::Inline, 10..18)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_code_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("```\ncode\n```"));

        assert_eq!(
            vec![(Element::Block, 0..12), (Element::Inline, 4..9)],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_block_quote() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("> quote"));

        assert_eq!(
            vec![
                (Element::Block, 0..7),
                (Element::Block, 2..7),
                (Element::Inline, 2..7)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_with_rule() {
        let splitter = MarkdownSplitter::new(10);
        let markdown = SemanticSplitRanges::new(splitter.parse("Some text\n\n---\n\nwith a rule"));

        assert_eq!(
            vec![
                (Element::Block, 0..10),
                (Element::Inline, 0..9),
                (Element::Rule, 11..15),
                (Element::Block, 16..27),
                (Element::Inline, 16..27)
            ],
            markdown.ranges_after_offset(0).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_heading() {
        // `index` is the number of extra `#` characters beyond the first, so it
        // shifts every expected offset by the same amount.
        for (index, (heading, level)) in [
            ("#", HeadingLevel::H1),
            ("##", HeadingLevel::H2),
            ("###", HeadingLevel::H3),
            ("####", HeadingLevel::H4),
            ("#####", HeadingLevel::H5),
            ("######", HeadingLevel::H6),
        ]
        .into_iter()
        .enumerate()
        {
            let splitter = MarkdownSplitter::new(10);
            let markdown = SemanticSplitRanges::new(splitter.parse(&format!("{heading} Heading")));

            assert_eq!(
                vec![
                    (Element::Heading(level), 0..9 + index),
                    (Element::Inline, 2 + index..9 + index)
                ],
                markdown.ranges_after_offset(0).collect::<Vec<_>>()
            );
        }
    }

    #[test]
    fn test_ranges_after_offset_block() {
        let splitter = MarkdownSplitter::new(10);
        let markdown =
            SemanticSplitRanges::new(splitter.parse("- [ ] incomplete task\n- [x] completed task"));

        assert_eq!(
            vec![(Element::Block, 0..22), (Element::Block, 22..42),],
            markdown
                .level_ranges_after_offset(0, Element::Block)
                .collect::<Vec<_>>()
        );
    }
}