text_splitter/splitter/code.rs
1use std::{cmp::Ordering, ops::Range};
2
3use thiserror::Error;
4use tree_sitter::{Language, LanguageError, Parser, TreeCursor, MIN_COMPATIBLE_LANGUAGE_VERSION};
5
6use crate::{
7 splitter::{SemanticLevel, Splitter},
8 trim::Trim,
9 ChunkConfig, ChunkSizer,
10};
11
12use super::ChunkCharIndex;
13
/// Indicates there was an error with creating a `CodeSplitter`.
/// The `Display` implementation will provide a human-readable error message to
/// help debug the issue that caused the error.
///
/// Opaque wrapper around the private [`CodeSplitterErrorRepr`]: the concrete
/// variants can change in minor releases without breaking the public API.
/// `#[error(transparent)]` forwards `Display` and `source` to the inner repr.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct CodeSplitterError(#[from] CodeSplitterErrorRepr);
20
/// Private error and free to change across minor version of the crate.
#[derive(Error, Debug)]
enum CodeSplitterErrorRepr {
    /// The tree-sitter grammar's ABI version is older than the minimum the
    /// linked tree-sitter runtime supports (surfaced by `Parser::set_language`).
    #[error(
        "Language version {0:?} is too old. Expected at least version {min_version}",
        min_version=MIN_COMPATIBLE_LANGUAGE_VERSION,
    )]
    LanguageError(LanguageError),
}
30
/// Source code splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. Also will attempt to merge
/// neighboring chunks if they can fit within the given chunk size.
///
/// Semantic units come from a tree-sitter syntax tree; shallower nodes count
/// as higher semantic levels (see [`Depth`]'s reversed ordering below).
#[derive(Debug)]
pub struct CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
    /// Language to use for parsing the code. Validated in [`Self::new`] so
    /// `Splitter::parse` can rely on loading it succeeding later.
    language: Language,
}
44
45impl<Sizer> CodeSplitter<Sizer>
46where
47 Sizer: ChunkSizer,
48{
49 /// Creates a new [`CodeSplitter`].
50 ///
51 /// ```
52 /// use text_splitter::CodeSplitter;
53 ///
54 /// // By default, the chunk sizer is based on characters.
55 /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 512).expect("Invalid language");
56 /// ```
57 ///
58 /// # Errors
59 ///
60 /// Will return an error if the language version is too old to be compatible
61 /// with the current version of the tree-sitter crate.
62 pub fn new(
63 language: impl Into<Language>,
64 chunk_config: impl Into<ChunkConfig<Sizer>>,
65 ) -> Result<Self, CodeSplitterError> {
66 // Verify that this is a valid language so we can rely on that later.
67 let mut parser = Parser::new();
68 let language = language.into();
69 parser
70 .set_language(&language)
71 .map_err(CodeSplitterErrorRepr::LanguageError)?;
72 Ok(Self {
73 chunk_config: chunk_config.into(),
74 language,
75 })
76 }
77
78 /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
79 ///
80 /// ## Method
81 ///
82 /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
83 ///
84 /// 1. Split the text by a increasing semantic levels.
85 /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
86 /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
87 /// Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertantly cross semantic boundaries.
88 ///
89 /// The boundaries used to split the text if using the `chunks` method, in ascending order:
90 ///
91 /// 1. Characters
92 /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
93 /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
94 /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
95 /// 5. Ascending depth of the syntax tree. So function would have a higher level than a statement inside of the function, and so on.
96 ///
97 /// Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
98 ///
99 /// ```
100 /// use text_splitter::CodeSplitter;
101 ///
102 /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
103 /// let text = "Some text\n\nfrom a\ndocument";
104 /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
105 ///
106 /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
107 /// ```
108 pub fn chunks<'splitter, 'text: 'splitter>(
109 &'splitter self,
110 text: &'text str,
111 ) -> impl Iterator<Item = &'text str> + 'splitter {
112 Splitter::<_>::chunks(self, text)
113 }
114
115 /// Returns an iterator over chunks of the text and their byte offsets.
116 /// Each chunk will be up to the `chunk_capacity`.
117 ///
118 /// See [`CodeSplitter::chunks`] for more information.
119 ///
120 /// ```
121 /// use text_splitter::{ChunkCharIndex, CodeSplitter};
122 ///
123 /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
124 /// let text = "Some text\n\nfrom a\ndocument";
125 /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
126 ///
127 /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
128 /// ```
129 pub fn chunk_indices<'splitter, 'text: 'splitter>(
130 &'splitter self,
131 text: &'text str,
132 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
133 Splitter::<_>::chunk_indices(self, text)
134 }
135
136 /// Returns an iterator over chunks of the text with their byte and character offsets.
137 /// Each chunk will be up to the `chunk_capacity`.
138 ///
139 /// See [`CodeSplitter::chunks`] for more information.
140 ///
141 /// This will be more expensive than just byte offsets, and for most usage in Rust, just
142 /// having byte offsets is sufficient. But when interfacing with other languages or systems
143 /// that require character offsets, this will track the character offsets for you,
144 /// accounting for any trimming that may have occurred.
145 ///
146 /// ```
147 /// use text_splitter::{ChunkCharIndex, CodeSplitter};
148 ///
149 /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
150 /// let text = "Some text\n\nfrom a\ndocument";
151 /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
152 ///
153 /// assert_eq!(vec![ChunkCharIndex {chunk: "Some text", byte_offset: 0, char_offset: 0}, ChunkCharIndex {chunk: "from a", byte_offset: 11, char_offset: 11}, ChunkCharIndex {chunk: "document", byte_offset: 18, char_offset: 18}], chunks);
154 /// ```
155 pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
156 &'splitter self,
157 text: &'text str,
158 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
159 Splitter::<_>::chunk_char_indices(self, text)
160 }
161}
162
163impl<Sizer> Splitter<Sizer> for CodeSplitter<Sizer>
164where
165 Sizer: ChunkSizer,
166{
167 type Level = Depth;
168
169 const TRIM: Trim = Trim::PreserveIndentation;
170
171 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
172 &self.chunk_config
173 }
174
175 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
176 let mut parser = Parser::new();
177 parser
178 .set_language(&self.language)
179 // We verify at initialization that the language is valid, so this should be safe.
180 .expect("Error loading language");
181 // The only reason the tree would be None is:
182 // - No language was set (we do that)
183 // - There was a timeout or cancellation option set (we don't)
184 // - So it should be safe to unwrap here
185 let tree = parser.parse(text, None).expect("Error parsing source code");
186
187 CursorOffsets::new(tree.walk()).collect()
188 }
189}
190
/// Depth of a node within the parsed syntax tree.
///
/// Newtype over `usize` so the ordering can be customized: shallower nodes
/// (numerically smaller depth) represent *higher* semantic levels, so the
/// comparison is deliberately reversed relative to the inner integer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Depth(usize);

