Skip to main content

sparrow/engine/
treesitter.rs

// ─── RepoMap symbol extraction (Phase 3 Item 11) — line-based heuristics; tree-sitter not yet wired in ───
2
3use std::path::Path;
4
/// Line-based symbol extractor used to build the RepoMap.
///
/// Despite its name, this type does not currently use tree-sitter (and has no
/// regex fallback): every language is scanned line by line with keyword
/// heuristics, and unrecognised languages get a generic keyword scan. See
/// [`TreeSitterParser::extract`].
#[derive(Debug, Clone, Copy)]
pub struct TreeSitterParser;
8
9impl TreeSitterParser {
10    /// Extract symbols from file content, returning (name, kind, line).
11    pub fn extract(content: &str, language: &str) -> Vec<(String, String, u32)> {
12        match language {
13            "rust" => Self::extract_rust(content),
14            "python" => Self::extract_python(content),
15            "typescript" | "javascript" => Self::extract_ts(content),
16            "go" => Self::extract_go(content),
17            _ => Self::extract_generic(content),
18        }
19    }
20
21    fn extract_rust(content: &str) -> Vec<(String, String, u32)> {
22        let mut symbols = Vec::new();
23        // Strip line comments and block-comment bodies before scanning so we don't
24        // pick up declarations inside doc-comments or commented-out code.
25        let stripped = strip_rust_comments(content);
26        // Track visibility prefixes like `pub`, `pub(crate)`, `pub(super)`.
27        let strip_vis = |s: &str| -> String {
28            let mut t = s.trim_start();
29            // Strip `pub(...)` first.
30            if t.starts_with("pub(") {
31                if let Some(close) = t.find(')') {
32                    t = t[close + 1..].trim_start();
33                }
34            } else if let Some(rest) = t.strip_prefix("pub ") {
35                t = rest.trim_start();
36            }
37            // Strip async/unsafe/extern modifiers so the next-token check works.
38            for kw in ["async ", "unsafe ", "extern \"C\" ", "extern ", "default "] {
39                if let Some(rest) = t.strip_prefix(kw) {
40                    t = rest.trim_start();
41                }
42            }
43            t.to_string()
44        };
45        for (i, line) in stripped.lines().enumerate() {
46            let trimmed = strip_vis(line.trim());
47            let line_num = (i + 1) as u32;
48
49            if let Some(rest) = trimmed.strip_prefix("fn ") {
50                let name = extract_ident(rest);
51                if !name.is_empty() {
52                    symbols.push((name, "fn".into(), line_num));
53                }
54            } else if let Some(rest) = trimmed.strip_prefix("struct ") {
55                let name = extract_ident(rest);
56                if !name.is_empty() {
57                    symbols.push((name, "struct".into(), line_num));
58                }
59            } else if let Some(rest) = trimmed.strip_prefix("enum ") {
60                let name = extract_ident(rest);
61                if !name.is_empty() {
62                    symbols.push((name, "enum".into(), line_num));
63                }
64            } else if let Some(rest) = trimmed.strip_prefix("trait ") {
65                let name = extract_ident(rest);
66                if !name.is_empty() {
67                    symbols.push((name, "trait".into(), line_num));
68                }
69            } else if let Some(rest) = trimmed.strip_prefix("mod ") {
70                let name = extract_ident(rest);
71                if !name.is_empty() {
72                    symbols.push((name, "mod".into(), line_num));
73                }
74            } else if trimmed.starts_with("impl") {
75                let name = extract_impl_name(&trimmed);
76                symbols.push((name, "impl".into(), line_num));
77            } else if let Some(rest) = trimmed
78                .strip_prefix("const ")
79                .or_else(|| trimmed.strip_prefix("static "))
80            {
81                let kind = if trimmed.starts_with("static ") {
82                    "static"
83                } else {
84                    "const"
85                };
86                let name = rest
87                    .split(|c: char| c == ':' || c == '=' || c == ';')
88                    .next()
89                    .unwrap_or("")
90                    .trim()
91                    .to_string();
92                // Rust idiomatic constants are SCREAMING_SNAKE — accept digits too
93                // (e.g. CONST_1, V2_TOKEN). Allow leading underscore for visibility.
94                let is_screaming = !name.is_empty()
95                    && name
96                        .chars()
97                        .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_');
98                if is_screaming {
99                    symbols.push((name, kind.into(), line_num));
100                }
101            } else if let Some(rest) = trimmed.strip_prefix("type ") {
102                let name = extract_ident(rest);
103                if !name.is_empty() {
104                    symbols.push((name, "type".into(), line_num));
105                }
106            }
107        }
108        symbols
109    }
110
111    fn extract_python(content: &str) -> Vec<(String, String, u32)> {
112        let mut symbols = Vec::new();
113        for (i, line) in content.lines().enumerate() {
114            let trimmed = line.trim();
115            let line_num = (i + 1) as u32;
116            if trimmed.starts_with("def ") {
117                let name = trimmed
118                    .trim_start_matches("def ")
119                    .split('(')
120                    .next()
121                    .unwrap_or("")
122                    .trim()
123                    .to_string();
124                symbols.push((name, "fn".into(), line_num));
125            } else if trimmed.starts_with("class ") {
126                let name = trimmed
127                    .trim_start_matches("class ")
128                    .split(|c: char| c == '(' || c == ':')
129                    .next()
130                    .unwrap_or("")
131                    .trim()
132                    .to_string();
133                symbols.push((name, "class".into(), line_num));
134            } else if trimmed.starts_with("async def ") {
135                let name = trimmed
136                    .trim_start_matches("async def ")
137                    .split('(')
138                    .next()
139                    .unwrap_or("")
140                    .trim()
141                    .to_string();
142                symbols.push((name, "async fn".into(), line_num));
143            }
144        }
145        symbols
146    }
147
148    fn extract_ts(content: &str) -> Vec<(String, String, u32)> {
149        let mut symbols = Vec::new();
150        for (i, line) in content.lines().enumerate() {
151            let trimmed = line.trim();
152            let line_num = (i + 1) as u32;
153            let kind = if trimmed.starts_with("export function ")
154                || trimmed.starts_with("function ")
155            {
156                "fn"
157            } else if trimmed.starts_with("export class ") || trimmed.starts_with("class ") {
158                "class"
159            } else if trimmed.starts_with("export interface ") || trimmed.starts_with("interface ")
160            {
161                "interface"
162            } else if trimmed.starts_with("export const ") || trimmed.starts_with("const ") {
163                "const"
164            } else {
165                continue;
166            };
167            let name = trimmed.split_whitespace().nth(1).unwrap_or("").to_string();
168            symbols.push((name, kind.into(), line_num));
169        }
170        symbols
171    }
172
173    fn extract_go(content: &str) -> Vec<(String, String, u32)> {
174        let mut symbols = Vec::new();
175        for (i, line) in content.lines().enumerate() {
176            let trimmed = line.trim();
177            let line_num = (i + 1) as u32;
178            if trimmed.starts_with("func ") {
179                let name = trimmed
180                    .trim_start_matches("func ")
181                    .split('(')
182                    .next()
183                    .unwrap_or("")
184                    .trim()
185                    .to_string();
186                symbols.push((name, "fn".into(), line_num));
187            } else if trimmed.starts_with("type ") && trimmed.contains("struct") {
188                let name = trimmed
189                    .trim_start_matches("type ")
190                    .split_whitespace()
191                    .next()
192                    .unwrap_or("")
193                    .to_string();
194                symbols.push((name, "struct".into(), line_num));
195            }
196        }
197        symbols
198    }
199
200    fn extract_generic(content: &str) -> Vec<(String, String, u32)> {
201        let mut symbols = Vec::new();
202        for (i, line) in content.lines().enumerate() {
203            let trimmed = line.trim();
204            if trimmed.starts_with("def ")
205                || trimmed.starts_with("fn ")
206                || trimmed.starts_with("func ")
207                || trimmed.starts_with("class ")
208                || trimmed.starts_with("struct ")
209            {
210                let line_num = (i + 1) as u32;
211                let name = trimmed.split_whitespace().nth(1).unwrap_or("").to_string();
212                symbols.push((name, "symbol".into(), line_num));
213            }
214        }
215        symbols
216    }
217}
218
/// Extract the leading Rust identifier from `s`, ignoring leading whitespace.
///
/// Takes the longest run of alphanumeric characters and `_`. The check is
/// Unicode-aware, since Rust accepts non-ASCII identifiers such as `café`.
/// A raw-identifier prefix (`r#match`) is kept as part of the result.
/// Returns an empty string if `s` does not start with an identifier character.
fn extract_ident(s: &str) -> String {
    let is_ident_char = |c: &char| c.is_alphanumeric() || *c == '_';
    let s = s.trim_start();
    if let Some(rest) = s.strip_prefix("r#") {
        let raw: String = rest.chars().take_while(is_ident_char).collect();
        if !raw.is_empty() {
            return format!("r#{}", raw);
        }
    }
    s.chars().take_while(is_ident_char).collect()
}
227
228/// Extract the type being implemented from an `impl` line.
229/// Handles `impl Foo`, `impl<T> Foo<T>`, `impl Trait for Foo<T>`,
230/// `impl<T: Bound> Trait<T> for Foo<T> where ... {`.
231fn extract_impl_name(line: &str) -> String {
232    // Drop leading `impl` keyword.
233    let after_impl = line.trim_start().trim_start_matches("impl").trim_start();
234
235    // Skip optional generic parameter list `<...>` immediately after `impl`,
236    // balancing angle brackets so we don't get fooled by `<T: Trait<U>>`.
237    let after_generics = skip_balanced_angles(after_impl);
238
239    // If the body contains ` for `, the right-hand side is the implementing type.
240    let target_slice = if let Some(idx) = find_keyword(after_generics, " for ") {
241        &after_generics[idx + 5..]
242    } else {
243        after_generics
244    };
245    let target = target_slice.trim_start();
246
247    // Take the type name up to `<`, `{`, `where`, whitespace, or end.
248    let end = target
249        .find(|c: char| c == '<' || c == '{' || c.is_whitespace())
250        .unwrap_or(target.len());
251    let name = target[..end].trim();
252    if name.is_empty() {
253        "impl".into()
254    } else {
255        format!("impl {}", name)
256    }
257}
258
/// If `s` starts with `<`, return the text after the matching `>`, with
/// leading whitespace trimmed. Nesting is respected.
///
/// Returns `s` unchanged if it does not start with `<`, or if the brackets
/// never balance on this line. The `>` of a `->` arrow (as in
/// `<F: Fn() -> u8>`) is not treated as a closing bracket.
fn skip_balanced_angles(s: &str) -> &str {
    if !s.starts_with('<') {
        return s;
    }
    // The first byte is `<`, so depth is at least 1 before any `>` is seen
    // and cannot underflow.
    let mut depth = 0usize;
    let mut prev = 0u8;
    for (i, &b) in s.as_bytes().iter().enumerate() {
        match b {
            b'<' => depth += 1,
            b'>' if prev != b'-' => {
                depth -= 1;
                if depth == 0 {
                    return s[i + 1..].trim_start();
                }
            }
            _ => {}
        }
        prev = b;
    }
    s
}
281
/// Return the byte offset of the first occurrence of `needle` in `haystack`.
///
/// This is a cheap stand-in for a word-boundary search. Callers bake the
/// boundaries into `needle` (e.g. `" for "`), so the returned index points at
/// the needle's leading space.
fn find_keyword(haystack: &str, needle: &str) -> Option<usize> {
    haystack
        .match_indices(needle)
        .map(|(offset, _)| offset)
        .next()
}
287
/// Strip Rust comments from `s`.
///
/// Removes line comments (`// ...`, including `///` and `//!` doc comments)
/// and block comments (`/* ... */`, possibly nested). Newlines inside removed
/// block comments, and the newline that ends a line comment, are kept, so
/// line numbers in the output match the input.
///
/// The following are copied verbatim, so comment markers inside them are
/// never treated as comments:
/// - string literals (`"..."`, `b"..."`);
/// - raw strings (`r"..."`, `r#"..."#`, `br"..."`);
/// - char literals (`'x'`, `'\n'`, `'"'`).
///
/// A `"` inside a char literal does not open a string. Lifetimes and labels
/// (`'a`) are left alone, and non-ASCII text is preserved unchanged.
fn strip_rust_comments(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();
        if b == b'/' && next == Some(b'/') {
            // Line comment: drop up to, but not including, the newline.
            while i < bytes.len() && bytes[i] != b'\n' {
                i += 1;
            }
        } else if b == b'/' && next == Some(b'*') {
            i = skip_rust_block_comment(bytes, i, &mut out);
        } else if let Some(end) = rust_raw_string_end(bytes, i) {
            out.extend_from_slice(&bytes[i..end]);
            i = end;
        } else if b == b'"' {
            let end = rust_string_end(bytes, i);
            out.extend_from_slice(&bytes[i..end]);
            i = end;
        } else if b == b'\'' {
            // Char literal, or a lone `'` of a lifetime / loop label.
            let end = rust_char_literal_end(s, i).unwrap_or(i + 1);
            out.extend_from_slice(&bytes[i..end]);
            i = end;
        } else {
            out.push(b);
            i += 1;
        }
    }
    // Only whole comments were removed, and each starts and ends at an ASCII
    // byte (a char boundary), so `out` is still valid UTF-8. The lossy
    // fallback exists only to avoid a panic path.
    String::from_utf8(out).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
}

