1use crate::binder;
2use crate::binder::Binder;
3use crate::resolve;
4use crate::symbols::Name;
5use rowan::{TextRange, TextSize};
6use squawk_syntax::ast::{self, AstNode};
7
/// The category of an [`InlayHint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InlayHintKind {
    /// A type annotation hint (none of the hint collectors in this module emit it).
    Type,
    /// A name label shown in front of a value: a function parameter name before a
    /// call argument, or a column name before an `INSERT ... VALUES` expression.
    Parameter,
}
14
/// A single inlay hint: a label to display at a given offset of the source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InlayHint {
    /// Offset the label is displayed at: the start of the argument or value it annotates.
    pub position: TextSize,
    /// Text to display, including the trailing separator (e.g. `"a: "`).
    pub label: String,
    /// Category of the hint.
    pub kind: InlayHintKind,
    /// Range of the name the label refers to (a parameter or column name), or `None`
    /// when that definition could not be located.
    pub target: Option<TextRange>,
}
22
23pub fn inlay_hints(file: &ast::SourceFile) -> Vec<InlayHint> {
24 let mut hints = vec![];
25 let binder = binder::bind(file);
26
27 for node in file.syntax().descendants() {
28 if let Some(call_expr) = ast::CallExpr::cast(node.clone()) {
29 inlay_hint_call_expr(&mut hints, file, &binder, call_expr);
30 } else if let Some(insert) = ast::Insert::cast(node) {
31 inlay_hint_insert(&mut hints, file, &binder, insert);
32 }
33 }
34
35 hints
36}
37
38fn inlay_hint_call_expr(
39 hints: &mut Vec<InlayHint>,
40 file: &ast::SourceFile,
41 binder: &Binder,
42 call_expr: ast::CallExpr,
43) -> Option<()> {
44 let arg_list = call_expr.arg_list()?;
45 let expr = call_expr.expr()?;
46
47 let name_ref = if let Some(name_ref) = ast::NameRef::cast(expr.syntax().clone()) {
48 name_ref
49 } else {
50 ast::FieldExpr::cast(expr.syntax().clone())?.field()?
51 };
52
53 let function_ptr = resolve::resolve_name_ref(binder, &name_ref)?;
54
55 let root = file.syntax();
56 let function_name_node = function_ptr.to_node(root);
57
58 if let Some(create_function) = function_name_node
59 .ancestors()
60 .find_map(ast::CreateFunction::cast)
61 && let Some(param_list) = create_function.param_list()
62 {
63 for (param, arg) in param_list.params().zip(arg_list.args()) {
64 if let Some(param_name) = param.name() {
65 let arg_start = arg.syntax().text_range().start();
66 let target = Some(param_name.syntax().text_range());
67 hints.push(InlayHint {
68 position: arg_start,
69 label: format!("{}: ", param_name.syntax().text()),
70 kind: InlayHintKind::Parameter,
71 target,
72 });
73 }
74 }
75 };
76
77 Some(())
78}
79
80fn inlay_hint_insert(
81 hints: &mut Vec<InlayHint>,
82 file: &ast::SourceFile,
83 binder: &Binder,
84 insert: ast::Insert,
85) -> Option<()> {
86 let values = insert.values()?;
87 let row_list = values.row_list()?;
88 let create_table = resolve::resolve_insert_create_table(file, binder, &insert);
89
90 let columns: Vec<(Name, Option<TextRange>)> = if let Some(column_list) = insert.column_list() {
91 column_list
93 .columns()
94 .filter_map(|col| {
95 let col_name = resolve::extract_column_name(&col)?;
96 let target = create_table
97 .as_ref()
98 .and_then(|x| resolve::find_column_in_create_table(x, &col_name))
99 .map(|x| x.text_range());
100 Some((col_name, target))
101 })
102 .collect()
103 } else {
104 create_table?
106 .table_arg_list()?
107 .args()
108 .filter_map(|arg| {
109 if let ast::TableArg::Column(column) = arg
110 && let Some(name) = column.name()
111 {
112 let col_name = Name::from_node(&name);
113 let target = Some(name.syntax().text_range());
114 Some((col_name, target))
115 } else {
116 None
117 }
118 })
119 .collect()
120 };
121
122 for row in row_list.rows() {
123 for ((column_name, target), expr) in columns.iter().zip(row.exprs()) {
124 let expr_start = expr.syntax().text_range().start();
125 hints.push(InlayHint {
126 position: expr_start,
127 label: format!("{}: ", column_name),
128 kind: InlayHintKind::Parameter,
129 target: *target,
130 });
131 }
132 }
133
134 Some(())
135}
136
// Snapshot tests for `inlay_hints`. Each test renders the SQL with the hint
// labels spliced in and underlines every inserted label.
#[cfg(test)]
mod test {
    use crate::inlay_hints::inlay_hints;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Parses `sql`, computes its inlay hints, and renders the SQL with every hint
    /// label inserted at its position and underlined.
    ///
    /// Returns an empty string when there are no hints. Panics if `sql` has parse
    /// errors; `#[track_caller]` makes that panic point at the calling test.
    #[track_caller]
    fn check_inlay_hints(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let hints = inlay_hints(&file);

        if hints.is_empty() {
            return String::new();
        }

        let mut modified_sql = sql.to_string();
        let mut insertions: Vec<(usize, String)> = hints
            .iter()
            .map(|hint| {
                let offset: usize = hint.position.into();
                (offset, hint.label.clone())
            })
            .collect();

        // Insert from the highest offset down, so each insertion leaves the
        // smaller offsets still to be processed unchanged.
        insertions.sort_by(|a, b| b.0.cmp(&a.0));

        for (offset, label) in &insertions {
            modified_sql.insert_str(*offset, label);
        }

        let mut annotations = vec![];
        let mut cumulative_offset = 0;

        // Walk the insertions in ascending order. Shifting each original offset by
        // the total length of the labels inserted before it gives the label's
        // position in `modified_sql`.
        insertions.reverse();
        for (original_offset, label) in insertions {
            let new_offset = original_offset + cumulative_offset;
            annotations.push((new_offset, label.len()));
            cumulative_offset += label.len();
        }

        let mut snippet = Snippet::source(&modified_sql).fold(true);

        for (offset, len) in annotations {
            snippet = snippet.annotation(AnnotationKind::Context.span(offset..offset + len));
        }

        let group = Level::INFO.primary_title("inlay hints").element(snippet);

        // Render, then swap the default "info: ..." title for a shorter header.
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        renderer
            .render(&[group])
            .to_string()
            .replace("info: inlay hints", "inlay hints:")
    }