impl PartialOrd for Depth {
    // Canonical forwarding to `Ord` so the two orderings can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Depth {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural `usize` ordering, flipped: a smaller depth sorts greater.
        self.0.cmp(&other.0).reverse()
    }
}
208
209/// New type around a tree-sitter cursor to allow for implementing an iterator.
210/// Each call to `next()` will return the next node in the tree in a depth-first
211/// order.
212struct CursorOffsets<'cursor> {
213 cursor: TreeCursor<'cursor>,
214}
215
216impl<'cursor> CursorOffsets<'cursor> {
217 fn new(cursor: TreeCursor<'cursor>) -> Self {
218 Self { cursor }
219 }
220}
221
222impl Iterator for CursorOffsets<'_> {
223 type Item = (Depth, Range<usize>);
224
225 fn next(&mut self) -> Option<Self::Item> {
226 // There are children (can call this initially because we don't want the root node)
227 if self.cursor.goto_first_child() {
228 return Some((
229 Depth(self.cursor.depth() as usize),
230 self.cursor.node().byte_range(),
231 ));
232 }
233
234 loop {
235 // There are sibling elements to grab
236 if self.cursor.goto_next_sibling() {
237 return Some((
238 Depth(self.cursor.depth() as usize),
239 self.cursor.node().byte_range(),
240 ));
241 // Start going back up the tree and check for next sibling on next iteration.
242 } else if self.cursor.goto_parent() {
243 continue;
244 }
245
246 // We have no more siblings or parents, so we're done.
247 return None;
248 }
249 }
250}
251
// Marker impl: lets `Depth` act as the semantic-level type required by the
// generic `Splitter` machinery (see `type Level = Depth` above).
impl SemanticLevel for Depth {}
253
#[cfg(test)]
mod tests {
    use tree_sitter::{Node, Tree};

