Skip to main content

solidity_language_server/
symbols.rs

1#![allow(deprecated)]
2
3use tower_lsp::lsp_types::{
4    DocumentSymbol, Location, Position, Range, SymbolInformation, SymbolKind, Url,
5};
6use tree_sitter::{Node, Parser};
7
8// ── Document symbols (hierarchical, single file) ───────────────────────────
9
10/// Extract hierarchical document symbols from Solidity source using tree-sitter.
11pub fn extract_document_symbols(source: &str) -> Vec<DocumentSymbol> {
12    let tree = match parse(source) {
13        Some(t) => t,
14        None => return vec![],
15    };
16    collect_top_level(tree.root_node(), source)
17}
18
19fn collect_top_level(node: Node, source: &str) -> Vec<DocumentSymbol> {
20    named_children(node)
21        .filter_map(|child| match child.kind() {
22            "pragma_directive" => Some(text_symbol(child, source, SymbolKind::STRING)),
23            "import_directive" => Some(import_symbol(child, source)),
24            "contract_declaration" => contract_symbol(child, source, SymbolKind::CLASS),
25            "interface_declaration" => contract_symbol(child, source, SymbolKind::INTERFACE),
26            "library_declaration" => contract_symbol(child, source, SymbolKind::NAMESPACE),
27            "struct_declaration" => struct_symbol(child, source),
28            "enum_declaration" => enum_symbol(child, source),
29            "function_definition" => function_symbol(child, source),
30            "event_definition" | "error_declaration" => id_symbol(child, source, SymbolKind::EVENT),
31            "state_variable_declaration" => id_symbol(child, source, SymbolKind::FIELD),
32            "user_defined_type_definition" => id_symbol(child, source, SymbolKind::TYPE_PARAMETER),
33            _ => None,
34        })
35        .collect()
36}
37
38fn collect_contract_members(body: Node, source: &str) -> Vec<DocumentSymbol> {
39    named_children(body)
40        .filter_map(|child| match child.kind() {
41            "function_definition" => function_symbol(child, source),
42            "constructor_definition" => Some(leaf("constructor", SymbolKind::CONSTRUCTOR, child)),
43            "fallback_receive_definition" => Some(leaf(
44                &fallback_or_receive(child, source),
45                SymbolKind::FUNCTION,
46                child,
47            )),
48            "state_variable_declaration" => id_symbol(child, source, SymbolKind::FIELD),
49            "event_definition" | "error_declaration" => id_symbol(child, source, SymbolKind::EVENT),
50            "modifier_definition" => id_symbol(child, source, SymbolKind::METHOD),
51            "struct_declaration" => struct_symbol(child, source),
52            "enum_declaration" => enum_symbol(child, source),
53            "using_directive" => Some(text_symbol(child, source, SymbolKind::PROPERTY)),
54            "user_defined_type_definition" => id_symbol(child, source, SymbolKind::TYPE_PARAMETER),
55            _ => None,
56        })
57        .collect()
58}
59
60// ── Symbol builders ────────────────────────────────────────────────────────
61
62fn contract_symbol(node: Node, source: &str, kind: SymbolKind) -> Option<DocumentSymbol> {
63    let name = child_id_text(node, source)?;
64    let children = find_child(node, "contract_body")
65        .map(|body| collect_contract_members(body, source))
66        .filter(|c| !c.is_empty());
67
68    Some(DocumentSymbol {
69        name: name.into(),
70        detail: None,
71        kind,
72        range: range(node),
73        selection_range: child_id_range(node)?,
74        children,
75        tags: None,
76        deprecated: None,
77    })
78}
79
80fn function_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
81    let name = child_id_text(node, source)?;
82    Some(DocumentSymbol {
83        name: name.into(),
84        detail: Some(function_detail(node, source)),
85        kind: SymbolKind::FUNCTION,
86        range: range(node),
87        selection_range: child_id_range(node)?,
88        children: None,
89        tags: None,
90        deprecated: None,
91    })
92}
93
94fn struct_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
95    let name = child_id_text(node, source)?;
96    let children = find_child(node, "struct_body")
97        .map(|body| {
98            named_children(body)
99                .filter(|c| c.kind() == "struct_member")
100                .filter_map(|c| id_symbol(c, source, SymbolKind::FIELD))
101                .collect::<Vec<_>>()
102        })
103        .filter(|c| !c.is_empty());
104
105    Some(DocumentSymbol {
106        name: name.into(),
107        detail: None,
108        kind: SymbolKind::STRUCT,
109        range: range(node),
110        selection_range: child_id_range(node)?,
111        children,
112        tags: None,
113        deprecated: None,
114    })
115}
116
117fn enum_symbol(node: Node, source: &str) -> Option<DocumentSymbol> {
118    let name = child_id_text(node, source)?;
119    let children = find_child(node, "enum_body")
120        .map(|body| {
121            named_children(body)
122                .filter(|c| c.kind() == "enum_value")
123                .map(|c| leaf(&source[c.byte_range()], SymbolKind::ENUM_MEMBER, c))
124                .collect::<Vec<_>>()
125        })
126        .filter(|c| !c.is_empty());
127
128    Some(DocumentSymbol {
129        name: name.into(),
130        detail: None,
131        kind: SymbolKind::ENUM,
132        range: range(node),
133        selection_range: child_id_range(node)?,
134        children,
135        tags: None,
136        deprecated: None,
137    })
138}
139
140/// Symbol whose name comes from its first `identifier` child.
141fn id_symbol(node: Node, source: &str, kind: SymbolKind) -> Option<DocumentSymbol> {
142    let name = child_id_text(node, source)?;
143    Some(DocumentSymbol {
144        name: name.into(),
145        detail: None,
146        kind,
147        range: range(node),
148        selection_range: child_id_range(node).unwrap_or(range(node)),
149        children: None,
150        tags: None,
151        deprecated: None,
152    })
153}
154
155/// Symbol whose name is the full node text (pragmas, using directives).
156fn text_symbol(node: Node, source: &str, kind: SymbolKind) -> DocumentSymbol {
157    let text = source[node.byte_range()].trim_end_matches(';').trim();
158    leaf(text, kind, node)
159}
160
161fn import_symbol(node: Node, source: &str) -> DocumentSymbol {
162    let name = find_child(node, "string")
163        .map(|s| format!("import {}", &source[s.byte_range()]))
164        .unwrap_or_else(|| {
165            source[node.byte_range()]
166                .trim_end_matches(';')
167                .trim()
168                .into()
169        });
170    leaf(&name, SymbolKind::MODULE, node)
171}
172
173/// Leaf symbol with no children — range equals selection_range.
174fn leaf(name: &str, kind: SymbolKind, node: Node) -> DocumentSymbol {
175    DocumentSymbol {
176        name: name.into(),
177        detail: None,
178        kind,
179        range: range(node),
180        selection_range: range(node),
181        children: None,
182        tags: None,
183        deprecated: None,
184    }
185}
186
187fn function_detail(node: Node, source: &str) -> String {
188    let params: Vec<&str> = named_children(node)
189        .filter(|c| c.kind() == "parameter")
190        .map(|c| source[c.byte_range()].trim())
191        .collect();
192
193    let returns: Vec<&str> = find_child(node, "return_type_definition")
194        .map(|ret| {
195            named_children(ret)
196                .filter(|c| c.kind() == "parameter")
197                .map(|c| source[c.byte_range()].trim())
198                .collect()
199        })
200        .unwrap_or_default();
201
202    let mut sig = format!("({})", params.join(", "));
203    if !returns.is_empty() {
204        sig.push_str(&format!(" returns ({})", returns.join(", ")));
205    }
206    sig
207}
208
209// ── Workspace symbols (flat, multi-file) ───────────────────────────────────
210
211/// Extract flat workspace symbols from multiple files.
212pub fn extract_workspace_symbols(files: &[(Url, String)]) -> Vec<SymbolInformation> {
213    let mut parser = Parser::new();
214    parser
215        .set_language(&tree_sitter_solidity::LANGUAGE.into())
216        .expect("failed to load Solidity grammar");
217
218    let mut symbols = Vec::new();
219    for (uri, source) in files {
220        if let Some(tree) = parser.parse(source, None) {
221            collect_workspace_symbols(tree.root_node(), source, uri, None, &mut symbols);
222        }
223    }
224    symbols
225}
226
227fn collect_workspace_symbols(
228    node: Node,
229    source: &str,
230    uri: &Url,
231    container: Option<&str>,
232    out: &mut Vec<SymbolInformation>,
233) {
234    for child in named_children(node) {
235        match child.kind() {
236            // Containers: recurse into body
237            "contract_declaration" | "interface_declaration" | "library_declaration" => {
238                let kind = match child.kind() {
239                    "interface_declaration" => SymbolKind::INTERFACE,
240                    "library_declaration" => SymbolKind::NAMESPACE,
241                    _ => SymbolKind::CLASS,
242                };
243                if let Some(name) = child_id_text(child, source) {
244                    push_info(out, name, kind, child, uri, container);
245                    if let Some(body) = find_child(child, "contract_body") {
246                        collect_workspace_symbols(body, source, uri, Some(name), out);
247                    }
248                }
249            }
250            "struct_declaration" => {
251                if let Some(name) = child_id_text(child, source) {
252                    push_info(out, name, SymbolKind::STRUCT, child, uri, container);
253                    if let Some(body) = find_child(child, "struct_body") {
254                        collect_workspace_symbols(body, source, uri, Some(name), out);
255                    }
256                }
257            }
258            "enum_declaration" => {
259                if let Some(name) = child_id_text(child, source) {
260                    push_info(out, name, SymbolKind::ENUM, child, uri, container);
261                    if let Some(body) = find_child(child, "enum_body") {
262                        collect_workspace_symbols(body, source, uri, Some(name), out);
263                    }
264                }
265            }
266            // Leaves
267            "function_definition" => {
268                push_id(out, child, source, SymbolKind::FUNCTION, uri, container)
269            }
270            "constructor_definition" => push_info(
271                out,
272                "constructor",
273                SymbolKind::CONSTRUCTOR,
274                child,
275                uri,
276                container,
277            ),
278            "state_variable_declaration" | "struct_member" => {
279                push_id(out, child, source, SymbolKind::FIELD, uri, container)
280            }
281            "event_definition" | "error_declaration" => {
282                push_id(out, child, source, SymbolKind::EVENT, uri, container)
283            }
284            "modifier_definition" => {
285                push_id(out, child, source, SymbolKind::METHOD, uri, container)
286            }
287            "enum_value" => push_info(
288                out,
289                &source[child.byte_range()],
290                SymbolKind::ENUM_MEMBER,
291                child,
292                uri,
293                container,
294            ),
295            "user_defined_type_definition" => push_id(
296                out,
297                child,
298                source,
299                SymbolKind::TYPE_PARAMETER,
300                uri,
301                container,
302            ),
303            _ => {}
304        }
305    }
306}
307
308fn push_id(
309    out: &mut Vec<SymbolInformation>,
310    node: Node,
311    source: &str,
312    kind: SymbolKind,
313    uri: &Url,
314    container: Option<&str>,
315) {
316    if let Some(name) = child_id_text(node, source) {
317        push_info(out, name, kind, node, uri, container);
318    }
319}
320
321fn push_info(
322    out: &mut Vec<SymbolInformation>,
323    name: &str,
324    kind: SymbolKind,
325    node: Node,
326    uri: &Url,
327    container: Option<&str>,
328) {
329    out.push(SymbolInformation {
330        name: name.into(),
331        kind,
332        tags: None,
333        deprecated: None,
334        location: Location {
335            uri: uri.clone(),
336            range: range(node),
337        },
338        container_name: container.map(Into::into),
339    });
340}
341
342// ── Helpers ────────────────────────────────────────────────────────────────
343
344fn parse(source: &str) -> Option<tree_sitter::Tree> {
345    let mut parser = Parser::new();
346    parser
347        .set_language(&tree_sitter_solidity::LANGUAGE.into())
348        .expect("failed to load Solidity grammar");
349    parser.parse(source, None)
350}
351
352fn range(node: Node) -> Range {
353    let s = node.start_position();
354    let e = node.end_position();
355    Range {
356        start: Position::new(s.row as u32, s.column as u32),
357        end: Position::new(e.row as u32, e.column as u32),
358    }
359}
360
361fn named_children(node: Node) -> impl Iterator<Item = Node> {
362    let mut cursor = node.walk();
363    let children: Vec<Node> = node
364        .children(&mut cursor)
365        .filter(|c| c.is_named())
366        .collect();
367    children.into_iter()
368}
369
370fn child_id_text<'a>(node: Node<'a>, source: &'a str) -> Option<&'a str> {
371    let mut cursor = node.walk();
372    node.children(&mut cursor)
373        .find(|c| c.kind() == "identifier" && c.is_named())
374        .map(|c| &source[c.byte_range()])
375}
376
377fn child_id_range(node: Node) -> Option<Range> {
378    let mut cursor = node.walk();
379    node.children(&mut cursor)
380        .find(|c| c.kind() == "identifier" && c.is_named())
381        .map(|c| range(c))
382}
383
384fn find_child<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
385    let mut cursor = node.walk();
386    node.children(&mut cursor).find(|c| c.kind() == kind)
387}
388
389fn fallback_or_receive(node: Node, source: &str) -> String {
390    let mut cursor = node.walk();
391    node.children(&mut cursor)
392        .find(|c| !c.is_named() && matches!(&source[c.byte_range()], "fallback" | "receive"))
393        .map(|c| source[c.byte_range()].into())
394        .unwrap_or_else(|| "fallback".into())
395}
396
397// ── Tests ──────────────────────────────────────────────────────────────────
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `symbols` contains an entry with exactly this name and kind.
    fn has(symbols: &[DocumentSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    /// Children of the first symbol of `kind`. Panics if there is no such
    /// symbol or it has no children.
    fn children_of(symbols: &[DocumentSymbol], kind: SymbolKind) -> &[DocumentSymbol] {
        symbols
            .iter()
            .find(|s| s.kind == kind)
            .and_then(|s| s.children.as_deref())
            .expect("expected a symbol of the requested kind with children")
    }

    #[test]
    fn test_empty_source() {
        let symbols = extract_document_symbols("");
        assert!(symbols.is_empty());
    }

    #[test]
    fn test_simple_contract() {
        let source = r#"
pragma solidity ^0.8.0;

contract Counter {
    uint256 public count;
    function increment() public { count += 1; }
    function getCount() public view returns (uint256) { return count; }
}
"#;
        let symbols = extract_document_symbols(source);
        assert!(symbols.len() >= 2);

        let contract = symbols
            .iter()
            .find(|s| s.kind == SymbolKind::CLASS)
            .unwrap();
        assert_eq!(contract.name, "Counter");

        let members = contract.children.as_deref().unwrap();
        assert!(has(members, "count", SymbolKind::FIELD));
        assert!(has(members, "increment", SymbolKind::FUNCTION));
        assert!(has(members, "getCount", SymbolKind::FUNCTION));
    }

    #[test]
    fn test_struct_with_members() {
        let source = "contract Foo { struct Info { string name; uint256 value; } }";
        let symbols = extract_document_symbols(source);
        let contract_members = symbols[0].children.as_deref().unwrap();
        let fields = children_of(contract_members, SymbolKind::STRUCT);

        assert_eq!(fields.len(), 2);
        assert!(fields.iter().any(|m| m.name == "name"));
        assert!(fields.iter().any(|m| m.name == "value"));
    }

    #[test]
    fn test_enum_with_values() {
        let source = "contract Foo { enum Status { Active, Paused, Stopped } }";
        let symbols = extract_document_symbols(source);
        let contract_members = symbols[0].children.as_deref().unwrap();
        let variants = children_of(contract_members, SymbolKind::ENUM);

        assert_eq!(variants.len(), 3);
        for name in ["Active", "Paused", "Stopped"].iter() {
            assert!(variants.iter().any(|m| m.name == *name), "missing {}", name);
        }
    }

    #[test]
    fn test_all_member_types() {
        let source = r#"
contract Token {
    event Transfer(address from, address to, uint256 value);
    error Unauthorized();
    uint256 public totalSupply;
    modifier onlyOwner() { _; }
    constructor() {}
    function transfer(address to, uint256 amount) external returns (bool) { return true; }
    fallback() external payable {}
    receive() external payable {}
    type Price is uint256;
}
"#;
        let symbols = extract_document_symbols(source);
        let members = children_of(&symbols, SymbolKind::CLASS);

        let expected = [
            ("Transfer", SymbolKind::EVENT),
            ("Unauthorized", SymbolKind::EVENT),
            ("totalSupply", SymbolKind::FIELD),
            ("onlyOwner", SymbolKind::METHOD),
            ("constructor", SymbolKind::CONSTRUCTOR),
            ("transfer", SymbolKind::FUNCTION),
            ("fallback", SymbolKind::FUNCTION),
            ("receive", SymbolKind::FUNCTION),
            ("Price", SymbolKind::TYPE_PARAMETER),
        ];
        for (name, kind) in expected.iter() {
            assert!(has(members, name, *kind), "missing {}", name);
        }
    }

    #[test]
    fn test_interface_and_library() {
        let source = r#"
interface IToken { function transfer(address to, uint256 amount) external returns (bool); }
library SafeMath { function add(uint256 a, uint256 b) internal pure returns (uint256) { return a + b; } }
"#;
        let symbols = extract_document_symbols(source);
        assert!(has(&symbols, "IToken", SymbolKind::INTERFACE));
        assert!(has(&symbols, "SafeMath", SymbolKind::NAMESPACE));
    }

    #[test]
    fn test_workspace_symbols() {
        let uri = Url::parse("file:///test.sol").unwrap();
        let source = "contract Foo { uint256 public bar; function baz() public {} }";
        let symbols = extract_workspace_symbols(&[(uri, source.into())]);

        let inside_foo = |name: &str| {
            symbols
                .iter()
                .any(|s| s.name == name && s.container_name.as_deref() == Some("Foo"))
        };
        assert!(
            symbols
                .iter()
                .any(|s| s.name == "Foo" && s.kind == SymbolKind::CLASS)
        );
        assert!(inside_foo("bar"));
        assert!(inside_foo("baz"));
    }

    #[test]
    fn test_counter_sol() {
        let source = std::fs::read_to_string("example/Counter.sol").unwrap();
        let symbols = extract_document_symbols(&source);
        let members = children_of(&symbols, SymbolKind::CLASS);

        for name in ["increment", "decrement", "reset", "getCount"].iter() {
            assert!(members.iter().any(|c| c.name == *name), "missing {}", name);
        }
    }

    #[test]
    fn test_function_detail() {
        let source = "contract Foo { function bar(uint256 x, address y) public pure returns (bool) { return true; } }";
        let symbols = extract_document_symbols(source);
        let bar = symbols[0]
            .children
            .as_deref()
            .unwrap()
            .iter()
            .find(|c| c.name == "bar")
            .unwrap();
        let detail = bar.detail.as_deref().unwrap();

        assert!(detail.contains("uint256 x"));
        assert!(detail.contains("address y"));
        assert!(detail.contains("returns"));
    }
}