text_splitter/splitter/
code.rs

1use std::{cmp::Ordering, ops::Range};
2
3use thiserror::Error;
4use tree_sitter::{Language, LanguageError, Parser, TreeCursor, MIN_COMPATIBLE_LANGUAGE_VERSION};
5
6use crate::{
7    splitter::{SemanticLevel, Splitter},
8    trim::Trim,
9    ChunkConfig, ChunkSizer,
10};
11
12use super::ChunkCharIndex;
13
/// Indicates there was an error with creating a `CodeSplitter`.
/// The `Display` implementation will provide a human-readable error message to
/// help debug the issue that caused the error.
// Newtype wrapper: `#[error(transparent)]` forwards `Display`/`source` to the
// private repr so the variants can evolve without breaking the public API.
#[derive(Error, Debug)]
#[error(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct CodeSplitterError(#[from] CodeSplitterErrorRepr);
21
/// Private error and free to change across minor version of the crate.
#[derive(Error, Debug)]
enum CodeSplitterErrorRepr {
    /// The tree-sitter grammar was compiled against an ABI version older than
    /// the minimum this version of the `tree-sitter` crate supports.
    #[error(
        "Language version {0:?} is too old. Expected at least version {min_version}",
        min_version=MIN_COMPATIBLE_LANGUAGE_VERSION,
    )]
    LanguageError(LanguageError),
}
31
/// Source code splitter. Recursively splits chunks into the largest
/// semantic units that fit within the chunk size. Also will attempt to merge
/// neighboring chunks if they can fit within the given chunk size.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Method of determining chunk sizes.
    chunk_config: ChunkConfig<Sizer>,
    /// Language to use for parsing the code. Validated in [`CodeSplitter::new`],
    /// so later parser setup can assume it is compatible.
    language: Language,
}
46
impl<Sizer> CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a new [`CodeSplitter`].
    ///
    /// ```
    /// use text_splitter::CodeSplitter;
    ///
    /// // By default, the chunk sizer is based on characters.
    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 512).expect("Invalid language");
    /// ```
    ///
    /// # Errors
    ///
    /// Will return an error if the language version is too old to be compatible
    /// with the current version of the tree-sitter crate.
    pub fn new(
        language: impl Into<Language>,
        chunk_config: impl Into<ChunkConfig<Sizer>>,
    ) -> Result<Self, CodeSplitterError> {
        // Verify that this is a valid language so we can rely on that later.
        // The throwaway parser is only used for this compatibility check.
        let mut parser = Parser::new();
        let language = language.into();
        parser
            .set_language(&language)
            .map_err(CodeSplitterErrorRepr::LanguageError)?;
        Ok(Self {
            chunk_config: chunk_config.into(),
            language,
        })
    }

    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
    ///
    /// ## Method
    ///
    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
    ///
    /// 1. Split the text by increasing semantic levels.
    /// 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
    /// 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
    ///    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertently cross semantic boundaries.
    ///
    /// The boundaries used to split the text if using the `chunks` method, in ascending order:
    ///
    /// 1. Characters
    /// 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
    /// 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
    /// 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
    /// 5. Ascending depth of the syntax tree. So function would have a higher level than a statement inside of the function, and so on.
    ///
    /// Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
    ///
    /// ```
    /// use text_splitter::CodeSplitter;
    ///
    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
    /// ```
    pub fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter {
        // Delegate to the shared `Splitter` trait implementation below.
        Splitter::<_>::chunks(self, text)
    }

    /// Returns an iterator over chunks of the text and their byte offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`CodeSplitter::chunks`] for more information.
    ///
    /// ```
    /// use text_splitter::CodeSplitter;
    ///
    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
    /// ```
    pub fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Splitter::<_>::chunk_indices(self, text)
    }

    /// Returns an iterator over chunks of the text with their byte and character offsets.
    /// Each chunk will be up to the `chunk_capacity`.
    ///
    /// See [`CodeSplitter::chunks`] for more information.
    ///
    /// This will be more expensive than just byte offsets, and for most usage in Rust, just
    /// having byte offsets is sufficient. But when interfacing with other languages or systems
    /// that require character offsets, this will track the character offsets for you,
    /// accounting for any trimming that may have occurred.
    ///
    /// ```
    /// use text_splitter::{ChunkCharIndex, CodeSplitter};
    ///
    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
    /// let text = "Some text\n\nfrom a\ndocument";
    /// let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();
    ///
    /// assert_eq!(vec![ChunkCharIndex {chunk: "Some text", byte_offset: 0, char_offset: 0}, ChunkCharIndex {chunk: "from a", byte_offset: 11, char_offset: 11}, ChunkCharIndex {chunk: "document", byte_offset: 18, char_offset: 18}], chunks);
    /// ```
    pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
        Splitter::<_>::chunk_char_indices(self, text)
    }
}
164
impl<Sizer> Splitter<Sizer> for CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Semantic levels are syntax-tree depths; see [`Depth`] for why the
    /// ordering is inverted (shallower nodes rank higher).
    type Level = Depth;

    /// Keep leading whitespace when trimming chunks so code indentation survives.
    const TRIM: Trim = Trim::PreserveIndentation;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    /// Parses `text` with the configured language and returns a
    /// `(Depth, byte_range)` pair for every syntax node except the root,
    /// in depth-first order.
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
        let mut parser = Parser::new();
        parser
            .set_language(&self.language)
            // We verify at initialization that the language is valid, so this should be safe.
            .expect("Error loading language");
        // The only reason the tree would be None is:
        // - No language was set (we do that)
        // - There was a timeout or cancellation option set (we don't)
        // - So it should be safe to unwrap here
        let tree = parser.parse(text, None).expect("Error parsing source code");

        CursorOffsets::new(tree.walk()).collect()
    }
}
192
/// Newtype over `usize` recording how deep a syntax node sits in the tree.
///
/// The ordering is deliberately inverted: a *shallower* node (smaller inner
/// value) compares as *greater*, so that larger semantic units sort with
/// higher priority when choosing split boundaries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Depth(usize);