    use super::*;

    /// Parses `source` as Rust, panicking on any grammar/parse failure.
    fn rust_tree(source: &str) -> Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        parser.parse(source, None).expect("Error parsing source code")
    }

    #[test]
    fn rust_splitter() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let chunks = splitter
            .chunks("fn main() {\n    let x = 5;\n}")
            .collect::<Vec<_>>();

        assert_eq!(chunks, vec!["fn main()", "{\n    let x = 5;", "}"]);
    }

    #[test]
    fn rust_splitter_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let chunks = splitter
            .chunk_indices("fn main() {\n    let x = 5;\n}")
            .collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![(0, "fn main()"), (10, "{\n    let x = 5;"), (27, "}")]
        );
    }

    #[test]
    fn rust_splitter_char_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let chunks = splitter
            .chunk_char_indices("fn main() {\n    let x = 5;\n}")
            .collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![
                ChunkCharIndex {
                    chunk: "fn main()",
                    byte_offset: 0,
                    char_offset: 0
                },
                ChunkCharIndex {
                    chunk: "{\n    let x = 5;",
                    byte_offset: 10,
                    char_offset: 10
                },
                ChunkCharIndex {
                    chunk: "}",
                    byte_offset: 27,
                    char_offset: 27
                }
            ]
        );
    }

    #[test]
    fn depth_partialord() {
        assert_eq!(Depth(0).partial_cmp(&Depth(1)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(2)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(1)), Some(Ordering::Equal));
        assert_eq!(Depth(2).partial_cmp(&Depth(1)), Some(Ordering::Less));
    }

    #[test]
    fn depth_ord() {
        assert_eq!(Depth(0).cmp(&Depth(1)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(2)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(1)), Ordering::Equal);
        assert_eq!(Depth(2).cmp(&Depth(1)), Ordering::Less);
    }

    #[test]
    fn depth_sorting() {
        let mut depths = vec![Depth(0), Depth(1), Depth(2)];
        depths.sort();
        // Deeper nodes sort first because `Depth`'s ordering is reversed.
        assert_eq!(depths, [Depth(2), Depth(1), Depth(0)]);
    }

    /// Checks that the optimized cursor iterator produces the same results as
    /// the naive recursive traversal.
    #[test]
    fn optimized_code_offsets() {
        let tree = rust_tree("fn test() {\n    let x = 1;\n}");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    /// Multiple items at the top level exercise the sibling-walking path of
    /// the iterator at depth 1.
    #[test]
    fn multiple_top_siblings() {
        let tree = rust_tree("\nfn fn1() {}\nfn fn2() {}\nfn fn3() {}\nfn fn4() {}");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    /// Collects `(depth, byte range)` for every node via plain recursion.
    fn naive_offsets(tree: &Tree) -> Vec<(Depth, Range<usize>)> {
        let mut offsets = vec![];
        recursive_naive_offsets(&mut offsets, tree.root_node(), 0);
        offsets
    }

    // Basic version to compare the optimized iterator against. According to
    // the tree-sitter documentation this is inefficient for large trees, but
    // it is the easiest to reason about, which makes it a good correctness
    // oracle.
    fn recursive_naive_offsets(
        collection: &mut Vec<(Depth, Range<usize>)>,
        node: Node<'_>,
        depth: usize,
    ) {
        // The root node (depth 0) is deliberately skipped, matching the
        // iterator's behavior.
        if depth > 0 {
            collection.push((Depth(depth), node.byte_range()));
        }

        for child in node.children(&mut node.walk()) {
            recursive_naive_offsets(collection, child, depth + 1);
        }
    }
}
395}