// squawk_ide/inlay_hints.rs

1use crate::binder;
2use crate::binder::Binder;
3use crate::resolve;
4use crate::symbols::Name;
5use rowan::{TextRange, TextSize};
6use squawk_syntax::{
7    SyntaxNode,
8    ast::{self, AstNode},
9};
10
/// The category an inlay hint belongs to.
///
/// `VSCode` has some theming options based on these types.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InlayHintKind {
    /// Type-style hint (not produced by this module).
    Type,
    /// Name hint placed before a value; used here both for function arguments and for
    /// `insert` values.
    Parameter,
}
17
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct InlayHint {
20    pub position: TextSize,
21    pub label: String,
22    pub kind: InlayHintKind,
23    pub target: Option<TextRange>,
24}
25
26pub fn inlay_hints(file: &ast::SourceFile) -> Vec<InlayHint> {
27    let mut hints = vec![];
28    let binder = binder::bind(file);
29    let root = file.syntax();
30
31    for node in root.descendants() {
32        if let Some(call_expr) = ast::CallExpr::cast(node.clone()) {
33            inlay_hint_call_expr(&mut hints, root, &binder, call_expr);
34        } else if let Some(insert) = ast::Insert::cast(node) {
35            inlay_hint_insert(&mut hints, root, &binder, insert);
36        }
37    }
38
39    hints
40}
41
42fn inlay_hint_call_expr(
43    hints: &mut Vec<InlayHint>,
44    root: &SyntaxNode,
45    binder: &Binder,
46    call_expr: ast::CallExpr,
47) -> Option<()> {
48    let arg_list = call_expr.arg_list()?;
49    let expr = call_expr.expr()?;
50
51    let name_ref = if let Some(name_ref) = ast::NameRef::cast(expr.syntax().clone()) {
52        name_ref
53    } else {
54        ast::FieldExpr::cast(expr.syntax().clone())?.field()?
55    };
56
57    let function_ptr = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)?
58        .into_iter()
59        .next()?;
60
61    let function_name_node = function_ptr.to_node(root);
62
63    if let Some(create_function) = function_name_node
64        .ancestors()
65        .find_map(ast::CreateFunction::cast)
66        && let Some(param_list) = create_function.param_list()
67    {
68        for (param, arg) in param_list.params().zip(arg_list.args()) {
69            if let Some(param_name) = param.name() {
70                let arg_start = arg.syntax().text_range().start();
71                let target = Some(param_name.syntax().text_range());
72                hints.push(InlayHint {
73                    position: arg_start,
74                    label: format!("{}: ", param_name.syntax().text()),
75                    kind: InlayHintKind::Parameter,
76                    target,
77                });
78            }
79        }
80    };
81
82    Some(())
83}
84
85fn inlay_hint_insert(
86    hints: &mut Vec<InlayHint>,
87    root: &SyntaxNode,
88    binder: &Binder,
89    insert: ast::Insert,
90) -> Option<()> {
91    let create_table = resolve::resolve_insert_create_table(root, binder, &insert);
92
93    let columns = if let Some(column_list) = insert.column_list() {
94        // `insert into t(a, b, c) values (1, 2, 3)`
95        column_list
96            .columns()
97            .filter_map(|col| {
98                let col_name = resolve::extract_column_name(&col)?;
99                let target = create_table
100                    .as_ref()
101                    .and_then(|x| resolve::find_column_in_create_table(binder, root, x, &col_name))
102                    .map(|x| x.text_range());
103                Some((col_name, target))
104            })
105            .collect()
106    } else {
107        // `insert into t values (1, 2, 3)`
108        create_table?
109            .table_arg_list()?
110            .args()
111            .filter_map(|arg| {
112                if let ast::TableArg::Column(column) = arg
113                    && let Some(name) = column.name()
114                {
115                    let col_name = Name::from_node(&name);
116                    let target = Some(name.syntax().text_range());
117                    Some((col_name, target))
118                } else {
119                    None
120                }
121            })
122            .collect()
123    };
124
125    let Some(values) = insert.values() else {
126        return inlay_hint_insert_select(hints, columns, insert.stmt()?);
127    };
128    let row_list = values.row_list()?;
129
130    for row in row_list.rows() {
131        for ((column_name, target), expr) in columns.iter().zip(row.exprs()) {
132            let expr_start = expr.syntax().text_range().start();
133            hints.push(InlayHint {
134                position: expr_start,
135                label: format!("{}: ", column_name),
136                kind: InlayHintKind::Parameter,
137                target: *target,
138            });
139        }
140    }
141
142    Some(())
143}
144
145fn inlay_hint_insert_select(
146    hints: &mut Vec<InlayHint>,
147    columns: Vec<(Name, Option<TextRange>)>,
148    stmt: ast::Stmt,
149) -> Option<()> {
150    let target_list = match stmt {
151        ast::Stmt::Select(select) => select.select_clause()?.target_list(),
152        ast::Stmt::SelectInto(select_into) => select_into.select_clause()?.target_list(),
153        ast::Stmt::ParenSelect(paren_select) => {
154            target_list_from_select_variant(paren_select.select()?)
155        }
156        _ => None,
157    }?;
158
159    for ((column_name, target), target_expr) in columns.iter().zip(target_list.targets()) {
160        let expr = target_expr.expr()?;
161        let expr_start = expr.syntax().text_range().start();
162        hints.push(InlayHint {
163            position: expr_start,
164            label: format!("{}: ", column_name),
165            kind: InlayHintKind::Parameter,
166            target: *target,
167        });
168    }
169
170    Some(())
171}
172
173fn target_list_from_select_variant(select: ast::SelectVariant) -> Option<ast::TargetList> {
174    let mut current = select;
175    for _ in 0..100 {
176        match current {
177            ast::SelectVariant::Select(select) => {
178                return select.select_clause()?.target_list();
179            }
180            ast::SelectVariant::SelectInto(select_into) => {
181                return select_into.select_clause()?.target_list();
182            }
183            ast::SelectVariant::ParenSelect(paren_select) => {
184                current = paren_select.select()?;
185            }
186            _ => return None,
187        }
188    }
189    None
190}
191
#[cfg(test)]
mod test {
    use crate::inlay_hints::inlay_hints;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Parses `sql`, computes its inlay hints, and renders the SQL with every hint label
    /// spliced in at its position and underlined, for snapshot comparison.
    ///
    /// Returns an empty string when there are no hints. Fails the test (via `assert_eq!`)
    /// if `sql` has parse errors.
    #[track_caller]
    fn check_inlay_hints(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let hints = inlay_hints(&file);

        if hints.is_empty() {
            return String::new();
        }

        let mut modified_sql = sql.to_string();
        let mut insertions: Vec<(usize, String)> = hints
            .iter()
            .map(|hint| {
                let offset: usize = hint.position.into();
                (offset, hint.label.clone())
            })
            .collect();

        // Insert from the end of the text backwards so that earlier insertions don't shift
        // the offsets of the ones still to be applied.
        insertions.sort_by(|a, b| b.0.cmp(&a.0));

        for (offset, label) in &insertions {
            modified_sql.insert_str(*offset, label);
        }

        let mut annotations = vec![];
        let mut cumulative_offset = 0;

        // Walk forward again. Each label's position in `modified_sql` is its original offset
        // plus the lengths of all labels inserted before it.
        insertions.reverse();
        for (original_offset, label) in insertions {
            let new_offset = original_offset + cumulative_offset;
            annotations.push((new_offset, label.len()));
            cumulative_offset += label.len();
        }

        let mut snippet = Snippet::source(&modified_sql).fold(true);

        for (offset, len) in annotations {
            snippet = snippet.annotation(AnnotationKind::Context.span(offset..offset + len));
        }

        let group = Level::INFO.primary_title("inlay hints").element(snippet);

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        renderer
            .render(&[group])
            .to_string()
            .replace("info: inlay hints", "inlay hints:")
    }

