1use crate::language_registry;
2use crate::test_mapper::SourceFile;
3use crate::types::{ExtractedFunction, Language};
4use crate::Result;
5use crate::TestGapError;
6use streaming_iterator::StreamingIterator;
7
8pub fn extract_functions(file: &SourceFile) -> Result<Vec<ExtractedFunction>> {
10 let source = std::fs::read_to_string(&file.path).map_err(TestGapError::Io)?;
11 let lang = file.language;
12 let ts_language = language_registry::get_language(lang);
13
14 let mut parser = tree_sitter::Parser::new();
15 parser
16 .set_language(&ts_language)
17 .map_err(|e| TestGapError::Parse {
18 file: file.path.display().to_string(),
19 message: e.to_string(),
20 })?;
21
22 let tree = parser
23 .parse(&source, None)
24 .ok_or_else(|| TestGapError::Parse {
25 file: file.path.display().to_string(),
26 message: "Failed to parse file".into(),
27 })?;
28
29 let query_src = language_registry::function_query(lang);
30 let query =
31 tree_sitter::Query::new(&ts_language, query_src).map_err(|e| TestGapError::Parse {
32 file: file.path.display().to_string(),
33 message: format!("Query error: {e}"),
34 })?;
35
36 let mut cursor = tree_sitter::QueryCursor::new();
37 let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
38
39 let name_idx = query
40 .capture_index_for_name("name")
41 .expect("query must have @name capture");
42 let func_idx = query
43 .capture_index_for_name("function")
44 .expect("query must have @function capture");
45
46 let mut functions = Vec::new();
47
48 while let Some(m) = matches.next() {
49 let mut name_node = None;
50 let mut func_node = None;
51
52 for cap in m.captures {
53 if cap.index == name_idx {
54 name_node = Some(cap.node);
55 } else if cap.index == func_idx {
56 func_node = Some(cap.node);
57 }
58 }
59
60 let (Some(name_n), Some(func_n)) = (name_node, func_node) else {
61 continue;
62 };
63
64 let name: String = name_n
65 .utf8_text(source.as_bytes())
66 .unwrap_or("")
67 .to_string();
68
69 if name.is_empty() {
70 continue;
71 }
72
73 let body: String = func_n
74 .utf8_text(source.as_bytes())
75 .unwrap_or("")
76 .to_string();
77
78 let line_start = func_n.start_position().row + 1;
79 let line_end = func_n.end_position().row + 1;
80
81 let signature = extract_signature(&source, func_n, lang);
82 let is_public = check_visibility(&source, func_n, lang);
83 let is_test = check_is_test(&name, &source, func_n, lang, file.is_test);
84 let complexity = estimate_complexity(&body);
85
86 functions.push(ExtractedFunction {
87 name,
88 file_path: file.path.clone(),
89 line_start,
90 line_end,
91 signature,
92 body,
93 language: lang,
94 is_public,
95 is_test,
96 complexity,
97 });
98 }
99
100 functions.sort_by(|a, b| {
102 a.file_path
103 .cmp(&b.file_path)
104 .then_with(|| a.line_start.cmp(&b.line_start))
105 .then_with(|| a.name.cmp(&b.name))
106 });
107 functions.dedup_by(|a, b| {
108 a.file_path == b.file_path && a.line_start == b.line_start && a.name == b.name
109 });
110
111 Ok(functions)
112}
113
114fn extract_signature(source: &str, node: tree_sitter::Node, lang: Language) -> String {
115 let text = node.utf8_text(source.as_bytes()).unwrap_or("");
116 match lang {
117 Language::Rust => {
118 if let Some(brace_pos) = text.find('{') {
120 text[..brace_pos].trim().to_string()
121 } else {
122 text.lines().next().unwrap_or("").to_string()
123 }
124 }
125 Language::Go => {
126 if let Some(brace_pos) = text.find('{') {
127 text[..brace_pos].trim().to_string()
128 } else {
129 text.lines().next().unwrap_or("").to_string()
130 }
131 }
132 Language::Python => {
133 let first_line = text.lines().next().unwrap_or("");
135 if let Some(colon_pos) = first_line.rfind(':') {
136 first_line[..colon_pos].trim().to_string()
137 } else {
138 first_line.to_string()
139 }
140 }
141 Language::JavaScript | Language::TypeScript => {
142 if let Some(brace_pos) = text.find('{') {
144 text[..brace_pos].trim().to_string()
145 } else if let Some(arrow_pos) = text.find("=>") {
146 text[..arrow_pos + 2].trim().to_string()
147 } else {
148 text.lines().next().unwrap_or("").to_string()
149 }
150 }
151 }
152}
153
154fn check_visibility(source: &str, node: tree_sitter::Node, lang: Language) -> bool {
155 match lang {
156 Language::Rust => {
157 let text = node.utf8_text(source.as_bytes()).unwrap_or("");
158 text.starts_with("pub ")
159 || text.starts_with("pub(crate)")
160 || text.starts_with("pub(super)")
161 }
162 Language::Python => {
163 let text = node.utf8_text(source.as_bytes()).unwrap_or("");
164 if let Some(line) = text.lines().next() {
166 if let Some(name_start) = line.find("def ") {
167 let after_def = &line[name_start + 4..];
168 return !after_def.starts_with('_');
169 }
170 }
171 true
172 }
173 Language::Go => {
174 let text = node.utf8_text(source.as_bytes()).unwrap_or("");
176 if let Some(func_pos) = text.find("func ") {
177 let after_func = text[func_pos + 5..].trim_start();
178 let name_part = if after_func.starts_with('(') {
180 if let Some(paren_end) = after_func.find(") ") {
181 after_func[paren_end + 2..].trim_start()
182 } else {
183 after_func
184 }
185 } else {
186 after_func
187 };
188 name_part.chars().next().is_some_and(|c| c.is_uppercase())
189 } else {
190 false
191 }
192 }
193 Language::JavaScript | Language::TypeScript => {
194 if let Some(parent) = node.parent() {
196 let kind = parent.kind();
197 kind == "export_statement" || kind == "export_default_declaration"
198 } else {
199 true
201 }
202 }
203 }
204}
205
206fn check_is_test(
207 name: &str,
208 source: &str,
209 node: tree_sitter::Node,
210 lang: Language,
211 is_test_file: bool,
212) -> bool {
213 if is_test_file {
214 return true;
215 }
216
217 match lang {
218 Language::Rust => {
219 let mut sibling = node.prev_sibling();
221 while let Some(prev) = sibling {
222 if prev.kind() == "attribute_item" {
223 let attr_text = prev.utf8_text(source.as_bytes()).unwrap_or("");
224 if attr_text.contains("test") {
225 return true;
226 }
227 } else {
228 break;
229 }
230 sibling = prev.prev_sibling();
231 }
232 name.starts_with("test_")
233 }
234 Language::Python => name.starts_with("test_"),
235 Language::Go => name.starts_with("Test") || name.starts_with("Benchmark"),
236 Language::JavaScript | Language::TypeScript => {
237 name == "it" || name == "test" || name == "describe"
238 }
239 }
240}
241
/// Returns a rough branching score for `body`: 1 plus the number of occurrences of
/// each branch-like token.
///
/// Counting is plain substring matching, so tokens are counted independently of each
/// other (for example `elif ` contributes to both the `elif ` and the `if ` tally) and
/// occurrences inside strings or comments are counted as well.
fn estimate_complexity(body: &str) -> u32 {
    const BRANCH_TOKENS: [&str; 15] = [
        "if ", "else ", "else{", "match ", "for ", "while ", "loop ", "case ", "catch ",
        "except ", "elif ", "?", "&&", "||", "switch ",
    ];

    let branch_hits: u32 = BRANCH_TOKENS
        .iter()
        .map(|token| body.matches(*token).count() as u32)
        .sum();

    1 + branch_hits
}
254
#[cfg(test)]
mod tests {
    use crate::function_extractor::extract_functions;
    use crate::test_mapper::SourceFile;
    use crate::types::{ExtractedFunction, Language};
    use std::io::Write;
    use tempfile;

