1use rowan::{NodeOrToken, TextRange};
2use salsa::Database as Db;
3use squawk_syntax::{
4 SyntaxElement, SyntaxKind,
5 ast::{self, AstNode},
6};
7
8use crate::db::{File, parse};
9use crate::goto_definition::{LocationKind, goto_definition};
10
11fn highlight_param_mode(out: &mut SemanticTokenBuilder, mode: ast::ParamMode) {
12 match mode {
13 ast::ParamMode::ParamIn(param_in) => {
14 if let Some(token) = param_in.in_token() {
15 out.push_keyword(token.into());
16 }
17 }
18 ast::ParamMode::ParamInOut(param_in_out) => {
19 if let Some(token) = param_in_out.in_token() {
20 out.push_keyword(token.into());
21 }
22 if let Some(token) = param_in_out.inout_token() {
23 out.push_keyword(token.into());
24 }
25 if let Some(token) = param_in_out.out_token() {
26 out.push_keyword(token.into());
27 }
28 }
29 ast::ParamMode::ParamOut(param_out) => {
30 if let Some(token) = param_out.out_token() {
31 out.push_keyword(token.into());
32 }
33 }
34 ast::ParamMode::ParamVariadic(param_variadic) => {
35 if let Some(token) = param_variadic.variadic_token() {
36 out.push_keyword(token.into());
37 }
38 }
39 }
40}
41
42fn highlight_type(out: &mut SemanticTokenBuilder, ty: ast::Type) {
43 match ty {
44 ast::Type::ArrayType(_) => (),
45 ast::Type::BitType(bit_type) => {
46 if let Some(token) = bit_type.setof_token() {
47 out.push_type(token.into());
48 }
49 if let Some(token) = bit_type.bit_token() {
50 out.push_type(token.into());
51 }
52 if let Some(token) = bit_type.varying_token() {
53 out.push_type(token.into());
54 }
55 }
56 ast::Type::CharType(char_type) => {
57 if let Some(token) = char_type.setof_token() {
58 out.push_type(token.into());
59 }
60 if let Some(token) = char_type.national_token() {
61 out.push_type(token.into());
62 }
63
64 if let Some(token) = char_type
65 .varchar_token()
66 .or_else(|| char_type.nchar_token())
67 .or_else(|| char_type.character_token())
68 .or_else(|| char_type.char_token())
69 {
70 out.push_type(token.into());
71 }
72 if let Some(token) = char_type.varying_token() {
73 out.push_type(token.into());
74 }
75 }
76 ast::Type::DoubleType(double_type) => {
77 if let Some(token) = double_type.setof_token() {
78 out.push_type(token.into());
79 }
80 if let Some(token) = double_type.double_token() {
81 out.push_type(token.into());
82 }
83 if let Some(token) = double_type.precision_token() {
84 out.push_type(token.into());
85 }
86 }
87 ast::Type::ExprType(_) => (),
88 ast::Type::IntervalType(interval_type) => {
89 if let Some(token) = interval_type.setof_token() {
90 out.push_type(token.into());
91 }
92 if let Some(token) = interval_type.interval_token() {
93 out.push_type(token.into());
94 }
95 }
96 ast::Type::PathType(path_type) => {
97 if let Some(token) = path_type.setof_token() {
98 out.push_type(token.into());
99 }
100 }
101 ast::Type::PercentType(_) => (),
102 ast::Type::TimeType(time_type) => {
103 if let Some(token) = time_type.setof_token() {
104 out.push_type(token.into());
105 }
106 if let Some(token) = time_type
107 .timestamp_token()
108 .or_else(|| time_type.time_token())
109 {
110 out.push_type(token.into());
111 }
112
113 if let Some(timezone) = time_type.timezone() {
114 match timezone {
115 ast::Timezone::WithTimezone(with_timezone) => {
116 if let Some(token) = with_timezone.with_token() {
117 out.push_type(token.into());
118 }
119 if let Some(token) = with_timezone.time_token() {
120 out.push_type(token.into());
121 }
122 if let Some(token) = with_timezone.zone_token() {
123 out.push_type(token.into());
124 }
125 }
126 ast::Timezone::WithoutTimezone(without_timezone) => {
127 if let Some(token) = without_timezone.without_token() {
128 out.push_type(token.into());
129 }
130 if let Some(token) = without_timezone.time_token() {
131 out.push_type(token.into());
132 }
133 if let Some(token) = without_timezone.zone_token() {
134 out.push_type(token.into());
135 }
136 }
137 }
138 }
139 }
140 }
141}
142
/// A highlighted span of source text together with its semantic category.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SemanticToken {
    // Byte range in the source file that this token covers.
    pub range: TextRange,
    // Semantic category (keyword, type, column, ...).
    pub token_type: SemanticTokenType,
    // Optional extra qualifier; `semantic_tokens` currently always emits `None`.
    pub modifiers: Option<SemanticTokenModifier>,
}
150
/// Extra qualifier that can accompany a [`SemanticTokenType`].
///
/// `repr(u8)` with explicit starting discriminant so the values are stable
/// for consumers that encode modifiers numerically.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(u8)]
pub enum SemanticTokenModifier {
    Definition = 0,
    Readonly,
    Documentation,
}
158
/// Semantic category assigned to a highlighted token.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SemanticTokenType {
    Keyword,
    String,
    Bool,
    Number,
    Function,
    Operator,
    Punctuation,
    Name,
    NameRef,
    Comment,
    Column,
    Type,
    Parameter,
    PositionalParam,
    Table,
    Schema,
}
179
180impl TryFrom<LocationKind> for SemanticTokenType {
181 type Error = LocationKind;
182
183 fn try_from(kind: LocationKind) -> Result<Self, Self::Error> {
184 match kind {
185 LocationKind::Aggregate | LocationKind::Function | LocationKind::Procedure => {
186 Ok(SemanticTokenType::Function)
187 }
188 LocationKind::Column => Ok(SemanticTokenType::Column),
189 LocationKind::NamedArgParameter => Ok(SemanticTokenType::Parameter),
190 LocationKind::Schema => Ok(SemanticTokenType::Schema),
191 LocationKind::Sequence | LocationKind::Table | LocationKind::View => {
192 Ok(SemanticTokenType::Table)
193 }
194 LocationKind::Type => Ok(SemanticTokenType::Type),
195 LocationKind::CaseExpr
196 | LocationKind::Channel
197 | LocationKind::CommitBegin
198 | LocationKind::CommitEnd
199 | LocationKind::Cursor
200 | LocationKind::Database
201 | LocationKind::EventTrigger
202 | LocationKind::Extension
203 | LocationKind::Index
204 | LocationKind::Policy
205 | LocationKind::PreparedStatement
206 | LocationKind::PropertyGraph
207 | LocationKind::Role
208 | LocationKind::Server
209 | LocationKind::Tablespace
210 | LocationKind::Trigger
211 | LocationKind::Window => Err(kind),
212 }
213 }
214}
215
216fn token_type_for_node<T: AstNode>(db: &dyn Db, file: File, node: &T) -> Option<SemanticTokenType> {
217 let offset = node.syntax().text_range().start();
218 let location = goto_definition(db, file, offset).into_iter().next()?;
219
220 SemanticTokenType::try_from(location.kind).ok()
221}
222
/// Accumulates [`SemanticToken`]s while walking the syntax tree.
#[derive(Default)]
struct SemanticTokenBuilder {
    // Tokens in push order; sorted by range when the builder is finished.
    tokens: Vec<SemanticToken>,
}
227
228impl SemanticTokenBuilder {
229 fn build(mut self) -> Vec<SemanticToken> {
230 self.tokens
231 .sort_by_key(|token| (token.range.start(), token.range.end()));
232 self.tokens
233 }
234
235 fn push_keyword(&mut self, syntax_element: SyntaxElement) {
236 self.push_token(syntax_element, SemanticTokenType::Keyword);
237 }
238
239 fn push_type(&mut self, syntax_element: SyntaxElement) {
240 self.push_token(syntax_element, SemanticTokenType::Type);
241 }
242
243 fn push_token(&mut self, syntax_element: SyntaxElement, token_type: SemanticTokenType) {
244 self.tokens.push(SemanticToken {
245 range: syntax_element.text_range(),
246 token_type,
247 modifiers: None,
248 });
249 }
250}
251
/// Computes semantic highlighting tokens for `file`, sorted by position.
///
/// When `range_to_highlight` is `Some`, only the subtree covering that range
/// is walked and elements that do not intersect the range are skipped;
/// otherwise the whole file is processed.
#[salsa::tracked]
pub fn semantic_tokens(
    db: &dyn Db,
    file: File,
    range_to_highlight: Option<TextRange>,
) -> Vec<SemanticToken> {
    let parse = parse(db, file);
    let tree = parse.tree();
    let root = tree.syntax();

    // Narrow the walk to the smallest node covering the requested range.
    // A covering token is widened to its parent node so node-level casts below
    // still get a chance to run.
    let (root, range_to_highlight) = {
        let source_file = root;
        match range_to_highlight {
            Some(range) => {
                let node = match source_file.covering_element(range) {
                    NodeOrToken::Node(it) => it,
                    NodeOrToken::Token(it) => it.parent().unwrap_or_else(|| source_file.clone()),
                };
                (node, range)
            }
            None => (source_file.clone(), source_file.text_range()),
        }
    };

    let mut out = SemanticTokenBuilder::default();

    let preorder = root.preorder_with_tokens();
    for event in preorder {
        use rowan::WalkEvent::{Enter, Leave};

        let range = match &event {
            Enter(it) | Leave(it) => it.text_range(),
        };

        // Skip anything entirely outside the requested range.
        if range_to_highlight.intersect(range).is_none() {
            continue;
        }

        match event {
            Enter(NodeOrToken::Node(node)) => {
                // Names and name refs are classified via goto-definition
                // resolution (table, column, schema, function, ...).
                if let Some(name) = ast::Name::cast(node.clone())
                    && let Some(token_type) = token_type_for_node(db, file, &name)
                {
                    out.push_token(name.syntax().clone().into(), token_type);
                }

                if let Some(name_ref) = ast::NameRef::cast(node.clone())
                    && let Some(token_type) = token_type_for_node(db, file, &name_ref)
                {
                    out.push_token(name_ref.syntax().clone().into(), token_type);
                }

                // Type names and parameter modes get dedicated handling.
                if let Some(ty) = ast::Type::cast(node.clone()) {
                    highlight_type(&mut out, ty);
                }

                if let Some(mode) = ast::ParamMode::cast(node.clone()) {
                    highlight_param_mode(&mut out, mode);
                }

                // A few keywords that aren't covered by the cases above.
                if let Some(like_clause) = ast::LikeClause::cast(node.clone())
                    && let Some(token) = like_clause.like_token()
                {
                    out.push_keyword(token.into());
                }
                if let Some(not_null_constraint) = ast::NotNullConstraint::cast(node.clone())
                    && let Some(token) = not_null_constraint.not_token()
                {
                    out.push_keyword(token.into());
                }
                if let Some(partition_for_values_in) = ast::PartitionForValuesIn::cast(node.clone())
                    && let Some(token) = partition_for_values_in.in_token()
                {
                    out.push_keyword(token.into());
                }
            }
            Enter(NodeOrToken::Token(token)) => {
                if token.kind() == SyntaxKind::WHITESPACE {
                    continue;
                }
                // `$1`-style parameters are plain tokens, not names.
                if token.kind() == SyntaxKind::POSITIONAL_PARAM {
                    out.push_token(token.into(), SemanticTokenType::PositionalParam);
                }
            }
            Leave(_) => {}
        }
    }

    out.build()
}
348
#[cfg(test)]
mod test {
    use crate::db::{Database, File};
    use insta::assert_snapshot;
    use std::fmt::Write;

