squawk_ide/
inlay_hints.rs

1use crate::binder;
2use crate::binder::Binder;
3use crate::resolve;
4use crate::symbols::Name;
5use rowan::{TextRange, TextSize};
6use squawk_syntax::ast::{self, AstNode};
7
/// Category of an inlay hint.
///
/// `VSCode` exposes theming options keyed on these categories.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InlayHintKind {
    /// Presumably a type annotation hint; not produced by the code in this file.
    Type,
    /// A hint naming the function parameter or table column that a value binds to.
    Parameter,
}
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct InlayHint {
17    pub position: TextSize,
18    pub label: String,
19    pub kind: InlayHintKind,
20    pub target: Option<TextRange>,
21}
22
23pub fn inlay_hints(file: &ast::SourceFile) -> Vec<InlayHint> {
24    let mut hints = vec![];
25    let binder = binder::bind(file);
26
27    for node in file.syntax().descendants() {
28        if let Some(call_expr) = ast::CallExpr::cast(node.clone()) {
29            inlay_hint_call_expr(&mut hints, file, &binder, call_expr);
30        } else if let Some(insert) = ast::Insert::cast(node) {
31            inlay_hint_insert(&mut hints, file, &binder, insert);
32        }
33    }
34
35    hints
36}
37
38fn inlay_hint_call_expr(
39    hints: &mut Vec<InlayHint>,
40    file: &ast::SourceFile,
41    binder: &Binder,
42    call_expr: ast::CallExpr,
43) -> Option<()> {
44    let arg_list = call_expr.arg_list()?;
45    let expr = call_expr.expr()?;
46
47    let name_ref = if let Some(name_ref) = ast::NameRef::cast(expr.syntax().clone()) {
48        name_ref
49    } else {
50        ast::FieldExpr::cast(expr.syntax().clone())?.field()?
51    };
52
53    let function_ptr = resolve::resolve_name_ref(binder, &name_ref)?;
54
55    let root = file.syntax();
56    let function_name_node = function_ptr.to_node(root);
57
58    if let Some(create_function) = function_name_node
59        .ancestors()
60        .find_map(ast::CreateFunction::cast)
61        && let Some(param_list) = create_function.param_list()
62    {
63        for (param, arg) in param_list.params().zip(arg_list.args()) {
64            if let Some(param_name) = param.name() {
65                let arg_start = arg.syntax().text_range().start();
66                let target = Some(param_name.syntax().text_range());
67                hints.push(InlayHint {
68                    position: arg_start,
69                    label: format!("{}: ", param_name.syntax().text()),
70                    kind: InlayHintKind::Parameter,
71                    target,
72                });
73            }
74        }
75    };
76
77    Some(())
78}
79
80fn inlay_hint_insert(
81    hints: &mut Vec<InlayHint>,
82    file: &ast::SourceFile,
83    binder: &Binder,
84    insert: ast::Insert,
85) -> Option<()> {
86    let values = insert.values()?;
87    let row_list = values.row_list()?;
88    let create_table = resolve::resolve_insert_create_table(file, binder, &insert);
89
90    let columns: Vec<(Name, Option<TextRange>)> = if let Some(column_list) = insert.column_list() {
91        // `insert into t(a, b, c) values (1, 2, 3)`
92        column_list
93            .columns()
94            .filter_map(|col| {
95                let col_name = resolve::extract_column_name(&col)?;
96                let target = create_table
97                    .as_ref()
98                    .and_then(|x| resolve::find_column_in_create_table(x, &col_name))
99                    .map(|x| x.text_range());
100                Some((col_name, target))
101            })
102            .collect()
103    } else {
104        // `insert into t values (1, 2, 3)`
105        create_table?
106            .table_arg_list()?
107            .args()
108            .filter_map(|arg| {
109                if let ast::TableArg::Column(column) = arg
110                    && let Some(name) = column.name()
111                {
112                    let col_name = Name::from_node(&name);
113                    let target = Some(name.syntax().text_range());
114                    Some((col_name, target))
115                } else {
116                    None
117                }
118            })
119            .collect()
120    };
121
122    for row in row_list.rows() {
123        for ((column_name, target), expr) in columns.iter().zip(row.exprs()) {
124            let expr_start = expr.syntax().text_range().start();
125            hints.push(InlayHint {
126                position: expr_start,
127                label: format!("{}: ", column_name),
128                kind: InlayHintKind::Parameter,
129                target: *target,
130            });
131        }
132    }
133
134    Some(())
135}
136
#[cfg(test)]
mod test {
    use crate::inlay_hints::inlay_hints;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Renders `sql` with every inlay-hint label spliced in and underlined.
    ///
    /// Parses `sql` and computes its inlay hints. Each label is inserted at
    /// its hint's position, and the result is rendered as an annotated
    /// snippet for snapshot comparison.
    ///
    /// Returns an empty string when there are no hints. Panics if `sql` has
    /// parse errors.
    #[track_caller]
    fn check_inlay_hints(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let hints = inlay_hints(&file);
        if hints.is_empty() {
            return String::new();
        }

        // Order labels by ascending offset. Reversing before the stable sort
        // places labels that share an offset last-hint-first. That matches
        // inserting each label at its offset one after another.
        let mut labels: Vec<(usize, &str)> = hints
            .iter()
            .rev()
            .map(|hint| {
                let offset: usize = hint.position.into();
                (offset, hint.label.as_str())
            })
            .collect();
        labels.sort_by_key(|&(offset, _)| offset);

        // Rebuild the SQL left to right, remembering where each label lands.
        let extra: usize = labels.iter().map(|(_, label)| label.len()).sum();
        let mut modified_sql = String::with_capacity(sql.len() + extra);
        let mut spans = Vec::with_capacity(labels.len());
        let mut copied_up_to = 0;
        for &(offset, label) in &labels {
            modified_sql.push_str(&sql[copied_up_to..offset]);
            let start = modified_sql.len();
            modified_sql.push_str(label);
            spans.push(start..modified_sql.len());
            copied_up_to = offset;
        }
        modified_sql.push_str(&sql[copied_up_to..]);

        let mut snippet = Snippet::source(&modified_sql).fold(true);
        for span in spans {
            snippet = snippet.annotation(AnnotationKind::Context.span(span));
        }

        let report = Level::INFO.primary_title("inlay hints").element(snippet);
        Renderer::plain()
            .decor_style(DecorStyle::Unicode)
            .render(&[report])
            .to_string()
            .replace("info: inlay hints", "inlay hints:")
    }

