squawk_ide/
inlay_hints.rs

1use crate::binder;
2use crate::binder::Binder;
3use crate::resolve;
4use crate::symbols::Name;
5use rowan::{TextRange, TextSize};
6use squawk_syntax::{
7    SyntaxNode,
8    ast::{self, AstNode},
9};
10
/// Category of an [`InlayHint`].
///
/// `VSCode` has some theming options based on these types.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InlayHintKind {
    /// Type hint (not emitted by the functions in this file).
    Type,
    /// Names the function parameter or table column that a value is bound to.
    Parameter,
}
17
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct InlayHint {
20    pub position: TextSize,
21    pub label: String,
22    pub kind: InlayHintKind,
23    pub target: Option<TextRange>,
24}
25
26pub fn inlay_hints(file: &ast::SourceFile) -> Vec<InlayHint> {
27    let mut hints = vec![];
28    let binder = binder::bind(file);
29    let root = file.syntax();
30
31    for node in root.descendants() {
32        if let Some(call_expr) = ast::CallExpr::cast(node.clone()) {
33            inlay_hint_call_expr(&mut hints, root, &binder, call_expr);
34        } else if let Some(insert) = ast::Insert::cast(node) {
35            inlay_hint_insert(&mut hints, root, &binder, insert);
36        }
37    }
38
39    hints
40}
41
42fn inlay_hint_call_expr(
43    hints: &mut Vec<InlayHint>,
44    root: &SyntaxNode,
45    binder: &Binder,
46    call_expr: ast::CallExpr,
47) -> Option<()> {
48    let arg_list = call_expr.arg_list()?;
49    let expr = call_expr.expr()?;
50
51    let name_ref = if let Some(name_ref) = ast::NameRef::cast(expr.syntax().clone()) {
52        name_ref
53    } else {
54        ast::FieldExpr::cast(expr.syntax().clone())?.field()?
55    };
56
57    let function_ptr = resolve::resolve_name_ref(binder, root, &name_ref)?
58        .into_iter()
59        .next()?;
60
61    let function_name_node = function_ptr.to_node(root);
62
63    if let Some(create_function) = function_name_node
64        .ancestors()
65        .find_map(ast::CreateFunction::cast)
66        && let Some(param_list) = create_function.param_list()
67    {
68        for (param, arg) in param_list.params().zip(arg_list.args()) {
69            if let Some(param_name) = param.name() {
70                let arg_start = arg.syntax().text_range().start();
71                let target = Some(param_name.syntax().text_range());
72                hints.push(InlayHint {
73                    position: arg_start,
74                    label: format!("{}: ", param_name.syntax().text()),
75                    kind: InlayHintKind::Parameter,
76                    target,
77                });
78            }
79        }
80    };
81
82    Some(())
83}
84
85fn inlay_hint_insert(
86    hints: &mut Vec<InlayHint>,
87    root: &SyntaxNode,
88    binder: &Binder,
89    insert: ast::Insert,
90) -> Option<()> {
91    let values = insert.values()?;
92    let row_list = values.row_list()?;
93    let create_table = resolve::resolve_insert_create_table(root, binder, &insert);
94
95    let columns: Vec<(Name, Option<TextRange>)> = if let Some(column_list) = insert.column_list() {
96        // `insert into t(a, b, c) values (1, 2, 3)`
97        column_list
98            .columns()
99            .filter_map(|col| {
100                let col_name = resolve::extract_column_name(&col)?;
101                let target = create_table
102                    .as_ref()
103                    .and_then(|x| resolve::find_column_in_create_table(x, &col_name))
104                    .map(|x| x.text_range());
105                Some((col_name, target))
106            })
107            .collect()
108    } else {
109        // `insert into t values (1, 2, 3)`
110        create_table?
111            .table_arg_list()?
112            .args()
113            .filter_map(|arg| {
114                if let ast::TableArg::Column(column) = arg
115                    && let Some(name) = column.name()
116                {
117                    let col_name = Name::from_node(&name);
118                    let target = Some(name.syntax().text_range());
119                    Some((col_name, target))
120                } else {
121                    None
122                }
123            })
124            .collect()
125    };
126
127    for row in row_list.rows() {
128        for ((column_name, target), expr) in columns.iter().zip(row.exprs()) {
129            let expr_start = expr.syntax().text_range().start();
130            hints.push(InlayHint {
131                position: expr_start,
132                label: format!("{}: ", column_name),
133                kind: InlayHintKind::Parameter,
134                target: *target,
135            });
136        }
137    }
138
139    Some(())
140}
141
#[cfg(test)]
mod test {
    use crate::inlay_hints::inlay_hints;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Parses `sql`, asserting it has no parse errors, and renders the inlay
    /// hints computed for it as a text snippet.
    ///
    /// Each hint's label is spliced into the SQL at its position and
    /// underlined. The result is meant for snapshot comparison. Returns an
    /// empty string when no hints are produced.
    #[track_caller]
    fn check_inlay_hints(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let hints = inlay_hints(&file);

        if hints.is_empty() {
            return String::new();
        }

        let mut modified_sql = sql.to_string();
        let mut insertions: Vec<(usize, String)> = hints
            .iter()
            .map(|hint| {
                let offset: usize = hint.position.into();
                (offset, hint.label.clone())
            })
            .collect();

        // Insert from the highest offset down so each insertion leaves the
        // (smaller) offsets of the remaining insertions valid.
        insertions.sort_by(|a, b| b.0.cmp(&a.0));

        for (offset, label) in &insertions {
            modified_sql.insert_str(*offset, label);
        }

        let mut annotations = vec![];
        let mut cumulative_offset = 0;

        // Walk in ascending order, shifting each original offset by the total
        // length of the labels inserted before it, to locate the label's span
        // in `modified_sql`.
        insertions.reverse();
        for (original_offset, label) in insertions {
            let new_offset = original_offset + cumulative_offset;
            annotations.push((new_offset, label.len()));
            cumulative_offset += label.len();
        }

        let mut snippet = Snippet::source(&modified_sql).fold(true);

        for (offset, len) in annotations {
            snippet = snippet.annotation(AnnotationKind::Context.span(offset..offset + len));
        }

        let group = Level::INFO.primary_title("inlay hints").element(snippet);

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Rewrite the renderer's `info:` title into the header used by the snapshots.
        renderer
            .render(&[group])
            .to_string()
            .replace("info: inlay hints", "inlay hints:")
    }