/// Skip the (possibly nested) block comment opening at `start`.
///
/// Only the comment's newlines are pushed onto `out`. Returns the index just
/// past the closing `*/`, or the input length if the comment is unterminated.
fn skip_rust_block_comment(bytes: &[u8], start: usize, out: &mut Vec<u8>) -> usize {
    let mut depth = 0usize;
    let mut i = start;
    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();
        if b == b'/' && next == Some(b'*') {
            depth += 1;
            i += 2;
        } else if b == b'*' && next == Some(b'/') {
            // `start` opens a comment, so depth >= 1 here.
            depth -= 1;
            i += 2;
            if depth == 0 {
                return i;
            }
        } else {
            if b == b'\n' {
                out.push(b'\n');
            }
            i += 1;
        }
    }
    bytes.len()
}

/// If a raw string literal starts with the `r` at `start`, return the index
/// just past its end (or the input length if unterminated).
///
/// Covers `r"..."`, `r#"..."#`, and the `br`/`cr` prefixed forms.
fn rust_raw_string_end(bytes: &[u8], start: usize) -> Option<usize> {
    if bytes.get(start) != Some(&b'r') {
        return None;
    }
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80;
    // The `r` must begin a token, optionally after a `b`/`c` prefix;
    // otherwise it is the tail of an identifier.
    let token_start = if start == 0 {
        true
    } else {
        let prev = bytes[start - 1];
        if prev == b'b' || prev == b'c' {
            start < 2 || !is_ident(bytes[start - 2])
        } else {
            !is_ident(prev)
        }
    };
    if !token_start {
        return None;
    }
    let mut j = start + 1;
    while bytes.get(j) == Some(&b'#') {
        j += 1;
    }
    // `r#ident` (raw identifier) has no quote here and is not a string.
    if bytes.get(j) != Some(&b'"') {
        return None;
    }
    let hashes = j - start - 1;
    j += 1;
    while j < bytes.len() {
        if bytes[j] == b'"'
            && bytes
                .get(j + 1..j + 1 + hashes)
                .map_or(false, |tail| tail.iter().all(|&h| h == b'#'))
        {
            return Some(j + 1 + hashes);
        }
        j += 1;
    }
    Some(bytes.len())
}

