swiftide_indexing/transformers/
chunk_markdown.rs1use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::TextNode};
7use text_splitter::{Characters, ChunkConfig, MarkdownSplitter};
8
/// Default chunk-size ceiling, in characters.
///
/// Serves as the default `max_characters`, the end of the default `range`, and the
/// splitter capacity when the builder is given neither.
// NOTE(review): 2056 rather than the more usual 2048 — confirm this is intentional.
const DEFAULT_MAX_CHAR_SIZE: usize = 2_056;
10
11#[derive(Clone, Builder)]
12#[builder(setter(strip_option))]
13pub struct ChunkMarkdown {
22 #[builder(default)]
25 concurrency: Option<usize>,
26
27 #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
31 max_characters: usize,
32
33 #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
42 range: std::ops::Range<usize>,
43
44 #[builder(setter(into), default = "self.default_client()")]
48 chunker: Arc<MarkdownSplitter<Characters>>,
49}
50
51impl std::fmt::Debug for ChunkMarkdown {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("ChunkMarkdown")
54 .field("concurrency", &self.concurrency)
55 .field("max_characters", &self.max_characters)
56 .field("range", &self.range)
57 .finish()
58 }
59}
60
61impl Default for ChunkMarkdown {
62 fn default() -> Self {
63 Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
64 }
65}
66
67impl ChunkMarkdown {
68 pub fn builder() -> ChunkMarkdownBuilder {
69 ChunkMarkdownBuilder::default()
70 }
71
72 #[allow(clippy::missing_panics_doc)]
74 pub fn from_max_characters(max_characters: usize) -> Self {
75 Self::builder()
76 .max_characters(max_characters)
77 .build()
78 .expect("Cannot fail")
79 }
80
81 #[allow(clippy::missing_panics_doc)]
85 pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
86 Self::builder().range(range).build().expect("Cannot fail")
87 }
88
89 #[must_use]
91 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
92 self.concurrency = Some(concurrency);
93 self
94 }
95
96 fn min_size(&self) -> usize {
97 self.range.start
98 }
99}
100
101impl ChunkMarkdownBuilder {
102 fn default_client(&self) -> Arc<MarkdownSplitter<Characters>> {
103 let chunk_config: ChunkConfig<Characters> = self
104 .range
105 .clone()
106 .map(ChunkConfig::<Characters>::from)
107 .or_else(|| self.max_characters.map(Into::into))
108 .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
109
110 Arc::new(MarkdownSplitter::new(chunk_config))
111 }
112}
113
114#[async_trait]
115impl ChunkerTransformer for ChunkMarkdown {
116 type Input = String;
117 type Output = String;
118
119 #[tracing::instrument(skip_all)]
120 async fn transform_node(&self, node: TextNode) -> IndexingStream<String> {
121 let chunks = self
122 .chunker
123 .chunks(&node.chunk)
124 .filter_map(|chunk| {
125 let trim = chunk.trim();
126 if trim.is_empty() || trim.len() < self.min_size() {
127 None
128 } else {
129 Some(chunk.to_string())
130 }
131 })
132 .collect::<Vec<String>>();
133
134 IndexingStream::iter(
135 chunks
136 .into_iter()
137 .map(move |chunk| TextNode::build_from_other(&node).chunk(chunk).build()),
138 )
139 }
140
141 fn concurrency(&self) -> Option<usize> {
142 self.concurrency
143 }
144}
145
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::stream::TryStreamExt;

    const MARKDOWN: &str = r"
        # Hello, world!

        This is a test markdown document.

        ## Section 1

        This is a paragraph.

        ## Section 2

        This is another paragraph.
        ";

    /// With a 40-character cap, every non-blank line of `MARKDOWN` comes back as its
    /// own trimmed chunk.
    #[tokio::test]
    async fn test_transforming_with_max_characters_and_trimming() {
        let chunker = ChunkMarkdown::from_max_characters(40);

        let node = TextNode::new(MARKDOWN.to_string());

        let nodes: Vec<TextNode> = chunker
            .transform_node(node)
            .await
            .try_collect()
            .await
            .unwrap();

        let chunks: Vec<&str> = nodes.iter().map(|n| n.chunk.as_str()).collect();
        for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) {
            assert!(
                chunks.contains(&line.trim()),
                "Line not found: {line}; chunks: {chunks:?}"
            );
        }

        assert_eq!(nodes.len(), 6);
    }

    /// Every chunk produced for a configured range has a length inside that range.
    #[tokio::test]
    async fn test_always_within_range() {
        let ranges = [10..15, 20..25, 30..35, 40..45, 50..55];
        for range in ranges {
            let chunker = ChunkMarkdown::from_chunk_range(range.clone());
            let node = TextNode::new(MARKDOWN.to_string());
            let nodes: Vec<TextNode> = chunker
                .transform_node(node)
                .await
                .try_collect()
                .await
                .unwrap();

            // Collect the offenders so the failure message lists only them. Formatting
            // a lazy `Filter` with `{:?}` prints its whole underlying iterator, which
            // means every node.
            let out_of_range: Vec<&TextNode> = nodes
                .iter()
                .filter(|node| !range.contains(&node.chunk.len()))
                .collect();
            assert!(out_of_range.is_empty(), "{range:?}, {out_of_range:?}");
        }
    }

    #[test]
    fn test_builder() {
        ChunkMarkdown::builder()
            .chunker(MarkdownSplitter::new(40))
            .concurrency(10)
            .range(10..20)
            .build()
            .unwrap();
    }
}