    #[test]
    fn single_param() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1);
          ╰╴           ───
        ");
    }

    #[test]
    fn multiple_params() {
        assert_snapshot!(check_inlay_hints("
create function add(a int, b int) returns int as 'select $$1 + $$2' language sql;
select add(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select add(a: 1, b: 2);
          ╰╴           ───   ───
        ");
    }

    #[test]
    fn no_params() {
        assert_snapshot!(check_inlay_hints("
create function foo() returns int as 'select 1' language sql;
select foo();
"), @"");
    }

    // A schema-qualified call (`public.foo(...)`) resolves through its field name.
    #[test]
    fn with_schema() {
        assert_snapshot!(check_inlay_hints("
create function public.foo(x int) returns int as 'select $$1' language sql;
select public.foo(42);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select public.foo(x: 42);
          ╰╴                  ───
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(check_inlay_hints(r#"
set search_path to myschema;
create function foo(val int) returns int as 'select $$1' language sql;
select foo(100);
"#), @r"
        inlay hints:
          ╭▸ 
        4 │ select foo(val: 100);
          ╰╴           ─────
        ");
    }

    #[test]
    fn multiple_calls() {
        assert_snapshot!(check_inlay_hints("
create function inc(n int) returns int as 'select $$1 + 1' language sql;
select inc(1), inc(2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select inc(n: 1), inc(n: 2);
          ╰╴           ───        ───
        ");
    }

    // Arguments beyond the declared parameters get no hint.
    #[test]
    fn more_args_than_params() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1, 2);
          ╰╴           ───
        ");
    }

    #[test]
    fn insert_with_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t (column_a, column_c) values (1, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t (column_a, column_c) values (column_a: 1, column_c: 'foo');
          ╰╴                                           ──────────   ──────────
        ");
    }

    #[test]
    fn insert_without_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t values (1, 2, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (column_a: 1, column_b: 2, column_c: 'foo');
          ╰╴                      ──────────   ──────────   ──────────
        ");
    }

    #[test]
    fn insert_multiple_rows() {
        assert_snapshot!(check_inlay_hints("
create table t (x int, y int);
insert into t values (1, 2), (3, 4);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (x: 1, y: 2), (x: 3, y: 4);
          ╰╴                      ───   ───     ───   ───
        ");
    }

    // Without a resolvable `create table`, hints still come from the explicit column list.
    #[test]
    fn insert_no_create_table() {
        assert_snapshot!(check_inlay_hints("
insert into t (a, b) values (1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        2 │ insert into t (a, b) values (a: 1, b: 2);
          ╰╴                             ───   ───
        ");
    }

    // Values beyond the table's columns get no hint.
    #[test]
    fn insert_more_values_than_columns() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t values (1, 2, 3);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (a: 1, b: 2, 3);
          ╰╴                      ───   ───
        ");
    }

    #[test]
    fn insert_select() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t select 1, 2;
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t select a: 1, b: 2;
          ╰╴                     ───   ───
        ");
    }
}