squawk_ide/
find_references.rs

1use crate::binder::{self, Binder};
2use crate::offsets::token_from_offset;
3use crate::resolve;
4use rowan::{TextRange, TextSize};
5use smallvec::{SmallVec, smallvec};
6use squawk_syntax::SyntaxNode;
7use squawk_syntax::{
8    SyntaxNodePtr,
9    ast::{self, AstNode},
10    match_ast,
11};
12
13pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<TextRange> {
14    let binder = binder::bind(file);
15    let root = file.syntax();
16    let Some(targets) = find_targets(file, root, offset, &binder) else {
17        return vec![];
18    };
19
20    let mut refs = vec![];
21    for node in file.syntax().descendants() {
22        match_ast! {
23            match node {
24                ast::NameRef(name_ref) => {
25                    if let Some(found_refs) = resolve::resolve_name_ref_ptrs(&binder, root, &name_ref)
26                        && found_refs.iter().any(|ptr| targets.contains(ptr))
27                    {
28                        refs.push(name_ref.syntax().text_range());
29                    }
30                },
31                ast::Name(name) => {
32                    let found = SyntaxNodePtr::new(name.syntax());
33                    if targets.contains(&found) {
34                        refs.push(name.syntax().text_range());
35                    }
36                },
37                _ => (),
38            }
39        }
40    }
41
42    refs.sort_by_key(|range| range.start());
43    refs
44}
45
46fn find_targets(
47    file: &ast::SourceFile,
48    root: &SyntaxNode,
49    offset: TextSize,
50    binder: &Binder,
51) -> Option<SmallVec<[SyntaxNodePtr; 1]>> {
52    let token = token_from_offset(file, offset)?;
53    let parent = token.parent()?;
54
55    if let Some(name) = ast::Name::cast(parent.clone()) {
56        return Some(smallvec![SyntaxNodePtr::new(name.syntax())]);
57    }
58
59    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
60        return resolve::resolve_name_ref_ptrs(binder, root, &name_ref);
61    }
62
63    None
64}
65
#[cfg(test)]
mod test {
    use crate::find_references::find_references;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Runs `find_references` at the `$0` cursor in `sql` and renders the
    /// result as an annotated snippet for inline snapshot comparison.
    ///
    /// The cursor position is labelled `0. query`; each reference is labelled
    /// `N. reference` in the order returned by `find_references`.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        let (offset, sql) = fixture(sql);
        // The `$0` marker is placed right after the identifier under test;
        // step back one position so the offset falls inside that identifier.
        let offset = offset.checked_sub(1.into()).unwrap_or_default();
        let parse = ast::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let references = find_references(&file, offset);

        let offset_usize: usize = offset.into();

        // The snippet borrows its labels, so they are built up front and
        // must outlive the rendering below.
        let labels: Vec<String> = (1..=references.len())
            .map(|i| format!("{i}. reference"))
            .collect();

        let mut snippet = Snippet::source(&sql).fold(true).annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );

        for (range, label) in references.iter().zip(&labels) {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span((*range).into())
                    .label(label),
            );
        }

        let group = Level::INFO.primary_title("references").element(snippet);
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Strip the title text so snapshots only contain the annotated source.
        renderer
            .render(&[group])
            .to_string()
            .replace("info: references", "")
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_does_not_shadow_public() {
        // Pins the resolver's current behavior: starting from the temp table
        // definition, `drop table t` is not reported as a reference.
        // NOTE(review): PostgreSQL searches the session's temp schema first for
        // relation names when `pg_temp` is not listed in `search_path`, so at
        // runtime this `drop table t` would target the temp table — confirm the
        // intended resolution rule.
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn case_insensitive_part_2() {
        // Unquoted `ACTORS` and `actors` match the unquoted `actors` definition,
        // while the quoted `"Actors"` table must not be reported.
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }
}