Skip to main content

squawk_ide/
semantic_tokens.rs

1use rowan::{NodeOrToken, TextRange};
2use salsa::Database as Db;
3use squawk_syntax::{
4    SyntaxElement, SyntaxKind,
5    ast::{self, AstNode},
6};
7
8use crate::db::{File, parse};
9
10fn highlight_param_mode(out: &mut SemanticTokenBuilder, mode: ast::ParamMode) {
11    match mode {
12        ast::ParamMode::ParamIn(param_in) => {
13            if let Some(token) = param_in.in_token() {
14                out.push_keyword(token.into());
15            }
16        }
17        ast::ParamMode::ParamInOut(param_in_out) => {
18            if let Some(token) = param_in_out.in_token() {
19                out.push_keyword(token.into());
20            }
21            if let Some(token) = param_in_out.inout_token() {
22                out.push_keyword(token.into());
23            }
24            if let Some(token) = param_in_out.out_token() {
25                out.push_keyword(token.into());
26            }
27        }
28        ast::ParamMode::ParamOut(param_out) => {
29            if let Some(token) = param_out.out_token() {
30                out.push_keyword(token.into());
31            }
32        }
33        ast::ParamMode::ParamVariadic(param_variadic) => {
34            if let Some(token) = param_variadic.variadic_token() {
35                out.push_keyword(token.into());
36            }
37        }
38    }
39}
40
41fn highlight_type(out: &mut SemanticTokenBuilder, ty: ast::Type) {
42    match ty {
43        ast::Type::ArrayType(array_type) => {
44            if let Some(ty) = array_type.ty() {
45                highlight_type(out, ty);
46            }
47        }
48        ast::Type::BitType(bit_type) => {
49            if let Some(token) = bit_type.bit_token() {
50                out.push_type(token.into());
51            }
52        }
53        ast::Type::CharType(char_type) => {
54            if let Some(token) = char_type
55                .varchar_token()
56                .or_else(|| char_type.nchar_token())
57                .or_else(|| char_type.character_token())
58                .or_else(|| char_type.char_token())
59            {
60                out.push_type(token.into());
61            };
62        }
63        ast::Type::DoubleType(double_type) => {
64            if let Some(token) = double_type.double_token() {
65                out.push_type(token.into());
66            }
67        }
68        ast::Type::ExprType(_) => (),
69        ast::Type::IntervalType(interval_type) => {
70            if let Some(token) = interval_type.interval_token() {
71                out.push_type(token.into());
72            }
73        }
74        ast::Type::PathType(path_type) => {
75            if let Some(name_ref) = path_type
76                .path()
77                .and_then(|path| path.segment())
78                .and_then(|ps| ps.name_ref())
79            {
80                out.push_type(name_ref.syntax().clone().into());
81            }
82        }
83        ast::Type::PercentType(_) => (),
84        ast::Type::TimeType(time_type) => {
85            if let Some(token) = time_type
86                .timestamp_token()
87                .or_else(|| time_type.time_token())
88            {
89                out.push_type(token.into());
90            }
91        }
92    }
93}
94
95/// A semantic token with its position and classification.
96#[derive(Debug, Clone, PartialEq, Eq)]
97pub struct SemanticToken {
98    pub range: TextRange,
99    pub token_type: SemanticTokenType,
100    pub modifiers: Option<SemanticTokenModifier>,
101}
102
/// Modifiers that can be attached to a semantic token.
///
/// The enum is `#[repr(u8)]` with discriminants starting at zero; they are
/// spelled out so the numeric value of each variant is visible at a glance.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum SemanticTokenModifier {
    Definition = 0,
    Readonly = 1,
    Documentation = 2,
}
110
/// Classifications a semantic token can carry.
///
/// Variant order is kept stable because the implicit discriminants follow it.
/// Only some variants are produced by the code in this file; the others are
/// noted as such below.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SemanticTokenType {
    /// Keyword, e.g. a parameter mode such as `in` or `variadic`.
    Keyword,
    /// Not emitted by the highlighting code in this file.
    String,
    /// Not emitted by the highlighting code in this file.
    Bool,
    /// Not emitted by the highlighting code in this file.
    Number,
    /// Not emitted by the highlighting code in this file.
    Function,
    /// Not emitted by the highlighting code in this file.
    Operator,
    /// Not emitted by the highlighting code in this file.
    Punctuation,
    /// Name introduced by an `as`-style alias on a select target.
    Name,
    /// Not emitted by the highlighting code in this file.
    NameRef,
    /// Not emitted by the highlighting code in this file.
    Comment,
    /// Type name, e.g. in a cast, a parameter, or a column definition.
    Type,
    /// Name of a function parameter.
    Parameter,
    /// Positional parameter such as `$1`.
    PositionalParam,
}
128
129#[derive(Default)]
130struct SemanticTokenBuilder {
131    tokens: Vec<SemanticToken>,
132}
133
134impl SemanticTokenBuilder {
135    fn build(mut self) -> Vec<SemanticToken> {
136        self.tokens
137            .sort_by_key(|token| (token.range.start(), token.range.end()));
138        self.tokens
139    }
140
141    fn push_keyword(&mut self, syntax_element: SyntaxElement) {
142        self.push_token(syntax_element, SemanticTokenType::Keyword);
143    }
144
145    fn push_type(&mut self, syntax_element: SyntaxElement) {
146        self.push_token(syntax_element, SemanticTokenType::Type);
147    }
148
149    fn push_token(&mut self, syntax_element: SyntaxElement, token_type: SemanticTokenType) {
150        self.tokens.push(SemanticToken {
151            range: syntax_element.text_range(),
152            token_type,
153            modifiers: None,
154        });
155    }
156}
157
158#[salsa::tracked]
159pub fn semantic_tokens(
160    db: &dyn Db,
161    file: File,
162    range_to_highlight: Option<TextRange>,
163) -> Vec<SemanticToken> {
164    let parse = parse(db, file);
165    let tree = parse.tree();
166    let root = tree.syntax();
167
168    // Determine the root based on the given range.
169    let (root, range_to_highlight) = {
170        let source_file = root;
171        match range_to_highlight {
172            Some(range) => {
173                let node = match source_file.covering_element(range) {
174                    NodeOrToken::Node(it) => it,
175                    NodeOrToken::Token(it) => it.parent().unwrap_or_else(|| source_file.clone()),
176                };
177                (node, range)
178            }
179            None => (source_file.clone(), source_file.text_range()),
180        }
181    };
182
183    let mut out = SemanticTokenBuilder::default();
184
185    // Taken from: https://github.com/rust-lang/rust-analyzer/blob/2efc80078029894eec0699f62ec8d5c1a56af763/crates/ide/src/syntax_highlighting.rs#L267C21-L267C21
186    let preorder = root.preorder_with_tokens();
187    for event in preorder {
188        use rowan::WalkEvent::{Enter, Leave};
189
190        let range = match &event {
191            Enter(it) | Leave(it) => it.text_range(),
192        };
193
194        // Element outside of the viewport, no need to highlight
195        if range_to_highlight.intersect(range).is_none() {
196            continue;
197        }
198
199        match event {
200            Enter(NodeOrToken::Node(node)) => {
201                if let Some(target) = ast::Target::cast(node.clone())
202                    && let Some(as_name) = target.as_name()
203                    && let Some(name) = as_name.name()
204                {
205                    out.push_token(name.syntax().clone().into(), SemanticTokenType::Name);
206                };
207
208                if let Some(alias) = ast::Alias::cast(node.clone())
209                    && let Some(column_list) = alias.column_list()
210                {
211                    for column in column_list.columns() {
212                        if let Some(ty) = column.ty() {
213                            highlight_type(&mut out, ty);
214                        }
215                    }
216                }
217
218                if let Some(cast_expr) = ast::CastExpr::cast(node.clone())
219                    && let Some(ty) = cast_expr.ty()
220                {
221                    highlight_type(&mut out, ty);
222                }
223
224                if let Some(create_function) = ast::CreateFunction::cast(node) {
225                    if let Some(param_list) = create_function.param_list() {
226                        for param in param_list.params() {
227                            if let Some(mode) = param.mode() {
228                                highlight_param_mode(&mut out, mode);
229                            }
230                            if let Some(name) = param.name() {
231                                out.push_token(
232                                    name.syntax().clone().into(),
233                                    SemanticTokenType::Parameter,
234                                );
235                            }
236                            if let Some(ty) = param.ty() {
237                                highlight_type(&mut out, ty);
238                            }
239                        }
240                    }
241
242                    if let Some(ret_type) = create_function.ret_type() {
243                        if let Some(ty) = ret_type.ty() {
244                            highlight_type(&mut out, ty);
245                        }
246                        if let Some(table_arg_list) = ret_type.table_arg_list() {
247                            for arg in table_arg_list.args() {
248                                if let ast::TableArg::Column(column) = arg
249                                    && let Some(ty) = column.ty()
250                                {
251                                    highlight_type(&mut out, ty);
252                                }
253                            }
254                        }
255                    }
256                }
257            }
258            Enter(NodeOrToken::Token(token)) => {
259                if token.kind() == SyntaxKind::WHITESPACE {
260                    continue;
261                }
262                if token.kind() == SyntaxKind::POSITIONAL_PARAM {
263                    out.push_token(token.into(), SemanticTokenType::PositionalParam);
264                }
265            }
266            Leave(_) => {}
267        }
268    }
269
270    out.build()
271}
272
#[cfg(test)]
mod test {
    use crate::db::{Database, File};
    use insta::assert_snapshot;
    use std::fmt::Write;

