Skip to main content

squawk_ide/
find_references.rs

1use crate::binder::Binder;
2use crate::builtins::{builtins_binder, parse_builtins};
3use crate::classify::classify_def_node;
4use crate::db::{File, bind, parse};
5use crate::goto_definition::{FileId, Location, LocationKind};
6use crate::offsets::token_from_offset;
7use crate::resolve;
8use rowan::TextSize;
9use salsa::Database as Db;
10use smallvec::{SmallVec, smallvec};
11use squawk_syntax::{
12    SyntaxNodePtr,
13    ast::{self, AstNode},
14    match_ast,
15};
16
17#[salsa::tracked]
18pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
19    let parse = parse(db, file);
20    let source_file = parse.tree();
21
22    let current_binder = bind(db, file);
23
24    let builtins_tree = parse_builtins(db).tree();
25    let builtins_binder = builtins_binder(db);
26
27    let Some((target_file, target_defs, target_kind)) = find_target_defs(
28        &source_file,
29        offset,
30        &current_binder,
31        &builtins_tree,
32        &builtins_binder,
33    ) else {
34        return vec![];
35    };
36
37    let (binder, root) = match target_file {
38        FileId::Current => (&current_binder, source_file.syntax()),
39        FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
40    };
41
42    let mut refs: Vec<Location> = vec![];
43
44    if target_file == FileId::Builtins {
45        for ptr in &target_defs {
46            refs.push(Location {
47                file: FileId::Builtins,
48                range: ptr.to_node(builtins_tree.syntax()).text_range(),
49                kind: target_kind,
50            });
51        }
52    }
53
54    for node in source_file.syntax().descendants() {
55        match_ast! {
56            match node {
57                ast::NameRef(name_ref) => {
58                    // Check if the ref matches one of the defs
59                    if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
60                        && found_defs.iter().any(|def| target_defs.contains(def))
61                    {
62                        refs.push(Location {
63                            file: FileId::Current,
64                            range: name_ref.syntax().text_range(),
65                            kind: target_kind,
66                        });
67                    }
68                },
69                ast::Name(name) => {
70                    // Find refs also includes the defs so we have to check.
71                    let found = SyntaxNodePtr::new(name.syntax());
72                    if target_defs.contains(&found) {
73                        refs.push(Location {
74                            file: FileId::Current,
75                            range: name.syntax().text_range(),
76                            kind: target_kind,
77                        });
78                    }
79                },
80                _ => (),
81            }
82        }
83    }
84
85    refs.sort_by_key(|loc| (loc.file, loc.range.start()));
86    refs
87}
88
89fn find_target_defs(
90    file: &ast::SourceFile,
91    offset: TextSize,
92    current_binder: &Binder,
93    builtins_tree: &ast::SourceFile,
94    builtins_binder: &Binder,
95) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>, LocationKind)> {
96    let token = token_from_offset(file, offset)?;
97    let parent = token.parent()?;
98
99    if let Some(name) = ast::Name::cast(parent.clone())
100        && let Some(kind) = classify_def_node(name.syntax()).map(LocationKind::from)
101    {
102        return Some((
103            FileId::Current,
104            smallvec![SyntaxNodePtr::new(name.syntax())],
105            kind,
106        ));
107    }
108
109    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
110        for file_id in [FileId::Current, FileId::Builtins] {
111            let (binder, root) = match file_id {
112                FileId::Current => (current_binder, file.syntax()),
113                FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
114            };
115            if let Some((ptrs, kind)) = resolve::resolve_name_ref(binder, root, &name_ref) {
116                return Some((file_id, ptrs, kind));
117            }
118        }
119    }
120
121    None
122}
123
// Snapshot tests for `find_references`: each case marks the cursor with `$0` and
// compares an annotate-snippets rendering of the returned locations.
#[cfg(test)]
mod test {
    use crate::builtins::BUILTINS_SQL;
    use crate::db::{Database, File};
    use crate::find_references::find_references;
    use crate::goto_definition::FileId;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use rowan::TextRange;

    /// Runs `find_references` at the `$0` marker in `sql` and renders the result as
    /// annotated snippets for snapshot comparison.
    ///
    /// The query position is the character just before `$0`. Reference labels are
    /// numbered from 1 in the order `find_references` returns them. When any
    /// reference lands in the builtins, the input snippet is labelled `current.sql`
    /// and a second snippet over `BUILTINS_SQL`, labelled `builtin.sql`, is appended.
    ///
    /// # Panics
    ///
    /// Panics if the SQL (with the marker removed) has parse errors.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        // NOTE(review): assumes `fixture` strips `$0` and returns its offset.
        let (mut offset, sql) = fixture(sql);
        // Point at the character just before the marker; clamp at 0 so a marker at
        // the very start does not underflow.
        offset = offset.checked_sub(1.into()).unwrap_or_default();
        let db = Database::default();
        let file = File::new(&db, sql.clone().into());
        // Fixtures must be valid SQL, otherwise the snapshot is meaningless.
        assert_eq!(crate::db::parse(&db, file).errors(), vec![]);

        let references = find_references(&db, file, offset);

        let offset_usize: usize = offset.into();

        // Split the references per file, keeping their 1-based position in the
        // overall result as the label number.
        let mut current_refs = vec![];
        let mut builtin_refs = vec![];
        for (i, location) in references.iter().enumerate() {
            let label_index = i + 1;
            match location.file {
                FileId::Current => current_refs.push((label_index, location.range)),
                FileId::Builtins => builtin_refs.push((label_index, location.range)),
            }
        }

        let has_builtins = !builtin_refs.is_empty();

        let mut snippet = Snippet::source(&sql).fold(true);
        // File paths are only shown when two files are rendered, so single-file
        // snapshots stay path-free.
        if has_builtins {
            snippet = snippet.path("current.sql");
        }
        snippet = snippet.annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );
        snippet = annotate_refs(snippet, current_refs);

        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];

        if has_builtins {
            let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
            let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
            groups.push(
                Level::INFO
                    .primary_title("references")
                    .element(builtins_snippet),
            );
        }

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // The group titles are only there to satisfy the renderer; strip them so
        // the snapshot shows just the annotated source.
        renderer
            .render(&groups)
            .to_string()
            .replace("info: references", "")
    }

    /// Adds a `Context` annotation labelled `"{label_index}. reference"` to
    /// `snippet` for every `(label_index, range)` pair in `refs`.
    fn annotate_refs<'a>(
        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
        refs: Vec<(usize, TextRange)>,
    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
        for (label_index, range) in refs {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span(range.into())
                    .label(format!("{}. reference", label_index)),
            );
        }
        snippet
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }
    #[test]
    fn case_insensitive_part_2() {
        // The quoted `"Actors"` is a distinct table and must not be reported; we
        // should see refs for `drop table` and `table` (plus the `actors` def).
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }

    // A builtin target is rendered as two snippets: the uses in the current file
    // and the definition in the builtins file.
    #[test]
    fn builtin_function_references() {
        assert_snapshot!(find_refs("
select now$0();
select now();
"), @"
              ╭▸ current.sql:2:8
              │
            2 │ select now();
              │        ┬─┬
              │        │ │
              │        │ 0. query
              │        1. reference
            3 │ select now();
              │        ─── 2. reference
              ╰╴

              ╭▸ builtin.sql:11089:28
              │
        11089 │ create function pg_catalog.now() returns timestamp with time zone
              ╰╴                           ─── 3. reference
        ");
    }
}