Skip to main content

solidity_language_server/
inlay_hints.rs

1use crate::goto::CachedBuild;
2use serde_json::Value;
3use std::collections::HashMap;
4use tower_lsp::lsp_types::*;
5use tree_sitter::{Node, Parser};
6
/// Parameter names of a callable declaration, plus how many leading
/// parameters have no explicit argument at the call site.
///
/// Deriving `Clone` lets callers duplicate a value with `.clone()` instead of
/// copying it field by field; `Debug`/`PartialEq`/`Eq` make it usable in
/// assertions and diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParamInfo {
    /// Parameter names in declaration order; an unnamed parameter is stored
    /// as an empty string.
    names: Vec<String>,
    /// Number of leading params to skip (1 for using-for library calls, where
    /// the receiver fills the first parameter; 0 otherwise).
    skip: usize,
}
14
15/// Call-site info extracted from the AST, keyed by source byte offset.
16struct CallSite {
17    /// The resolved parameter info for this specific call.
18    info: ParamInfo,
19    /// Function/event name (for matching with tree-sitter).
20    name: String,
21}
22
23/// Both lookup strategies: exact byte-offset match and (name, arg_count) fallback.
24struct HintLookup {
25    /// Primary: byte_offset → CallSite (exact match when AST offsets are fresh).
26    by_offset: HashMap<usize, CallSite>,
27    /// Fallback: (name, arg_count) → ParamInfo (works even with stale offsets).
28    by_name: HashMap<(String, usize), ParamInfo>,
29}
30
31/// Generate inlay hints for a given range of source.
32///
33/// Uses tree-sitter on the **live buffer** for argument positions (so hints
34/// follow edits in real time) and the Forge AST for semantic info (parameter
35/// names via `referencedDeclaration`).
36pub fn inlay_hints(
37    build: &CachedBuild,
38    uri: &Url,
39    range: Range,
40    live_source: &[u8],
41) -> Vec<InlayHint> {
42    let sources = match build.ast.get("sources") {
43        Some(s) => s,
44        None => return vec![],
45    };
46
47    let path_str = match uri.to_file_path() {
48        Ok(p) => p.to_str().unwrap_or("").to_string(),
49        Err(_) => return vec![],
50    };
51
52    let abs = match build
53        .path_to_abs
54        .iter()
55        .find(|(k, _)| path_str.ends_with(k.as_str()))
56    {
57        Some((_, v)) => v.clone(),
58        None => return vec![],
59    };
60
61    let file_ast = match find_file_ast(sources, &abs) {
62        Some(a) => a,
63        None => return vec![],
64    };
65
66    // Phase 1: Build lookup from AST
67    let lookup = build_hint_lookup(file_ast, sources);
68
69    // Phase 2: Walk tree-sitter on the live buffer for real-time positions
70    let source_str = String::from_utf8_lossy(live_source);
71    let tree = match ts_parse(&source_str) {
72        Some(t) => t,
73        None => return vec![],
74    };
75
76    let mut hints = Vec::new();
77    collect_ts_hints(tree.root_node(), &source_str, &range, &lookup, &mut hints);
78    hints
79}
80
81/// Parse Solidity source with tree-sitter.
82fn ts_parse(source: &str) -> Option<tree_sitter::Tree> {
83    let mut parser = Parser::new();
84    parser
85        .set_language(&tree_sitter_solidity::LANGUAGE.into())
86        .expect("failed to load Solidity grammar");
87    parser.parse(source, None)
88}
89
90/// Build both lookup strategies from the AST.
91fn build_hint_lookup(file_ast: &Value, sources: &Value) -> HintLookup {
92    let mut lookup = HintLookup {
93        by_offset: HashMap::new(),
94        by_name: HashMap::new(),
95    };
96    collect_ast_calls(file_ast, sources, &mut lookup);
97    lookup
98}
99
100/// Parse the `src` field ("offset:length:fileId") and return the byte offset.
101fn parse_src_offset(node: &Value) -> Option<usize> {
102    let src = node.get("src").and_then(|v| v.as_str())?;
103    src.split(':').next()?.parse().ok()
104}
105
106/// Recursively walk AST nodes collecting call site info.
107fn collect_ast_calls(node: &Value, sources: &Value, lookup: &mut HintLookup) {
108    let node_type = node.get("nodeType").and_then(|v| v.as_str()).unwrap_or("");
109
110    match node_type {
111        "FunctionCall" => {
112            if let Some((name, info)) = extract_call_info(node, sources) {
113                let arg_count = node
114                    .get("arguments")
115                    .and_then(|v| v.as_array())
116                    .map(|a| a.len())
117                    .unwrap_or(0);
118                if let Some(offset) = parse_src_offset(node) {
119                    lookup.by_offset.insert(
120                        offset,
121                        CallSite {
122                            info: ParamInfo {
123                                names: info.names.clone(),
124                                skip: info.skip,
125                            },
126                            name: name.clone(),
127                        },
128                    );
129                }
130                lookup.by_name.entry((name, arg_count)).or_insert(info);
131            }
132        }
133        "EmitStatement" => {
134            if let Some(event_call) = node.get("eventCall") {
135                if let Some((name, info)) = extract_call_info(event_call, sources) {
136                    let arg_count = event_call
137                        .get("arguments")
138                        .and_then(|v| v.as_array())
139                        .map(|a| a.len())
140                        .unwrap_or(0);
141                    if let Some(offset) = parse_src_offset(node) {
142                        lookup.by_offset.insert(
143                            offset,
144                            CallSite {
145                                info: ParamInfo {
146                                    names: info.names.clone(),
147                                    skip: info.skip,
148                                },
149                                name: name.clone(),
150                            },
151                        );
152                    }
153                    lookup.by_name.entry((name, arg_count)).or_insert(info);
154                }
155            }
156        }
157        _ => {}
158    }
159
160    // Recurse into children
161    for key in crate::goto::CHILD_KEYS {
162        if let Some(child) = node.get(*key) {
163            if child.is_array() {
164                if let Some(arr) = child.as_array() {
165                    for item in arr {
166                        collect_ast_calls(item, sources, lookup);
167                    }
168                }
169            } else if child.is_object() {
170                collect_ast_calls(child, sources, lookup);
171            }
172        }
173    }
174}
175
176/// Extract function/event name and parameter info from an AST FunctionCall node.
177fn extract_call_info(node: &Value, sources: &Value) -> Option<(String, ParamInfo)> {
178    let args = node.get("arguments")?.as_array()?;
179    if args.is_empty() {
180        return None;
181    }
182
183    // Skip struct constructors with named args
184    let kind = node.get("kind").and_then(|v| v.as_str()).unwrap_or("");
185    if kind == "structConstructorCall" {
186        if node
187            .get("names")
188            .and_then(|v| v.as_array())
189            .is_some_and(|n| !n.is_empty())
190        {
191            return None;
192        }
193    }
194
195    let expr = node.get("expression")?;
196    let decl_id = expr.get("referencedDeclaration").and_then(|v| v.as_u64())?;
197
198    let decl_node = find_declaration(sources, decl_id)?;
199    let names = get_parameter_names(&decl_node)?;
200
201    // Extract the function name from the expression
202    let func_name = extract_function_name(expr)?;
203
204    // Using-for library calls pass the receiver as the implicit first param,
205    // so the AST has one fewer arg than the declaration has params.
206    // Direct library calls (Transaction.addTax) and struct constructors
207    // pass all params explicitly — arg count matches param count.
208    let arg_count = node
209        .get("arguments")
210        .and_then(|v| v.as_array())
211        .map(|a| a.len())
212        .unwrap_or(0);
213    let skip = if is_member_access(expr) && arg_count < names.len() {
214        1
215    } else {
216        0
217    };
218
219    Some((func_name, ParamInfo { names, skip }))
220}
221
222/// Extract the function/event name from an AST expression node.
223fn extract_function_name(expr: &Value) -> Option<String> {
224    let node_type = expr.get("nodeType").and_then(|v| v.as_str())?;
225    match node_type {
226        "Identifier" => expr.get("name").and_then(|v| v.as_str()).map(String::from),
227        "MemberAccess" => expr
228            .get("memberName")
229            .and_then(|v| v.as_str())
230            .map(String::from),
231        _ => None,
232    }
233}
234
235/// Check if expression is a MemberAccess (potential using-for call).
236fn is_member_access(expr: &Value) -> bool {
237    expr.get("nodeType")
238        .and_then(|v| v.as_str())
239        .is_some_and(|t| t == "MemberAccess")
240}
241
// ── Tree-sitter walk ──────────────────────────────────────────────────────
243
244/// Look up param info: try exact byte-offset match first, fall back to (name, arg_count).
245fn lookup_info<'a>(
246    lookup: &'a HintLookup,
247    offset: usize,
248    name: &str,
249    arg_count: usize,
250) -> Option<&'a ParamInfo> {
251    // Exact match by byte offset (works when AST is fresh)
252    if let Some(site) = lookup.by_offset.get(&offset) {
253        if site.name == name {
254            return Some(&site.info);
255        }
256    }
257    // Fallback by (name, arg_count) (works with stale offsets after edits)
258    lookup.by_name.get(&(name.to_string(), arg_count))
259}
260
261/// Recursively walk tree-sitter nodes, emitting hints for calls in the visible range.
262fn collect_ts_hints(
263    node: Node,
264    source: &str,
265    range: &Range,
266    lookup: &HintLookup,
267    hints: &mut Vec<InlayHint>,
268) {
269    // Quick range check — skip nodes entirely outside the visible range
270    let node_start = node.start_position();
271    let node_end = node.end_position();
272    if (node_end.row as u32) < range.start.line || (node_start.row as u32) > range.end.line {
273        return;
274    }
275
276    match node.kind() {
277        "call_expression" => {
278            emit_call_hints(node, source, lookup, hints);
279        }
280        "emit_statement" => {
281            emit_emit_hints(node, source, lookup, hints);
282        }
283        _ => {}
284    }
285
286    // Recurse into children
287    let mut cursor = node.walk();
288    for child in node.children(&mut cursor) {
289        collect_ts_hints(child, source, range, lookup, hints);
290    }
291}
292
293/// Emit parameter hints for a `call_expression` node.
294fn emit_call_hints(node: Node, source: &str, lookup: &HintLookup, hints: &mut Vec<InlayHint>) {
295    let func_name = match ts_call_function_name(node, source) {
296        Some(n) => n,
297        None => return,
298    };
299
300    let args = ts_call_arguments(node);
301    if args.is_empty() {
302        return;
303    }
304
305    let info = match lookup_info(lookup, node.start_byte(), func_name, args.len()) {
306        Some(i) => i,
307        None => return,
308    };
309
310    emit_param_hints(&args, info, hints);
311}
312
313/// Emit parameter hints for an `emit_statement` node.
314fn emit_emit_hints(node: Node, source: &str, lookup: &HintLookup, hints: &mut Vec<InlayHint>) {
315    let event_name = match ts_emit_event_name(node, source) {
316        Some(n) => n,
317        None => return,
318    };
319
320    let args = ts_call_arguments(node);
321    if args.is_empty() {
322        return;
323    }
324
325    let info = match lookup_info(lookup, node.start_byte(), event_name, args.len()) {
326        Some(i) => i,
327        None => return,
328    };
329
330    emit_param_hints(&args, info, hints);
331}
332
333/// Emit InlayHint items for each argument, using tree-sitter positions.
334fn emit_param_hints(args: &[Node], info: &ParamInfo, hints: &mut Vec<InlayHint>) {
335    for (i, arg) in args.iter().enumerate() {
336        let pi = i + info.skip;
337        if pi >= info.names.len() || info.names[pi].is_empty() {
338            continue;
339        }
340
341        let start = arg.start_position();
342        let position = Position::new(start.row as u32, start.column as u32);
343
344        hints.push(InlayHint {
345            position,
346            kind: Some(InlayHintKind::PARAMETER),
347            label: InlayHintLabel::String(format!("{}:", info.names[pi])),
348            text_edits: None,
349            tooltip: None,
350            padding_left: None,
351            padding_right: Some(true),
352            data: None,
353        });
354    }
355}
356
// ── Tree-sitter helpers ───────────────────────────────────────────────────
358
359/// Get the function name from a `call_expression` node.
360///
361/// For `transfer(...)` → "transfer"
362/// For `PRICE.addTax(...)` → "addTax"
363fn ts_call_function_name<'a>(node: Node<'a>, source: &'a str) -> Option<&'a str> {
364    let func_expr = node.child_by_field_name("function")?;
365    // The expression wrapper has one named child
366    let inner = first_named_child(func_expr)?;
367    match inner.kind() {
368        "identifier" => Some(&source[inner.byte_range()]),
369        "member_expression" => {
370            let prop = inner.child_by_field_name("property")?;
371            Some(&source[prop.byte_range()])
372        }
373        _ => None,
374    }
375}
376
377/// Get the event name from an `emit_statement` node.
378fn ts_emit_event_name<'a>(node: Node<'a>, source: &'a str) -> Option<&'a str> {
379    let name_expr = node.child_by_field_name("name")?;
380    let inner = first_named_child(name_expr)?;
381    match inner.kind() {
382        "identifier" => Some(&source[inner.byte_range()]),
383        "member_expression" => {
384            let prop = inner.child_by_field_name("property")?;
385            Some(&source[prop.byte_range()])
386        }
387        _ => None,
388    }
389}
390
391/// Collect `call_argument` children from a node (works for both
392/// `call_expression` and `emit_statement` since `_call_arguments` is hidden).
393fn ts_call_arguments(node: Node) -> Vec<Node> {
394    let mut args = Vec::new();
395    let mut cursor = node.walk();
396    for child in node.children(&mut cursor) {
397        if child.kind() == "call_argument" {
398            args.push(child);
399        }
400    }
401    args
402}
403
404/// Get the first named child of a node.
405fn first_named_child(node: Node) -> Option<Node> {
406    let mut cursor = node.walk();
407    node.children(&mut cursor).find(|c| c.is_named())
408}
409
// ── AST helpers ──────────────────────────────────────────────────────────
411
412/// Find a declaration node by ID in the AST sources.
413fn find_declaration(sources: &Value, decl_id: u64) -> Option<Value> {
414    let sources_obj = sources.as_object()?;
415    for (_, file_data) in sources_obj {
416        let entries = file_data.as_array()?;
417        for entry in entries {
418            let ast = entry.get("source_file")?.get("ast")?;
419            if let Some(found) = find_node_by_id(ast, decl_id) {
420                return Some(found.clone());
421            }
422        }
423    }
424    None
425}
426
427/// Recursively find a node by its `id` field.
428fn find_node_by_id(node: &Value, id: u64) -> Option<&Value> {
429    if node.get("id").and_then(|v| v.as_u64()) == Some(id) {
430        return Some(node);
431    }
432    for key in crate::goto::CHILD_KEYS {
433        if let Some(child) = node.get(*key) {
434            if child.is_array() {
435                if let Some(arr) = child.as_array() {
436                    for item in arr {
437                        if let Some(found) = find_node_by_id(item, id) {
438                            return Some(found);
439                        }
440                    }
441                }
442            } else if child.is_object() {
443                if let Some(found) = find_node_by_id(child, id) {
444                    return Some(found);
445                }
446            }
447        }
448    }
449    if let Some(nodes) = node.get("nodes").and_then(|v| v.as_array()) {
450        for child in nodes {
451            if let Some(found) = find_node_by_id(child, id) {
452                return Some(found);
453            }
454        }
455    }
456    None
457}
458
459/// Extract parameter names from a function/event/error/struct declaration.
460fn get_parameter_names(decl: &Value) -> Option<Vec<String>> {
461    // Functions, events, errors: parameters.parameters[]
462    // Structs: members[]
463    let items = decl
464        .get("parameters")
465        .and_then(|p| p.get("parameters"))
466        .and_then(|v| v.as_array())
467        .or_else(|| decl.get("members").and_then(|v| v.as_array()))?;
468    Some(
469        items
470            .iter()
471            .map(|p| {
472                p.get("name")
473                    .and_then(|v| v.as_str())
474                    .unwrap_or("")
475                    .to_string()
476            })
477            .collect(),
478    )
479}
480
481/// Find the AST for a specific file by its absolutePath.
482fn find_file_ast<'a>(sources: &'a Value, abs_path: &str) -> Option<&'a Value> {
483    let sources_obj = sources.as_object()?;
484    for (_, file_data) in sources_obj {
485        let entries = file_data.as_array()?;
486        for entry in entries {
487            let ast = entry.get("source_file")?.get("ast")?;
488            if ast.get("absolutePath").and_then(|v| v.as_str()) == Some(abs_path) {
489                return Some(ast);
490            }
491        }
492    }
493    None
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_parameter_names() {
        let decl: Value = serde_json::json!({
            "parameters": {
                "parameters": [
                    {"name": "to", "nodeType": "VariableDeclaration"},
                    {"name": "amount", "nodeType": "VariableDeclaration"},
                ]
            }
        });
        let names = get_parameter_names(&decl).unwrap();
        assert_eq!(names, vec!["to", "amount"]);
    }

    #[test]
    fn test_ts_call_function_name() {
        let source = r#"
contract Foo {
    function bar(uint x) public {}
    function test() public {
        bar(42);
    }
}
"#;
        let tree = ts_parse(source).unwrap();
        let mut names = Vec::new();
        find_calls(tree.root_node(), source, &mut names);
        assert_eq!(names, vec!["bar"]);
    }

    #[test]
    fn test_ts_member_call_name() {
        let source = r#"
contract Foo {
    function test() public {
        PRICE.addTax(TAX, TAX_BASE);
    }
}
"#;
        let tree = ts_parse(source).unwrap();
        let mut names = Vec::new();
        find_calls(tree.root_node(), source, &mut names);
        assert_eq!(names, vec!["addTax"]);
    }

    #[test]
    fn test_ts_emit_event_name() {
        let source = r#"
contract Foo {
    event Purchase(address buyer, uint256 price);
    function test() public {
        emit Purchase(msg.sender, 100);
    }
}
"#;
        let tree = ts_parse(source).unwrap();
        let mut names = Vec::new();
        find_emits(tree.root_node(), source, &mut names);
        assert_eq!(names, vec!["Purchase"]);
    }

    #[test]
    fn test_ts_call_arguments_count() {
        let source = r#"
contract Foo {
    function bar(uint x, uint y) public {}
    function test() public {
        bar(1, 2);
    }
}
"#;
        let tree = ts_parse(source).unwrap();
        let mut arg_counts = Vec::new();
        find_call_arg_counts(tree.root_node(), &mut arg_counts);
        assert_eq!(arg_counts, vec![2]);
    }

    #[test]
    fn test_ts_argument_positions_follow_live_buffer() {
        // An edited buffer in which each argument sits on its own line.
        let source = r#"
contract Foo {
    function bar(uint x, uint y) public {}
    function test() public {
        bar(
            1,
            2
        );
    }
}
"#;
        let tree = ts_parse(source).unwrap();
        let mut positions = Vec::new();
        find_arg_positions(tree.root_node(), &mut positions);
        // Rows are 0-indexed: "1" is on row 5 and "2" on row 6.
        let rows: Vec<usize> = positions.iter().map(|p| p.0).collect();
        assert_eq!(rows, vec![5, 6]);
    }

    // ── Test helpers ──────────────────────────────────────────────────────

    /// Call `visit` on `node` and then on every descendant, in pre-order.
    fn for_each_node<'a>(node: Node<'a>, visit: &mut impl FnMut(Node<'a>)) {
        visit(node);
        let mut walker = node.walk();
        for child in node.children(&mut walker) {
            for_each_node(child, visit);
        }
    }

    /// Collect the callee name of every `call_expression` under `node`.
    fn find_calls<'a>(node: Node<'a>, source: &'a str, out: &mut Vec<&'a str>) {
        for_each_node(node, &mut |n: Node<'a>| {
            if n.kind() == "call_expression" {
                out.extend(ts_call_function_name(n, source));
            }
        });
    }

    /// Collect the event name of every `emit_statement` under `node`.
    fn find_emits<'a>(node: Node<'a>, source: &'a str, out: &mut Vec<&'a str>) {
        for_each_node(node, &mut |n: Node<'a>| {
            if n.kind() == "emit_statement" {
                out.extend(ts_emit_event_name(n, source));
            }
        });
    }

    /// Collect the argument count of every `call_expression` under `node`.
    fn find_call_arg_counts(node: Node, out: &mut Vec<usize>) {
        for_each_node(node, &mut |n| {
            if n.kind() == "call_expression" {
                out.push(ts_call_arguments(n).len());
            }
        });
    }

    /// Collect the `(row, column)` of every argument of every
    /// `call_expression` under `node`.
    fn find_arg_positions(node: Node, out: &mut Vec<(usize, usize)>) {
        for_each_node(node, &mut |n| {
            if n.kind() == "call_expression" {
                for arg in ts_call_arguments(n) {
                    let start = arg.start_position();
                    out.push((start.row, start.column));
                }
            }
        });
    }
}
651}