splitter_tree_sitter/
lib.rs

1use thiserror::Error;
2use tree_sitter::{Tree, TreeCursor};
3
4#[derive(Error, Debug)]
5pub enum NewError {
6    #[error("chunk_size must be greater than chunk_overlap")]
7    SizeOverlapError,
8}
9
10#[derive(Error, Debug)]
11pub enum SplitError {
12    #[error("converting utf8 to str")]
13    Utf8Error(#[from] core::str::Utf8Error),
14}
15
/// Splits source code into chunks along the boundaries of a tree-sitter syntax tree.
///
/// Construct it with [`TreeSitterCodeSplitter::new`], then call
/// [`TreeSitterCodeSplitter::split`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TreeSitterCodeSplitter {
    /// Target chunk size; syntax nodes are compared against it by character count.
    chunk_size: usize,
    /// How much consecutive windows cut from a single oversized leaf node share.
    chunk_overlap: usize,
}
20
/// A half-open byte range `[start_byte, end_byte)` into the source buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ByteRange {
    /// Offset of the first byte in the range.
    pub start_byte: usize,
    /// Offset one past the last byte in the range.
    pub end_byte: usize,
}

impl ByteRange {
    /// Creates a range covering `start_byte..end_byte`.
    fn new(start_byte: usize, end_byte: usize) -> Self {
        Self {
            start_byte,
            end_byte,
        }
    }
}

/// A piece of source text, borrowed from the source buffer, together with the byte range it
/// occupies in that buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Chunk<'a> {
    /// The chunk's text.
    pub text: &'a str,
    /// Where `text` sits in the source buffer.
    pub range: ByteRange,
}

