1use rowan::{NodeOrToken, TextRange};
2use salsa::Database as Db;
3use squawk_syntax::{
4 SyntaxElement, SyntaxKind,
5 ast::{self, AstNode},
6};
7
8use crate::db::{File, parse};
9
10fn highlight_param_mode(out: &mut SemanticTokenBuilder, mode: ast::ParamMode) {
11 match mode {
12 ast::ParamMode::ParamIn(param_in) => {
13 if let Some(token) = param_in.in_token() {
14 out.push_keyword(token.into());
15 }
16 }
17 ast::ParamMode::ParamInOut(param_in_out) => {
18 if let Some(token) = param_in_out.in_token() {
19 out.push_keyword(token.into());
20 }
21 if let Some(token) = param_in_out.inout_token() {
22 out.push_keyword(token.into());
23 }
24 if let Some(token) = param_in_out.out_token() {
25 out.push_keyword(token.into());
26 }
27 }
28 ast::ParamMode::ParamOut(param_out) => {
29 if let Some(token) = param_out.out_token() {
30 out.push_keyword(token.into());
31 }
32 }
33 ast::ParamMode::ParamVariadic(param_variadic) => {
34 if let Some(token) = param_variadic.variadic_token() {
35 out.push_keyword(token.into());
36 }
37 }
38 }
39}
40
41fn highlight_type(out: &mut SemanticTokenBuilder, ty: ast::Type) {
42 match ty {
43 ast::Type::ArrayType(array_type) => {
44 if let Some(ty) = array_type.ty() {
45 highlight_type(out, ty);
46 }
47 }
48 ast::Type::BitType(bit_type) => {
49 if let Some(token) = bit_type.bit_token() {
50 out.push_type(token.into());
51 }
52 }
53 ast::Type::CharType(char_type) => {
54 if let Some(token) = char_type
55 .varchar_token()
56 .or_else(|| char_type.nchar_token())
57 .or_else(|| char_type.character_token())
58 .or_else(|| char_type.char_token())
59 {
60 out.push_type(token.into());
61 };
62 }
63 ast::Type::DoubleType(double_type) => {
64 if let Some(token) = double_type.double_token() {
65 out.push_type(token.into());
66 }
67 }
68 ast::Type::ExprType(_) => (),
69 ast::Type::IntervalType(interval_type) => {
70 if let Some(token) = interval_type.interval_token() {
71 out.push_type(token.into());
72 }
73 }
74 ast::Type::PathType(path_type) => {
75 if let Some(name_ref) = path_type
76 .path()
77 .and_then(|path| path.segment())
78 .and_then(|ps| ps.name_ref())
79 {
80 out.push_type(name_ref.syntax().clone().into());
81 }
82 }
83 ast::Type::PercentType(_) => (),
84 ast::Type::TimeType(time_type) => {
85 if let Some(token) = time_type
86 .timestamp_token()
87 .or_else(|| time_type.time_token())
88 {
89 out.push_type(token.into());
90 }
91 }
92 }
93}
94
95#[derive(Debug, Clone, PartialEq, Eq)]
97pub struct SemanticToken {
98 pub range: TextRange,
99 pub token_type: SemanticTokenType,
100 pub modifiers: Option<SemanticTokenModifier>,
101}
102
/// Modifier that can be attached to a [`SemanticToken`].
///
/// The enum is `#[repr(u8)]`; discriminants are spelled out so the numeric
/// value of each modifier is visible at a glance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum SemanticTokenModifier {
    Definition = 0,
    Readonly = 1,
    Documentation = 2,
}
110
/// Classification assigned to a [`SemanticToken`].
///
/// Variant order is kept stable; per-variant notes are given only for the
/// kinds that the highlighting code in this file emits.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SemanticTokenType {
    /// Keyword tokens, e.g. the mode keywords of a function parameter.
    Keyword,
    String,
    Bool,
    Number,
    Function,
    Operator,
    Punctuation,
    /// A name introduced by an `as` clause on a select target.
    Name,
    NameRef,
    Comment,
    /// The token or name that spells a type.
    Type,
    /// The name of a function parameter.
    Parameter,
    /// A positional parameter token such as `$1`.
    PositionalParam,
}
128
129#[derive(Default)]
130struct SemanticTokenBuilder {
131 tokens: Vec<SemanticToken>,
132}
133
134impl SemanticTokenBuilder {
135 fn build(mut self) -> Vec<SemanticToken> {
136 self.tokens
137 .sort_by_key(|token| (token.range.start(), token.range.end()));
138 self.tokens
139 }
140
141 fn push_keyword(&mut self, syntax_element: SyntaxElement) {
142 self.push_token(syntax_element, SemanticTokenType::Keyword);
143 }
144
145 fn push_type(&mut self, syntax_element: SyntaxElement) {
146 self.push_token(syntax_element, SemanticTokenType::Type);
147 }
148
149 fn push_token(&mut self, syntax_element: SyntaxElement, token_type: SemanticTokenType) {
150 self.tokens.push(SemanticToken {
151 range: syntax_element.text_range(),
152 token_type,
153 modifiers: None,
154 });
155 }
156}
157
158#[salsa::tracked]
159pub fn semantic_tokens(
160 db: &dyn Db,
161 file: File,
162 range_to_highlight: Option<TextRange>,
163) -> Vec<SemanticToken> {
164 let parse = parse(db, file);
165 let tree = parse.tree();
166 let root = tree.syntax();
167
168 let (root, range_to_highlight) = {
170 let source_file = root;
171 match range_to_highlight {
172 Some(range) => {
173 let node = match source_file.covering_element(range) {
174 NodeOrToken::Node(it) => it,
175 NodeOrToken::Token(it) => it.parent().unwrap_or_else(|| source_file.clone()),
176 };
177 (node, range)
178 }
179 None => (source_file.clone(), source_file.text_range()),
180 }
181 };
182
183 let mut out = SemanticTokenBuilder::default();
184
185 let preorder = root.preorder_with_tokens();
187 for event in preorder {
188 use rowan::WalkEvent::{Enter, Leave};
189
190 let range = match &event {
191 Enter(it) | Leave(it) => it.text_range(),
192 };
193
194 if range_to_highlight.intersect(range).is_none() {
196 continue;
197 }
198
199 match event {
200 Enter(NodeOrToken::Node(node)) => {
201 if let Some(target) = ast::Target::cast(node.clone())
202 && let Some(as_name) = target.as_name()
203 && let Some(name) = as_name.name()
204 {
205 out.push_token(name.syntax().clone().into(), SemanticTokenType::Name);
206 };
207
208 if let Some(alias) = ast::Alias::cast(node.clone())
209 && let Some(column_list) = alias.column_list()
210 {
211 for column in column_list.columns() {
212 if let Some(ty) = column.ty() {
213 highlight_type(&mut out, ty);
214 }
215 }
216 }
217
218 if let Some(cast_expr) = ast::CastExpr::cast(node.clone())
219 && let Some(ty) = cast_expr.ty()
220 {
221 highlight_type(&mut out, ty);
222 }
223
224 if let Some(create_function) = ast::CreateFunction::cast(node) {
225 if let Some(param_list) = create_function.param_list() {
226 for param in param_list.params() {
227 if let Some(mode) = param.mode() {
228 highlight_param_mode(&mut out, mode);
229 }
230 if let Some(name) = param.name() {
231 out.push_token(
232 name.syntax().clone().into(),
233 SemanticTokenType::Parameter,
234 );
235 }
236 if let Some(ty) = param.ty() {
237 highlight_type(&mut out, ty);
238 }
239 }
240 }
241
242 if let Some(ret_type) = create_function.ret_type() {
243 if let Some(ty) = ret_type.ty() {
244 highlight_type(&mut out, ty);
245 }
246 if let Some(table_arg_list) = ret_type.table_arg_list() {
247 for arg in table_arg_list.args() {
248 if let ast::TableArg::Column(column) = arg
249 && let Some(ty) = column.ty()
250 {
251 highlight_type(&mut out, ty);
252 }
253 }
254 }
255 }
256 }
257 }
258 Enter(NodeOrToken::Token(token)) => {
259 if token.kind() == SyntaxKind::WHITESPACE {
260 continue;
261 }
262 if token.kind() == SyntaxKind::POSITIONAL_PARAM {
263 out.push_token(token.into(), SemanticTokenType::PositionalParam);
264 }
265 }
266 Leave(_) => {}
267 }
268 }
269
270 out.build()
271}
272
#[cfg(test)]
mod test {
    use crate::db::{Database, File};
    use insta::assert_snapshot;
    use std::fmt::Write;