impl PartialOrd for Depth {
    /// Delegates to [`Ord`] so the two orderings can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Depth {
    /// Compare the raw depths, then flip the result to get the inverted
    /// priority described on the type.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.cmp(&other.0).reverse()
    }
}
210
/// New type around a tree-sitter cursor to allow for implementing an iterator.
/// Each call to `next()` will return the next node in the tree in a depth-first
/// order.
struct CursorOffsets<'cursor> {
    /// Cursor positioned within the parsed tree; advanced by the iterator.
    cursor: TreeCursor<'cursor>,
}

impl<'cursor> CursorOffsets<'cursor> {
    /// Wraps a cursor, which should be freshly obtained from `Tree::walk()`
    /// so iteration starts at the root.
    fn new(cursor: TreeCursor<'cursor>) -> Self {
        Self { cursor }
    }
}
223
224impl Iterator for CursorOffsets<'_> {
225    type Item = (Depth, Range<usize>);
226
227    fn next(&mut self) -> Option<Self::Item> {
228        // There are children (can call this initially because we don't want the root node)
229        if self.cursor.goto_first_child() {
230            return Some((
231                Depth(self.cursor.depth() as usize),
232                self.cursor.node().byte_range(),
233            ));
234        }
235
236        loop {
237            // There are sibling elements to grab
238            if self.cursor.goto_next_sibling() {
239                return Some((
240                    Depth(self.cursor.depth() as usize),
241                    self.cursor.node().byte_range(),
242                ));
243            // Start going back up the tree and check for next sibling on next iteration.
244            } else if self.cursor.goto_parent() {
245                continue;
246            }
247
248            // We have no more siblings or parents, so we're done.
249            return None;
250        }
251    }
252}
253
// Marker impl: lets `Depth` act as the semantic level type for `Splitter`.
impl SemanticLevel for Depth {}
255
#[cfg(test)]
mod tests {
    use tree_sitter::{Node, Tree};

    use super::*;

    // Capacity 16 is small enough to force the sample function to split
    // at syntax-node boundaries rather than fit in one chunk.
    #[test]
    fn rust_splitter() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunks(text).collect::<Vec<_>>();

        assert_eq!(chunks, vec!["fn main()", "{\n    let x = 5;", "}"]);
    }

    #[test]
    fn rust_splitter_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![(0, "fn main()"), (10, "{\n    let x = 5;"), (27, "}")]
        );
    }

    // ASCII input, so byte and char offsets are expected to coincide.
    #[test]
    fn rust_splitter_char_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![
                ChunkCharIndex {
                    chunk: "fn main()",
                    byte_offset: 0,
                    char_offset: 0
                },
                ChunkCharIndex {
                    chunk: "{\n    let x = 5;",
                    byte_offset: 10,
                    char_offset: 10
                },
                ChunkCharIndex {
                    chunk: "}",
                    byte_offset: 27,
                    char_offset: 27
                }
            ]
        );
    }

    // Depth ordering is intentionally inverted: shallower depth sorts greater.
    #[test]
    fn depth_partialord() {
        assert_eq!(Depth(0).partial_cmp(&Depth(1)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(2)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(1)), Some(Ordering::Equal));
        assert_eq!(Depth(2).partial_cmp(&Depth(1)), Some(Ordering::Less));
    }

    #[test]
    fn depth_ord() {
        assert_eq!(Depth(0).cmp(&Depth(1)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(2)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(1)), Ordering::Equal);
        assert_eq!(Depth(2).cmp(&Depth(1)), Ordering::Less);
    }

    #[test]
    fn depth_sorting() {
        let mut depths = vec![Depth(0), Depth(1), Depth(2)];
        depths.sort();
        assert_eq!(depths, [Depth(2), Depth(1), Depth(0)]);
    }

    /// Checks that the optimized version of the code produces the same results as the naive version.
    #[test]
    fn optimized_code_offsets() {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        let source_code = "fn test() {
    let x = 1;
}";
        let tree = parser
            .parse(source_code, None)
            .expect("Error parsing source code");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    // Regression check for the sibling/parent backtracking in the iterator:
    // several top-level items exercise the goto_next_sibling path at depth 1.
    #[test]
    fn multiple_top_siblings() {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        let source_code = "
fn fn1() {}
fn fn2() {}
fn fn3() {}
fn fn4() {}";
        let tree = parser
            .parse(source_code, None)
            .expect("Error parsing source code");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    // Reference implementation: full recursive walk starting at the root.
    fn naive_offsets(tree: &Tree) -> Vec<(Depth, Range<usize>)> {
        let root_node = tree.root_node();
        let mut offsets = vec![];
        recursive_naive_offsets(&mut offsets, root_node, 0);
        offsets
    }

    // Basic version to compare an optimized version against. According to the tree-sitter
    // documentation, this is not efficient for large trees. But because it is the easiest
    // to reason about it is a good check for correctness.
    fn recursive_naive_offsets(
        collection: &mut Vec<(Depth, Range<usize>)>,
        node: Node<'_>,
        depth: usize,
    ) {
        // We can skip the root node
        if depth > 0 {
            collection.push((Depth(depth), node.byte_range()));
        }

        for child in node.children(&mut node.walk()) {
            recursive_naive_offsets(collection, child, depth + 1);
        }
    }
}
397}