    #[test]
    fn single_param() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1);
          ╰╴           ───
        ");
    }

    #[test]
    fn multiple_params() {
        assert_snapshot!(check_inlay_hints("
create function add(a int, b int) returns int as 'select $$1 + $$2' language sql;
select add(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select add(a: 1, b: 2);
          ╰╴           ───   ───
        ");
    }

    #[test]
    fn no_params() {
        assert_snapshot!(check_inlay_hints("
create function foo() returns int as 'select 1' language sql;
select foo();
"), @"");
    }

    #[test]
    fn with_schema() {
        assert_snapshot!(check_inlay_hints("
create function public.foo(x int) returns int as 'select $$1' language sql;
select public.foo(42);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select public.foo(x: 42);
          ╰╴                  ───
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(check_inlay_hints(r#"
set search_path to myschema;
create function foo(val int) returns int as 'select $$1' language sql;
select foo(100);
"#), @r"
        inlay hints:
          ╭▸ 
        4 │ select foo(val: 100);
          ╰╴           ─────
        ");
    }

    #[test]
    fn multiple_calls() {
        assert_snapshot!(check_inlay_hints("
create function inc(n int) returns int as 'select $$1 + 1' language sql;
select inc(1), inc(2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select inc(n: 1), inc(n: 2);
          ╰╴           ───        ───
        ");
    }

    #[test]
    fn more_args_than_params() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1, 2);
          ╰╴           ───
        ");
    }

    #[test]
    fn insert_with_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t (column_a, column_c) values (1, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t (column_a, column_c) values (column_a: 1, column_c: 'foo');
          ╰╴                                           ──────────   ──────────
        ");
    }

    #[test]
    fn insert_without_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t values (1, 2, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (column_a: 1, column_b: 2, column_c: 'foo');
          ╰╴                      ──────────   ──────────   ──────────
        ");
    }

    #[test]
    fn insert_multiple_rows() {
        assert_snapshot!(check_inlay_hints("
create table t (x int, y int);
insert into t values (1, 2), (3, 4);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (x: 1, y: 2), (x: 3, y: 4);
          ╰╴                      ───   ───     ───   ───
        ");
    }

    #[test]
    fn insert_no_create_table() {
        assert_snapshot!(check_inlay_hints("
insert into t (a, b) values (1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        2 │ insert into t (a, b) values (a: 1, b: 2);
          ╰╴                             ───   ───
        ");
    }

    #[test]
    fn insert_more_values_than_columns() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t values (1, 2, 3);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (a: 1, b: 2, 3);
          ╰╴                      ───   ───
        ");
    }
}