Skip to main content

seekr_code/parser/
treesitter.rs

//! Tree-sitter integration.
//!
//! Uses individual tree-sitter language crates to parse source files into ASTs.
//! Supports language detection by file extension and AST traversal.
5
6use std::path::Path;
7
8use tree_sitter::{Language, Parser};
9
10use crate::error::ParserError;
11
/// Every language this module can parse, one variant per Tree-sitter grammar.
///
/// The type is a plain, field-less enum: it is `Copy`, comparable, and usable
/// as a hash-map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SupportedLanguage {
    /// Rust source.
    Rust,
    /// Python source.
    Python,
    /// JavaScript source.
    JavaScript,
    /// TypeScript source.
    TypeScript,
    /// TSX source; kept separate from [`SupportedLanguage::TypeScript`].
    Tsx,
    /// Go source.
    Go,
    /// Java source.
    Java,
    /// C source.
    C,
    /// C++ source.
    Cpp,
    /// JSON documents.
    Json,
    /// TOML documents.
    Toml,
    /// YAML documents.
    Yaml,
    /// HTML documents.
    Html,
    /// CSS stylesheets.
    Css,
    /// Ruby source.
    Ruby,
    /// Bash / shell scripts.
    Bash,
}
32
33impl SupportedLanguage {
34    /// Detect the language from a file extension.
35    pub fn from_extension(ext: &str) -> Option<Self> {
36        match ext.to_lowercase().as_str() {
37            "rs" => Some(Self::Rust),
38            "py" | "pyi" | "pyx" => Some(Self::Python),
39            "js" | "jsx" | "mjs" | "cjs" => Some(Self::JavaScript),
40            "ts" | "mts" | "cts" => Some(Self::TypeScript),
41            "tsx" => Some(Self::Tsx),
42            "go" => Some(Self::Go),
43            "java" => Some(Self::Java),
44            "c" | "h" => Some(Self::C),
45            "cc" | "cpp" | "cxx" | "hpp" | "hxx" => Some(Self::Cpp),
46            "json" => Some(Self::Json),
47            "toml" => Some(Self::Toml),
48            "yaml" | "yml" => Some(Self::Yaml),
49            "html" | "htm" => Some(Self::Html),
50            "css" | "scss" => Some(Self::Css),
51            "rb" => Some(Self::Ruby),
52            "sh" | "bash" | "zsh" => Some(Self::Bash),
53            _ => None,
54        }
55    }
56
57    /// Detect the language from a file path.
58    pub fn from_path(path: &Path) -> Option<Self> {
59        // Check special filenames first
60        if let Some(filename) = path.file_name().and_then(|f| f.to_str()) {
61            match filename.to_lowercase().as_str() {
62                "makefile" | "gnumakefile" => return Some(Self::Bash),
63                "dockerfile" => return Some(Self::Bash),
64                _ => {}
65            }
66        }
67
68        // Check extension
69        path.extension()
70            .and_then(|ext| ext.to_str())
71            .and_then(Self::from_extension)
72    }
73
74    /// Get the Tree-sitter language grammar.
75    pub fn grammar(&self) -> Language {
76        match self {
77            Self::Rust => tree_sitter_rust::LANGUAGE.into(),
78            Self::Python => tree_sitter_python::LANGUAGE.into(),
79            Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
80            Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
81            Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
82            Self::Go => tree_sitter_go::LANGUAGE.into(),
83            Self::Java => tree_sitter_java::LANGUAGE.into(),
84            Self::C => tree_sitter_c::LANGUAGE.into(),
85            Self::Cpp => tree_sitter_cpp::LANGUAGE.into(),
86            Self::Json => tree_sitter_json::LANGUAGE.into(),
87            Self::Toml => tree_sitter_toml_ng::LANGUAGE.into(),
88            Self::Yaml => tree_sitter_yaml::LANGUAGE.into(),
89            Self::Html => tree_sitter_html::LANGUAGE.into(),
90            Self::Css => tree_sitter_css::LANGUAGE.into(),
91            Self::Ruby => tree_sitter_ruby::LANGUAGE.into(),
92            Self::Bash => tree_sitter_bash::LANGUAGE.into(),
93        }
94    }
95
96    /// Get the language name as a string.
97    pub fn name(&self) -> &'static str {
98        match self {
99            Self::Rust => "rust",
100            Self::Python => "python",
101            Self::JavaScript => "javascript",
102            Self::TypeScript => "typescript",
103            Self::Tsx => "tsx",
104            Self::Go => "go",
105            Self::Java => "java",
106            Self::C => "c",
107            Self::Cpp => "cpp",
108            Self::Json => "json",
109            Self::Toml => "toml",
110            Self::Yaml => "yaml",
111            Self::Html => "html",
112            Self::Css => "css",
113            Self::Ruby => "ruby",
114            Self::Bash => "bash",
115        }
116    }
117
118    /// Get AST node kinds that represent interesting code constructs
119    /// (functions, classes, methods, etc.) for this language.
120    pub fn chunk_node_kinds(&self) -> &[&str] {
121        match self {
122            Self::Rust => &[
123                "function_item",
124                "impl_item",
125                "struct_item",
126                "enum_item",
127                "trait_item",
128                "mod_item",
129                "const_item",
130                "static_item",
131                "type_item",
132                "macro_definition",
133            ],
134            Self::Python => &[
135                "function_definition",
136                "class_definition",
137                "decorated_definition",
138            ],
139            Self::JavaScript | Self::Tsx => &[
140                "function_declaration",
141                "class_declaration",
142                "method_definition",
143                "arrow_function",
144                "export_statement",
145            ],
146            Self::TypeScript => &[
147                "function_declaration",
148                "class_declaration",
149                "method_definition",
150                "arrow_function",
151                "interface_declaration",
152                "type_alias_declaration",
153                "export_statement",
154            ],
155            Self::Go => &[
156                "function_declaration",
157                "method_declaration",
158                "type_declaration",
159            ],
160            Self::Java => &[
161                "class_declaration",
162                "method_declaration",
163                "interface_declaration",
164                "enum_declaration",
165                "constructor_declaration",
166            ],
167            Self::C => &["function_definition", "struct_specifier", "enum_specifier"],
168            Self::Cpp => &[
169                "function_definition",
170                "class_specifier",
171                "struct_specifier",
172                "enum_specifier",
173                "namespace_definition",
174            ],
175            Self::Ruby => &["method", "class", "module", "singleton_method"],
176            // For config/data languages, we don't chunk
177            Self::Json | Self::Toml | Self::Yaml | Self::Html | Self::Css | Self::Bash => &[],
178        }
179    }
180}
181
182/// Create a parser configured for the given language.
183pub fn create_parser(lang: SupportedLanguage) -> Result<Parser, ParserError> {
184    let mut parser = Parser::new();
185    parser
186        .set_language(&lang.grammar())
187        .map_err(|e| ParserError::UnsupportedLanguage(format!("{}: {}", lang.name(), e)))?;
188    Ok(parser)
189}
190
191/// Parse source code with the given language and return the tree.
192pub fn parse_source(
193    source: &str,
194    lang: SupportedLanguage,
195) -> Result<tree_sitter::Tree, ParserError> {
196    let mut parser = create_parser(lang)?;
197    parser
198        .parse(source, None)
199        .ok_or_else(|| ParserError::ParseFailed {
200            path: std::path::PathBuf::from("<string>"),
201            reason: "Parser returned None".to_string(),
202        })
203}
204
205impl std::fmt::Display for SupportedLanguage {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        write!(f, "{}", self.name())
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as `lang` and return the kind of the root node together
    /// with its child count.
    fn root_of(source: &str, lang: SupportedLanguage) -> (String, usize) {
        let tree = parse_source(source, lang).unwrap();
        let root = tree.root_node();
        (root.kind().to_string(), root.child_count())
    }

    /// Known extensions map to their language; unknown ones map to `None`.
    #[test]
    fn test_language_detection() {
        let expectations = [
            ("rs", Some(SupportedLanguage::Rust)),
            ("py", Some(SupportedLanguage::Python)),
            ("tsx", Some(SupportedLanguage::Tsx)),
            ("unknown", None),
        ];

        for (ext, expected) in expectations {
            assert_eq!(SupportedLanguage::from_extension(ext), expected);
        }
    }

    /// A one-line Rust function parses into a non-empty `source_file`.
    #[test]
    fn test_parse_rust() {
        let (kind, children) = root_of(
            "fn hello() -> String { \"world\".to_string() }",
            SupportedLanguage::Rust,
        );
        assert_eq!(kind, "source_file");
        assert!(children > 0);
    }

    /// A small Python function parses into a `module` root.
    #[test]
    fn test_parse_python() {
        let (kind, _) = root_of(
            "def greet(name: str) -> str:\n    return f\"Hello, {name}\"",
            SupportedLanguage::Python,
        );
        assert_eq!(kind, "module");
    }

    /// A small JavaScript function parses into a `program` root.
    #[test]
    fn test_parse_javascript() {
        let (kind, _) = root_of(
            "function add(a, b) { return a + b; }",
            SupportedLanguage::JavaScript,
        );
        assert_eq!(kind, "program");
    }

    /// Every supported grammar can be installed into a parser.
    #[test]
    fn test_all_grammars_load() {
        let every_language = [
            SupportedLanguage::Rust,
            SupportedLanguage::Python,
            SupportedLanguage::JavaScript,
            SupportedLanguage::TypeScript,
            SupportedLanguage::Tsx,
            SupportedLanguage::Go,
            SupportedLanguage::Java,
            SupportedLanguage::C,
            SupportedLanguage::Cpp,
            SupportedLanguage::Json,
            SupportedLanguage::Toml,
            SupportedLanguage::Yaml,
            SupportedLanguage::Html,
            SupportedLanguage::Css,
            SupportedLanguage::Ruby,
            SupportedLanguage::Bash,
        ];

        for lang in every_language {
            assert!(
                create_parser(lang).is_ok(),
                "Failed to create parser for {}",
                lang.name()
            );
        }
    }
}
287}