solidity_language_server/
highlight.rs1use tower_lsp::lsp_types::{DocumentHighlight, DocumentHighlightKind, Position, Range};
2use tree_sitter::{Node, Parser};
3
4pub fn document_highlights(source: &str, position: Position) -> Vec<DocumentHighlight> {
10 let tree = match parse(source) {
11 Some(t) => t,
12 None => return vec![],
13 };
14
15 let root = tree.root_node();
16
17 let target = match find_identifier_at(root, source, position) {
19 Some(node) => node,
20 None => return vec![],
21 };
22
23 let name = &source[target.byte_range()];
24
25 let mut highlights = Vec::new();
27 collect_matching_identifiers(root, source, name, &mut highlights);
28 highlights
29}
30
31fn find_identifier_at<'a>(root: Node<'a>, _source: &str, position: Position) -> Option<Node<'a>> {
36 let point = tree_sitter::Point {
37 row: position.line as usize,
38 column: position.character as usize,
39 };
40
41 let node = root.descendant_for_point_range(point, point)?;
42
43 if node.kind() == "identifier" {
45 return Some(node);
46 }
47
48 let mut current = node;
52 for _ in 0..3 {
53 if current.kind() == "identifier" {
54 return Some(current);
55 }
56 current = current.parent()?;
57 }
58
59 let parent = node.parent()?;
64 let mut cursor = parent.walk();
65 parent
66 .children(&mut cursor)
67 .find(|child| child.kind() == "identifier" && contains_point(*child, point))
68}
69
70fn contains_point(node: Node, point: tree_sitter::Point) -> bool {
72 node.start_position() <= point && point <= node.end_position()
73}
74
75fn collect_matching_identifiers(
78 node: Node,
79 source: &str,
80 name: &str,
81 out: &mut Vec<DocumentHighlight>,
82) {
83 if node.kind() == "identifier" && &source[node.byte_range()] == name {
84 let kind = classify_highlight(node, source);
85 out.push(DocumentHighlight {
86 range: range(node),
87 kind: Some(kind),
88 });
89 return; }
91
92 let mut cursor = node.walk();
94 for child in node.children(&mut cursor) {
95 collect_matching_identifiers(child, source, name, out);
96 }
97}
98
99fn classify_highlight(node: Node, _source: &str) -> DocumentHighlightKind {
116 let parent = match node.parent() {
117 Some(p) => p,
118 None => return DocumentHighlightKind::READ,
119 };
120
121 match parent.kind() {
124 "function_definition"
126 | "constructor_definition"
127 | "modifier_definition"
128 | "contract_declaration"
129 | "interface_declaration"
130 | "library_declaration"
131 | "struct_declaration"
132 | "enum_declaration"
133 | "event_definition"
134 | "error_declaration"
135 | "user_defined_type_definition"
136 | "state_variable_declaration"
137 | "struct_member" => {
138 if is_first_identifier(parent, node) {
139 return DocumentHighlightKind::WRITE;
140 }
141 return DocumentHighlightKind::READ;
142 }
143
144 "variable_declaration" => {
146 if is_first_identifier(parent, node) {
147 return DocumentHighlightKind::WRITE;
148 }
149 return DocumentHighlightKind::READ;
150 }
151
152 "parameter" | "event_parameter" | "error_parameter" => {
154 if is_first_identifier(parent, node) {
155 return DocumentHighlightKind::WRITE;
156 }
157 return DocumentHighlightKind::READ;
158 }
159
160 _ => {}
161 }
162
163 if parent.kind() == "expression"
166 && let Some(grandparent) = parent.parent()
167 {
168 return classify_expression_context(grandparent, parent);
169 }
170
171 DocumentHighlightKind::READ
172}
173
174fn classify_expression_context(grandparent: Node, expr_node: Node) -> DocumentHighlightKind {
178 match grandparent.kind() {
179 "assignment_expression" => {
181 if is_lhs_of_assignment(grandparent, expr_node) {
182 DocumentHighlightKind::WRITE
183 } else {
184 DocumentHighlightKind::READ
185 }
186 }
187
188 "augmented_assignment_expression" => {
190 if is_lhs_of_assignment(grandparent, expr_node) {
191 DocumentHighlightKind::WRITE
192 } else {
193 DocumentHighlightKind::READ
194 }
195 }
196
197 "update_expression" => DocumentHighlightKind::WRITE,
199
200 "delete_expression" | "delete_statement" => DocumentHighlightKind::WRITE,
202
203 "tuple_expression" => {
206 if let Some(great_grandparent) = grandparent.parent()
207 && let Some(ggp) = great_grandparent.parent()
208 && (ggp.kind() == "assignment_expression"
209 || ggp.kind() == "augmented_assignment_expression")
210 && is_lhs_of_assignment(ggp, great_grandparent)
211 {
212 return DocumentHighlightKind::WRITE;
213 }
214 DocumentHighlightKind::READ
215 }
216
217 _ => DocumentHighlightKind::READ,
218 }
219}
220
221fn is_first_identifier(parent: Node, node: Node) -> bool {
223 let mut cursor = parent.walk();
224 for child in parent.children(&mut cursor) {
225 if child.kind() == "identifier" {
226 return child.id() == node.id();
227 }
228 }
229 false
230}
231
232fn is_lhs_of_assignment(assignment: Node, node: Node) -> bool {
238 let mut cursor = assignment.walk();
239 for child in assignment.children(&mut cursor) {
240 if child.is_named() {
241 return child.id() == node.id()
244 || (child.start_byte() <= node.start_byte()
245 && node.end_byte() <= child.end_byte());
246 }
247 }
248 false
249}
250
251fn parse(source: &str) -> Option<tree_sitter::Tree> {
254 let mut parser = Parser::new();
255 parser
256 .set_language(&tree_sitter_solidity::LANGUAGE.into())
257 .expect("failed to load Solidity grammar");
258 parser.parse(source, None)
259}
260
261fn range(node: Node) -> Range {
262 let s = node.start_position();
263 let e = node.end_position();
264 Range {
265 start: Position::new(s.row as u32, s.column as u32),
266 end: Position::new(e.row as u32, e.column as u32),
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `document_highlights` and flattens each result to
    /// `(start line, start character, kind)`.
    fn highlights_at(source: &str, line: u32, col: u32) -> Vec<(u32, u32, DocumentHighlightKind)> {
        let result = document_highlights(source, Position::new(line, col));
        result
            .into_iter()
            .map(|h| (h.range.start.line, h.range.start.character, h.kind.unwrap()))
            .collect()
    }

    #[test]
    fn test_empty_source() {
        assert!(document_highlights("", Position::new(0, 0)).is_empty());
    }

    #[test]
    fn test_no_identifier_at_position() {
        // Only checks that a keyword position does not panic; whether any
        // highlight is produced depends on how the grammar models `pragma`.
        let source = "pragma solidity ^0.8.0;";
        let result = document_highlights(source, Position::new(0, 0));
        let _ = result;
    }

    #[test]
    fn test_state_variable_read_write() {
        let source = r#"contract Foo {
    uint256 public count;
    function inc() public {
        count += 1;
    }
    function get() public view returns (uint256) {
        return count;
    }
}"#;
        let highlights = highlights_at(source, 1, 23);
        assert!(
            highlights.len() == 3,
            "expected 3 highlights for 'count', got {}: {:?}",
            highlights.len(),
            highlights
        );

        let decl = highlights.iter().find(|h| h.0 == 1);
        assert_eq!(
            decl.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "declaration should be Write"
        );

        let assign = highlights.iter().find(|h| h.0 == 3);
        assert_eq!(
            assign.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "`count += 1` should be Write"
        );

        let read = highlights.iter().find(|h| h.0 == 6);
        assert_eq!(
            read.map(|h| h.2),
            Some(DocumentHighlightKind::READ),
            "`return count` should be Read"
        );
    }

    #[test]
    fn test_function_name_highlights() {
        let source = r#"contract Foo {
    function bar() public {}
    function baz() public {
        bar();
    }
}"#;
        let highlights = highlights_at(source, 1, 13);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'bar'");

        // Declaration first (document order), then the call site.
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_parameter_highlights() {
        let source = r#"contract Foo {
    function add(uint256 a, uint256 b) public pure returns (uint256) {
        return a + b;
    }
}"#;
        let highlights = highlights_at(source, 1, 25);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'a'");
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_local_variable_highlights() {
        let source = r#"contract Foo {
    function bar() public {
        uint256 x = 1;
        uint256 y = x + 1;
        x = y;
    }
}"#;
        let highlights = highlights_at(source, 2, 16);
        assert_eq!(
            highlights.len(),
            3,
            "expected 3 highlights for 'x': {:?}",
            highlights
        );
        // declaration, read in `x + 1`, assignment target in `x = y`
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
        assert_eq!(highlights[2].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_contract_name_highlights() {
        let source = r#"contract Foo {
    Foo public self;
}"#;
        let highlights = highlights_at(source, 0, 9);
        assert!(
            !highlights.is_empty(),
            "expected at least 1 highlight for contract name 'Foo'"
        );
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_struct_name_and_members() {
        let source = r#"contract Foo {
    struct Info {
        string name;
        uint256 value;
    }
    Info public info;
}"#;
        let highlights = highlights_at(source, 1, 11);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'Info'"
        );
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_event_name_highlights() {
        let source = r#"contract Foo {
    event Transfer(address from, address to, uint256 value);
    function send() public {
        emit Transfer(msg.sender, address(0), 100);
    }
}"#;
        let highlights = highlights_at(source, 1, 10);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'Transfer'");
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_no_cross_name_pollution() {
        let source = r#"contract Foo {
    uint256 public x;
    uint256 public y;
    function bar() public {
        x = y;
    }
}"#;
        // `x` starts at column 19 of `    uint256 public x;`. Column 23 lies
        // past the end of that line and used to yield no highlights at all,
        // which made the per-highlight check below pass vacuously.
        let highlights = highlights_at(source, 1, 19);
        assert_eq!(
            highlights.len(),
            2,
            "expected declaration + assignment of 'x': {:?}",
            highlights
        );

        for (line_no, col, _) in &highlights {
            let line = source.lines().nth(*line_no as usize).unwrap();
            let col = *col as usize;
            assert_eq!(
                &line[col..col + 1],
                "x",
                "highlight on line {} should start at 'x': '{}'",
                line_no,
                line
            );
        }
    }

    #[test]
    fn test_enum_name_highlights() {
        let source = r#"contract Foo {
    enum Status { Active, Paused }
    Status public status;
}"#;
        let highlights = highlights_at(source, 1, 9);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'Status'"
        );
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
    }

    #[test]
    fn test_modifier_name_highlights() {
        let source = r#"contract Foo {
    address public owner;
    modifier onlyOwner() {
        require(msg.sender == owner);
        _;
    }
    function bar() public onlyOwner {}
}"#;
        let highlights = highlights_at(source, 2, 13);
        assert_eq!(highlights.len(), 2, "expected 2 highlights for 'onlyOwner'");
        assert_eq!(highlights[0].2, DocumentHighlightKind::WRITE);
        assert_eq!(highlights[1].2, DocumentHighlightKind::READ);
    }

    #[test]
    fn test_shop_sol() {
        let source = std::fs::read_to_string("example/Shop.sol").unwrap();
        let highlights = document_highlights(&source, Position::new(68, 22));
        assert!(
            highlights.len() >= 2,
            "Shop.sol 'PRICE' should have at least 2 highlights, got {}",
            highlights.len()
        );

        let decl = highlights.iter().find(|h| h.range.start.line == 68);
        assert_eq!(
            decl.map(|h| h.kind),
            Some(Some(DocumentHighlightKind::WRITE))
        );
    }

    #[test]
    fn test_increment_is_write() {
        let source = r#"contract Foo {
    uint256 public x;
    function inc() public {
        x++;
    }
}"#;
        let highlights = highlights_at(source, 1, 19);
        assert!(
            highlights.len() >= 2,
            "expected at least 2 highlights for 'x', got {}: {:?}",
            highlights.len(),
            highlights
        );
        let inc = highlights.iter().find(|h| h.0 == 3);
        assert_eq!(
            inc.map(|h| h.2),
            Some(DocumentHighlightKind::WRITE),
            "`x++` should be Write, all highlights: {:?}",
            highlights
        );
    }

    #[test]
    fn test_cursor_on_usage_finds_all() {
        let source = r#"contract Foo {
    uint256 public count;
    function inc() public {
        count += 1;
    }
}"#;
        let highlights_from_usage = highlights_at(source, 3, 8);
        let highlights_from_decl = highlights_at(source, 1, 23);
        assert_eq!(
            highlights_from_usage.len(),
            highlights_from_decl.len(),
            "clicking on usage vs declaration should find the same set"
        );
    }
}