swiftide_indexing/transformers/
chunk_markdown.rs

1//! Chunk markdown content into smaller pieces
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::Node};
7use text_splitter::{Characters, ChunkConfig, MarkdownSplitter};
8
/// Default maximum chunk size in characters, used when neither
/// `max_characters` nor `range` is configured on the builder.
const DEFAULT_MAX_CHAR_SIZE: usize = 2056;
10
#[derive(Clone, Builder)]
#[builder(setter(strip_option))]
/// A transformer that chunks markdown content into smaller pieces.
///
/// The transformer will split the markdown content into smaller pieces based on the specified
/// `max_characters` or `range` of characters.
///
/// For further customization, you can use the builder to create a custom splitter.
///
/// Technically that might work with every splitter `text_splitter` provides.
pub struct ChunkMarkdown {
    /// Concurrency hint forwarded to the indexing pipeline.
    ///
    /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be
    /// tuned.
    #[builder(default)]
    concurrency: Option<usize>,

    /// Optional maximum number of characters per chunk.
    ///
    /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`].
    #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
    max_characters: usize,

    /// A range of minimum and maximum characters per chunk.
    ///
    /// Chunks smaller than the range min will be ignored. `max_characters` will be ignored if this
    /// is set.
    ///
    /// If you provide a custom chunker with a range, you might want to set the range as well.
    ///
    /// Defaults to 0..[`max_characters`]
    #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
    range: std::ops::Range<usize>,

    /// The markdown splitter from [`text_splitter`]
    ///
    /// Defaults to a new [`MarkdownSplitter`] with the specified `max_characters`
    /// (see `ChunkMarkdownBuilder::default_client`).
    #[builder(setter(into), default = "self.default_client()")]
    chunker: Arc<MarkdownSplitter<Characters>>,
}
50
51impl std::fmt::Debug for ChunkMarkdown {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("ChunkMarkdown")
54            .field("concurrency", &self.concurrency)
55            .field("max_characters", &self.max_characters)
56            .field("range", &self.range)
57            .finish()
58    }
59}
60
61impl Default for ChunkMarkdown {
62    fn default() -> Self {
63        Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
64    }
65}
66
67impl ChunkMarkdown {
68    pub fn builder() -> ChunkMarkdownBuilder {
69        ChunkMarkdownBuilder::default()
70    }
71
72    /// Create a new transformer with a maximum number of characters per chunk.
73    #[allow(clippy::missing_panics_doc)]
74    pub fn from_max_characters(max_characters: usize) -> Self {
75        Self::builder()
76            .max_characters(max_characters)
77            .build()
78            .expect("Cannot fail")
79    }
80
81    /// Create a new transformer with a range of characters per chunk.
82    ///
83    /// Chunks smaller than the range will be ignored.
84    #[allow(clippy::missing_panics_doc)]
85    pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
86        Self::builder().range(range).build().expect("Cannot fail")
87    }
88
89    /// Set the number of concurrent chunks to process.
90    #[must_use]
91    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
92        self.concurrency = Some(concurrency);
93        self
94    }
95
96    fn min_size(&self) -> usize {
97        self.range.start
98    }
99}
100
101impl ChunkMarkdownBuilder {
102    fn default_client(&self) -> Arc<MarkdownSplitter<Characters>> {
103        let chunk_config: ChunkConfig<Characters> = self
104            .range
105            .clone()
106            .map(ChunkConfig::<Characters>::from)
107            .or_else(|| self.max_characters.map(Into::into))
108            .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
109
110        Arc::new(MarkdownSplitter::new(chunk_config))
111    }
112}
113
114#[async_trait]
115impl ChunkerTransformer for ChunkMarkdown {
116    #[tracing::instrument(skip_all)]
117    async fn transform_node(&self, node: Node) -> IndexingStream {
118        let chunks = self
119            .chunker
120            .chunks(&node.chunk)
121            .filter_map(|chunk| {
122                let trim = chunk.trim();
123                if trim.is_empty() || trim.len() < self.min_size() {
124                    None
125                } else {
126                    Some(chunk.to_string())
127                }
128            })
129            .collect::<Vec<String>>();
130
131        IndexingStream::iter(
132            chunks
133                .into_iter()
134                .map(move |chunk| Node::build_from_other(&node).chunk(chunk).build()),
135        )
136    }
137
138    fn concurrency(&self) -> Option<usize> {
139        self.concurrency
140    }
141}
142
143#[cfg(test)]
144mod test {
145    use super::*;
146    use futures_util::stream::TryStreamExt;
147
    // Fixture: an indented markdown document with a title, two sections, and
    // blank lines between paragraphs. Indentation is intentional; the splitter
    // is expected to trim it.
    const MARKDOWN: &str = r"
        # Hello, world!

        This is a test markdown document.

        ## Section 1

        This is a paragraph.

        ## Section 2

        This is another paragraph.
        ";
161
162    #[tokio::test]
163    async fn test_transforming_with_max_characters_and_trimming() {
164        let chunker = ChunkMarkdown::from_max_characters(40);
165
166        let node = Node::new(MARKDOWN.to_string());
167
168        let nodes: Vec<Node> = chunker
169            .transform_node(node)
170            .await
171            .try_collect()
172            .await
173            .unwrap();
174
175        dbg!(&nodes.iter().map(|n| n.chunk.clone()).collect::<Vec<_>>());
176        for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) {
177            nodes
178                .iter()
179                .find(|node| node.chunk == line.trim())
180                .unwrap_or_else(|| panic!("Line not found: {line}"));
181        }
182
183        assert_eq!(nodes.len(), 6);
184    }
185
186    #[tokio::test]
187    async fn test_always_within_range() {
188        let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
189        for range in ranges {
190            let chunker = ChunkMarkdown::from_chunk_range(range.clone());
191            let node = Node::new(MARKDOWN.to_string());
192            let nodes: Vec<Node> = chunker
193                .transform_node(node)
194                .await
195                .try_collect()
196                .await
197                .unwrap();
198            // Assert all nodes chunk length within the range
199            assert!(
200                nodes.iter().all(|node| {
201                    let len = node.chunk.len();
202                    range.contains(&len)
203                }),
204                "{:?}, {:?}",
205                range,
206                nodes.iter().filter(|node| {
207                    let len = node.chunk.len();
208                    !range.contains(&len)
209                })
210            );
211        }
212    }
213
214    #[test]
215    fn test_builder() {
216        ChunkMarkdown::builder()
217            .chunker(MarkdownSplitter::new(40))
218            .concurrency(10)
219            .range(10..20)
220            .build()
221            .unwrap();
222    }
223}