squawk_ide/
find_references.rs

1use crate::binder::{self, Binder};
2use crate::offsets::token_from_offset;
3use crate::resolve;
4use rowan::{TextRange, TextSize};
5use squawk_syntax::{
6    SyntaxNodePtr,
7    ast::{self, AstNode},
8    match_ast,
9};
10
11pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<TextRange> {
12    let binder = binder::bind(file);
13    let Some(target) = find_target(file, offset, &binder) else {
14        return vec![];
15    };
16
17    let mut refs = vec![];
18
19    for node in file.syntax().descendants() {
20        match_ast! {
21            match node {
22                ast::NameRef(name_ref) => {
23                    if let Some(found) = resolve::resolve_name_ref(&binder, &name_ref)
24                      && found == target
25                    {
26                        refs.push(name_ref.syntax().text_range());
27                    }
28                },
29                ast::Name(name) => {
30                    let found = SyntaxNodePtr::new(name.syntax());
31                    if found == target {
32                        refs.push(name.syntax().text_range());
33                    }
34                },
35                _ => (),
36            }
37        }
38    }
39
40    refs.sort_by_key(|range| range.start());
41    refs
42}
43
44fn find_target(file: &ast::SourceFile, offset: TextSize, binder: &Binder) -> Option<SyntaxNodePtr> {
45    let token = token_from_offset(file, offset)?;
46    let parent = token.parent()?;
47
48    if let Some(name) = ast::Name::cast(parent.clone()) {
49        return Some(SyntaxNodePtr::new(name.syntax()));
50    }
51
52    if let Some(name_ref) = ast::NameRef::cast(parent.clone())
53        && let Some(ptr) = resolve::resolve_name_ref(binder, &name_ref)
54    {
55        return Some(ptr);
56    }
57
58    None
59}
60
/// Snapshot tests for `find_references`.
///
/// Each fixture marks the cursor with `$0`. The rendered snapshot labels the
/// query position `0. query` and each returned range `N. reference`, numbered
/// in the order `find_references` returns them.
#[cfg(test)]
mod test {
    use crate::find_references::find_references;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Runs `find_references` on the fixture `sql` and renders the query
    /// position plus every returned range as an annotated snippet.
    ///
    /// Panics if the fixture SQL (with the `$0` marker removed) does not
    /// parse without errors.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        let (mut offset, sql) = fixture(sql);
        // NOTE(review): assumes `fixture` returns the offset where `$0` was,
        // i.e. just past the identifier it follows; stepping back one
        // character places the query on that identifier's last character.
        offset = offset.checked_sub(1.into()).unwrap_or_default();
        let parse = ast::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let references = find_references(&file, offset);

        let offset_usize: usize = offset.into();

        // Built up front so the borrowed labels outlive the snippet that
        // holds references to them.
        let labels: Vec<String> = (1..=references.len())
            .map(|i| format!("{i}. reference"))
            .collect();

        let mut snippet = Snippet::source(&sql).fold(true).annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );

        for (range, label) in references.iter().zip(&labels) {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span((*range).into())
                    .label(label),
            );
        }

        let group = Level::INFO.primary_title("references").element(snippet);
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Strip the group title so snapshots contain only the annotated source.
        renderer
            .render(&[group])
            .to_string()
            .replace("info: references", "")
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_does_not_shadow_public() {
        // Pins current behavior: the unqualified `drop table t` is not
        // reported as a reference to the temp table.
        // NOTE(review): PostgreSQL searches the session's temp schema first
        // when `pg_temp` is not listed in `search_path`, so in a real session
        // `drop table t` here would target the temp table. Confirm this
        // expectation is intended.
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn case_insensitive_part_2() {
        // Unquoted identifiers fold to lowercase, so `ACTORS` and `actors`
        // name the same table while the quoted `"Actors"` is a distinct one:
        // we should see refs for `drop table` and `table`, not `"Actors"`.
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }
}