squawk_ide/
goto_definition.rs

1use rowan::{TextRange, TextSize};
2use squawk_syntax::{
3    SyntaxKind, SyntaxToken,
4    ast::{self, AstNode},
5};
6
7pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> Option<TextRange> {
8    let token = token_from_offset(&file, offset)?;
9    let parent = token.parent()?;
10
11    // goto def on case exprs
12    if (token.kind() == SyntaxKind::WHEN_KW && parent.kind() == SyntaxKind::WHEN_CLAUSE)
13        || (token.kind() == SyntaxKind::ELSE_KW && parent.kind() == SyntaxKind::ELSE_CLAUSE)
14        || (token.kind() == SyntaxKind::END_KW && parent.kind() == SyntaxKind::CASE_EXPR)
15    {
16        for parent in token.parent_ancestors() {
17            if let Some(case_expr) = ast::CaseExpr::cast(parent) {
18                if let Some(case_token) = case_expr.case_token() {
19                    return Some(case_token.text_range());
20                }
21            }
22        }
23    }
24
25    return None;
26}
27
28fn token_from_offset(file: &ast::SourceFile, offset: TextSize) -> Option<SyntaxToken> {
29    let mut token = file.syntax().token_at_offset(offset).right_biased()?;
30    // want to be lenient in case someone clicks the trailing `;` of a line
31    // instead of an identifier
32    if token.kind() == SyntaxKind::SEMICOLON {
33        token = token.prev_token()?;
34    }
35    return Some(token);
36}
37
#[cfg(test)]
mod test {
    use crate::goto_definition::goto_definition;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Runs go-to-definition at the `$0` marker in `sql` and returns the
    /// rendered result, panicking if no definition is found.
    #[track_caller]
    fn goto(sql: &str) -> String {
        goto_(sql).expect("should always find a definition")
    }

    /// Runs go-to-definition at the `$0` marker in `sql`.
    ///
    /// Returns a rendered snippet annotating the cursor position ("1. source")
    /// and the definition range ("2. destination"), or `None` when
    /// `goto_definition` finds nothing. Panics if `sql` has parse errors.
    #[track_caller]
    fn goto_(sql: &str) -> Option<String> {
        let (offset, sql) = fixture(sql);
        // For go to def we want the previous character since we usually put the
        // marker after the item we're trying to go to def on.
        let offset = offset.checked_sub(1.into()).unwrap_or_default();
        let parse = ast::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();
        let result = goto_definition(file, offset)?;

        let offset: usize = offset.into();
        let group = Level::INFO.primary_title("definition").element(
            Snippet::source(&sql)
                .fold(true)
                .annotation(
                    AnnotationKind::Context
                        .span(result.into())
                        .label("2. destination"),
                )
                .annotation(
                    AnnotationKind::Context
                        .span(offset..offset + 1)
                        .label("1. source"),
                ),
        );
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        Some(
            renderer
                .render(&[group])
                .to_string()
                // hacky cleanup to make the text shorter
                .replace("info: definition", ""),
        )
    }

    /// Asserts that go-to-definition at the `$0` marker in `sql` finds nothing,
    /// showing the unexpected result on failure.
    #[track_caller]
    fn goto_not_found(sql: &str) {
        if let Some(found) = goto_(sql) {
            panic!("Should not find a definition, but found:\n{found}");
        }
    }

    #[test]
    fn goto_case_when() {
        assert_snapshot!(goto("
select case when$0 x > 1 then 1 else 2 end;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          │        ┬───    ─ 1. source
          │        │
          ╰╴       2. destination
        ");
    }

    #[test]
    fn goto_case_else() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else$0 2 end;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination       ─ 1. source
        ");
    }

    #[test]
    fn goto_case_end() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else 2 end$0;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination             ─ 1. source
        ");
    }

    #[test]
    fn goto_case_end_trailing_semi() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else 2 end;$0
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination              ─ 1. source
        ");
    }

    #[test]
    fn goto_case_then_not_found() {
        goto_not_found(
            "
select case when x > 1 then$0 1 else 2 end;
",
        )
    }
}