Skip to main content

ryo_source/ops/
def_refs.rs

1//! Definition and Reference analysis.
2//!
3//! Query operation to find where symbols are defined and referenced.
4
5use std::collections::HashMap;
6use syn::visit::Visit;
7
8use crate::ast::RustAST;
9
/// Identifies a place in the source by symbol name only.
///
/// No file, line, or column information is stored; the name is the entire
/// location.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Location {
    /// Name of the symbol this location refers to.
    pub name: String,
}

impl Location {
    /// Builds a `Location` that owns a copy of `name`.
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}
25
26/// Information about a symbol (variable, function, type, etc.).
27#[derive(Debug, Clone)]
28pub struct Symbol {
29    /// Symbol name.
30    pub name: String,
31    /// Kind of symbol.
32    pub kind: SymbolKind,
33    /// Definition location.
34    pub definition: Location,
35    /// All reference locations.
36    pub references: Vec<Location>,
37}
38
/// Classification of a [`Symbol`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SymbolKind {
    /// A variable introduced by a `let` binding.
    LocalVar,
    /// A parameter of a function.
    Parameter,
    /// A function definition.
    Function,
    /// A struct definition.
    Struct,
    /// An enum definition.
    Enum,
    /// A `const` or `static` item.
    Const,
    /// A type alias.
    TypeAlias,
    /// An impl block (for a type).
    Impl,
}
59
60/// Definition and Reference finder.
61pub struct DefRefs;
62
63impl DefRefs {
64    /// Find all symbols and their references in the AST.
65    pub fn analyze(ast: &RustAST) -> SymbolTable {
66        let mut collector = SymbolCollector::new();
67        collector.visit_file(ast.file());
68        collector.table
69    }
70
71    /// Find definition of a symbol at a given location.
72    pub fn find_definition(ast: &RustAST, name: &str) -> Option<Symbol> {
73        let table = Self::analyze(ast);
74        table.symbols.get(name).cloned()
75    }
76
77    /// Find all references to a symbol.
78    pub fn find_references(ast: &RustAST, name: &str) -> Vec<Location> {
79        let table = Self::analyze(ast);
80        table
81            .symbols
82            .get(name)
83            .map(|s| s.references.clone())
84            .unwrap_or_default()
85    }
86}
87
88/// Table of all symbols in a file.
89#[derive(Debug, Default)]
90pub struct SymbolTable {
91    /// All symbols indexed by name.
92    pub symbols: HashMap<String, Symbol>,
93}
94
95impl SymbolTable {
96    /// Get all symbols of a specific kind.
97    pub fn by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
98        self.symbols.values().filter(|s| s.kind == kind).collect()
99    }
100
101    /// Get all function definitions.
102    pub fn functions(&self) -> Vec<&Symbol> {
103        self.by_kind(SymbolKind::Function)
104    }
105
106    /// Get all local variables.
107    pub fn local_vars(&self) -> Vec<&Symbol> {
108        self.by_kind(SymbolKind::LocalVar)
109    }
110}
111
112/// Visitor that collects symbol definitions and references.
113struct SymbolCollector {
114    table: SymbolTable,
115    /// Stack of scopes (for tracking local variables).
116    scopes: Vec<HashMap<String, Location>>,
117}
118
119impl SymbolCollector {
120    fn new() -> Self {
121        Self {
122            table: SymbolTable::default(),
123            scopes: vec![HashMap::new()], // Global scope
124        }
125    }
126
127    fn enter_scope(&mut self) {
128        self.scopes.push(HashMap::new());
129    }
130
131    fn exit_scope(&mut self) {
132        self.scopes.pop();
133    }
134
135    fn define_symbol(&mut self, name: &str, kind: SymbolKind) {
136        let loc = Location::new(name);
137
138        // Add to current scope for locals
139        if matches!(kind, SymbolKind::LocalVar | SymbolKind::Parameter) {
140            if let Some(scope) = self.scopes.last_mut() {
141                scope.insert(name.to_string(), loc.clone());
142            }
143        }
144
145        // Add to symbol table
146        self.table.symbols.insert(
147            name.to_string(),
148            Symbol {
149                name: name.to_string(),
150                kind,
151                definition: loc,
152                references: vec![],
153            },
154        );
155    }
156
157    fn add_reference(&mut self, name: &str) {
158        let loc = Location::new(name);
159        if let Some(symbol) = self.table.symbols.get_mut(name) {
160            symbol.references.push(loc);
161        }
162    }
163
164    fn is_defined(&self, name: &str) -> bool {
165        self.scopes.iter().rev().any(|s| s.contains_key(name))
166            || self.table.symbols.contains_key(name)
167    }
168
169    /// Visit a pattern and define any bound variables.
170    fn define_from_pat(&mut self, pat: &syn::Pat, kind: SymbolKind) {
171        match pat {
172            syn::Pat::Ident(pat_ident) => {
173                self.define_symbol(&pat_ident.ident.to_string(), kind);
174            }
175            syn::Pat::Tuple(pat_tuple) => {
176                for elem in &pat_tuple.elems {
177                    self.define_from_pat(elem, kind);
178                }
179            }
180            syn::Pat::TupleStruct(pat_tuple_struct) => {
181                for elem in &pat_tuple_struct.elems {
182                    self.define_from_pat(elem, kind);
183                }
184            }
185            syn::Pat::Struct(pat_struct) => {
186                for field in &pat_struct.fields {
187                    self.define_from_pat(&field.pat, kind);
188                }
189            }
190            syn::Pat::Reference(pat_ref) => {
191                self.define_from_pat(&pat_ref.pat, kind);
192            }
193            syn::Pat::Type(pat_type) => {
194                self.define_from_pat(&pat_type.pat, kind);
195            }
196            syn::Pat::Or(pat_or) => {
197                for case in &pat_or.cases {
198                    self.define_from_pat(case, kind);
199                }
200            }
201            syn::Pat::Slice(pat_slice) => {
202                for elem in &pat_slice.elems {
203                    self.define_from_pat(elem, kind);
204                }
205            }
206            _ => {}
207        }
208    }
209}
210
211impl<'ast> Visit<'ast> for SymbolCollector {
212    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
213        // Define the function
214        self.define_symbol(&node.sig.ident.to_string(), SymbolKind::Function);
215
216        // Enter function scope
217        self.enter_scope();
218
219        // Define parameters
220        for param in &node.sig.inputs {
221            if let syn::FnArg::Typed(pat_type) = param {
222                self.define_from_pat(&pat_type.pat, SymbolKind::Parameter);
223            }
224        }
225
226        // Visit function body
227        syn::visit::visit_block(self, &node.block);
228
229        self.exit_scope();
230    }
231
232    fn visit_local(&mut self, node: &'ast syn::Local) {
233        // Visit the init expression first (before defining the variable)
234        if let Some(init) = &node.init {
235            self.visit_expr(&init.expr);
236        }
237
238        // Define local variable(s) from pattern
239        self.define_from_pat(&node.pat, SymbolKind::LocalVar);
240    }
241
242    fn visit_expr_path(&mut self, node: &'ast syn::ExprPath) {
243        // This might be a reference to a variable
244        if node.path.segments.len() == 1 {
245            let name = node.path.segments[0].ident.to_string();
246            if self.is_defined(&name) {
247                self.add_reference(&name);
248            }
249        }
250        syn::visit::visit_expr_path(self, node);
251    }
252
253    fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) {
254        self.define_symbol(&node.ident.to_string(), SymbolKind::Struct);
255        syn::visit::visit_item_struct(self, node);
256    }
257
258    fn visit_item_enum(&mut self, node: &'ast syn::ItemEnum) {
259        self.define_symbol(&node.ident.to_string(), SymbolKind::Enum);
260        syn::visit::visit_item_enum(self, node);
261    }
262
263    fn visit_item_const(&mut self, node: &'ast syn::ItemConst) {
264        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
265        syn::visit::visit_item_const(self, node);
266    }
267
268    fn visit_item_static(&mut self, node: &'ast syn::ItemStatic) {
269        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
270        syn::visit::visit_item_static(self, node);
271    }
272
273    fn visit_item_type(&mut self, node: &'ast syn::ItemType) {
274        self.define_symbol(&node.ident.to_string(), SymbolKind::TypeAlias);
275        syn::visit::visit_item_type(self, node);
276    }
277
278    fn visit_block(&mut self, node: &'ast syn::Block) {
279        self.enter_scope();
280        syn::visit::visit_block(self, node);
281        self.exit_scope();
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Both top-level functions are collected and classified as functions.
    #[test]
    fn test_find_function_def() {
        let source = r#"
            fn hello() {}
            fn world() {}
            "#;
        let parsed = RustAST::parse(source).unwrap();

        let symbols = DefRefs::analyze(&parsed);
        for name in ["hello", "world"] {
            assert!(symbols.symbols.contains_key(name));
        }
        assert_eq!(symbols.functions().len(), 2);
    }

    /// `let` bindings inside a function body end up in the table.
    #[test]
    fn test_find_local_var() {
        let source = r#"
            fn main() {
                let x = 1;
                let y = 2;
            }
            "#;
        let parsed = RustAST::parse(source).unwrap();

        let symbols = DefRefs::analyze(&parsed);
        for name in ["x", "y"] {
            assert!(symbols.symbols.contains_key(name));
        }
    }

    /// Each use of a local in a later initializer counts as one reference.
    #[test]
    fn test_find_references() {
        let source = r#"
            fn main() {
                let x = 1;
                let y = x + 1;
                let z = x + y;
            }
            "#;
        let parsed = RustAST::parse(source).unwrap();

        let uses_of_x = DefRefs::find_references(&parsed, "x");
        assert_eq!(uses_of_x.len(), 2); // x is used twice
    }

    /// A struct item is recorded under its name with kind `Struct`.
    #[test]
    fn test_struct_definition() {
        let source = r#"
            struct Point {
                x: i32,
                y: i32,
            }
            "#;
        let parsed = RustAST::parse(source).unwrap();

        let symbols = DefRefs::analyze(&parsed);
        let point_kind = symbols.symbols.get("Point").map(|symbol| symbol.kind);
        assert_eq!(point_kind, Some(SymbolKind::Struct));
    }

    /// `by_kind` returns exactly one symbol for each item kind in the source.
    #[test]
    fn test_symbol_table_by_kind() {
        let source = r#"
            struct Foo {}
            enum Bar {}
            fn baz() {
                let x = 1;
            }
            "#;
        let parsed = RustAST::parse(source).unwrap();

        let symbols = DefRefs::analyze(&parsed);
        for kind in [SymbolKind::Struct, SymbolKind::Enum, SymbolKind::Function] {
            assert_eq!(symbols.by_kind(kind).len(), 1);
        }
    }
}