impl<'a> Chunk<'a> {
    /// Pairs `text` with its location `range`.
    fn new(text: &'a str, range: ByteRange) -> Self {
        Self { text, range }
    }
}
45
46impl TreeSitterCodeSplitter {
47    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, NewError> {
48        if chunk_overlap > chunk_size {
49            Err(NewError::SizeOverlapError)
50        } else {
51            Ok(Self {
52                chunk_size,
53                chunk_overlap,
54            })
55        }
56    }
57
58    pub fn split<'c>(&self, tree: &Tree, utf8: &'c [u8]) -> Result<Vec<Chunk<'c>>, SplitError> {
59        let cursor = tree.walk();
60        Ok(self
61            .split_recursive(cursor, utf8)?
62            .into_iter()
63            .rev()
64            // Let's combine some of our smaller chunks together
65            // We also want to do this in reverse as it (seems) to make more sense to combine code slices from bottom to top
66            .try_fold(vec![], |mut acc, current| {
67                if acc.is_empty() {
68                    acc.push(current);
69                    Ok::<_, SplitError>(acc)
70                } else {
71                    if acc.last().as_ref().unwrap().text.len() + current.text.len()
72                        < self.chunk_size
73                    {
74                        let last = acc.pop().unwrap();
75                        let text = std::str::from_utf8(
76                            &utf8[current.range.start_byte..last.range.end_byte],
77                        )?;
78                        acc.push(Chunk::new(
79                            text,
80                            ByteRange::new(current.range.start_byte, last.range.end_byte),
81                        ));
82                    } else {
83                        acc.push(current);
84                    }
85                    Ok(acc)
86                }
87            })?
88            .into_iter()
89            .rev()
90            .collect())
91    }
92
93    fn split_recursive<'c>(
94        &self,
95        mut cursor: TreeCursor<'_>,
96        utf8: &'c [u8],
97    ) -> Result<Vec<Chunk<'c>>, SplitError> {
98        let node = cursor.node();
99        let text = node.utf8_text(utf8)?;
100
101        // There are three cases:
102        // 1. Is the current range of code smaller than the chunk_size? If so, return it
103        // 2. If not, does the current node have children? If so, recursively walk down
104        // 3. If not, we must split our current node
105        let mut out = if text.chars().count() <= self.chunk_size {
106            vec![Chunk::new(
107                text,
108                ByteRange::new(node.range().start_byte, node.range().end_byte),
109            )]
110        } else {
111            let mut cursor_copy = cursor.clone();
112            if cursor_copy.goto_first_child() {
113                self.split_recursive(cursor_copy, utf8)?
114            } else {
115                let mut current_range =
116                    ByteRange::new(node.range().start_byte, node.range().end_byte);
117                let mut chunks = vec![];
118                let mut current_chunk = text;
119                loop {
120                    if current_chunk.len() < self.chunk_size {
121                        chunks.push(Chunk::new(current_chunk, current_range));
122                        break;
123                    } else {
124                        let new_chunk = &current_chunk[0..self.chunk_size.min(current_chunk.len())];
125                        let new_range = ByteRange::new(
126                            current_range.start_byte,
127                            current_range.start_byte + new_chunk.as_bytes().len(),
128                        );
129                        chunks.push(Chunk::new(new_chunk, new_range));
130                        let new_current_chunk =
131                            &current_chunk[self.chunk_size - self.chunk_overlap..];
132                        let byte_diff =
133                            current_chunk.as_bytes().len() - new_current_chunk.as_bytes().len();
134                        current_range = ByteRange::new(
135                            current_range.start_byte + byte_diff,
136                            current_range.end_byte,
137                        );
138                        current_chunk = new_current_chunk
139                    }
140                }
141                chunks
142            }
143        };
144        if cursor.goto_next_sibling() {
145            out.append(&mut self.split_recursive(cursor, utf8)?);
146        }
147        Ok(out)
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `source` with the tree-sitter Rust grammar.
    ///
    /// Panics if the grammar cannot be loaded or the parser returns no tree.
    fn parse_with_rust_grammar(source: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::language())
            .expect("Error loading Rust grammar");
        parser.parse(source, None).unwrap()
    }

    /// Asserts that the first `expected.len()` chunks have exactly the given texts, in order.
    /// Any chunks beyond those are not checked.
    fn assert_leading_chunk_texts(chunks: &[Chunk<'_>], expected: &[&str]) {
        for (index, &want) in expected.iter().enumerate() {
            assert_eq!(chunks[index].text, want);
        }
    }

    #[test]
    fn test_split_rust() {
        let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap();

        let source_code = r#"
#[derive(Debug)]
struct Rectangle {
    width: u32,
    height: u32,
}

impl Rectangle {
    fn area(&self) -> u32 {
        self.width * self.height
    }
}

fn main() {
    let rect1 = Rectangle {
        width: 30,
        height: 50,
    };

    println!(
        "The area of the rectangle is {} square pixels.",
        rect1.area()
    );
}
"#;
        let tree = parse_with_rust_grammar(source_code);
        let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();

        let expected = [
            r#"#[derive(Debug)]
struct Rectangle {
    width: u32,
    height: u32,
}"#,
            r#"impl Rectangle {
    fn area(&self) -> u32 {
        self.width * self.height
    }
}"#,
            r#"fn main() {
    let rect1 = Rectangle {
        width: 30,
        height: 50,
    };"#,
            r#"println!(
        "The area of the rectangle is {} square pixels.",
        rect1.area()
    );
}"#,
        ];
        assert_leading_chunk_texts(&chunks, &expected);
    }

    #[test]
    fn test_split_zig() {
        let splitter = TreeSitterCodeSplitter::new(128, 10).unwrap();

        // NOTE(review): this Zig source is parsed with the Rust grammar, so the expected chunks
        // below reflect how the Rust parser recovers from it. Confirm whether a Zig grammar
        // was intended.
        let source_code = r#"
const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print("Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string 9 ...", .{});

test "parse integers" {
    const input = "123 67 89,99";
    const ally = std.testing.allocator;

    var list = std.ArrayList(u32).init(ally);
    // Ensure the list is freed at scope exit.
    // Try commenting out this line!
    defer list.deinit();

    var it = std.mem.tokenizeAny(u8, input, " ,");
    while (it.next()) |num| {
        const n = try parseInt(u32, num, 10);
        try list.append(n);
    }

    const expected = [_]u32{ 123, 67, 89, 99 };

    for (expected, list.items) |exp, actual| {
        try std.testing.expectEqual(exp, actual);
    }
}
"#;
        let tree = parse_with_rust_grammar(source_code);
        let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();

        let expected = [
            r#"const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print(""#,
            r#"Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long s"#,
            r#"s a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string "#,
            r#"ng string 9 ...", .{});"#,
            r#"test "parse integers" {
    const input = "123 67 89,99";
    const ally = std.testing.allocator;

    var list = std.ArrayList"#,
            r#"(u32).init(ally);
    // Ensure the list is freed at scope exit.
    // Try commenting out this line!"#,
            r#"defer list.deinit();

    var it = std.mem.tokenizeAny(u8, input, " ,");
    while (it.next()) |num"#,
            r#"| {
        const n = try parseInt(u32, num, 10);
        try list.append(n);
    }

    const expected = [_]u32{ 123, 67, 89,"#,
            r#"99 };

    for (expected, list.items) |exp, actual| {
        try std.testing.expectEqual(exp, actual);
    }
}"#,
        ];
        assert_leading_chunk_texts(&chunks, &expected);
    }
}