1use std::collections::HashMap;
8use scribe_core::Result;
9use tree_sitter::{Parser, Language, Node, Tree};
10
/// A single import/dependency reference found in a source file.
///
/// Derives `PartialEq`/`Eq`/`Hash` so results can be compared in tests and
/// deduplicated in sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SimpleImport {
    /// The imported module as written in the source (quotes already stripped
    /// for string-literal specifiers).
    pub module: String,
    /// 1-based line number where the import was found.
    pub line_number: usize,
}
19
/// Source languages whose imports [`SimpleAstParser`] knows how to extract.
///
/// Cheap to copy and usable as a hash-map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ImportLanguage {
    /// Python (`.py`, `.pyi`, `.pyw`).
    Python,
    /// JavaScript (`.js`, `.mjs`, `.cjs`).
    JavaScript,
    /// TypeScript (`.ts`, `.mts`, `.cts`).
    TypeScript,
    /// Go (`.go`).
    Go,
    /// Rust (`.rs`).
    Rust,
}
29
30impl ImportLanguage {
31 pub fn tree_sitter_language(&self) -> Language {
33 match self {
34 ImportLanguage::Python => tree_sitter_python::language(),
35 ImportLanguage::JavaScript => tree_sitter_javascript::language(),
36 ImportLanguage::TypeScript => tree_sitter_typescript::language_typescript(),
37 ImportLanguage::Go => tree_sitter_go::language(),
38 ImportLanguage::Rust => tree_sitter_rust::language(),
39 }
40 }
41
42 pub fn from_extension(ext: &str) -> Option<Self> {
44 match ext.to_lowercase().as_str() {
45 "py" | "pyi" | "pyw" => Some(ImportLanguage::Python),
46 "js" | "mjs" | "cjs" => Some(ImportLanguage::JavaScript),
47 "ts" | "mts" | "cts" => Some(ImportLanguage::TypeScript),
48 "go" => Some(ImportLanguage::Go),
49 "rs" => Some(ImportLanguage::Rust),
50 _ => None,
51 }
52 }
53}
54
55pub struct SimpleAstParser {
57 parsers: HashMap<ImportLanguage, Parser>,
58}
59
60impl std::fmt::Debug for SimpleAstParser {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("SimpleAstParser")
63 .field("parsers", &format!("[{} parsers]", self.parsers.len()))
64 .finish()
65 }
66}
67
68impl SimpleAstParser {
69 pub fn new() -> Result<Self> {
71 let mut parsers = HashMap::new();
72
73 for language in [
74 ImportLanguage::Python,
75 ImportLanguage::JavaScript,
76 ImportLanguage::TypeScript,
77 ImportLanguage::Go,
78 ImportLanguage::Rust,
79 ] {
80 let mut parser = Parser::new();
81 parser.set_language(language.tree_sitter_language())
82 .map_err(|e| scribe_core::ScribeError::parse(format!("Failed to set tree-sitter language: {}", e)))?;
83 parsers.insert(language, parser);
84 }
85
86 Ok(Self { parsers })
87 }
88
89 pub fn extract_imports(&self, content: &str, language: ImportLanguage) -> Result<Vec<SimpleImport>> {
91 let mut parser = Parser::new();
93 parser.set_language(language.tree_sitter_language()).map_err(|e|
94 scribe_core::ScribeError::parse(format!("Failed to set language: {}", e)))?;
95
96 let tree = parser.parse(content, None)
97 .ok_or_else(|| scribe_core::ScribeError::parse("Failed to parse content"))?;
98
99 let mut imports = Vec::new();
100 let root_node = tree.root_node();
101
102 match language {
104 ImportLanguage::Python => {
105 self.extract_python_imports(&root_node, content, &mut imports)?;
106 }
107 ImportLanguage::JavaScript | ImportLanguage::TypeScript => {
108 self.extract_js_ts_imports(&root_node, content, &mut imports)?;
109 }
110 ImportLanguage::Go => {
111 self.extract_go_imports(&root_node, content, &mut imports)?;
112 }
113 ImportLanguage::Rust => {
114 self.extract_rust_imports(&root_node, content, &mut imports)?;
115 }
116 }
117
118 Ok(imports)
119 }
120
121 fn extract_python_imports(&self, node: &Node, content: &str, imports: &mut Vec<SimpleImport>) -> Result<()> {
123 if node.kind() == "import_statement" {
124 for i in 0..node.child_count() {
126 if let Some(child) = node.child(i) {
127 if child.kind() == "dotted_name" || child.kind() == "identifier" {
128 let module = self.node_text(child, content);
129 let line_number = child.start_position().row + 1;
130
131 imports.push(SimpleImport {
132 module,
133 line_number,
134 });
135 }
136 }
137 }
138 } else if node.kind() == "import_from_statement" {
139 if let Some(module_node) = node.child_by_field_name("module_name") {
140 let module = self.node_text(module_node, content);
141 let line_number = node.start_position().row + 1;
142 imports.push(SimpleImport {
143 module,
144 line_number,
145 });
146 }
147 }
148
149 for i in 0..node.child_count() {
151 if let Some(child) = node.child(i) {
152 self.extract_python_imports(&child, content, imports)?;
153 }
154 }
155
156 Ok(())
157 }
158
159 fn extract_js_ts_imports(&self, node: &Node, content: &str, imports: &mut Vec<SimpleImport>) -> Result<()> {
161 if node.kind() == "import_statement" {
162 for i in 0..node.child_count() {
164 if let Some(child) = node.child(i) {
165 if child.kind() == "string" {
166 let mut module = self.node_text(child, content);
167 module = module.trim_matches('"').trim_matches('\'').to_string();
169 let line_number = node.start_position().row + 1;
170 imports.push(SimpleImport {
171 module,
172 line_number,
173 });
174 break;
175 }
176 }
177 }
178 }
179
180 for i in 0..node.child_count() {
182 if let Some(child) = node.child(i) {
183 self.extract_js_ts_imports(&child, content, imports)?;
184 }
185 }
186
187 Ok(())
188 }
189
190 fn extract_go_imports(&self, node: &Node, content: &str, imports: &mut Vec<SimpleImport>) -> Result<()> {
192 if node.kind() == "import_spec" {
193 for i in 0..node.child_count() {
194 if let Some(child) = node.child(i) {
195 if child.kind() == "interpreted_string_literal" {
196 let module = self.node_text(child, content);
197 let module = module.trim_matches('"').to_string();
198 let line_number = child.start_position().row + 1;
199
200 imports.push(SimpleImport {
201 module,
202 line_number,
203 });
204 }
205 }
206 }
207 }
208
209 for i in 0..node.child_count() {
211 if let Some(child) = node.child(i) {
212 self.extract_go_imports(&child, content, imports)?;
213 }
214 }
215
216 Ok(())
217 }
218
219 fn extract_rust_imports(&self, node: &Node, content: &str, imports: &mut Vec<SimpleImport>) -> Result<()> {
221 if node.kind() == "use_declaration" {
222 if let Some(use_tree) = node.child_by_field_name("argument") {
223 let module = self.node_text(use_tree, content);
224 let line_number = node.start_position().row + 1;
225
226 imports.push(SimpleImport {
227 module,
228 line_number,
229 });
230 }
231 }
232
233 for i in 0..node.child_count() {
235 if let Some(child) = node.child(i) {
236 self.extract_rust_imports(&child, content, imports)?;
237 }
238 }
239
240 Ok(())
241 }
242
243 fn node_text(&self, node: Node, content: &str) -> String {
245 content[node.start_byte()..node.end_byte()].to_string()
246 }
247}
248
249impl Default for SimpleAstParser {
250 fn default() -> Self {
251 Self::new().expect("Failed to create SimpleAstParser")
252 }
253}