    /// Highlights all of `sql` and renders one line per token in the form
    /// `"text" @ start..end: TokenType`, ready for snapshotting.
    fn semantic_tokens(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());

        let mut rendered = String::new();
        for token in super::semantic_tokens(&db, file, None) {
            let start = usize::from(token.range.start());
            let end = usize::from(token.range.end());
            // Modifiers are not rendered; the placeholder keeps the line format stable.
            let modifiers_text = "";
            writeln!(
                rendered,
                "{:?} @ {}..{}: {:?}{}",
                &sql[start..end],
                start,
                end,
                token.token_type,
                modifiers_text
            )
            .unwrap();
        }
        rendered
    }

    /// Every parameter mode spelling, with defaults and array types.
    #[test]
    fn create_function_misc_params() {
        assert_snapshot!(semantic_tokens(
            "
create function add(
  in a int = 1,
  inout b text default 'x',
  in out c varchar(10)[],
  variadic d int[]
) returns int
as 'select $1 + $2'
language sql;
",
        ), @r#"
        "in" @ 24..26: Keyword
        "a" @ 27..28: Parameter
        "int" @ 29..32: Type
        "inout" @ 40..45: Keyword
        "b" @ 46..47: Parameter
        "text" @ 48..52: Type
        "in" @ 68..70: Keyword
        "out" @ 71..74: Keyword
        "c" @ 75..76: Parameter
        "varchar" @ 77..84: Type
        "variadic" @ 94..102: Keyword
        "d" @ 103..104: Parameter
        "int" @ 105..108: Type
        "int" @ 121..124: Type
        "#);
    }

    /// A mode keyword that follows the parameter name, where the name itself
    /// looks like a type.
    #[test]
    fn create_function_param_mode_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(int8 in int8)
returns void
as '' language sql;
",
        ), @r#"
        "int8" @ 19..23: Parameter
        "in" @ 24..26: Keyword
        "int8" @ 27..31: Type
        "void" @ 41..45: Type
        "#);
    }

    /// `%type` references yield no `Type` token; only the parameter name is
    /// reported.
    #[test]
    fn create_function_percent_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(a t.c%type)