    // Runs semantic highlighting over `sql` and renders each token as one
    // `"text" @ start..end: Type` line for inline-snapshot comparison.
    fn semantic_tokens(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());
        let tokens = super::semantic_tokens(&db, file, None);

        let mut result = String::new();
        for token in tokens {
            let start: usize = token.range.start().into();
            let end: usize = token.range.end().into();
            let token_text = &sql[start..end];
            // Modifiers are never emitted today, so this stays empty.
            let modifiers_text = "";
            writeln!(
                result,
                "{:?} @ {}..{}: {:?}{}",
                token_text, start, end, token.token_type, modifiers_text
            )
            .unwrap();
        }
        result
    }

    #[test]
    fn create_function_misc_params() {
        assert_snapshot!(semantic_tokens(
            "
create function add(
  in a int = 1,
  inout b text default 'x',
  in out c varchar(10)[],
  variadic d int[]
) returns int
as 'select $1 + $2'
language sql;
",
        ), @r#"
        "add" @ 17..20: Function
        "in" @ 24..26: Keyword
        "a" @ 27..28: Parameter
        "int" @ 29..32: Type
        "inout" @ 40..45: Keyword
        "b" @ 46..47: Parameter
        "text" @ 48..52: Type
        "in" @ 68..70: Keyword
        "out" @ 71..74: Keyword
        "c" @ 75..76: Parameter
        "varchar" @ 77..84: Type
        "variadic" @ 94..102: Keyword
        "d" @ 103..104: Parameter
        "int" @ 105..108: Type
        "int" @ 121..124: Type
        "#);
    }

    #[test]
    fn create_function_param_mode_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(int8 in int8)
returns void
as '' language sql;
",
        ), @r#"
        "f" @ 17..18: Function
        "int8" @ 19..23: Parameter
        "in" @ 24..26: Keyword
        "int8" @ 27..31: Type
        "void" @ 41..45: Type
        "#);
    }

    #[test]
    fn create_function_percent_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(a t.c%type)
