1use std::collections::HashMap;
6use syn::visit::Visit;
7
8use crate::ast::RustAST;
9
/// A place in the analyzed source where a symbol is defined or referenced.
///
/// Only the identifier text is recorded; no span (line/column) information is
/// captured, so two `Location`s for the same name compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct Location {
    /// Identifier text at this location.
    pub name: String,
}

impl Location {
    /// Creates a location for the identifier `name`.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
        }
    }
}

/// A named item recorded by the analyzer, together with its definition and the
/// use sites attributed to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Symbol {
    /// The symbol's identifier.
    pub name: String,
    /// What kind of construct introduced the symbol.
    pub kind: SymbolKind,
    /// Where the symbol was defined (the most recent definition of this name).
    pub definition: Location,
    /// Recorded use sites, in the order the analyzer encountered them.
    pub references: Vec<Location>,
}

/// Category of a [`Symbol`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// A binding introduced by a `let` statement.
    LocalVar,
    /// A binding introduced by a function parameter pattern.
    Parameter,
    /// A free function item (`fn`).
    Function,
    /// A `struct` item.
    Struct,
    /// An `enum` item.
    Enum,
    /// A `const` item; `static` items are reported with this kind as well.
    Const,
    /// A `type` alias item.
    TypeAlias,
    /// An `impl` block. NOTE(review): the current collector never emits this kind.
    Impl,
}
59
60pub struct DefRefs;
62
63impl DefRefs {
64 pub fn analyze(ast: &RustAST) -> SymbolTable {
66 let mut collector = SymbolCollector::new();
67 collector.visit_file(ast.file());
68 collector.table
69 }
70
71 pub fn find_definition(ast: &RustAST, name: &str) -> Option<Symbol> {
73 let table = Self::analyze(ast);
74 table.symbols.get(name).cloned()
75 }
76
77 pub fn find_references(ast: &RustAST, name: &str) -> Vec<Location> {
79 let table = Self::analyze(ast);
80 table
81 .symbols
82 .get(name)
83 .map(|s| s.references.clone())
84 .unwrap_or_default()
85 }
86}
87
88#[derive(Debug, Default)]
90pub struct SymbolTable {
91 pub symbols: HashMap<String, Symbol>,
93}
94
95impl SymbolTable {
96 pub fn by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
98 self.symbols.values().filter(|s| s.kind == kind).collect()
99 }
100
101 pub fn functions(&self) -> Vec<&Symbol> {
103 self.by_kind(SymbolKind::Function)
104 }
105
106 pub fn local_vars(&self) -> Vec<&Symbol> {
108 self.by_kind(SymbolKind::LocalVar)
109 }
110}
111
112struct SymbolCollector {
114 table: SymbolTable,
115 scopes: Vec<HashMap<String, Location>>,
117}
118
119impl SymbolCollector {
120 fn new() -> Self {
121 Self {
122 table: SymbolTable::default(),
123 scopes: vec![HashMap::new()], }
125 }
126
127 fn enter_scope(&mut self) {
128 self.scopes.push(HashMap::new());
129 }
130
131 fn exit_scope(&mut self) {
132 self.scopes.pop();
133 }
134
135 fn define_symbol(&mut self, name: &str, kind: SymbolKind) {
136 let loc = Location::new(name);
137
138 if matches!(kind, SymbolKind::LocalVar | SymbolKind::Parameter) {
140 if let Some(scope) = self.scopes.last_mut() {
141 scope.insert(name.to_string(), loc.clone());
142 }
143 }
144
145 self.table.symbols.insert(
147 name.to_string(),
148 Symbol {
149 name: name.to_string(),
150 kind,
151 definition: loc,
152 references: vec![],
153 },
154 );
155 }
156
157 fn add_reference(&mut self, name: &str) {
158 let loc = Location::new(name);
159 if let Some(symbol) = self.table.symbols.get_mut(name) {
160 symbol.references.push(loc);
161 }
162 }
163
164 fn is_defined(&self, name: &str) -> bool {
165 self.scopes.iter().rev().any(|s| s.contains_key(name))
166 || self.table.symbols.contains_key(name)
167 }
168
169 fn define_from_pat(&mut self, pat: &syn::Pat, kind: SymbolKind) {
171 match pat {
172 syn::Pat::Ident(pat_ident) => {
173 self.define_symbol(&pat_ident.ident.to_string(), kind);
174 }
175 syn::Pat::Tuple(pat_tuple) => {
176 for elem in &pat_tuple.elems {
177 self.define_from_pat(elem, kind);
178 }
179 }
180 syn::Pat::TupleStruct(pat_tuple_struct) => {
181 for elem in &pat_tuple_struct.elems {
182 self.define_from_pat(elem, kind);
183 }
184 }
185 syn::Pat::Struct(pat_struct) => {
186 for field in &pat_struct.fields {
187 self.define_from_pat(&field.pat, kind);
188 }
189 }
190 syn::Pat::Reference(pat_ref) => {
191 self.define_from_pat(&pat_ref.pat, kind);
192 }
193 syn::Pat::Type(pat_type) => {
194 self.define_from_pat(&pat_type.pat, kind);
195 }
196 syn::Pat::Or(pat_or) => {
197 for case in &pat_or.cases {
198 self.define_from_pat(case, kind);
199 }
200 }
201 syn::Pat::Slice(pat_slice) => {
202 for elem in &pat_slice.elems {
203 self.define_from_pat(elem, kind);
204 }
205 }
206 _ => {}
207 }
208 }
209}
210
211impl<'ast> Visit<'ast> for SymbolCollector {
212 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
213 self.define_symbol(&node.sig.ident.to_string(), SymbolKind::Function);
215
216 self.enter_scope();
218
219 for param in &node.sig.inputs {
221 if let syn::FnArg::Typed(pat_type) = param {
222 self.define_from_pat(&pat_type.pat, SymbolKind::Parameter);
223 }
224 }
225
226 syn::visit::visit_block(self, &node.block);
228
229 self.exit_scope();
230 }
231
232 fn visit_local(&mut self, node: &'ast syn::Local) {
233 if let Some(init) = &node.init {
235 self.visit_expr(&init.expr);
236 }
237
238 self.define_from_pat(&node.pat, SymbolKind::LocalVar);
240 }
241
242 fn visit_expr_path(&mut self, node: &'ast syn::ExprPath) {
243 if node.path.segments.len() == 1 {
245 let name = node.path.segments[0].ident.to_string();
246 if self.is_defined(&name) {
247 self.add_reference(&name);
248 }
249 }
250 syn::visit::visit_expr_path(self, node);
251 }
252
253 fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) {
254 self.define_symbol(&node.ident.to_string(), SymbolKind::Struct);
255 syn::visit::visit_item_struct(self, node);
256 }
257
258 fn visit_item_enum(&mut self, node: &'ast syn::ItemEnum) {
259 self.define_symbol(&node.ident.to_string(), SymbolKind::Enum);
260 syn::visit::visit_item_enum(self, node);
261 }
262
263 fn visit_item_const(&mut self, node: &'ast syn::ItemConst) {
264 self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
265 syn::visit::visit_item_const(self, node);
266 }
267
268 fn visit_item_static(&mut self, node: &'ast syn::ItemStatic) {
269 self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
270 syn::visit::visit_item_static(self, node);
271 }
272
273 fn visit_item_type(&mut self, node: &'ast syn::ItemType) {
274 self.define_symbol(&node.ident.to_string(), SymbolKind::TypeAlias);
275 syn::visit::visit_item_type(self, node);
276 }
277
278 fn visit_block(&mut self, node: &'ast syn::Block) {
279 self.enter_scope();
280 syn::visit::visit_block(self, node);
281 self.exit_scope();
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` (panicking on a parse error) and runs the full analysis.
    fn analyze_src(src: &str) -> SymbolTable {
        let ast = RustAST::parse(src).unwrap();
        DefRefs::analyze(&ast)
    }

    #[test]
    fn test_find_function_def() {
        let table = analyze_src(
            r#"
fn hello() {}
fn world() {}
"#,
        );

        for name in ["hello", "world"] {
            assert!(table.symbols.contains_key(name));
        }
        assert_eq!(table.functions().len(), 2);
    }

    #[test]
    fn test_find_local_var() {
        let table = analyze_src(
            r#"
fn main() {
    let x = 1;
    let y = 2;
}
"#,
        );

        for name in ["x", "y"] {
            assert!(table.symbols.contains_key(name));
        }
    }

    #[test]
    fn test_find_references() {
        let src = r#"
fn main() {
    let x = 1;
    let y = x + 1;
    let z = x + y;
}
"#;
        let ast = RustAST::parse(src).unwrap();

        let refs = DefRefs::find_references(&ast, "x");
        assert_eq!(refs.len(), 2);
    }

    #[test]
    fn test_struct_definition() {
        let table = analyze_src(
            r#"
struct Point {
    x: i32,
    y: i32,
}
"#,
        );

        let point = table.symbols.get("Point");
        assert!(point.is_some());
        assert_eq!(point.unwrap().kind, SymbolKind::Struct);
    }

    #[test]
    fn test_symbol_table_by_kind() {
        let table = analyze_src(
            r#"
struct Foo {}
enum Bar {}
fn baz() {
    let x = 1;
}
"#,
        );

        for kind in [SymbolKind::Struct, SymbolKind::Enum, SymbolKind::Function] {
            assert_eq!(table.by_kind(kind).len(), 1);
        }
    }
}