// swiftide_indexing/transformers/chunk_markdown.rs

1//! Chunk markdown content into smaller pieces
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode};
7use text_splitter::{Characters, ChunkConfig, MarkdownSplitter};
8
/// Default maximum chunk size, in characters, used when neither
/// `max_characters` nor `range` is configured.
/// NOTE(review): 2056 looks like a possible typo for 2048 — confirm intent.
const DEFAULT_MAX_CHAR_SIZE: usize = 2056;
10
#[derive(Clone, Builder)]
#[builder(setter(strip_option))]
/// A transformer that chunks markdown content into smaller pieces.
///
/// The transformer will split the markdown content into smaller pieces based on the specified
/// `max_characters` or `range` of characters.
///
/// For further customization, you can use the builder to create a custom splitter.
///
/// Technically that might work with every splitter `text_splitter` provides.
pub struct ChunkMarkdown {
    /// Maximum number of nodes transformed concurrently by the pipeline.
    ///
    /// Defaults to `None`. If you use a splitter that is resource heavy, this parameter can be
    /// tuned.
    #[builder(default)]
    concurrency: Option<usize>,

    /// Optional maximum number of characters per chunk.
    ///
    /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`]. Ignored by the default chunker when `range` is
    /// set explicitly on the builder.
    #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
    max_characters: usize,

    /// A range of minimum and maximum characters per chunk.
    ///
    /// Chunks smaller than the range min will be ignored. `max_characters` will be ignored if this
    /// is set.
    ///
    /// If you provide a custom chunker with a range, you might want to set the range as well.
    ///
    /// Defaults to `0..`[`DEFAULT_MAX_CHAR_SIZE`] (note: not `0..max_characters` — the builder
    /// default cannot see a custom `max_characters` value).
    #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
    range: std::ops::Range<usize>,

    /// The markdown splitter from [`text_splitter`]
    ///
    /// Defaults to a new [`MarkdownSplitter`] configured from `range` (preferred) or
    /// `max_characters`; see `ChunkMarkdownBuilder::default_client`.
    #[builder(setter(into), default = "self.default_client()")]
    chunker: Arc<MarkdownSplitter<Characters>>,
}
50
51impl std::fmt::Debug for ChunkMarkdown {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("ChunkMarkdown")
54            .field("concurrency", &self.concurrency)
55            .field("max_characters", &self.max_characters)
56            .field("range", &self.range)
57            .finish()
58    }
59}
60
61impl Default for ChunkMarkdown {
62    fn default() -> Self {
63        Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
64    }
65}
66
67impl ChunkMarkdown {
68    pub fn builder() -> ChunkMarkdownBuilder {
69        ChunkMarkdownBuilder::default()
70    }
71
72    /// Create a new transformer with a maximum number of characters per chunk.
73    #[allow(clippy::missing_panics_doc)]
74    pub fn from_max_characters(max_characters: usize) -> Self {
75        Self::builder()
76            .max_characters(max_characters)
77            .build()
78            .expect("Cannot fail")
79    }
80
81    /// Create a new transformer with a range of characters per chunk.
82    ///
83    /// Chunks smaller than the range will be ignored.
84    #[allow(clippy::missing_panics_doc)]
85    pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
86        Self::builder().range(range).build().expect("Cannot fail")
87    }
88
89    /// Set the number of concurrent chunks to process.
90    #[must_use]
91    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
92        self.concurrency = Some(concurrency);
93        self
94    }
95
96    fn min_size(&self) -> usize {
97        self.range.start
98    }
99}
100
101impl ChunkMarkdownBuilder {
102    fn default_client(&self) -> Arc<MarkdownSplitter<Characters>> {
103        let chunk_config: ChunkConfig<Characters> = self
104            .range
105            .clone()
106            .map(ChunkConfig::<Characters>::from)
107            .or_else(|| self.max_characters.map(Into::into))
108            .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
109
110        Arc::new(MarkdownSplitter::new(chunk_config))
111    }
112}
113
114#[async_trait]
115impl ChunkerTransformer for ChunkMarkdown {
116    type Input = String;
117    type Output = String;
118
119    #[tracing::instrument(skip_all)]
120    async fn transform_node(&self, node: TextNode) -> IndexingStream<String> {
121        let chunks = self
122            .chunker
123            .chunks(&node.chunk)
124            .filter_map(|chunk| {
125                let trim = chunk.trim();
126                if trim.is_empty() || trim.len() < self.min_size() {
127                    None
128                } else {
129                    Some(chunk.to_string())
130                }
131            })
132            .collect::<Vec<String>>();
133
134        IndexingStream::iter(
135            chunks
136                .into_iter()
137                .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()),
138        )
139    }
140
141    fn concurrency(&self) -> Option<usize> {
142        self.concurrency
143    }
144}
145
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::stream::TryStreamExt;

    const MARKDOWN: &str = r"
        # Hello, world!

        This is a test markdown document.

        ## Section 1

        This is a paragraph.

        ## Section 2

        This is another paragraph.
        ";

    #[tokio::test]
    async fn test_transforming_with_max_characters_and_trimming() {
        let chunker = ChunkMarkdown::from_max_characters(40);

        let node = TextNode::new(MARKDOWN.to_string());

        let nodes: Vec<TextNode> = chunker
            .transform_node(node)
            .await
            .try_collect()
            .await
            .unwrap();

        // Every non-blank input line should show up (trimmed) as its own chunk.
        // Collect the chunks once so assertion failures can show them all.
        let chunks: Vec<&str> = nodes.iter().map(|n| n.chunk.as_str()).collect();
        for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) {
            assert!(
                chunks.contains(&line.trim()),
                "Line not found: {line}; chunks: {chunks:?}"
            );
        }

        assert_eq!(nodes.len(), 6);
    }

    #[tokio::test]
    async fn test_always_within_range() {
        let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
        for range in ranges {
            let chunker = ChunkMarkdown::from_chunk_range(range.clone());
            let node = TextNode::new(MARKDOWN.to_string());
            let nodes: Vec<TextNode> = chunker
                .transform_node(node)
                .await
                .try_collect()
                .await
                .unwrap();
            // Collect offenders eagerly: debug-printing a lazy `iter::Filter`
            // adapter does not display the elements it would yield, so the old
            // failure message showed nothing useful.
            let out_of_range: Vec<&TextNode> = nodes
                .iter()
                .filter(|node| !range.contains(&node.chunk.len()))
                .collect();
            assert!(
                out_of_range.is_empty(),
                "chunks outside {range:?}: {out_of_range:?}"
            );
        }
    }

    #[test]
    fn test_builder() {
        ChunkMarkdown::builder()
            .chunker(MarkdownSplitter::new(40))
            .concurrency(10)
            .range(10..20)
            .build()
            .unwrap();
    }
}