1use crate::binder;
2use crate::binder::Binder;
3use crate::resolve;
4use crate::symbols::Name;
5use rowan::{TextRange, TextSize};
6use squawk_syntax::{
7 SyntaxNode,
8 ast::{self, AstNode},
9};
10
/// The category of information an [`InlayHint`] conveys.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InlayHintKind {
    /// A type annotation hint.
    Type,
    /// A name hint placed in front of an argument or value, such as `a: `
    /// before a function argument or an `INSERT` value.
    Parameter,
}
17
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct InlayHint {
20 pub position: TextSize,
21 pub label: String,
22 pub kind: InlayHintKind,
23 pub target: Option<TextRange>,
24}
25
26pub fn inlay_hints(file: &ast::SourceFile) -> Vec<InlayHint> {
27 let mut hints = vec![];
28 let binder = binder::bind(file);
29 let root = file.syntax();
30
31 for node in root.descendants() {
32 if let Some(call_expr) = ast::CallExpr::cast(node.clone()) {
33 inlay_hint_call_expr(&mut hints, root, &binder, call_expr);
34 } else if let Some(insert) = ast::Insert::cast(node) {
35 inlay_hint_insert(&mut hints, root, &binder, insert);
36 }
37 }
38
39 hints
40}
41
42fn inlay_hint_call_expr(
43 hints: &mut Vec<InlayHint>,
44 root: &SyntaxNode,
45 binder: &Binder,
46 call_expr: ast::CallExpr,
47) -> Option<()> {
48 let arg_list = call_expr.arg_list()?;
49 let expr = call_expr.expr()?;
50
51 let name_ref = if let Some(name_ref) = ast::NameRef::cast(expr.syntax().clone()) {
52 name_ref
53 } else {
54 ast::FieldExpr::cast(expr.syntax().clone())?.field()?
55 };
56
57 let function_ptr = resolve::resolve_name_ref(binder, root, &name_ref)?
58 .into_iter()
59 .next()?;
60
61 let function_name_node = function_ptr.to_node(root);
62
63 if let Some(create_function) = function_name_node
64 .ancestors()
65 .find_map(ast::CreateFunction::cast)
66 && let Some(param_list) = create_function.param_list()
67 {
68 for (param, arg) in param_list.params().zip(arg_list.args()) {
69 if let Some(param_name) = param.name() {
70 let arg_start = arg.syntax().text_range().start();
71 let target = Some(param_name.syntax().text_range());
72 hints.push(InlayHint {
73 position: arg_start,
74 label: format!("{}: ", param_name.syntax().text()),
75 kind: InlayHintKind::Parameter,
76 target,
77 });
78 }
79 }
80 };
81
82 Some(())
83}
84
85fn inlay_hint_insert(
86 hints: &mut Vec<InlayHint>,
87 root: &SyntaxNode,
88 binder: &Binder,
89 insert: ast::Insert,
90) -> Option<()> {
91 let values = insert.values()?;
92 let row_list = values.row_list()?;
93 let create_table = resolve::resolve_insert_create_table(root, binder, &insert);
94
95 let columns: Vec<(Name, Option<TextRange>)> = if let Some(column_list) = insert.column_list() {
96 column_list
98 .columns()
99 .filter_map(|col| {
100 let col_name = resolve::extract_column_name(&col)?;
101 let target = create_table
102 .as_ref()
103 .and_then(|x| resolve::find_column_in_create_table(x, &col_name))
104 .map(|x| x.text_range());
105 Some((col_name, target))
106 })
107 .collect()
108 } else {
109 create_table?
111 .table_arg_list()?
112 .args()
113 .filter_map(|arg| {
114 if let ast::TableArg::Column(column) = arg
115 && let Some(name) = column.name()
116 {
117 let col_name = Name::from_node(&name);
118 let target = Some(name.syntax().text_range());
119 Some((col_name, target))
120 } else {
121 None
122 }
123 })
124 .collect()
125 };
126
127 for row in row_list.rows() {
128 for ((column_name, target), expr) in columns.iter().zip(row.exprs()) {
129 let expr_start = expr.syntax().text_range().start();
130 hints.push(InlayHint {
131 position: expr_start,
132 label: format!("{}: ", column_name),
133 kind: InlayHintKind::Parameter,
134 target: *target,
135 });
136 }
137 }
138
139 Some(())
140}
141
#[cfg(test)]
mod test {
    use crate::inlay_hints::inlay_hints;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Renders `sql` with every inlay hint spliced into the text and
    /// underlined, for comparison against inline snapshots.
    ///
    /// Asserts that `sql` parses without errors, and returns an empty string
    /// when no hints are produced. Trailing whitespace is trimmed from every
    /// rendered line. Editors tend to strip whitespace at line ends inside the
    /// snapshot literals, so without the trim the snapshots would fail
    /// spuriously.
    #[track_caller]
    fn check_inlay_hints(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let hints = inlay_hints(&file);

        if hints.is_empty() {
            return String::new();
        }

        let mut modified_sql = sql.to_string();
        let mut insertions: Vec<(usize, String)> = hints
            .iter()
            .map(|hint| {
                let offset: usize = hint.position.into();
                (offset, hint.label.clone())
            })
            .collect();

        // Insert from the end backwards so the remaining offsets stay valid.
        insertions.sort_by(|a, b| b.0.cmp(&a.0));

        for (offset, label) in &insertions {
            modified_sql.insert_str(*offset, label);
        }

        // Walk forwards again and move each original offset past every label
        // that was inserted before it in `modified_sql`.
        let mut annotations = vec![];
        let mut cumulative_offset = 0;

        insertions.reverse();
        for (original_offset, label) in insertions {
            let new_offset = original_offset + cumulative_offset;
            annotations.push((new_offset, label.len()));
            cumulative_offset += label.len();
        }

        let mut snippet = Snippet::source(&modified_sql).fold(true);

        for (offset, len) in annotations {
            snippet = snippet.annotation(AnnotationKind::Context.span(offset..offset + len));
        }

        let group = Level::INFO.primary_title("inlay hints").element(snippet);

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        let rendered = renderer
            .render(&[group])
            .to_string()
            .replace("info: inlay hints", "inlay hints:");

        rendered
            .lines()
            .map(str::trim_end)
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn single_param() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1);
"), @r"
        inlay hints:
          ╭▸
        3 │ select foo(a: 1);
          ╰╴           ───
        ");
    }

    #[test]
    fn multiple_params() {
        assert_snapshot!(check_inlay_hints("
create function add(a int, b int) returns int as 'select $$1 + $$2' language sql;
select add(1, 2);
"), @r"
        inlay hints:
          ╭▸
        3 │ select add(a: 1, b: 2);
          ╰╴           ───   ───
        ");
    }

    #[test]
    fn no_params() {
        assert_snapshot!(check_inlay_hints("
create function foo() returns int as 'select 1' language sql;
select foo();
"), @"");
    }

    #[test]
    fn with_schema() {
        assert_snapshot!(check_inlay_hints("
create function public.foo(x int) returns int as 'select $$1' language sql;
select public.foo(42);
"), @r"
        inlay hints:
          ╭▸
        3 │ select public.foo(x: 42);
          ╰╴                  ───
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(check_inlay_hints(r#"
set search_path to myschema;
create function foo(val int) returns int as 'select $$1' language sql;
select foo(100);
"#), @r"
        inlay hints:
          ╭▸
        4 │ select foo(val: 100);
          ╰╴           ─────
        ");
    }

    #[test]
    fn multiple_calls() {
        assert_snapshot!(check_inlay_hints("
create function inc(n int) returns int as 'select $$1 + 1' language sql;
select inc(1), inc(2);
"), @r"
        inlay hints:
          ╭▸
        3 │ select inc(n: 1), inc(n: 2);
          ╰╴           ───        ───
        ");
    }

    #[test]
    fn more_args_than_params() {
        assert_snapshot!(check_inlay_hints("
create function foo(a int) returns int as 'select $$1' language sql;
select foo(1, 2);
"), @r"
        inlay hints:
          ╭▸
        3 │ select foo(a: 1, 2);
          ╰╴           ───
        ");
    }

    #[test]
    fn insert_with_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t (column_a, column_c) values (1, 'foo');
"), @r"
        inlay hints:
          ╭▸
        3 │ insert into t (column_a, column_c) values (column_a: 1, column_c: 'foo');
          ╰╴                                           ──────────   ──────────
        ");
    }

    #[test]
    fn insert_without_column_list() {
        assert_snapshot!(check_inlay_hints("
create table t (column_a int, column_b int, column_c text);
insert into t values (1, 2, 'foo');
"), @r"
        inlay hints:
          ╭▸
        3 │ insert into t values (column_a: 1, column_b: 2, column_c: 'foo');
          ╰╴                      ──────────   ──────────   ──────────
        ");
    }

    #[test]
    fn insert_multiple_rows() {
        assert_snapshot!(check_inlay_hints("
create table t (x int, y int);
insert into t values (1, 2), (3, 4);
"), @r"
        inlay hints:
          ╭▸
        3 │ insert into t values (x: 1, y: 2), (x: 3, y: 4);
          ╰╴                      ───   ───     ───   ───
        ");
    }

    #[test]
    fn insert_no_create_table() {
        assert_snapshot!(check_inlay_hints("
insert into t (a, b) values (1, 2);
"), @r"
        inlay hints:
          ╭▸
        2 │ insert into t (a, b) values (a: 1, b: 2);
          ╰╴                             ───   ───
        ");
    }

    #[test]
    fn insert_more_values_than_columns() {
        assert_snapshot!(check_inlay_hints("
create table t (a int, b int);
insert into t values (1, 2, 3);
"), @r"
        inlay hints:
          ╭▸
        3 │ insert into t values (a: 1, b: 2, 3);
          ╰╴                      ───   ───
        ");
    }
}