    #[test]
    fn single_param() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1);
          ╰╴           ───
        ");
    }

    #[test]
    fn multiple_params() {
        assert_snapshot!(check_inlay_hints("
create function add(a int, b int) returns int as 'select $$1 + $$2' language sql;
select add(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select add(a: 1, b: 2);
          ╰╴           ───   ───
        ");
    }

    #[test]
    fn no_params() {
        assert_snapshot!(check_inlay_hints("
create function foo() returns int as 'select 1' language sql;
select foo();
"), @"");
    }

    #[test]
    fn with_schema() {
        assert_snapshot!(check_inlay_hints("
create function public.foo(x int) returns int as 'select $$1' language sql;
select public.foo(42);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select public.foo(x: 42);
          ╰╴                  ───
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(check_inlay_hints(r#"
set search_path to myschema;
create function foo(val int) returns int as 'select $$1' language sql;
select foo(100);
"#), @r"
        inlay hints:
          ╭▸ 
        4 │ select foo(val: 100);
          ╰╴           ─────
        ");
    }

    #[test]
    fn multiple_calls() {
        assert_snapshot!(check_inlay_hints("
create function inc(n int) returns int as 'select $$1 + 1' language sql;
select inc(1), inc(2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select inc(n: 1), inc(n: 2);
          ╰╴           ───        ───
        ");
    }

    #[test]
    fn more_args_than_params() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ select foo(a: 1, 2);
          ╰╴           ───
        ");
    }

    #[test]
    fn insert_with_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t (column_a, column_c) values (1, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t (column_a, column_c) values (column_a: 1, column_c: 'foo');
          ╰╴                                           ──────────   ──────────
        ");
    }

    #[test]
    fn insert_without_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t values (1, 2, 'foo');
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (column_a: 1, column_b: 2, column_c: 'foo');
          ╰╴                      ──────────   ──────────   ──────────
        ");
    }

    #[test]
    fn insert_multiple_rows() {
        assert_snapshot!(check_inlay_hints("
create table t (x int, y int);
insert into t values (1, 2), (3, 4);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (x: 1, y: 2), (x: 3, y: 4);
          ╰╴                      ───   ───     ───   ───
        ");
    }

    #[test]
    fn insert_no_create_table() {
        assert_snapshot!(check_inlay_hints("
insert into t (a, b) values (1, 2);
"), @r"
        inlay hints:
          ╭▸ 
        2 │ insert into t (a, b) values (a: 1, b: 2);
          ╰╴                             ───   ───
        ");
    }

    #[test]
    fn insert_more_values_than_columns() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t values (1, 2, 3);
"), @r"
        inlay hints:
          ╭▸ 
        3 │ insert into t values (a: 1, b: 2, 3);
          ╰╴                      ───   ───
        ");
    }
}