    /// Creates an empty temporary file whose name ends in `suffix`.
    fn temp_source(suffix: &str) -> tempfile::NamedTempFile {
        tempfile::Builder::new().suffix(suffix).tempfile().unwrap()
    }

    /// Flushes `file` and extracts its functions, treating it as a non-test source
    /// written in `language`.
    fn extract_from(
        file: &mut tempfile::NamedTempFile,
        language: Language,
    ) -> Vec<ExtractedFunction> {
        file.flush().unwrap();

        let source = SourceFile {
            path: file.path().to_path_buf(),
            language,
            is_test: false,
        };

        extract_functions(&source).unwrap()
    }

    /// Returns the extracted function called `name`, panicking when it is missing.
    fn find_fn<'a>(funcs: &'a [ExtractedFunction], name: &str) -> &'a ExtractedFunction {
        funcs
            .iter()
            .find(|f| f.name == name)
            .unwrap_or_else(|| panic!("should find '{name}'"))
    }

    #[test]
    fn parse_rust_snippet() {
        let mut file = temp_source(".rs");
        writeln!(
            file,
            r#"pub fn add(a: i32, b: i32) -> i32 {{
 a + b
}}

pub fn complex_calc(x: i32) -> i32 {{
 if x > 0 {{
 for i in 0..x {{
 if i % 2 == 0 {{
 return i;
 }}
 }}
 }}
 x
}}"#
        )
        .unwrap();

        let funcs = extract_from(&mut file, Language::Rust);
        assert!(
            funcs.len() >= 2,
            "expected at least 2 functions, got {}",
            funcs.len()
        );

        let add_fn = find_fn(&funcs, "add");
        assert!(add_fn.is_public, "add should be public");
        assert!(add_fn.line_start >= 1);
        assert!(add_fn.line_end >= add_fn.line_start);
        assert!(
            add_fn.signature.contains("fn add"),
            "signature should contain 'fn add', got: {}",
            add_fn.signature
        );

        let complex_fn = find_fn(&funcs, "complex_calc");
        assert!(complex_fn.is_public, "complex_calc should be public");
    }

    #[test]
    fn parse_typescript_snippet() {
        let mut file = temp_source(".ts");
        writeln!(
            file,
            r#"export function greet(name: string): string {{
 return "hello " + name;
}}"#
        )
        .unwrap();

        let funcs = extract_from(&mut file, Language::TypeScript);
        assert!(!funcs.is_empty(), "expected at least 1 function");

        let greet_fn = find_fn(&funcs, "greet");
        assert!(greet_fn.is_public, "exported function should be public");
        assert_eq!(greet_fn.name, "greet");
    }

    #[test]
    fn parse_python_snippet() {
        let mut file = temp_source(".py");
        write!(
            file,
            "def calculate(x, y):\n if x > 0:\n return x + y\n return y\n"
        )
        .unwrap();

        let funcs = extract_from(&mut file, Language::Python);
        assert!(!funcs.is_empty(), "expected at least 1 function");

        let calc_fn = find_fn(&funcs, "calculate");
        assert!(
            calc_fn.is_public,
            "calculate should be public (no leading underscore)"
        );
        assert_eq!(calc_fn.name, "calculate");
    }

    #[test]
    fn complexity_estimation_via_extract() {
        let mut file = temp_source(".rs");
        writeln!(
            file,
            r#"pub fn branchy(x: i32, y: i32) -> i32 {{
 if x > 0 {{
 if y > 0 {{
 for i in 0..x {{
 match i {{
 0 => return 0,
 _ => {{
 if i > 5 && y < 10 {{
 return i;
 }}
 }}
 }}
 }}
 }}
 }} else {{
 while x > 0 {{
 return y;
 }}
 }}
 x + y
}}"#
        )
        .unwrap();

        let funcs = extract_from(&mut file, Language::Rust);
        let branchy = find_fn(&funcs, "branchy");
        assert!(
            branchy.complexity > 1,
            "complex function should have complexity > 1, got {}",
            branchy.complexity
        );
    }
}