    #[test]
    fn single_param() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1);
          ╰╴           ───
        ");
    }

    #[test]
    fn multiple_params() {
        assert_snapshot!(check_inlay_hints("
create function add(a int, b int) returns int as 'select $$1 + $$2' language sql;
select add(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select add(a: 1, b: 2);
          ╰╴           ───   ───
        ");
    }

    // A function without parameters produces no hints at all.
    #[test]
    fn no_params() {
        assert_snapshot!(check_inlay_hints("
create function foo() returns int as 'select 1' language sql;
select foo();
"), @"");
    }

    // Schema-qualified calls resolve through the field part of `public.foo`.
    #[test]
    fn with_schema() {
        assert_snapshot!(check_inlay_hints("
create function public.foo(x int) returns int as 'select $$1' language sql;
select public.foo(42);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select public.foo(x: 42);
          ╰╴                  ───
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(check_inlay_hints(r#"
set search_path to myschema;
create function foo(val int) returns int as 'select $$1' language sql;
select foo(100);
"#), @r"
        inlay hints:
          ╭▸ 
        4 │ select foo(val: 100);
          ╰╴           ─────
        ");
    }

    #[test]
    fn multiple_calls() {
        assert_snapshot!(check_inlay_hints("
create function inc(n int) returns int as 'select $$1 + 1' language sql;
select inc(1), inc(2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select inc(n: 1), inc(n: 2);
          ╰╴           ───        ───
        ");
    }

    // Arguments beyond the declared parameters get no hint.
    #[test]
    fn more_args_than_params() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1, 2);
          ╰╴           ───
        ");
    }

    // With an explicit column list, hints follow the list, not the table definition.
    #[test]
    fn insert_with_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t (column_a, column_c) values (1, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t (column_a, column_c) values (column_a: 1, column_c: 'foo');
          ╰╴                                           ──────────   ──────────
        ");
    }

    // Without a column list, hints follow the table's column definitions.
    #[test]
    fn insert_without_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t values (1, 2, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (column_a: 1, column_b: 2, column_c: 'foo');
          ╰╴                      ──────────   ──────────   ──────────
        ");
    }

    #[test]
    fn insert_multiple_rows() {
        assert_snapshot!(check_inlay_hints("
create table t (x int, y int);
insert into t values (1, 2), (3, 4);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (x: 1, y: 2), (x: 3, y: 4);
          ╰╴                      ───   ───     ───   ───
        ");
    }

    // An explicit column list is enough even when the table is not defined.
    #[test]
    fn insert_no_create_table() {
        assert_snapshot!(check_inlay_hints("
insert into t (a, b) values (1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        2 │ insert into t (a, b) values (a: 1, b: 2);
          ╰╴                             ───   ───
        ");
    }

    // Values beyond the known columns get no hint.
    #[test]
    fn insert_more_values_than_columns() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t values (1, 2, 3);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (a: 1, b: 2, 3);
          ╰╴                      ───   ───
        ");
    }
}
351}