Skip to main content

squawk_ide/
find_references.rs

1use crate::db::{File, parse};
2use crate::goto_definition;
3use crate::location::Location;
4use rowan::TextSize;
5use salsa::Database as Db;
6use squawk_syntax::ast::{self, AstNode};
7
8#[salsa::tracked]
9pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
10    let targets = goto_definition::goto_definition(db, file, offset);
11    let Some(first) = targets.first() else {
12        return vec![];
13    };
14
15    let mut refs = targets.to_vec();
16
17    for node in parse(db, file)
18        .tree()
19        .syntax()
20        .descendants()
21        .filter(|x| ast::NameRef::can_cast(x.kind()))
22    {
23        let range = node.text_range();
24        let matches = goto_definition::goto_definition(db, file, range.start())
25            .into_iter()
26            .any(|location| targets.contains(&location));
27        if matches {
28            refs.push(Location {
29                file,
30                range,
31                kind: first.kind,
32            });
33        }
34    }
35    refs.sort_by_key(|loc| (loc.file != file, loc.range.start()));
36    refs
37}
38
// Snapshot tests for `find_references`. Each case marks the query position
// with `$0`; `find_refs` renders the query as `0. query` and every returned
// location as a numbered `N. reference` annotation.
#[cfg(test)]
mod test {
    use crate::builtins::builtins_file;
    use crate::db::{Database, File};
    use crate::find_references::find_references;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use rowan::TextRange;
    use rustc_hash::FxHashMap;

    /// Runs `find_references` at the `$0` marker in `sql` and renders the
    /// result as annotated source text for snapshot comparison.
    ///
    /// The query position is labelled `0. query` and each returned location
    /// `N. reference`, numbered from 1 in the order `find_references`
    /// returned them. References outside the current file (e.g. builtins)
    /// are rendered as separate snippet groups. File paths are printed unless
    /// the current file is the only file with references.
    ///
    /// # Panics
    ///
    /// Panics if `sql` has parse errors, or if a reference points into a file
    /// other than the current file or the builtins file.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        let (mut offset, sql) = fixture(sql);
        // Step back one byte so the query lands on the character just before
        // `$0`, i.e. the end of the identifier (presumably `fixture` returns
        // the offset of the stripped marker — confirm in `test_utils`).
        offset = offset.checked_sub(1.into()).unwrap_or_default();
        let db = Database::default();
        let current_file = File::new(&db, sql.clone().into());
        // Fixtures must parse cleanly.
        assert_eq!(crate::db::parse(&db, current_file).errors(), vec![]);

        let references = find_references(&db, current_file, offset);
        let offset_usize: usize = offset.into();

        // Display names for the files a reference may land in.
        let mut file_paths = FxHashMap::default();
        file_paths.insert(current_file, "current.sql");
        file_paths.insert(builtins_file(&db), "builtins.sql");

        // Group references by file, keeping each one's 1-based position in
        // the result list so labels reflect the order `find_references`
        // returned (0 is reserved for the query).
        let mut refs_by_file: FxHashMap<File, Vec<(usize, TextRange)>> = FxHashMap::default();
        for (i, location) in references.iter().enumerate() {
            refs_by_file
                .entry(location.file)
                .or_default()
                .push((i + 1, location.range));
        }

        // Print file paths unless every reference is in the current file
        // (an empty result also takes the multi-file branch).
        let multi_file = refs_by_file.len() > 1 || !refs_by_file.contains_key(&current_file);

        let mut snippet = Snippet::source(&sql).fold(true);
        if multi_file {
            snippet = snippet.path(*file_paths.get(&current_file).unwrap());
        }
        snippet = snippet.annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );
        if let Some(current_refs) = refs_by_file.remove(&current_file) {
            snippet = annotate_refs(snippet, current_refs);
        }

        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];

        // Whatever remains are references in other files, each rendered as
        // its own group. NOTE(review): with several other files, group order
        // follows hash-map iteration order.
        for (ref_file, refs) in refs_by_file {
            let path = file_paths.get(&ref_file).unwrap();
            let other_snippet = Snippet::source(ref_file.content(&db).as_ref())
                .path(*path)
                .fold(true);
            let other_snippet = annotate_refs(other_snippet, refs);
            groups.push(
                Level::INFO
                    .primary_title("references")
                    .element(other_snippet),
            );
        }

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Strip the group titles so snapshots contain only annotated source.
        renderer
            .render(&groups)
            .to_string()
            .replace("info: references", "")
    }

    /// Adds one `"{n}. reference"` context annotation to `snippet` for each
    /// `(n, range)` entry in `refs`.
    fn annotate_refs<'a>(
        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
        refs: Vec<(usize, TextRange)>,
    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
        for (label_index, range) in refs {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span(range.into())
                    .label(format!("{}. reference", label_index)),
            );
        }
        snippet
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn case_insensitive_part_2() {
        // The quoted `"Actors"` is a distinct table, so besides the unquoted
        // `actors` definition only the `drop table` and `table` statements
        // should show up as references.
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }

    #[test]
    fn builtin_function_references() {
        // The definition of `now` lives in the builtins file, so it is
        // rendered as a separate group with file paths shown.
        assert_snapshot!(find_refs("
select now$0();
select now();
"), @"
              ╭▸ current.sql:2:8
              │
            2 │ select now();
              │        ┬─┬
              │        │ │
              │        │ 0. query
              │        1. reference
            3 │ select now();
              │        ─── 2. reference
              ╰╴

              ╭▸ builtins.sql:11089:28
              │
        11089 │ create function pg_catalog.now() returns timestamp with time zone
              ╰╴                           ─── 3. reference
        ");
    }
}