returns t.b%type
as '' language plpgsql;
",
        ), @r#""a" @ 19..20: Parameter"#);
    }

    /// Keywords used as select-target aliases are reported as `Name`.
    #[test]
    fn select_keywords() {
        assert_snapshot!(semantic_tokens("
select 1 and, 2 select;
"), @r#"
        "and" @ 10..13: Name
        "select" @ 17..23: Name
        "#)
    }

    /// `$n` tokens are reported as `PositionalParam`.
    #[test]
    fn positional_param() {
        assert_snapshot!(semantic_tokens("
select $1, $2;
"), @r#"
        "$1" @ 8..10: PositionalParam
        "$2" @ 12..14: PositionalParam
        "#)
    }

    /// Column types inside a `from ... as t(...)` alias column list.
    #[test]
    fn from_alias_column_types() {
        assert_snapshot!(semantic_tokens(
            "
select *
from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop);
",
        ), @r#"
        "int" @ 24..27: Type
        "jsonb" @ 31..36: Type
        "text" @ 40..44: Type
        "int" @ 48..51: Type
        "char" @ 56..60: Type
        "int" @ 70..73: Type
        "jbpop" @ 81..86: Type
        "#);
    }

    /// Target types of `::` casts and `cast(... as ...)` expressions.
    #[test]
    fn cast_types() {
        assert_snapshot!(semantic_tokens(
            "
select '1'::jsonb, '2'::json, cast(1 as integer), cast(1 as int4[][]), cast(1 as varchar(10));
",
        ), @r#"
        "jsonb" @ 13..18: Type
        "json" @ 25..29: Type
        "integer" @ 41..48: Type
        "int4" @ 61..65: Type
        "varchar" @ 82..89: Type
        "#);
    }

    /// A positional parameter that is cast yields both tokens, in source order.
    #[test]
    fn positional_param_and_cast_type() {
        assert_snapshot!(semantic_tokens(
            "
select $2::jsonb;
",
        ), @r#"
        "$2" @ 8..10: PositionalParam
        "jsonb" @ 12..17: Type
        "#);
    }
}