Skip to main content

testgap_core/
function_extractor.rs

1use crate::language_registry;
2use crate::test_mapper::SourceFile;
3use crate::types::{ExtractedFunction, Language};
4use crate::Result;
5use crate::TestGapError;
6use streaming_iterator::StreamingIterator;
7
8/// Extract all functions from a source file using tree-sitter.
9pub fn extract_functions(file: &SourceFile) -> Result<Vec<ExtractedFunction>> {
10    let source = std::fs::read_to_string(&file.path).map_err(TestGapError::Io)?;
11    let lang = file.language;
12    let ts_language = language_registry::get_language(lang);
13
14    let mut parser = tree_sitter::Parser::new();
15    parser
16        .set_language(&ts_language)
17        .map_err(|e| TestGapError::Parse {
18            file: file.path.display().to_string(),
19            message: e.to_string(),
20        })?;
21
22    let tree = parser
23        .parse(&source, None)
24        .ok_or_else(|| TestGapError::Parse {
25            file: file.path.display().to_string(),
26            message: "Failed to parse file".into(),
27        })?;
28
29    let query_src = language_registry::function_query(lang);
30    let query =
31        tree_sitter::Query::new(&ts_language, query_src).map_err(|e| TestGapError::Parse {
32            file: file.path.display().to_string(),
33            message: format!("Query error: {e}"),
34        })?;
35
36    let mut cursor = tree_sitter::QueryCursor::new();
37    let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
38
39    let name_idx = query
40        .capture_index_for_name("name")
41        .expect("query must have @name capture");
42    let func_idx = query
43        .capture_index_for_name("function")
44        .expect("query must have @function capture");
45
46    let mut functions = Vec::new();
47
48    while let Some(m) = matches.next() {
49        let mut name_node = None;
50        let mut func_node = None;
51
52        for cap in m.captures {
53            if cap.index == name_idx {
54                name_node = Some(cap.node);
55            } else if cap.index == func_idx {
56                func_node = Some(cap.node);
57            }
58        }
59
60        let (Some(name_n), Some(func_n)) = (name_node, func_node) else {
61            continue;
62        };
63
64        let name: String = name_n
65            .utf8_text(source.as_bytes())
66            .unwrap_or("")
67            .to_string();
68
69        if name.is_empty() {
70            continue;
71        }
72
73        let body: String = func_n
74            .utf8_text(source.as_bytes())
75            .unwrap_or("")
76            .to_string();
77
78        let line_start = func_n.start_position().row + 1;
79        let line_end = func_n.end_position().row + 1;
80
81        let signature = extract_signature(&source, func_n, lang);
82        let is_public = check_visibility(&source, func_n, lang);
83        let is_test = check_is_test(&name, &source, func_n, lang, file.is_test);
84        let complexity = estimate_complexity(&body);
85
86        functions.push(ExtractedFunction {
87            name,
88            file_path: file.path.clone(),
89            line_start,
90            line_end,
91            signature,
92            body,
93            language: lang,
94            is_public,
95            is_test,
96            complexity,
97        });
98    }
99
100    // Sort then dedup by (file_path, line_start, name)
101    functions.sort_by(|a, b| {
102        a.file_path
103            .cmp(&b.file_path)
104            .then_with(|| a.line_start.cmp(&b.line_start))
105            .then_with(|| a.name.cmp(&b.name))
106    });
107    functions.dedup_by(|a, b| {
108        a.file_path == b.file_path && a.line_start == b.line_start && a.name == b.name
109    });
110
111    Ok(functions)
112}
113
114fn extract_signature(source: &str, node: tree_sitter::Node, lang: Language) -> String {
115    let text = node.utf8_text(source.as_bytes()).unwrap_or("");
116    match lang {
117        Language::Rust => {
118            // Take everything up to the opening brace
119            if let Some(brace_pos) = text.find('{') {
120                text[..brace_pos].trim().to_string()
121            } else {
122                text.lines().next().unwrap_or("").to_string()
123            }
124        }
125        Language::Go => {
126            if let Some(brace_pos) = text.find('{') {
127                text[..brace_pos].trim().to_string()
128            } else {
129                text.lines().next().unwrap_or("").to_string()
130            }
131        }
132        Language::Python => {
133            // Take the def line up to the colon (first line only)
134            let first_line = text.lines().next().unwrap_or("");
135            if let Some(colon_pos) = first_line.rfind(':') {
136                first_line[..colon_pos].trim().to_string()
137            } else {
138                first_line.to_string()
139            }
140        }
141        Language::JavaScript | Language::TypeScript => {
142            // Take the first line or up to opening brace
143            if let Some(brace_pos) = text.find('{') {
144                text[..brace_pos].trim().to_string()
145            } else if let Some(arrow_pos) = text.find("=>") {
146                text[..arrow_pos + 2].trim().to_string()
147            } else {
148                text.lines().next().unwrap_or("").to_string()
149            }
150        }
151    }
152}
153
154fn check_visibility(source: &str, node: tree_sitter::Node, lang: Language) -> bool {
155    match lang {
156        Language::Rust => {
157            let text = node.utf8_text(source.as_bytes()).unwrap_or("");
158            text.starts_with("pub ")
159                || text.starts_with("pub(crate)")
160                || text.starts_with("pub(super)")
161        }
162        Language::Python => {
163            let text = node.utf8_text(source.as_bytes()).unwrap_or("");
164            // Python: functions not starting with _ are public
165            if let Some(line) = text.lines().next() {
166                if let Some(name_start) = line.find("def ") {
167                    let after_def = &line[name_start + 4..];
168                    return !after_def.starts_with('_');
169                }
170            }
171            true
172        }
173        Language::Go => {
174            // Go: exported functions start with uppercase
175            let text = node.utf8_text(source.as_bytes()).unwrap_or("");
176            if let Some(func_pos) = text.find("func ") {
177                let after_func = text[func_pos + 5..].trim_start();
178                // Skip receiver for methods: (r *Type) Name
179                let name_part = if after_func.starts_with('(') {
180                    if let Some(paren_end) = after_func.find(") ") {
181                        after_func[paren_end + 2..].trim_start()
182                    } else {
183                        after_func
184                    }
185                } else {
186                    after_func
187                };
188                name_part.chars().next().is_some_and(|c| c.is_uppercase())
189            } else {
190                false
191            }
192        }
193        Language::JavaScript | Language::TypeScript => {
194            // Check if parent is an export statement
195            if let Some(parent) = node.parent() {
196                let kind = parent.kind();
197                kind == "export_statement" || kind == "export_default_declaration"
198            } else {
199                // Top-level functions without export — treat as module-public
200                true
201            }
202        }
203    }
204}
205
206fn check_is_test(
207    name: &str,
208    source: &str,
209    node: tree_sitter::Node,
210    lang: Language,
211    is_test_file: bool,
212) -> bool {
213    if is_test_file {
214        return true;
215    }
216
217    match lang {
218        Language::Rust => {
219            // Check for #[test] or #[cfg(test)] attribute (walk all preceding siblings)
220            let mut sibling = node.prev_sibling();
221            while let Some(prev) = sibling {
222                if prev.kind() == "attribute_item" {
223                    let attr_text = prev.utf8_text(source.as_bytes()).unwrap_or("");
224                    if attr_text.contains("test") {
225                        return true;
226                    }
227                } else {
228                    break;
229                }
230                sibling = prev.prev_sibling();
231            }
232            name.starts_with("test_")
233        }
234        Language::Python => name.starts_with("test_"),
235        Language::Go => name.starts_with("Test") || name.starts_with("Benchmark"),
236        Language::JavaScript | Language::TypeScript => {
237            name == "it" || name == "test" || name == "describe"
238        }
239    }
240}
241
/// Rough cyclomatic-complexity estimate for a function body.
///
/// Starts at 1 and adds one for every occurrence of a branching token
/// (`if `, `else `, `match `, `for `, `while `, `&&`, `||`, `?`, ...). The
/// scan is purely textual, so tokens inside strings and comments are counted
/// as well.
///
/// Word-like tokens only count when they are not glued to a preceding
/// identifier character. Without that check `elif ` would be counted twice
/// (it contains `if `) and identifiers such as `diff ` or `to_lowercase `
/// would register as branches.
fn estimate_complexity(body: &str) -> u32 {
    const BRANCH_TOKENS: [&str; 15] = [
        "if ", "else ", "else{", "match ", "for ", "while ", "loop ", "case ", "catch ", "except ",
        "elif ", "?", "&&", "||", "switch ",
    ];

    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut complexity: u32 = 1; // base complexity
    for token in BRANCH_TOKENS.iter() {
        // Operators (`?`, `&&`, `||`) need no word-boundary check.
        let is_word = token.starts_with(|c: char| c.is_ascii_alphabetic());
        for (start, _) in body.match_indices(*token) {
            // `start` is the beginning of a match, hence a valid char boundary.
            let glued_to_identifier =
                is_word && body[..start].chars().next_back().is_some_and(is_ident_char);
            if !glued_to_identifier {
                complexity = complexity.saturating_add(1);
            }
        }
    }
    complexity
}
254
#[cfg(test)]
mod tests {
    use crate::function_extractor::extract_functions;
    use crate::test_mapper::SourceFile;
    use crate::types::{ExtractedFunction, Language};
    use std::io::Write;
    use tempfile;

