1use thiserror::Error;
2use tree_sitter::{Tree, TreeCursor};
3
4#[derive(Error, Debug)]
5pub enum NewError {
6 #[error("chunk_size must be greater than chunk_overlap")]
7 SizeOverlapError,
8}
9
10#[derive(Error, Debug)]
11pub enum SplitError {
12 #[error("converting utf8 to str")]
13 Utf8Error(#[from] core::str::Utf8Error),
14}
15
/// Splits source code into chunks along the boundaries of a tree-sitter
/// syntax tree, falling back to fixed-width windows for oversized leaf nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TreeSitterCodeSplitter {
    /// Upper bound used both when cutting nodes into pieces and when
    /// merging adjacent pieces back together.
    chunk_size: usize,
    /// Amount shared by consecutive windows when a single oversized leaf
    /// node is cut into fixed-width pieces.
    chunk_overlap: usize,
}
20
/// Half-open byte range `start_byte..end_byte` into the source buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ByteRange {
    /// Offset of the first byte of the range.
    pub start_byte: usize,
    /// Offset one past the last byte of the range.
    pub end_byte: usize,
}

impl ByteRange {
    /// Builds a range from its start and (exclusive) end byte offsets.
    fn new(start_byte: usize, end_byte: usize) -> Self {
        Self {
            start_byte,
            end_byte,
        }
    }
}

/// A piece of the source text together with the byte range it occupies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk<'a> {
    /// The chunk's text, borrowed from the source buffer.
    pub text: &'a str,
    /// Where `text` lives in the source buffer.
    pub range: ByteRange,
}

impl<'a> Chunk<'a> {
    /// Pairs a borrowed text slice with its byte range.
    fn new(text: &'a str, range: ByteRange) -> Self {
        Self { text, range }
    }
}
45
46impl TreeSitterCodeSplitter {
47 pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, NewError> {
48 if chunk_overlap > chunk_size {
49 Err(NewError::SizeOverlapError)
50 } else {
51 Ok(Self {
52 chunk_size,
53 chunk_overlap,
54 })
55 }
56 }
57
58 pub fn split<'c>(&self, tree: &Tree, utf8: &'c [u8]) -> Result<Vec<Chunk<'c>>, SplitError> {
59 let cursor = tree.walk();
60 Ok(self
61 .split_recursive(cursor, utf8)?
62 .into_iter()
63 .rev()
64 .try_fold(vec![], |mut acc, current| {
67 if acc.is_empty() {
68 acc.push(current);
69 Ok::<_, SplitError>(acc)
70 } else {
71 if acc.last().as_ref().unwrap().text.len() + current.text.len()
72 < self.chunk_size
73 {
74 let last = acc.pop().unwrap();
75 let text = std::str::from_utf8(
76 &utf8[current.range.start_byte..last.range.end_byte],
77 )?;
78 acc.push(Chunk::new(
79 text,
80 ByteRange::new(current.range.start_byte, last.range.end_byte),
81 ));
82 } else {
83 acc.push(current);
84 }
85 Ok(acc)
86 }
87 })?
88 .into_iter()
89 .rev()
90 .collect())
91 }
92
93 fn split_recursive<'c>(
94 &self,
95 mut cursor: TreeCursor<'_>,
96 utf8: &'c [u8],
97 ) -> Result<Vec<Chunk<'c>>, SplitError> {
98 let node = cursor.node();
99 let text = node.utf8_text(utf8)?;
100
101 let mut out = if text.chars().count() <= self.chunk_size {
106 vec![Chunk::new(
107 text,
108 ByteRange::new(node.range().start_byte, node.range().end_byte),
109 )]
110 } else {
111 let mut cursor_copy = cursor.clone();
112 if cursor_copy.goto_first_child() {
113 self.split_recursive(cursor_copy, utf8)?
114 } else {
115 let mut current_range =
116 ByteRange::new(node.range().start_byte, node.range().end_byte);
117 let mut chunks = vec![];
118 let mut current_chunk = text;
119 loop {
120 if current_chunk.len() < self.chunk_size {
121 chunks.push(Chunk::new(current_chunk, current_range));
122 break;
123 } else {
124 let new_chunk = ¤t_chunk[0..self.chunk_size.min(current_chunk.len())];
125 let new_range = ByteRange::new(
126 current_range.start_byte,
127 current_range.start_byte + new_chunk.as_bytes().len(),
128 );
129 chunks.push(Chunk::new(new_chunk, new_range));
130 let new_current_chunk =
131 ¤t_chunk[self.chunk_size - self.chunk_overlap..];
132 let byte_diff =
133 current_chunk.as_bytes().len() - new_current_chunk.as_bytes().len();
134 current_range = ByteRange::new(
135 current_range.start_byte + byte_diff,
136 current_range.end_byte,
137 );
138 current_chunk = new_current_chunk
139 }
140 }
141 chunks
142 }
143 };
144 if cursor.goto_next_sibling() {
145 out.append(&mut self.split_recursive(cursor, utf8)?);
146 }
147 Ok(out)
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Rust source with a 128-character budget and no overlap. The expected
    /// output keeps the attribute with its struct, keeps the `impl` whole, and
    /// splits `main` between the `let` statement and the `println!` call.
    #[test]
    fn test_split_rust() {
        let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap();

        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::language())
            .expect("Error loading Rust grammar");

        let source_code = r#"