returns t.b%type
as '' language plpgsql;
",
        ), @r#"
        "f" @ 17..18: Function
        "a" @ 19..20: Parameter
        "#);
    }

    #[test]
    fn select_keywords() {
        assert_snapshot!(semantic_tokens("
select 1 and, 2 select;
"), @r#"
        "and" @ 10..13: Column
        "select" @ 17..23: Column
        "#)
    }

    #[test]
    fn positional_param() {
        assert_snapshot!(semantic_tokens("
select $1, $2;
"), @r#"
        "$1" @ 8..10: PositionalParam
        "$2" @ 12..14: PositionalParam
        "#)
    }

    #[test]
    fn insert_column_list() {
        assert_snapshot!(semantic_tokens(
            "
create table products (product_no bigint, name text, price text);
insert into products (product_no, name, price) values
  (1, 'Cheese', 9.99),
  (2, 'Bread', 1.99),
  (3, 'Milk', 2.99);
",
        ), @r#"
        "products" @ 14..22: Table
        "product_no" @ 24..34: Column
        "bigint" @ 35..41: Type
        "name" @ 43..47: Column
        "text" @ 48..52: Type
        "price" @ 54..59: Column
        "text" @ 60..64: Type
        "products" @ 79..87: Table
        "product_no" @ 89..99: Column
        "name" @ 101..105: Column
        "price" @ 107..112: Column
        "#)
    }

    #[test]
    fn from_alias_column_types() {
        assert_snapshot!(semantic_tokens(
            "
select *
from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r text);
",
        ), @r#"
        "t" @ 20..21: Table
        "a" @ 22..23: Column
        "int" @ 24..27: Type
        "b" @ 29..30: Column
        "jsonb" @ 31..36: Type
        "c" @ 38..39: Column
        "text" @ 40..44: Type
        "x" @ 46..47: Column
        "int" @ 48..51: Type
        "ca" @ 53..55: Column
        "char" @ 56..60: Type
        "ia" @ 67..69: Column
        "int" @ 70..73: Type
        "r" @ 79..80: Column
        "text" @ 81..85: Type
        "#);
    }

    #[test]
    fn json_table_columns() {
        assert_snapshot!(semantic_tokens(
            "
select *
from my_films,
json_table(
  js,
  '$.favorites[*]' columns (
    id for ordinality,
    kind text path '$.kind'
  )
) as jt;
",
        ), @r#"
        "id" @ 76..78: Column
        "kind" @ 99..103: Column
        "text" @ 104..108: Type
        "jt" @ 132..134: Table
        "#);
    }

    #[test]
    fn xml_table_columns() {
        assert_snapshot!(semantic_tokens(
            "
select *
from xmltable(
  '/root/item'
  passing xmlparse(document '<root><item id=\"1\"/></root>')
  columns
    row_num for ordinality,
    item_id integer path '@id'
);
",
        ), @r#"
        "row_num" @ 113..120: Column
        "item_id" @ 141..148: Column
        "integer" @ 149..156: Type
        "#);
    }

    #[test]
    fn cast_types() {
        assert_snapshot!(semantic_tokens(
            "
select '1'::jsonb, '2'::json, cast(1 as integer), cast(1 as int4[][]), cast(1 as varchar(10));
",
        ), @r#"
        "jsonb" @ 13..18: Type
        "json" @ 25..29: Type
        "integer" @ 41..48: Type
        "int4" @ 61..65: Type
        "varchar" @ 82..89: Type
        "#);
    }

    #[test]
    fn cast_double() {
        assert_snapshot!(semantic_tokens(
            "
select '1'::double precision;
",
        ), @r#"
        "double" @ 13..19: Type
        "precision" @ 20..29: Type
        "#);
    }

    #[test]
    fn cast_time_and_timestamp_time_zone() {
        assert_snapshot!(semantic_tokens(
            "
select cast(1 as timestamp with time zone), cast(1 as timestamp without time zone), cast(1 as time with time zone), cast(1 as time without time zone);
",
        ), @r#"
        "timestamp" @ 18..27: Type
        "with" @ 28..32: Type
        "time" @ 33..37: Type
        "zone" @ 38..42: Type
        "timestamp" @ 55..64: Type
        "without" @ 65..72: Type
        "time" @ 73..77: Type
        "zone" @ 78..82: Type
        "time" @ 95..99: Type
        "with" @ 100..104: Type
        "time" @ 105..109: Type
        "zone" @ 110..114: Type
        "time" @ 127..131: Type
        "without" @ 132..139: Type
        "time" @ 140..144: Type
        "zone" @ 145..149: Type
        "#);
    }

    #[test]
    fn cast_national_character_varying_type() {
        assert_snapshot!(semantic_tokens(
            "
select 'foo'::national character varying;
",
        ), @r#"
        "national" @ 15..23: Type
        "character" @ 24..33: Type
        "varying" @ 34..41: Type
        "#);
    }

    #[test]
    fn create_function_returns_setof_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f() returns setof int
as 'select 1'
language sql;
",
        ), @r#"
        "f" @ 17..18: Function
        "setof" @ 29..34: Type
        "int" @ 35..38: Type
        "#);
    }

    #[test]
    fn create_table_temporal_primary_key_column_types() {
        assert_snapshot!(semantic_tokens(
            "
-- temporal_primary_key
CREATE TABLE addresses (
    id int8 generated BY DEFAULT AS IDENTITY,
    valid_range tstzrange NOT NULL DEFAULT tstzrange(now(), 'infinity', '[)'),
    recipient text NOT NULL,
    PRIMARY KEY (id, valid_range WITHOUT OVERLAPS)
);
",
        ), @r#"
        "addresses" @ 38..47: Table
        "id" @ 54..56: Column
        "int8" @ 57..61: Type
        "valid_range" @ 100..111: Column
        "tstzrange" @ 112..121: Type
        "NOT" @ 122..125: Keyword
        "tstzrange" @ 139..148: Function
        "now" @ 149..152: Function
        "recipient" @ 179..188: Column
        "text" @ 189..193: Type
        "NOT" @ 194..197: Keyword
        "id" @ 221..223: Column
        "valid_range" @ 225..236: Column
        "#);
    }

    #[test]
    fn like_clause_keyword() {
        assert_snapshot!(semantic_tokens(
            "
create table products(a text);
create table test (
  like products
);
",
        ), @r#"
        "products" @ 14..22: Table
        "a" @ 23..24: Column
        "text" @ 25..29: Type
        "test" @ 45..49: Table
        "like" @ 54..58: Keyword
        "products" @ 59..67: Table
        "#)
    }

    #[test]
    fn partition_for_values_in_keywords() {
        assert_snapshot!(semantic_tokens(
            "
create table t(a int);
create table t_1 partition of t for values in (1);
",
        ), @r#"
        "t" @ 14..15: Table
        "a" @ 16..17: Column
        "int" @ 18..21: Type
        "t_1" @ 37..40: Table
        "t" @ 54..55: Table
        "in" @ 67..69: Keyword
        "#)
    }

    #[test]
    fn positional_param_and_cast_type() {
        assert_snapshot!(semantic_tokens(
            "
select $2::jsonb;
",
        ), @r#"
        "$2" @ 8..10: PositionalParam
        "jsonb" @ 12..17: Type
        "#);
    }

    #[test]
    fn select_target_column() {
        assert_snapshot!(semantic_tokens(
            "
create table t(a int, b text);
select a, b from t;
",
        ), @r#"
        "t" @ 14..15: Table
        "a" @ 16..17: Column
        "int" @ 18..21: Type
        "b" @ 23..24: Column
        "text" @ 25..29: Type
        "a" @ 39..40: Column
        "b" @ 42..43: Column
        "t" @ 49..50: Table
        "#);
    }

    #[test]
    fn select_target_qualified_column() {
        assert_snapshot!(semantic_tokens(
            "
create table t(a int);
select t.a from t;
",
        ), @r#"
        "t" @ 14..15: Table
        "a" @ 16..17: Column
        "int" @ 18..21: Type
        "t" @ 31..32: Table
        "a" @ 33..34: Column
        "t" @ 40..41: Table
        "#);
    }

    #[test]
    fn select_target_function_call() {
        assert_snapshot!(semantic_tokens(
            "
create function f() returns int as 'select 1' language sql;
select f();
",
        ), @r#"
        "f" @ 17..18: Function
        "int" @ 29..32: Type
        "f" @ 68..69: Function
        "#);
    }

    #[test]
    fn select_function_arg_and_qualified_column() {
        assert_snapshot!(semantic_tokens(
            "
create table t(a int);
create function b(t) returns int as 'select 1' language sql;
select b(t), t.b from t;
",
        ), @r#"
        "t" @ 14..15: Table
        "a" @ 16..17: Column
        "int" @ 18..21: Type
        "b" @ 40..41: Function
        "t" @ 42..43: Type
        "int" @ 53..56: Type
        "b" @ 92..93: Function
        "t" @ 94..95: Table
        "t" @ 98..99: Table
        "b" @ 100..101: Function
        "t" @ 107..108: Table
        "#);
    }

    #[test]
    fn policy_field_style_function_call() {
        assert_snapshot!(semantic_tokens(
            "
create table t(c int);
create function x(t) returns int as 'select 1' language sql;
create policy p on t
  with check (t.x > 0 and t.c > 0);
",
        ), @r#"
        "t" @ 14..15: Table
        "c" @ 16..17: Column
        "int" @ 18..21: Type
        "x" @ 40..41: Function
        "t" @ 42..43: Type
        "int" @ 53..56: Type
        "t" @ 104..105: Table
        "t" @ 120..121: Table
        "x" @ 122..123: Function
        "t" @ 132..133: Table
        "c" @ 134..135: Column
        "#);
    }

    #[test]
    fn with_cte_name() {
        assert_snapshot!(semantic_tokens(
            "
with t as (
  select 1
)
select * from t;
",
        ), @r#"
        "t" @ 6..7: Table
        "t" @ 40..41: Table
        "#);
    }

    #[test]
    fn select_target_schema_qualified() {
        assert_snapshot!(semantic_tokens(
            "
create schema s;
create table s.t(a int);
select s.t.a from s.t;
",
        ), @r#"
        "s" @ 15..16: Schema
        "s" @ 31..32: Schema
        "t" @ 33..34: Table
        "a" @ 35..36: Column
        "int" @ 37..40: Type
        "s" @ 50..51: Schema
        "t" @ 52..53: Table
        "a" @ 54..55: Column
        "s" @ 61..62: Schema
        "t" @ 63..64: Table
        "#);
    }
}