/// Return the index just past the closing `"` of the (non-raw) string
/// literal whose opening quote is at `start`.
///
/// Backslash escapes are honoured. Returns the input length if the string is
/// unterminated.
fn rust_string_end(bytes: &[u8], start: usize) -> usize {
    let mut j = start + 1;
    while j < bytes.len() {
        match bytes[j] {
            b'\\' => j += 2,
            b'"' => return j + 1,
            _ => j += 1,
        }
    }
    bytes.len()
}

/// If a char literal (`'x'`, `'\n'`, `'\u{1F600}'`, `'"'`) starts at the `'`
/// at byte `start`, return the index just past its closing `'`.
///
/// Returns `None` for lifetimes and labels such as `'a` or `'outer:`.
fn rust_char_literal_end(s: &str, start: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    if bytes.get(start + 1) == Some(&b'\\') {
        // Escaped char. Start scanning after the escaped character so that
        // `'\''` closes on the right quote, and never cross a line.
        let mut j = start + 3;
        while j < bytes.len() && bytes[j] != b'\n' {
            if bytes[j] == b'\'' {
                return Some(j + 1);
            }
            j += 1;
        }
        return None;
    }
    // `start + 1` is a char boundary because `bytes[start]` is the ASCII `'`.
    let c = s.get(start + 1..)?.chars().next()?;
    let close = start + 1 + c.len_utf8();
    if bytes.get(close) == Some(&b'\'') {
        Some(close + 1)
    } else {
        None
    }
}
372
/// Map a file path to the language name understood by
/// [`TreeSitterParser::extract`], based solely on the file extension.
///
/// Recognised extensions:
/// - `.rs`;
/// - `.py`, `.pyi`;
/// - `.ts`, `.tsx`, `.mts`, `.cts`;
/// - `.js`, `.jsx`, `.mjs`, `.cjs`;
/// - `.go`.
///
/// The match is case-sensitive. Returns `None` for any other extension, a
/// missing extension, or a non-UTF-8 extension.
pub fn detect_language(path: &Path) -> Option<&'static str> {
    let language = match path.extension()?.to_str()? {
        "rs" => "rust",
        "py" | "pyi" => "python",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "go" => "go",
        _ => return None,
    };
    Some(language)
}