1use rowan::TextRange;
2use squawk_syntax::ast::{self, AstNode};
3
4use crate::binder::{self, extract_string_literal};
5use crate::resolve::{resolve_function_info, resolve_table_info, resolve_type_info};
6
/// The category of SQL object a [`DocumentSymbol`] describes.
///
/// `Table`, `Function` and `Type` are produced for top-level `create`
/// statements; `Column` and `Variant` only appear as children (columns of a
/// table or composite type, labels of an enum type).
///
/// The enum is fieldless, so it is cheap to copy and compare; consumers can
/// match on it, compare it with `==`, or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DocumentSymbolKind {
    /// A `create table` statement.
    Table,
    /// A `create function` statement.
    Function,
    /// A `create type` statement (enum or composite).
    Type,
    /// A column of a table, or an attribute of a composite type.
    Column,
    /// A label of an enum type.
    Variant,
}
15
/// One entry in a document's symbol outline, possibly with nested children.
///
/// NOTE(review): presumably converted into an LSP `DocumentSymbol` by the
/// server layer; confirm against the caller.
#[derive(Debug)]
pub struct DocumentSymbol {
    /// Display name.
    ///
    /// Tables, functions and types use the binder-resolved `schema.name` form.
    /// Columns use the column name text as written. Enum variants use the
    /// string value of their label literal.
    pub name: String,
    /// Extra text shown next to the name.
    ///
    /// Only columns set this (to the source text of their type); all other
    /// kinds leave it `None`.
    pub detail: Option<String>,
    /// What kind of SQL object this symbol describes.
    pub kind: DocumentSymbolKind,
    /// Text range of the whole defining syntax node: the entire statement for
    /// top-level symbols, or the whole column / variant node for children.
    pub full_range: TextRange,
    /// Text range of the part that names the symbol.
    ///
    /// This is the name identifier for tables, functions, types and columns,
    /// and the string-literal token (quotes included) for enum variants.
    pub focus_range: TextRange,
    /// Nested symbols: columns of a table or composite type, or variants of
    /// an enum type. Empty for functions and for child symbols.
    pub children: Vec<DocumentSymbol>,
}
28
29pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> {
30 let binder = binder::bind(file);
31 let mut symbols = vec![];
32
33 for stmt in file.stmts() {
34 match stmt {
35 ast::Stmt::CreateTable(create_table) => {
36 if let Some(symbol) = create_table_symbol(&binder, create_table) {
37 symbols.push(symbol);
38 }
39 }
40 ast::Stmt::CreateFunction(create_function) => {
41 if let Some(symbol) = create_function_symbol(&binder, create_function) {
42 symbols.push(symbol);
43 }
44 }
45 ast::Stmt::CreateType(create_type) => {
46 if let Some(symbol) = create_type_symbol(&binder, create_type) {
47 symbols.push(symbol);
48 }
49 }
50 _ => {}
51 }
52 }
53
54 symbols
55}
56
57fn create_table_symbol(
58 binder: &binder::Binder,
59 create_table: ast::CreateTable,
60) -> Option<DocumentSymbol> {
61 let path = create_table.path()?;
62 let segment = path.segment()?;
63 let name_node = segment.name()?;
64
65 let (schema, table_name) = resolve_table_info(binder, &path)?;
66 let name = format!("{}.{}", schema.0, table_name);
67
68 let full_range = create_table.syntax().text_range();
69 let focus_range = name_node.syntax().text_range();
70
71 let mut children = vec![];
72 if let Some(table_arg_list) = create_table.table_arg_list() {
73 for arg in table_arg_list.args() {
74 if let ast::TableArg::Column(column) = arg
75 && let Some(column_symbol) = create_column_symbol(column)
76 {
77 children.push(column_symbol);
78 }
79 }
80 }
81
82 Some(DocumentSymbol {
83 name,
84 detail: None,
85 kind: DocumentSymbolKind::Table,
86 full_range,
87 focus_range,
88 children,
89 })
90}
91
92fn create_function_symbol(
93 binder: &binder::Binder,
94 create_function: ast::CreateFunction,
95) -> Option<DocumentSymbol> {
96 let path = create_function.path()?;
97 let segment = path.segment()?;
98 let name_node = segment.name()?;
99
100 let (schema, function_name) = resolve_function_info(binder, &path)?;
101 let name = format!("{}.{}", schema.0, function_name);
102
103 let full_range = create_function.syntax().text_range();
104 let focus_range = name_node.syntax().text_range();
105
106 Some(DocumentSymbol {
107 name,
108 detail: None,
109 kind: DocumentSymbolKind::Function,
110 full_range,
111 focus_range,
112 children: vec![],
113 })
114}
115
116fn create_type_symbol(
117 binder: &binder::Binder,
118 create_type: ast::CreateType,
119) -> Option<DocumentSymbol> {
120 let path = create_type.path()?;
121 let segment = path.segment()?;
122 let name_node = segment.name()?;
123
124 let (schema, type_name) = resolve_type_info(binder, &path)?;
125 let name = format!("{}.{}", schema.0, type_name);
126
127 let full_range = create_type.syntax().text_range();
128 let focus_range = name_node.syntax().text_range();
129
130 let mut children = vec![];
131 if let Some(variant_list) = create_type.variant_list() {
132 for variant in variant_list.variants() {
133 if let Some(variant_symbol) = create_variant_symbol(variant) {
134 children.push(variant_symbol);
135 }
136 }
137 } else if let Some(column_list) = create_type.column_list() {
138 for column in column_list.columns() {
139 if let Some(column_symbol) = create_column_symbol(column) {
140 children.push(column_symbol);
141 }
142 }
143 }
144
145 Some(DocumentSymbol {
146 name,
147 detail: None,
148 kind: DocumentSymbolKind::Type,
149 full_range,
150 focus_range,
151 children,
152 })
153}
154
155fn create_column_symbol(column: ast::Column) -> Option<DocumentSymbol> {
156 let name_node = column.name()?;
157 let name = name_node.syntax().text().to_string();
158
159 let detail = column.ty().map(|t| t.syntax().text().to_string());
160
161 let full_range = column.syntax().text_range();
162 let focus_range = name_node.syntax().text_range();
163
164 Some(DocumentSymbol {
165 name,
166 detail,
167 kind: DocumentSymbolKind::Column,
168 full_range,
169 focus_range,
170 children: vec![],
171 })
172}
173
174fn create_variant_symbol(variant: ast::Variant) -> Option<DocumentSymbol> {
175 let literal = variant.literal()?;
176 let name = extract_string_literal(&literal)?;
177
178 let full_range = variant.syntax().text_range();
179 let focus_range = literal.syntax().text_range();
180
181 Some(DocumentSymbol {
182 name,
183 detail: None,
184 kind: DocumentSymbolKind::Variant,
185 full_range,
186 focus_range,
187 children: vec![],
188 })
189}
190
191#[cfg(test)]
192mod tests {
193 use super::*;
194 use annotate_snippets::{
195 AnnotationKind, Group, Level, Renderer, Snippet, renderer::DecorStyle,
196 };
197 use insta::assert_snapshot;
198
199 fn symbols_not_found(sql: &str) {
200 let parse = ast::SourceFile::parse(sql);
201 let file = parse.tree();
202 let symbols = document_symbols(&file);
203 if !symbols.is_empty() {
204 panic!("Symbols found. If this is expected, use `symbols` instead.")
205 }
206 }
207
208 fn symbols(sql: &str) -> String {
209 let parse = ast::SourceFile::parse(sql);
210 let file = parse.tree();
211 let symbols = document_symbols(&file);
212 if symbols.is_empty() {
213 panic!("No symbols found. If this is expected, use `symbols_not_found` instead.")
214 }
215
216 let mut output = vec![];
217 for symbol in symbols {
218 let group = symbol_to_group(&symbol, sql);
219 output.push(group);
220 }
221 Renderer::plain()
222 .decor_style(DecorStyle::Unicode)
223 .render(&output)
224 .to_string()
225 }
226
227 fn symbol_to_group<'a>(symbol: &DocumentSymbol, sql: &'a str) -> Group<'a> {
228 let kind = match symbol.kind {
229 DocumentSymbolKind::Table => "table",
230 DocumentSymbolKind::Function => "function",
231 DocumentSymbolKind::Type => "type",
232 DocumentSymbolKind::Column => "column",
233 DocumentSymbolKind::Variant => "variant",
234 };
235
236 let title = if let Some(detail) = &symbol.detail {
237 format!("{}: {} {}", kind, symbol.name, detail)
238 } else {
239 format!("{}: {}", kind, symbol.name)
240 };
241
242 let snippet = Snippet::source(sql)
243 .fold(true)
244 .annotation(
245 AnnotationKind::Primary
246 .span(symbol.focus_range.into())
247 .label("focus range"),
248 )
249 .annotation(
250 AnnotationKind::Context
251 .span(symbol.full_range.into())
252 .label("full range"),
253 );
254
255 let mut group = Level::INFO.primary_title(title.clone()).element(snippet);
256
257 if !symbol.children.is_empty() {
258 let child_labels: Vec<String> = symbol
259 .children
260 .iter()
261 .map(|child| {
262 let kind = match child.kind {
263 DocumentSymbolKind::Column => "column",
264 DocumentSymbolKind::Variant => "variant",
265 _ => unreachable!("only columns and variants can be children"),
266 };
267 if let Some(detail) = &child.detail {
268 format!("{}: {} {}", kind, child.name, detail)
269 } else {
270 format!("{}: {}", kind, child.name)
271 }
272 })
273 .collect();
274
275 let mut children_snippet = Snippet::source(sql).fold(true);
276
277 for (i, child) in symbol.children.iter().enumerate() {
278 children_snippet = children_snippet
279 .annotation(
280 AnnotationKind::Context
281 .span(child.full_range.into())
282 .label(format!("full range for `{}`", child_labels[i].clone())),
283 )
284 .annotation(
285 AnnotationKind::Primary
286 .span(child.focus_range.into())
287 .label("focus range"),
288 );
289 }
290
291 group = group.element(children_snippet);
292 }
293
294 group
295 }
296
297 #[test]
298 fn create_table() {
299 assert_snapshot!(symbols("
300create table users (
301 id int,
302 email citext
303);"), @r"
304 info: table: public.users
305 ╭▸
306 2 │ create table users (
307 │ │ ━━━━━ focus range
308 │ ┌─┘
309 │ │
310 3 │ │ id int,
311 4 │ │ email citext
312 5 │ │ );
313 │ └─┘ full range
314 │
315 ⸬
316 3 │ id int,
317 │ ┯━────
318 │ │
319 │ full range for `column: id int`
320 │ focus range
321 4 │ email citext
322 │ ┯━━━━───────
323 │ │
324 │ full range for `column: email citext`
325 ╰╴ focus range
326 ");
327 }
328
329 #[test]
330 fn create_function() {
331 assert_snapshot!(
332 symbols("create function hello() returns void as $$ select 1; $$ language sql;"),
333 @r"
334 info: function: public.hello
335 ╭▸
336 1 │ create function hello() returns void as $$ select 1; $$ language sql;
337 │ ┬───────────────┯━━━━───────────────────────────────────────────────
338 │ │ │
339 │ │ focus range
340 ╰╴full range
341 "
342 );
343 }
344
345 #[test]
346 fn multiple_symbols() {
347 assert_snapshot!(symbols("
348create table users (id int);
349create table posts (id int);
350create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
351"), @r"
352 info: table: public.users
353 ╭▸
354 2 │ create table users (id int);
355 │ ┬────────────┯━━━━─────────
356 │ │ │
357 │ │ focus range
358 │ full range
359 │
360 ⸬
361 2 │ create table users (id int);
362 │ ┯━────
363 │ │
364 │ full range for `column: id int`
365 │ focus range
366 ╰╴
367 info: table: public.posts
368 ╭▸
369 3 │ create table posts (id int);
370 │ ┬────────────┯━━━━─────────
371 │ │ │
372 │ │ focus range
373 │ full range
374 │
375 ⸬
376 3 │ create table posts (id int);
377 │ ┯━────
378 │ │
379 │ full range for `column: id int`
380 ╰╴ focus range
381 info: function: public.get_user
382 ╭▸
383 4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
384 │ ┬───────────────┯━━━━━━━──────────────────────────────────────────────────────────
385 │ │ │
386 │ │ focus range
387 ╰╴full range
388 ");
389 }
390
391 #[test]
392 fn qualified_names() {
393 assert_snapshot!(symbols("
394create table public.users (id int);
395create function my_schema.hello() returns void as $$ select 1; $$ language sql;
396"), @r"
397 info: table: public.users
398 ╭▸
399 2 │ create table public.users (id int);
400 │ ┬───────────────────┯━━━━─────────
401 │ │ │
402 │ │ focus range
403 │ full range
404 │
405 ⸬
406 2 │ create table public.users (id int);
407 │ ┯━────
408 │ │
409 │ full range for `column: id int`
410 │ focus range
411 ╰╴
412 info: function: my_schema.hello
413 ╭▸
414 3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql;
415 │ ┬─────────────────────────┯━━━━───────────────────────────────────────────────
416 │ │ │
417 │ │ focus range
418 ╰╴full range
419 ");
420 }
421
422 #[test]
423 fn create_type() {
424 assert_snapshot!(
425 symbols("create type status as enum ('active', 'inactive');"),
426 @r"
427 info: type: public.status
428 ╭▸
429 1 │ create type status as enum ('active', 'inactive');
430 │ ┬───────────┯━━━━━───────────────────────────────
431 │ │ │
432 │ │ focus range
433 │ full range
434 │
435 ⸬
436 1 │ create type status as enum ('active', 'inactive');
437 │ ┯━━━━━━━ ┯━━━━━━━━━
438 │ │ │
439 │ │ full range for `variant: inactive`
440 │ │ focus range
441 │ full range for `variant: active`
442 ╰╴ focus range
443 "
444 );
445 }
446
447 #[test]
448 fn create_type_composite() {
449 assert_snapshot!(
450 symbols("create type person as (name text, age int);"),
451 @r"
452 info: type: public.person
453 ╭▸
454 1 │ create type person as (name text, age int);
455 │ ┬───────────┯━━━━━────────────────────────
456 │ │ │
457 │ │ focus range
458 │ full range
459 │
460 ⸬
461 1 │ create type person as (name text, age int);
462 │ ┯━━━───── ┯━━────
463 │ │ │
464 │ │ full range for `column: age int`
465 │ │ focus range
466 │ full range for `column: name text`
467 ╰╴ focus range
468 "
469 );
470 }
471
472 #[test]
473 fn create_type_composite_multiple_columns() {
474 assert_snapshot!(
475 symbols("create type address as (street text, city text, zip varchar(10));"),
476 @r"
477 info: type: public.address
478 ╭▸
479 1 │ create type address as (street text, city text, zip varchar(10));
480 │ ┬───────────┯━━━━━━─────────────────────────────────────────────
481 │ │ │
482 │ │ focus range
483 │ full range
484 │
485 ⸬
486 1 │ create type address as (street text, city text, zip varchar(10));
487 │ ┯━━━━━───── ┯━━━───── ┯━━────────────
488 │ │ │ │
489 │ │ │ full range for `column: zip varchar(10)`
490 │ │ │ focus range
491 │ │ full range for `column: city text`
492 │ │ focus range
493 │ full range for `column: street text`
494 ╰╴ focus range
495 "
496 );
497 }
498
499 #[test]
500 fn create_type_with_schema() {
501 assert_snapshot!(
502 symbols("create type myschema.status as enum ('active', 'inactive');"),
503 @r"
504 info: type: myschema.status
505 ╭▸
506 1 │ create type myschema.status as enum ('active', 'inactive');
507 │ ┬────────────────────┯━━━━━───────────────────────────────
508 │ │ │
509 │ │ focus range
510 │ full range
511 │
512 ⸬
513 1 │ create type myschema.status as enum ('active', 'inactive');
514 │ ┯━━━━━━━ ┯━━━━━━━━━
515 │ │ │
516 │ │ full range for `variant: inactive`
517 │ │ focus range
518 │ full range for `variant: active`
519 ╰╴ focus range
520 "
521 );
522 }
523
524 #[test]
525 fn create_type_enum_multiple_variants() {
526 assert_snapshot!(
527 symbols("create type priority as enum ('low', 'medium', 'high', 'urgent');"),
528 @r"
529 info: type: public.priority
530 ╭▸
531 1 │ create type priority as enum ('low', 'medium', 'high', 'urgent');
532 │ ┬───────────┯━━━━━━━────────────────────────────────────────────
533 │ │ │
534 │ │ focus range
535 │ full range
536 │
537 ⸬
538 1 │ create type priority as enum ('low', 'medium', 'high', 'urgent');
539 │ ┯━━━━ ┯━━━━━━━ ┯━━━━━ ┯━━━━━━━
540 │ │ │ │ │
541 │ │ │ │ full range for `variant: urgent`
542 │ │ │ │ focus range
543 │ │ │ full range for `variant: high`
544 │ │ │ focus range
545 │ │ full range for `variant: medium`
546 │ │ focus range
547 │ full range for `variant: low`
548 ╰╴ focus range
549 "
550 );
551 }
552
553 #[test]
554 fn empty_file() {
555 symbols_not_found("")
556 }
557
558 #[test]
559 fn non_create_statements() {
560 symbols_not_found("select * from users;")
561 }
562}