Skip to main content

squawk_ide/
find_references.rs

1use crate::db::{File, parse};
2use crate::goto_definition;
3use crate::location::Location;
4use rowan::TextSize;
5use salsa::Database as Db;
6use squawk_syntax::ast::{self, AstNode};
7
8#[salsa::tracked]
9pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
10    let targets = goto_definition::goto_definition(db, file, offset);
11    let Some(first) = targets.first() else {
12        return vec![];
13    };
14
15    let mut refs = targets.to_vec();
16
17    for node in parse(db, file)
18        .tree()
19        .syntax()
20        .descendants()
21        .filter(|x| ast::NameRef::can_cast(x.kind()))
22    {
23        let range = node.text_range();
24        let matches = goto_definition::goto_definition(db, file, range.start())
25            .into_iter()
26            .any(|location| targets.contains(&location));
27        if matches {
28            refs.push(Location {
29                file,
30                range,
31                kind: first.kind,
32            });
33        }
34    }
35    refs.sort_by_key(|loc| (loc.file != file, loc.range.start()));
36    refs
37}
38
// Snapshot tests for `find_references`. Each fixture marks the query position with
// `$0`. The rendered output labels the query span "0. query" and each returned
// location "N. reference", numbered in result order.
39#[cfg(test)]
40mod test {
41    use crate::builtins::builtins_file;
42    use crate::db::{Database, File};
43    use crate::find_references::find_references;
44    use crate::test_utils::Fixture;
45    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
46    use insta::assert_snapshot;
47    use rowan::TextRange;
48    use rustc_hash::FxHashMap;
49
    /// Runs `find_references` at the fixture's `$0` marker and renders the results
    /// as annotated snippets for snapshot comparison.
    ///
    /// References in the current file are drawn on the fixture SQL. References in
    /// another file get their own snippet. Only `builtins.sql` is mapped here besides
    /// the current file.
    ///
    /// File paths are printed unless all references (at least one) are in the
    /// current file.
    ///
    /// Panics if the fixture SQL has parse errors, or if a reference lands in a file
    /// that has no entry in `file_paths`.
50    #[track_caller]
51    fn find_refs(sql: &str) -> String {
52        let fixture = Fixture::new(sql);
53        let marker = fixture.marker();
54        let offset = marker.offset_before();
55        let query_span = marker.range();
56        let sql = fixture.sql();
57        let db = Database::default();
58        let current_file = File::new(&db, sql.into());
59        assert_eq!(crate::db::parse(&db, current_file).errors(), vec![]);
60
61        let references = find_references(&db, current_file, offset);
62
63        let mut file_paths = FxHashMap::default();
64        file_paths.insert(current_file, "current.sql");
65        file_paths.insert(builtins_file(&db), "builtins.sql");
66
        // Group ranges per file. Each range keeps its 1-based position in
        // `references` as its label, so numbering follows the result order.
67        let mut refs_by_file: FxHashMap<File, Vec<(usize, TextRange)>> = FxHashMap::default();
68        for (i, location) in references.iter().enumerate() {
69            refs_by_file
70                .entry(location.file)
71                .or_default()
72                .push((i + 1, location.range));
73        }
74
        // Paths are only needed when the output is more than a single current-file
        // snippet: references span several files, or none are in the current file.
75        let multi_file = refs_by_file.len() > 1 || !refs_by_file.contains_key(&current_file);
76
77        let mut snippet = Snippet::source(sql).fold(true);
78        if multi_file {
79            snippet = snippet.path(*file_paths.get(&current_file).unwrap());
80        }
81        snippet = snippet.annotation(AnnotationKind::Context.span(query_span).label("0. query"));
82        if let Some(current_refs) = refs_by_file.remove(&current_file) {
83            snippet = annotate_refs(snippet, current_refs);
84        }
85
86        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
87
        // Each remaining file (e.g. builtins) gets its own group.
        // NOTE(review): if several foreign files are referenced, group order follows
        // FxHashMap iteration order.
88        for (ref_file, refs) in refs_by_file {
89            let path = file_paths.get(&ref_file).unwrap();
90            let other_snippet = Snippet::source(ref_file.content(&db).as_ref())
91                .path(*path)
92                .fold(true);
93            let other_snippet = annotate_refs(other_snippet, refs);
94            groups.push(
95                Level::INFO
96                    .primary_title("references")
97                    .element(other_snippet),
98            );
99        }
100
101        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
102        renderer
103            .render(&groups)
104            .to_string()
            // Blank out the group titles; the snapshots only care about the annotated source.
105            .replace("info: references", "")
106    }
107
    /// Adds one context annotation labelled `N. reference` to `snippet` for each
    /// `(N, range)` pair in `refs`.
108    fn annotate_refs<'a>(
109        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
110        refs: Vec<(usize, TextRange)>,
111    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
112        for (label_index, range) in refs {
113            snippet = snippet.annotation(
114                AnnotationKind::Context
115                    .span(range.into())
116                    .label(format!("{}. reference", label_index)),
117            );
118        }
119        snippet
120    }
121
122    #[test]
123    fn simple_table_reference() {
124        assert_snapshot!(find_refs("
125create table t();
126drop table t$0;
127"), @r"
128          ╭▸ 
129        2 │ create table t();
130          │              ─ 1. reference
131        3 │ drop table t;
132          │            ┬
133          │            │
134          │            0. query
135          ╰╴           2. reference
136        ");
137    }
138
139    #[test]
140    fn multiple_references() {
141        assert_snapshot!(find_refs("
142create table users();
143drop table users$0;
144table users;
145"), @r"
146          ╭▸ 
147        2 │ create table users();
148          │              ───── 1. reference
149        3 │ drop table users;
150          │            ┬───┬
151          │            │   │
152          │            │   0. query
153          │            2. reference
154        4 │ table users;
155          ╰╴      ───── 3. reference
156        ");
157    }
158
    // A single `using (id)` column resolves to both joined tables' `id` columns,
    // so both definitions are returned.
159    #[test]
160    fn join_using_column() {
161        assert_snapshot!(find_refs("
162create table t(id int);
163create table u(id int);
164select * from t join u using (id$0);
165"), @r"
166          ╭▸ 
167        2 │ create table t(id int);
168          │                ── 1. reference
169        3 │ create table u(id int);
170          │                ── 2. reference
171        4 │ select * from t join u using (id);
172          │                               ┬┬
173          │                               ││
174          │                               │0. query
175          ╰╴                              3. reference
176        ");
177    }
178
    // Querying on the definition name itself returns the definition plus its uses.
179    #[test]
180    fn find_from_definition() {
181        assert_snapshot!(find_refs("
182create table t$0();
183drop table t;
184"), @r"
185          ╭▸ 
186        2 │ create table t();
187          │              ┬
188          │              │
189          │              0. query
190          │              1. reference
191        3 │ drop table t;
192          ╰╴           ─ 2. reference
193        ");
194    }
195
196    #[test]
197    fn with_schema_qualified() {
198        assert_snapshot!(find_refs("
199create table public.users();
200drop table public.users$0;
201table users;
202"), @r"
203          ╭▸ 
204        2 │ create table public.users();
205          │                     ───── 1. reference
206        3 │ drop table public.users;
207          │                   ┬───┬
208          │                   │   │
209          │                   │   0. query
210          │                   2. reference
211        4 │ table users;
212          ╰╴      ───── 3. reference
213        ");
214    }
215
    // NOTE(review): when pg_temp is not listed in search_path, PostgreSQL searches it
    // first, so a bare `drop table t` would normally hit the temp table. This snapshot
    // pins that it does *not* resolve to the temp `t`. Confirm this is intended.
216    #[test]
217    fn temp_table_do_not_shadows_public() {
218        assert_snapshot!(find_refs("
219create table t();
220create temp table t$0();
221drop table t;
222"), @r"
223          ╭▸ 
224        3 │ create temp table t();
225          │                   ┬
226          │                   │
227          │                   0. query
228          ╰╴                  1. reference
229        ");
230    }
231
232    #[test]
233    fn different_schema_no_match() {
234        assert_snapshot!(find_refs("
235create table foo.t();
236create table bar.t$0();
237"), @r"
238          ╭▸ 
239        3 │ create table bar.t();
240          │                  ┬
241          │                  │
242          │                  0. query
243          ╰╴                 1. reference
244        ");
245    }
246
247    #[test]
248    fn with_search_path() {
249        assert_snapshot!(find_refs("
250set search_path to myschema;
251create table myschema.users$0();
252drop table users;
253"), @r"
254          ╭▸ 
255        3 │ create table myschema.users();
256          │                       ┬───┬
257          │                       │   │
258          │                       │   0. query
259          │                       1. reference
260        4 │ drop table users;
261          ╰╴           ───── 2. reference
262        ");
263    }
264
265    #[test]
266    fn temp_table_with_pg_temp_schema() {
267        assert_snapshot!(find_refs("
268create temp table t();
269drop table pg_temp.t$0;
270"), @r"
271          ╭▸ 
272        2 │ create temp table t();
273          │                   ─ 1. reference
274        3 │ drop table pg_temp.t;
275          │                    ┬
276          │                    │
277          │                    0. query
278          ╰╴                   2. reference
279        ");
280    }
281
282    #[test]
283    fn case_insensitive() {
284        assert_snapshot!(find_refs("
285create table Users();
286drop table USERS$0;
287table users;
288"), @r"
289          ╭▸ 
290        2 │ create table Users();
291          │              ───── 1. reference
292        3 │ drop table USERS;
293          │            ┬───┬
294          │            │   │
295          │            │   0. query
296          │            2. reference
297        4 │ table users;
298          ╰╴      ───── 3. reference
299        ");
300    }
301    #[test]
302    fn case_insensitive_part_2() {
303        // we should see refs for `drop table` and `table`
        // (`"Actors"` is quoted, so it stays case-sensitive and names a different table.)
304        assert_snapshot!(find_refs(r#"
305create table actors();
306create table "Actors"();
307drop table ACTORS$0;
308table actors;
309"#), @r#"
310          ╭▸ 
311        2 │ create table actors();
312          │              ────── 1. reference
313        3 │ create table "Actors"();
314        4 │ drop table ACTORS;
315          │            ┬────┬
316          │            │    │
317          │            │    0. query
318          │            2. reference
319        5 │ table actors;
320          ╰╴      ────── 3. reference
321        "#);
322    }
323
324    #[test]
325    fn case_insensitive_with_schema() {
326        assert_snapshot!(find_refs("
327create table Public.Users();
328drop table PUBLIC.USERS$0;
329table public.users;
330"), @r"
331          ╭▸ 
332        2 │ create table Public.Users();
333          │                     ───── 1. reference
334        3 │ drop table PUBLIC.USERS;
335          │                   ┬───┬
336          │                   │   │
337          │                   │   0. query
338          │                   2. reference
339        4 │ table public.users;
340          ╰╴             ───── 3. reference
341        ");
342    }
343
    // References are found by name resolution, not by substring matching, so
    // `temp_t` must not show up as a reference to `t`.
344    #[test]
345    fn no_partial_match() {
346        assert_snapshot!(find_refs("
347create table t$0();
348create table temp_t();
349"), @r"
350          ╭▸ 
351        2 │ create table t();
352          │              ┬
353          │              │
354          │              0. query
355          ╰╴             1. reference
356        ");
357    }
358
359    #[test]
360    fn identifier_boundaries() {
361        assert_snapshot!(find_refs("
362create table foo$0();
363drop table foo;
364drop table foo1;
365drop table barfoo;
366drop table foo_bar;
367"), @r"
368          ╭▸ 
369        2 │ create table foo();
370          │              ┬─┬
371          │              │ │
372          │              │ 0. query
373          │              1. reference
374        3 │ drop table foo;
375          ╰╴           ─── 2. reference
376        ");
377    }
378
    // `now` resolves into builtins.sql, so the output gains a second, path-labelled
    // group that shows the builtin definition.
379    #[test]
380    fn builtin_function_references() {
381        assert_snapshot!(find_refs("
382select now$0();
383select now();
384"), @"
385              ╭▸ current.sql:2:8
386387            2 │ select now();
388              │        ┬─┬
389              │        │ │
390              │        │ 0. query
391              │        1. reference
392            3 │ select now();
393              │        ─── 2. reference
394              ╰╴
395
396              ╭▸ builtins.sql:11089:28
397398        11089 │ create function pg_catalog.now() returns timestamp with time zone
399              ╰╴                           ─── 3. reference
400        ");
401    }
402}