Skip to main content

solidity_language_server/
highlight.rs

1use tower_lsp::lsp_types::{DocumentHighlight, DocumentHighlightKind, Position, Range};
2use tree_sitter::{Node, Parser};
3
4/// Return all highlights for the identifier under the cursor.
5///
6/// Walks the tree-sitter parse tree to find every occurrence of the same
7/// identifier text and classifies each as Read or Write based on its
8/// syntactic context.
9pub fn document_highlights(source: &str, position: Position) -> Vec<DocumentHighlight> {
10    let tree = match parse(source) {
11        Some(t) => t,
12        None => return vec![],
13    };
14
15    let root = tree.root_node();
16
17    // Find the identifier node at the cursor position.
18    let target = match find_identifier_at(root, source, position) {
19        Some(node) => node,
20        None => return vec![],
21    };
22
23    let name = &source[target.byte_range()];
24
25    // Collect every identifier in the file with the same text.
26    let mut highlights = Vec::new();
27    collect_matching_identifiers(root, source, name, &mut highlights);
28    highlights
29}
30
31/// Find the identifier node at the given cursor position.
32///
33/// Descends to the deepest named node at the position, then walks up to
34/// find the nearest `identifier` node.
35fn find_identifier_at<'a>(root: Node<'a>, _source: &str, position: Position) -> Option<Node<'a>> {
36    let point = tree_sitter::Point {
37        row: position.line as usize,
38        column: position.character as usize,
39    };
40
41    let node = root.descendant_for_point_range(point, point)?;
42
43    // If we landed directly on an identifier, use it.
44    if node.kind() == "identifier" {
45        return Some(node);
46    }
47
48    // Check if the node text at this position is a keyword-like identifier
49    // that tree-sitter doesn't classify as "identifier" (e.g., type names
50    // in some contexts). Walk up a couple of levels.
51    let mut current = node;
52    for _ in 0..3 {
53        if current.kind() == "identifier" {
54            return Some(current);
55        }
56        current = current.parent()?;
57    }
58
59    // If the deepest node is a short anonymous node (like a keyword token),
60    // check if it overlaps with an identifier sibling at the same position.
61    // This handles cases where tree-sitter places the cursor on a non-named
62    // token adjacent to an identifier.
63    let parent = node.parent()?;
64    let mut cursor = parent.walk();
65    parent
66        .children(&mut cursor)
67        .find(|child| child.kind() == "identifier" && contains_point(*child, point))
68}
69
70/// Check if a node's range contains the given point.
71fn contains_point(node: Node, point: tree_sitter::Point) -> bool {
72    node.start_position() <= point && point <= node.end_position()
73}
74
75/// Recursively collect all identifier nodes matching `name`, classifying
76/// each as Read or Write.
77fn collect_matching_identifiers(
78    node: Node,
79    source: &str,
80    name: &str,
81    out: &mut Vec<DocumentHighlight>,
82) {
83    if node.kind() == "identifier" && &source[node.byte_range()] == name {
84        let kind = classify_highlight(node, source);
85        out.push(DocumentHighlight {
86            range: range(node),
87            kind: Some(kind),
88        });
89        return; // identifiers have no children
90    }
91
92    // Recurse into children
93    let mut cursor = node.walk();
94    for child in node.children(&mut cursor) {
95        collect_matching_identifiers(child, source, name, out);
96    }
97}
98
99/// Determine whether an identifier occurrence is a Write or Read.
100///
101/// Write contexts:
102/// - Declaration name (function, contract, struct, enum, event, error, modifier,
103///   state variable, local variable, parameter, constructor)
104/// - Left-hand side of an assignment expression (`=`, `+=`, `-=`, etc.)
105/// - Increment/decrement expressions (`++`, `--`)
106///
107/// Everything else is Read.
108///
109/// Note: in tree-sitter-solidity, identifiers inside statements are wrapped
110/// in `expression` nodes:
111///   `count += 1` → augmented_assignment_expression > expression > identifier
112///   `count++`    → update_expression > expression > identifier
113///   `x = 5`      → assignment_expression > expression > identifier
114/// So we check both the parent and grandparent to classify correctly.
115fn classify_highlight(node: Node, _source: &str) -> DocumentHighlightKind {
116    let parent = match node.parent() {
117        Some(p) => p,
118        None => return DocumentHighlightKind::READ,
119    };
120
121    // First, check if the immediate parent is a declaration context
122    // (identifiers in declarations are direct children, not wrapped in expression).
123    match parent.kind() {
124        // ── Declaration sites (the name being declared) ────────────────
125        "function_definition"
126        | "constructor_definition"
127        | "modifier_definition"
128        | "contract_declaration"
129        | "interface_declaration"
130        | "library_declaration"
131        | "struct_declaration"
132        | "enum_declaration"
133        | "event_definition"
134        | "error_declaration"
135        | "user_defined_type_definition"
136        | "state_variable_declaration"
137        | "struct_member" => {
138            if is_first_identifier(parent, node) {
139                return DocumentHighlightKind::WRITE;
140            }
141            return DocumentHighlightKind::READ;
142        }
143
144        // Local variable declaration: `uint256 x` inside variable_declaration
145        "variable_declaration" => {
146            if is_first_identifier(parent, node) {
147                return DocumentHighlightKind::WRITE;
148            }
149            return DocumentHighlightKind::READ;
150        }
151
152        // Parameters: `function foo(uint256 x)` — the name is a Write
153        "parameter" | "event_parameter" | "error_parameter" => {
154            if is_first_identifier(parent, node) {
155                return DocumentHighlightKind::WRITE;
156            }
157            return DocumentHighlightKind::READ;
158        }
159
160        _ => {}
161    }
162
163    // For expression-wrapped identifiers, check the grandparent.
164    // Tree structure: grandparent > expression(parent) > identifier(node)
165    if parent.kind() == "expression"
166        && let Some(grandparent) = parent.parent()
167    {
168        return classify_expression_context(grandparent, parent);
169    }
170
171    DocumentHighlightKind::READ
172}
173
174/// Classify an identifier that is wrapped in an `expression` node.
175/// `grandparent` is the node above `expression`, `expr_node` is the
176/// `expression` wrapping the identifier.
177fn classify_expression_context(grandparent: Node, expr_node: Node) -> DocumentHighlightKind {
178    match grandparent.kind() {
179        // `x = 5` → assignment_expression > expression(lhs) > identifier
180        "assignment_expression" => {
181            if is_lhs_of_assignment(grandparent, expr_node) {
182                DocumentHighlightKind::WRITE
183            } else {
184                DocumentHighlightKind::READ
185            }
186        }
187
188        // `count += 1` → augmented_assignment_expression > expression(lhs) > identifier
189        "augmented_assignment_expression" => {
190            if is_lhs_of_assignment(grandparent, expr_node) {
191                DocumentHighlightKind::WRITE
192            } else {
193                DocumentHighlightKind::READ
194            }
195        }
196
197        // `count++`, `++count` → update_expression > expression > identifier
198        "update_expression" => DocumentHighlightKind::WRITE,
199
200        // `delete x` → delete_expression > expression > identifier
201        "delete_expression" | "delete_statement" => DocumentHighlightKind::WRITE,
202
203        // Tuple destructuring: (a, b) = func()
204        // The tuple_expression wraps expressions that are LHS of assignment
205        "tuple_expression" => {
206            if let Some(great_grandparent) = grandparent.parent()
207                && let Some(ggp) = great_grandparent.parent()
208                && (ggp.kind() == "assignment_expression"
209                    || ggp.kind() == "augmented_assignment_expression")
210                && is_lhs_of_assignment(ggp, great_grandparent)
211            {
212                return DocumentHighlightKind::WRITE;
213            }
214            DocumentHighlightKind::READ
215        }
216
217        _ => DocumentHighlightKind::READ,
218    }
219}
220
221/// Check if `node` is the first `identifier` child of `parent`.
222fn is_first_identifier(parent: Node, node: Node) -> bool {
223    let mut cursor = parent.walk();
224    for child in parent.children(&mut cursor) {
225        if child.kind() == "identifier" {
226            return child.id() == node.id();
227        }
228    }
229    false
230}
231
232/// Check if `node` is on the left-hand side of an assignment expression.
233///
234/// In tree-sitter-solidity, assignment_expression has the structure:
235///   assignment_expression -> lhs operator rhs
236/// The LHS is the first named child.
237fn is_lhs_of_assignment(assignment: Node, node: Node) -> bool {
238    let mut cursor = assignment.walk();
239    for child in assignment.children(&mut cursor) {
240        if child.is_named() {
241            // The first named child is the LHS.
242            // Check if `node` is the LHS itself or is contained within it.
243            return child.id() == node.id()
244                || (child.start_byte() <= node.start_byte()
245                    && node.end_byte() <= child.end_byte());
246        }
247    }
248    false
249}
250
251// ── Helpers ────────────────────────────────────────────────────────────────
252
253fn parse(source: &str) -> Option<tree_sitter::Tree> {
254    let mut parser = Parser::new();
255    parser
256        .set_language(&tree_sitter_solidity::LANGUAGE.into())
257        .expect("failed to load Solidity grammar");
258    parser.parse(source, None)
259}
260
261fn range(node: Node) -> Range {
262    let s = node.start_position();
263    let e = node.end_position();
264    Range {
265        start: Position::new(s.row as u32, s.column as u32),
266        end: Position::new(e.row as u32, e.column as u32),
267    }
268}
269
270// ── Tests ──────────────────────────────────────────────────────────────────
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: return highlights as (line, col, kind) tuples for easy assertion.
    fn highlights_at(source: &str, line: u32, col: u32) -> Vec<(u32, u32, DocumentHighlightKind)> {
        let result = document_highlights(source, Position::new(line, col));
        result
            .into_iter()
            .map(|h| (h.range.start.line, h.range.start.character, h.kind.unwrap()))
            .collect()
    }

    #[test]
    fn test_empty_source() {
        assert!(document_highlights("", Position::new(0, 0)).is_empty());
    }

    #[test]
    fn test_no_identifier_at_position() {
        let source = "pragma solidity ^0.8.0;";
        let result = document_highlights(source, Position::new(0, 0));
        // "pragma" is a keyword, not an identifier — may or may not match
        // depending on tree-sitter grammar. Either empty or non-empty is fine;
        // this only checks that the call does not panic.
        let _ = result;
    }

    #[test]
    fn test_state_variable_read_write() {
        let source = r#"contract Foo {
    uint256 public count;
    function inc() public {
        count += 1;
    }
    function get() public view returns (uint256) {
        return count;
    }
}"#;
        // Click on "count" at the declaration (line 1, col 23)
        let highlights = highlights_at(source, 1, 23);
        assert_eq!(
            highlights.len(),
            3,
            "expected 3 highlights for 'count': {:?}",
            highlights
        );

        // Declaration should be Write
        let decl = highlights.iter().find(|h| h.0 == 1);
        assert_eq!(
            decl.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "declaration should be Write"
        );

        // `count += 1` should be Write
        let assign = highlights.iter().find(|h| h.0 == 3);
        assert_eq!(
            assign.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "`count += 1` should be Write"
        );

        // `return count` should be Read
        let read = highlights.iter().find(|h| h.0 == 6);
        assert_eq!(
            read.map(|h| h.2),
            Some(DocumentHighlightKind::READ),
            "`return count` should be Read"
        );
    }

    #[test]
    fn test_function_name_highlights() {
        let source = r#"contract Foo {
    function bar() public {}
    function baz() public {
        bar();
    }
}"#;
        // Click on "bar" at its definition (line 1)
        let highlights = highlights_at(source, 1, 13);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'bar'");

        // Definition is Write
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        // Call is Read
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_parameter_highlights() {
        let source = r#"contract Foo {
    function add(uint256 a, uint256 b) public pure returns (uint256) {
        return a + b;
    }
}"#;
        // Click on "a" at parameter declaration (line 1, col 25)
        let highlights = highlights_at(source, 1, 25);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'a'");
        // Parameter declaration is Write
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        // Usage in `return a + b` is Read
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_local_variable_highlights() {
        let source = r#"contract Foo {
    function bar() public {
        uint256 x = 1;
        uint256 y = x + 1;
        x = y;
    }
}"#;
        // Click on "x" at declaration (line 2)
        let highlights = highlights_at(source, 2, 16);
        assert_eq!(
            highlights.len(),
            3,
            "expected 3 highlights for 'x': {:?}",
            highlights
        );
        // Declaration: Write
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        // `x + 1`: Read
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
        // `x = y`: Write (LHS of assignment)
        assert_eq!(highlights[2].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_contract_name_highlights() {
        let source = r#"contract Foo {
    Foo public self;
}"#;
        let highlights = highlights_at(source, 0, 9);
        assert!(
            !highlights.is_empty(),
            "expected at least 1 highlight for contract name 'Foo'"
        );
        // Contract declaration name is Write
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_struct_name_and_members() {
        let source = r#"contract Foo {
    struct Info {
        string name;
        uint256 value;
    }
    Info public info;
}"#;
        // Click on "Info" at struct declaration (line 1)
        let highlights = highlights_at(source, 1, 11);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'Info'"
        );
        // Struct declaration is Write
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_event_name_highlights() {
        let source = r#"contract Foo {
    event Transfer(address from, address to, uint256 value);
    function send() public {
        emit Transfer(msg.sender, address(0), 100);
    }
}"#;
        // Click on "Transfer" at event definition (line 1)
        let highlights = highlights_at(source, 1, 10);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'Transfer'");
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_no_cross_name_pollution() {
        let source = r#"contract Foo {
    uint256 public x;
    uint256 public y;
    function bar() public {
        x = y;
    }
}"#;
        // Click on "x" at its declaration. In "    uint256 public x;" the
        // `x` is at column 19; a column past the end of the line would find
        // no identifier and make every check below vacuous.
        let highlights = document_highlights(source, Position::new(1, 19));
        assert_eq!(
            highlights.len(),
            2,
            "expected 2 highlights for 'x' (declaration and `x = y`): {:?}",
            highlights
        );
        // Each highlight must cover exactly the text "x". Checking only that
        // the line contains "x" would also accept a highlight on `y` in `x = y`.
        for h in &highlights {
            let line = source
                .lines()
                .nth(h.range.start.line as usize)
                .expect("highlight line exists in source");
            let text = &line[h.range.start.character as usize..h.range.end.character as usize];
            assert_eq!(
                text, "x",
                "unexpected highlight {:?} on line '{}'",
                h.range, line
            );
        }
    }

    #[test]
    fn test_enum_name_highlights() {
        let source = r#"contract Foo {
    enum Status { Active, Paused }
    Status public status;
}"#;
        let highlights = highlights_at(source, 1, 9);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'Status'"
        );
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_modifier_name_highlights() {
        let source = r#"contract Foo {
    address public owner;
    modifier onlyOwner() {
        require(msg.sender == owner);
        _;
    }
    function bar() public onlyOwner {}
}"#;
        let highlights = highlights_at(source, 2, 13);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'onlyOwner'");
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_shop_sol() {
        let source = std::fs::read_to_string("example/Shop.sol")
            .expect("example/Shop.sol should exist relative to the crate root");
        // "PRICE" is declared at line 68 (0-indexed) and used in buy()
        let highlights = document_highlights(&source, Position::new(68, 22));
        assert!(
            highlights.len() >= 2,
            "Shop.sol 'PRICE' should have at least 2 highlights, got {}",
            highlights.len()
        );

        // The declaration (line 68) should be Write
        let decl = highlights.iter().find(|h| h.range.start.line == 68);
        assert_eq!(
            decl.map(|h| h.kind),
            Some(Some(DocumentHighlightKind::WRITE))
        );
    }

    #[test]
    fn test_increment_is_write() {
        let source = r#"contract Foo {
    uint256 public x;
    function inc() public {
        x++;
    }
}"#;
        // `x` is at column 19: "    uint256 public x;"
        let highlights = highlights_at(source, 1, 19);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'x', got {}: {:?}",
            highlights.len(),
            highlights
        );
        let inc = highlights.iter().find(|h| h.0 == 3);
        assert_eq!(
            inc.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "`x++` should be Write, all highlights: {:?}",
            highlights
        );
    }

    #[test]
    fn test_cursor_on_usage_finds_all() {
        let source = r#"contract Foo {
    uint256 public count;
    function inc() public {
        count += 1;
    }
}"#;
        // Click on "count" at the usage site (line 3), not the declaration
        let highlights_from_usage = highlights_at(source, 3, 8);
        let highlights_from_decl = highlights_at(source, 1, 23);
        assert_eq!(
            highlights_from_usage.len(),
            highlights_from_decl.len(),
            "clicking on usage vs declaration should find the same set"
        );
    }
}