Skip to main content

squawk_ide/
find_references.rs

1use crate::db::{File, parse};
2use crate::goto_definition;
3use crate::location::Location;
4use rowan::TextSize;
5use salsa::Database as Db;
6use squawk_syntax::ast::{self, AstNode};
7
8#[salsa::tracked]
9pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
10    let targets = goto_definition::goto_definition(db, file, offset);
11    let Some(first) = targets.first() else {
12        return vec![];
13    };
14
15    let mut refs = targets.to_vec();
16
17    for node in parse(db, file)
18        .tree()
19        .syntax()
20        .descendants()
21        .filter(|x| ast::NameRef::can_cast(x.kind()))
22    {
23        let range = node.text_range();
24        let matches = goto_definition::goto_definition(db, file, range.start())
25            .into_iter()
26            .any(|location| targets.contains(&location));
27        if matches {
28            refs.push(Location {
29                file,
30                range,
31                kind: first.kind,
32            });
33        }
34    }
35    refs.sort_by_key(|loc| (loc.file != file, loc.range.start()));
36    refs
37}
38
39#[cfg(test)]
40mod test {
41    use crate::builtins::builtins_file;
42    use crate::db::{Database, File};
43    use crate::find_references::find_references;
44    use crate::test_utils::Fixture;
45    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
46    use insta::assert_snapshot;
47    use rowan::TextRange;
48    use rustc_hash::FxHashMap;
49
50    #[must_use]
51    #[track_caller]
52    fn find_refs(sql: &str) -> String {
53        let fixture = Fixture::new(sql);
54        let marker = fixture.marker();
55        let offset = marker.offset_before();
56        let query_span = marker.range();
57        let sql = fixture.sql();
58        let db = Database::default();
59        let current_file = File::new(&db, sql.into());
60        assert_eq!(crate::db::parse(&db, current_file).errors(), vec![]);
61
62        let references = find_references(&db, current_file, offset);
63
64        let mut file_paths = FxHashMap::default();
65        file_paths.insert(current_file, "current.sql");
66        file_paths.insert(builtins_file(&db), "builtins.sql");
67
68        let mut refs_by_file: FxHashMap<File, Vec<(usize, TextRange)>> = FxHashMap::default();
69        for (i, location) in references.iter().enumerate() {
70            refs_by_file
71                .entry(location.file)
72                .or_default()
73                .push((i + 1, location.range));
74        }
75
76        let multi_file = refs_by_file.len() > 1 || !refs_by_file.contains_key(&current_file);
77
78        let mut snippet = Snippet::source(sql).fold(true);
79        if multi_file {
80            snippet = snippet.path(*file_paths.get(&current_file).unwrap());
81        }
82        snippet = snippet.annotation(AnnotationKind::Context.span(query_span).label("0. query"));
83        if let Some(current_refs) = refs_by_file.remove(&current_file) {
84            snippet = annotate_refs(snippet, current_refs);
85        }
86
87        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
88
89        for (ref_file, refs) in refs_by_file {
90            let path = file_paths.get(&ref_file).unwrap();
91            let other_snippet = Snippet::source(ref_file.content(&db).as_ref())
92                .path(*path)
93                .fold(true);
94            let other_snippet = annotate_refs(other_snippet, refs);
95            groups.push(
96                Level::INFO
97                    .primary_title("references")
98                    .element(other_snippet),
99            );
100        }
101
102        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
103        renderer
104            .render(&groups)
105            .to_string()
106            .replace("info: references", "")
107    }
108
109    fn annotate_refs<'a>(
110        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
111        refs: Vec<(usize, TextRange)>,
112    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
113        for (label_index, range) in refs {
114            snippet = snippet.annotation(
115                AnnotationKind::Context
116                    .span(range.into())
117                    .label(format!("{}. reference", label_index)),
118            );
119        }
120        snippet
121    }
122
    /// Cursor on the table name in `drop table t`: the expected output lists
    /// the name in `create table t()` and the name in `drop table t`.
    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }
139
    /// A table used by several statements: the `create table`, `drop table`
    /// and `table` occurrences of `users` are all expected in the output.
    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }
159
    /// Cursor on the column inside `using (id)`: the expected output lists the
    /// `id` column of both joined tables plus the `using` occurrence itself.
    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }
179
    /// Cursor on the name in the `create table` statement itself: the expected
    /// output lists that name and the later `drop table t` occurrence.
    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }
196
    /// Schema-qualified `public.users` and the unqualified `table users` are
    /// all expected to be listed together.
    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }
216
    /// Cursor on the name in `create temp table t()`: the expected output
    /// lists only that name; neither the earlier `create table t()` nor the
    /// later `drop table t` is annotated.
    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }
232
    /// Same table name in two schemas: with the cursor on `bar.t`, the
    /// expected output lists only that name and leaves `foo.t` unannotated.
    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }
247
    /// After `set search_path to myschema`, the unqualified `drop table users`
    /// is expected to be listed as a reference to `myschema.users`.
    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }
265
    /// Cursor on `pg_temp.t` in the `drop table`: the expected output lists
    /// the name in `create temp table t()` and the `drop table` occurrence.
    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }
282
    /// Unquoted names differing only in case (`Users`, `USERS`, `users`) are
    /// all expected to be listed as references to the same table.
    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }
    /// Unquoted `actors` / `ACTORS` are expected to be listed together, while
    /// the quoted `"Actors"` table on line 3 of the fixture gets no annotation.
    #[test]
    fn case_insensitive_part_2() {
        // we should see refs for `drop table` and `table`
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }
324
    /// Schema-qualified names differing only in case (`Public.Users`,
    /// `PUBLIC.USERS`, `public.users`) are all expected to be listed.
    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }
344
    /// A table whose name merely ends with the queried name (`temp_t` vs `t`)
    /// is expected to be left out of the output.
    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }
359
    /// Names that only contain the queried name (`foo1`, `barfoo`, `foo_bar`)
    /// are expected to be left out; only the exact `foo` occurrences are listed.
    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }
379
380    #[test]
381    fn builtin_function_references() {
382        assert_snapshot!(find_refs("
383select now$0();
384select now();
385"), @"
386              ╭▸ current.sql:2:8
387388            2 │ select now();
389              │        ┬─┬
390              │        │ │
391              │        │ 0. query
392              │        1. reference
393            3 │ select now();
394              │        ─── 2. reference
395              ╰╴
396
397              ╭▸ builtins.sql:11089:28
398399        11089 │ create function pg_catalog.now() returns timestamp with time zone
400              ╰╴                           ─── 3. reference
401        ");
402    }
403}