swiftide_indexing/transformers/
chunk_markdown.rs1use std::sync::Arc;
3
4use async_trait::async_trait;
5use derive_builder::Builder;
6use swiftide_core::{ChunkerTransformer, indexing::IndexingStream, indexing::Node};
7use text_splitter::{Characters, ChunkConfig, MarkdownSplitter};
8
9const DEFAULT_MAX_CHAR_SIZE: usize = 2056;
10
/// Chunker transformer that splits a node's markdown text into smaller nodes
/// using a [`MarkdownSplitter`] sized in characters.
///
/// Instances are created through [`ChunkMarkdown::builder`] or one of the
/// `from_*` convenience constructors; every field has a builder default.
#[derive(Clone, Builder)]
#[builder(setter(strip_option))]
pub struct ChunkMarkdown {
    /// Value returned from `ChunkerTransformer::concurrency`.
    ///
    /// Defaults to `None`.
    #[builder(default)]
    concurrency: Option<usize>,

    /// Maximum chunk size used to configure the default chunker when no
    /// `range` was given to the builder.
    ///
    /// Defaults to [`DEFAULT_MAX_CHAR_SIZE`].
    #[builder(default = "DEFAULT_MAX_CHAR_SIZE")]
    max_characters: usize,

    /// Size range for chunks.
    ///
    /// `range.start` is the minimum size: `transform_node` drops chunks that
    /// are shorter. When set on the builder it takes precedence over
    /// `max_characters` for configuring the default chunker. A custom
    /// `chunker` is not reconfigured by this value; only the minimum-size
    /// filter uses it.
    ///
    /// Defaults to `0..DEFAULT_MAX_CHAR_SIZE`.
    #[builder(default = "0..DEFAULT_MAX_CHAR_SIZE")]
    range: std::ops::Range<usize>,

    /// The splitter that produces the chunks.
    ///
    /// Defaults to a [`MarkdownSplitter`] built by
    /// `ChunkMarkdownBuilder::default_client` from `range`, else
    /// `max_characters`, else [`DEFAULT_MAX_CHAR_SIZE`].
    #[builder(setter(into), default = "self.default_client()")]
    chunker: Arc<MarkdownSplitter<Characters>>,
}
50
51impl std::fmt::Debug for ChunkMarkdown {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("ChunkMarkdown")
54 .field("concurrency", &self.concurrency)
55 .field("max_characters", &self.max_characters)
56 .field("range", &self.range)
57 .finish()
58 }
59}
60
61impl Default for ChunkMarkdown {
62 fn default() -> Self {
63 Self::from_max_characters(DEFAULT_MAX_CHAR_SIZE)
64 }
65}
66
67impl ChunkMarkdown {
68 pub fn builder() -> ChunkMarkdownBuilder {
69 ChunkMarkdownBuilder::default()
70 }
71
72 #[allow(clippy::missing_panics_doc)]
74 pub fn from_max_characters(max_characters: usize) -> Self {
75 Self::builder()
76 .max_characters(max_characters)
77 .build()
78 .expect("Cannot fail")
79 }
80
81 #[allow(clippy::missing_panics_doc)]
85 pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
86 Self::builder().range(range).build().expect("Cannot fail")
87 }
88
89 #[must_use]
91 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
92 self.concurrency = Some(concurrency);
93 self
94 }
95
96 fn min_size(&self) -> usize {
97 self.range.start
98 }
99}
100
101impl ChunkMarkdownBuilder {
102 fn default_client(&self) -> Arc<MarkdownSplitter<Characters>> {
103 let chunk_config: ChunkConfig<Characters> = self
104 .range
105 .clone()
106 .map(ChunkConfig::<Characters>::from)
107 .or_else(|| self.max_characters.map(Into::into))
108 .unwrap_or(DEFAULT_MAX_CHAR_SIZE.into());
109
110 Arc::new(MarkdownSplitter::new(chunk_config))
111 }
112}
113
114#[async_trait]
115impl ChunkerTransformer for ChunkMarkdown {
116 #[tracing::instrument(skip_all)]
117 async fn transform_node(&self, node: Node) -> IndexingStream {
118 let chunks = self
119 .chunker
120 .chunks(&node.chunk)
121 .filter_map(|chunk| {
122 let trim = chunk.trim();
123 if trim.is_empty() || trim.len() < self.min_size() {
124 None
125 } else {
126 Some(chunk.to_string())
127 }
128 })
129 .collect::<Vec<String>>();
130
131 IndexingStream::iter(
132 chunks
133 .into_iter()
134 .map(move |chunk| Node::build_from_other(&node).chunk(chunk).build()),
135 )
136 }
137
138 fn concurrency(&self) -> Option<usize> {
139 self.concurrency
140 }
141}
142
143#[cfg(test)]
144mod test {
145 use super::*;
146 use futures_util::stream::TryStreamExt;
147
148 const MARKDOWN: &str = r"
149 # Hello, world!
150
151 This is a test markdown document.
152
153 ## Section 1
154
155 This is a paragraph.
156
157 ## Section 2
158
159 This is another paragraph.
160 ";
161
162 #[tokio::test]
163 async fn test_transforming_with_max_characters_and_trimming() {
164 let chunker = ChunkMarkdown::from_max_characters(40);
165
166 let node = Node::new(MARKDOWN.to_string());
167
168 let nodes: Vec<Node> = chunker
169 .transform_node(node)
170 .await
171 .try_collect()
172 .await
173 .unwrap();
174
175 dbg!(&nodes.iter().map(|n| n.chunk.clone()).collect::<Vec<_>>());
176 for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) {
177 nodes
178 .iter()
179 .find(|node| node.chunk == line.trim())
180 .unwrap_or_else(|| panic!("Line not found: {line}"));
181 }
182
183 assert_eq!(nodes.len(), 6);
184 }
185
186 #[tokio::test]
187 async fn test_always_within_range() {
188 let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
189 for range in ranges {
190 let chunker = ChunkMarkdown::from_chunk_range(range.clone());
191 let node = Node::new(MARKDOWN.to_string());
192 let nodes: Vec<Node> = chunker
193 .transform_node(node)
194 .await
195 .try_collect()
196 .await
197 .unwrap();
198 assert!(
200 nodes.iter().all(|node| {
201 let len = node.chunk.len();
202 range.contains(&len)
203 }),
204 "{:?}, {:?}",
205 range,
206 nodes.iter().filter(|node| {
207 let len = node.chunk.len();
208 !range.contains(&len)
209 })
210 );
211 }
212 }
213
214 #[test]
215 fn test_builder() {
216 ChunkMarkdown::builder()
217 .chunker(MarkdownSplitter::new(40))
218 .concurrency(10)
219 .range(10..20)
220 .build()
221 .unwrap();
222 }
223}