    /// Runs full-file highlighting over `sql` and renders one line per token
    /// in the form `"<text>" @ <start>..<end>: <TokenType>` for snapshotting.
    fn semantic_tokens(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());
        let tokens = super::semantic_tokens(&db, file, None);

        let mut result = String::new();
        for token in tokens {
            let start: usize = token.range.start().into();
            let end: usize = token.range.end().into();
            let token_text = &sql[start..end];
            // TODO: once we get modifiers, we'll need to update this
            let modifiers_text = "";
            writeln!(
                result,
                "{:?} @ {}..{}: {:?}{}",
                token_text, start, end, token.token_type, modifiers_text
            )
            .unwrap();
        }
        result
    }

    /// Covers every parameter mode spelling (`in`, `inout`, `in out`,
    /// `variadic`) together with parameter names, types and the return type.
    #[test]
    fn create_function_misc_params() {
        assert_snapshot!(semantic_tokens(
            "
create function add(
  in a int = 1,
  inout b text default 'x',
  in out c varchar(10)[],
  variadic d int[]
) returns int
as 'select $1 + $2'
language sql;
",
        ), @r#"
        "in" @ 24..26: Keyword
        "a" @ 27..28: Parameter
        "int" @ 29..32: Type
        "inout" @ 40..45: Keyword
        "b" @ 46..47: Parameter
        "text" @ 48..52: Type
        "in" @ 68..70: Keyword
        "out" @ 71..74: Keyword
        "c" @ 75..76: Parameter
        "varchar" @ 77..84: Type
        "variadic" @ 94..102: Keyword
        "d" @ 103..104: Parameter
        "int" @ 105..108: Type
        "int" @ 121..124: Type
        "#);
    }

    /// A parameter named like a type, with the mode keyword placed after the
    /// name: the first `int8` is the parameter, the second is its type.
    #[test]
    fn create_function_param_mode_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(int8 in int8)
returns void
as '' language sql;
",
        ), @r#"
        "int8" @ 19..23: Parameter
        "in" @ 24..26: Keyword
        "int8" @ 27..31: Type
        "void" @ 41..45: Type
        "#);
    }

    /// `%type` references produce no `Type` token; only the parameter name is
    /// highlighted.
    #[test]
    fn create_function_percent_type() {
        assert_snapshot!(semantic_tokens(
            "
create function f(a t.c%type) 
returns t.b%type 
as '' language plpgsql;
",
        ), @r#""a" @ 19..20: Parameter"#);
    }

    /// Keywords appearing in alias position on select targets are reported as
    /// `Name`.
    #[test]
    fn select_keywords() {
        assert_snapshot!(semantic_tokens("
select 1 and, 2 select;
"), @r#"
        "and" @ 10..13: Name
        "select" @ 17..23: Name
        "#)
    }

    /// `$n` tokens are reported as `PositionalParam`.
    #[test]
    fn positional_param() {
        assert_snapshot!(semantic_tokens("
select $1, $2;
"), @r#"
        "$1" @ 8..10: PositionalParam
        "$2" @ 12..14: PositionalParam
        "#)
    }

    /// Column types in a `from ... as t(...)` alias column list are
    /// highlighted, including the element types of array columns.
    #[test]
    fn from_alias_column_types() {
        assert_snapshot!(semantic_tokens(
            "
select *
from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop);
",
        ), @r#"
        "int" @ 24..27: Type
        "jsonb" @ 31..36: Type
        "text" @ 40..44: Type
        "int" @ 48..51: Type
        "char" @ 56..60: Type
        "int" @ 70..73: Type
        "jbpop" @ 81..86: Type
        "#);
    }

    /// Target types of both `::` casts and `cast(... as ...)` expressions are
    /// highlighted.
    #[test]
    fn cast_types() {
        assert_snapshot!(semantic_tokens(
            "
select '1'::jsonb, '2'::json, cast(1 as integer), cast(1 as int4[][]), cast(1 as varchar(10));
",
        ), @r#"
        "jsonb" @ 13..18: Type
        "json" @ 25..29: Type
        "integer" @ 41..48: Type
        "int4" @ 61..65: Type
        "varchar" @ 82..89: Type
        "#);
    }

    /// A positional parameter that is the operand of a cast yields both the
    /// parameter token and the cast's type token, in source order.
    #[test]
    fn positional_param_and_cast_type() {
        assert_snapshot!(semantic_tokens(
            "
select $2::jsonb;
",
        ), @r#"
        "$2" @ 8..10: PositionalParam
        "jsonb" @ 12..17: Type
        "#);
    }
}