    /// Writes `contents` verbatim to a fresh temp file ending in `suffix` and
    /// runs the extractor over it as non-test source in `language`.
    fn extract_from_snippet(
        suffix: &str,
        language: Language,
        contents: &str,
    ) -> Vec<ExtractedFunction> {
        let mut file = tempfile::Builder::new().suffix(suffix).tempfile().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();

        let source = SourceFile {
            path: file.path().to_path_buf(),
            language,
            is_test: false,
        };

        // `file` stays alive until the end of this function, so the path is
        // still valid while the extractor reads it.
        extract_functions(&source).unwrap()
    }

    /// Looks up an extracted function by name, failing the test if it is absent.
    fn function_named<'a>(funcs: &'a [ExtractedFunction], name: &str) -> &'a ExtractedFunction {
        funcs
            .iter()
            .find(|f| f.name == name)
            .unwrap_or_else(|| panic!("should find '{name}'"))
    }

    #[test]
    fn parse_rust_snippet() {
        let funcs = extract_from_snippet(
            ".rs",
            Language::Rust,
            r#"pub fn add(a: i32, b: i32) -> i32 {
    a + b
}

pub fn complex_calc(x: i32) -> i32 {
    if x > 0 {
        for i in 0..x {
            if i % 2 == 0 {
                return i;
            }
        }
    }
    x
}
"#,
        );
        assert!(
            funcs.len() >= 2,
            "expected at least 2 functions, got {}",
            funcs.len()
        );

        let add_fn = function_named(&funcs, "add");
        assert!(add_fn.is_public, "add should be public");
        assert!(add_fn.line_start >= 1);
        assert!(add_fn.line_end >= add_fn.line_start);
        assert!(
            add_fn.signature.contains("fn add"),
            "signature should contain 'fn add', got: {}",
            add_fn.signature
        );

        let complex_fn = function_named(&funcs, "complex_calc");
        assert!(complex_fn.is_public, "complex_calc should be public");
    }

    #[test]
    fn parse_typescript_snippet() {
        let funcs = extract_from_snippet(
            ".ts",
            Language::TypeScript,
            r#"export function greet(name: string): string {
    return "hello " + name;
}
"#,
        );
        assert!(!funcs.is_empty(), "expected at least 1 function");

        let greet_fn = function_named(&funcs, "greet");
        assert!(greet_fn.is_public, "exported function should be public");
        assert_eq!(greet_fn.name, "greet");
    }

    #[test]
    fn parse_python_snippet() {
        let funcs = extract_from_snippet(
            ".py",
            Language::Python,
            "def calculate(x, y):\n    if x > 0:\n        return x + y\n    return y\n",
        );
        assert!(!funcs.is_empty(), "expected at least 1 function");

        let calc_fn = function_named(&funcs, "calculate");
        assert!(
            calc_fn.is_public,
            "calculate should be public (no leading underscore)"
        );
        assert_eq!(calc_fn.name, "calculate");
    }

    #[test]
    fn complexity_estimation_via_extract() {
        let funcs = extract_from_snippet(
            ".rs",
            Language::Rust,
            r#"pub fn branchy(x: i32, y: i32) -> i32 {
    if x > 0 {
        if y > 0 {
            for i in 0..x {
                match i {
                    0 => return 0,
                    _ => {
                        if i > 5 && y < 10 {
                            return i;
                        }
                    }
                }
            }
        }
    } else {
        while x > 0 {
            return y;
        }
    }
    x + y
}
"#,
        );

        let branchy = function_named(&funcs, "branchy");
        assert!(
            branchy.complexity > 1,
            "complex function should have complexity > 1, got {}",
            branchy.complexity
        );
    }
}