#[derive(Debug)]
struct Rectangle {
    width: u32,
    height: u32,
}

impl Rectangle {
    fn area(&self) -> u32 {
        self.width * self.height
    }
}

fn main() {
    let rect1 = Rectangle {
        width: 30,
        height: 50,
    };

    println!(
        "The area of the rectangle is {} square pixels.",
        rect1.area()
    );
}
"#;
        let tree = parser.parse(source_code, None).unwrap();
        let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();
        assert_eq!(
            chunks[0].text,
            r#"#[derive(Debug)]
struct Rectangle {
    width: u32,
    height: u32,
}"#
        );
        assert_eq!(
            chunks[1].text,
            r#"impl Rectangle {
    fn area(&self) -> u32 {
        self.width * self.height
    }
}"#
        );
        assert_eq!(
            chunks[2].text,
            r#"fn main() {
    let rect1 = Rectangle {
        width: 30,
        height: 50,
    };"#
        );
        assert_eq!(
            chunks[3].text,
            r#"println!(
        "The area of the rectangle is {} square pixels.",
        rect1.area()
    );
}"#
        );
    }

    /// Exercises the fixed-width fallback on an over-long leaf (the long
    /// string literal), using a 10-character overlap between windows: each
    /// window starts 118 characters after the previous one.
    ///
    /// NOTE(review): despite its name, this parses Zig source with the *Rust*
    /// grammar, so the expected chunks reflect the Rust grammar's handling of
    /// foreign syntax. Presumably no Zig grammar is a dependency; confirm
    /// before switching grammars, since every expected value would change.
    #[test]
    fn test_split_zig() {
        let splitter = TreeSitterCodeSplitter::new(128, 10).unwrap();

        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::language())
            .expect("Error loading Rust grammar");

        let source_code = r#"
const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print("Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string 9 ...", .{});

test "parse integers" {
    const input = "123 67 89,99";
    const ally = std.testing.allocator;

    var list = std.ArrayList(u32).init(ally);
    // Ensure the list is freed at scope exit.
    // Try commenting out this line!
    defer list.deinit();

    var it = std.mem.tokenizeAny(u8, input, " ,");
    while (it.next()) |num| {
        const n = try parseInt(u32, num, 10);
        try list.append(n);
    }

    const expected = [_]u32{ 123, 67, 89, 99 };

    for (expected, list.items) |exp, actual| {
        try std.testing.expectEqual(exp, actual);
    }
}
"#;
        let tree = parser.parse(source_code, None).unwrap();
        let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();

        assert_eq!(
            chunks[0].text,
            r#"const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print(""#
        );
        // chunks[1..=3] are the windows of the long string literal; the last
        // one was merged with the tokens that follow the literal.
        assert_eq!(
            chunks[1].text,
            r#"Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long s"#
        );
        assert_eq!(
            chunks[2].text,
            r#"s a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string "#
        );
        assert_eq!(chunks[3].text, r#"ng string 9 ...", .{});"#);
        assert_eq!(
            chunks[4].text,
            r#"test "parse integers" {
    const input = "123 67 89,99";
    const ally = std.testing.allocator;

    var list = std.ArrayList"#
        );
        assert_eq!(
            chunks[5].text,
            r#"(u32).init(ally);
    // Ensure the list is freed at scope exit.
    // Try commenting out this line!"#
        );
        assert_eq!(
            chunks[6].text,
            r#"defer list.deinit();

    var it = std.mem.tokenizeAny(u8, input, " ,");
    while (it.next()) |num"#
        );
        assert_eq!(
            chunks[7].text,
            r#"| {
        const n = try parseInt(u32, num, 10);
        try list.append(n);
    }

    const expected = [_]u32{ 123, 67, 89,"#
        );
        assert_eq!(
            chunks[8].text,
            r#"99 };

    for (expected, list.items) |exp, actual| {
        try std.testing.expectEqual(exp, actual);
    }
}"#
        );
    }
}