//! `squawk_ide/find_references.rs`: locate every reference to the symbol at a given position.

use crate::db::parse;
use crate::file::InFile;
use crate::goto_definition;
use crate::location::Location;
use rowan::TextSize;
use salsa::Database as Db;
use squawk_syntax::{
    SyntaxNode,
    ast::{self, AstNode},
};
use std::collections::HashSet;
11
12fn is_reference_node(node: &SyntaxNode) -> bool {
13    if ast::NameRef::can_cast(node.kind()) {
14        return true;
15    }
16
17    if let Some(ty) = ast::Type::cast(node.clone()) {
18        return match ty {
19            ast::Type::BitType(_)
20            | ast::Type::CharType(_)
21            | ast::Type::DoubleType(_)
22            | ast::Type::IntervalType(_)
23            | ast::Type::TimeType(_) => true,
24            ast::Type::ArrayType(_)
25            | ast::Type::ExprType(_)
26            | ast::Type::PathType(_)
27            | ast::Type::PercentType(_) => false,
28        };
29    }
30
31    false
32}
33
34pub fn find_references(db: &dyn Db, position: InFile<TextSize>) -> Vec<Location> {
35    let file = position.file_id;
36    let targets = goto_definition::goto_definition(db, position);
37    let Some(first) = targets.first() else {
38        return vec![];
39    };
40
41    let mut refs = targets.to_vec();
42
43    for node in parse(db, file)
44        .tree()
45        .syntax()
46        .descendants()
47        .filter(is_reference_node)
48    {
49        let range = node.text_range();
50        let matches = goto_definition::goto_definition(db, InFile::new(file, range.start()))
51            .into_iter()
52            .any(|location| targets.contains(&location));
53        if matches {
54            refs.push(Location {
55                file,
56                range,
57                kind: first.kind,
58            });
59        }
60    }
61    refs.sort_by_key(|loc| (loc.file != file, loc.range.start()));
62    refs
63}
64
#[cfg(test)]
mod test {
    use crate::builtins::builtins_file;
    use crate::db::File;

    use crate::find_references::find_references;
    use crate::test_utils::Fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use rowan::TextRange;
    use rustc_hash::FxHashMap;

    /// Runs `find_references` at the `$0` marker in `sql` and renders the
    /// results as annotated snippets for snapshot testing.
    ///
    /// The marker is labelled `0. query`. Each result is labelled
    /// `N. reference`, where `N` is its 1-based position in the list returned
    /// by `find_references`. References outside the current file are rendered
    /// in separate snippets, one per file, ordered by display path. File paths
    /// are only printed when results are not confined to the current file.
    #[must_use]
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        let fixture = Fixture::new(sql);
        let marker = fixture.marker();
        let offset = marker.offset_before();
        let query_span = marker.range();
        let db = fixture.db();
        let current_file = offset.file_id;

        let references = find_references(db, offset);

        let mut file_paths = FxHashMap::default();
        file_paths.insert(current_file, "current.sql");
        file_paths.insert(builtins_file(db), "builtins.sql");

        // Group references per file, remembering their 1-based result position.
        let mut refs_by_file: FxHashMap<File, Vec<(usize, TextRange)>> = FxHashMap::default();
        for (i, location) in references.iter().enumerate() {
            refs_by_file
                .entry(location.file)
                .or_default()
                .push((i + 1, location.range));
        }

        let multi_file = refs_by_file.len() > 1 || !refs_by_file.contains_key(&current_file);

        let mut snippet = Snippet::source(current_file.content(db).as_ref()).fold(true);
        if multi_file {
            snippet = snippet.path(*file_paths.get(&current_file).unwrap());
        }
        snippet = snippet.annotation(AnnotationKind::Context.span(query_span).label("0. query"));
        if let Some(current_refs) = refs_by_file.remove(&current_file) {
            snippet = annotate_refs(snippet, current_refs);
        }

        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];

        // Render the remaining files in a stable order so the snapshot does not
        // depend on hash-map iteration order.
        let mut other_files: Vec<_> = refs_by_file.into_iter().collect();
        other_files.sort_by_key(|(ref_file, _)| file_paths.get(ref_file).copied());

        for (ref_file, refs) in other_files {
            let path = file_paths
                .get(&ref_file)
                .expect("references should only point into current.sql or builtins.sql");
            let other_snippet = Snippet::source(ref_file.content(db).as_ref())
                .path(*path)
                .fold(true);
            let other_snippet = annotate_refs(other_snippet, refs);
            groups.push(
                Level::INFO
                    .primary_title("references")
                    .element(other_snippet),
            );
        }

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        renderer
            .render(&groups)
            .to_string()
            .replace("info: references", "")
    }

    /// Adds an `N. reference` context annotation to `snippet` for every
    /// `(N, range)` pair in `refs`.
    fn annotate_refs<'a>(
        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
        refs: Vec<(usize, TextRange)>,
    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
        for (label_index, range) in refs {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span(range.into())
                    .label(format!("{label_index}. reference")),
            );
        }
        snippet
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn case_insensitive_part_2() {
        // we should see refs for `drop table` and `table`
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }

    #[test]
    fn builtin_function_references() {
        assert_snapshot!(find_refs("
-- include-builtins
select now$0();
select now();
"), @"
              ╭▸ current.sql:3:8
              │
            3 │ select now();
              │        ┬─┬
              │        │ │
              │        │ 0. query
              │        1. reference
            4 │ select now();
              │        ─── 2. reference
              ╰╴

              ╭▸ builtins.sql:11089:28
              │
        11089 │ create function pg_catalog.now() returns timestamp with time zone
              ╰╴                           ─── 3. reference
        ");
    }

    #[test]
    fn bit() {
        assert_snapshot!(find_refs("
create type pg_catalog.bit$0;

create function pg_catalog.bit(bigint, integer) returns bit
  language internal;

create function pg_catalog.bit(bit, integer, boolean) returns bit
  language internal;

create function pg_catalog.bit(integer, integer) returns bit
  language internal;
"), @"
           ╭▸ 
         2 │ create type pg_catalog.bit;
           │                        ┬─┬
           │                        │ │
           │                        │ 0. query
           │                        1. reference
         3 │
         4 │ create function pg_catalog.bit(bigint, integer) returns bit
           │                                                         ─── 2. reference
           ‡
         7 │ create function pg_catalog.bit(bit, integer, boolean) returns bit
           │                                ─── 3. reference               ─── 4. reference
           ‡
        10 │ create function pg_catalog.bit(integer, integer) returns bit
           ╰╴                                                         ─── 5. reference
        ");
    }

    #[test]
    fn char() {
        assert_snapshot!(find_refs("
create type pg_catalog.bpchar$0;

select '1'::char;
select '1'::bpchar;
"), @"
          ╭▸ 
        2 │ create type pg_catalog.bpchar;
          │                        ┬────┬
          │                        │    │
          │                        │    0. query
          │                        1. reference
        3 │
        4 │ select '1'::char;
          │             ──── 2. reference
        5 │ select '1'::bpchar;
          ╰╴            ────── 3. reference
        ");
    }
}