squawk_ide/
find_references.rs1use crate::db::parse;
2use crate::file::InFile;
3use crate::goto_definition;
4use crate::location::Location;
5use rowan::TextSize;
6use salsa::Database as Db;
7use squawk_syntax::{
8 SyntaxNode,
9 ast::{self, AstNode},
10};
11
12fn is_reference_node(node: &SyntaxNode) -> bool {
13 if ast::NameRef::can_cast(node.kind()) {
14 return true;
15 }
16
17 if let Some(ty) = ast::Type::cast(node.clone()) {
18 return match ty {
19 ast::Type::BitType(_)
20 | ast::Type::CharType(_)
21 | ast::Type::DoubleType(_)
22 | ast::Type::IntervalType(_)
23 | ast::Type::TimeType(_) => true,
24 ast::Type::ArrayType(_)
25 | ast::Type::ExprType(_)
26 | ast::Type::PathType(_)
27 | ast::Type::PercentType(_) => false,
28 };
29 }
30
31 false
32}
33
34pub fn find_references(db: &dyn Db, position: InFile<TextSize>) -> Vec<Location> {
35 let file = position.file_id;
36 let targets = goto_definition::goto_definition(db, position);
37 let Some(first) = targets.first() else {
38 return vec![];
39 };
40
41 let mut refs = targets.to_vec();
42
43 for node in parse(db, file)
44 .tree()
45 .syntax()
46 .descendants()
47 .filter(is_reference_node)
48 {
49 let range = node.text_range();
50 let matches = goto_definition::goto_definition(db, InFile::new(file, range.start()))
51 .into_iter()
52 .any(|location| targets.contains(&location));
53 if matches {
54 refs.push(Location {
55 file,
56 range,
57 kind: first.kind,
58 });
59 }
60 }
61 refs.sort_by_key(|loc| (loc.file != file, loc.range.start()));
62 refs
63}
64
65#[cfg(test)]
66mod test {
// Snapshot tests for `find_references`: each case renders the query marker
// and every returned location as an annotated snippet and compares the text
// against an inline `insta` snapshot.
67 use crate::builtins::builtins_file;
68 use crate::db::File;
69
70 use crate::find_references::find_references;
71 use crate::test_utils::Fixture;
72 use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
73 use insta::assert_snapshot;
74 use rowan::TextRange;
75 use rustc_hash::FxHashMap;
76
/// Runs `find_references` at the `$0` marker in `sql` and renders the result.
///
/// The marker's span is labelled `0. query`. Each returned location is
/// labelled `N. reference`, numbered from 1 in the order `find_references`
/// returned them. References in the queried file share one snippet, and each
/// other file gets its own snippet group that always shows its path
/// (`builtins.sql` is the only other path this helper knows). The queried
/// file's path (`current.sql`) is shown only when references span more than
/// one file or none of them are in the queried file. The `info: references`
/// title text is stripped from the rendered output.
77 #[must_use]
78 #[track_caller]
79 fn find_refs(sql: &str) -> String {
80 let fixture = Fixture::new(sql);
81 let marker = fixture.marker();
82 let offset = marker.offset_before();
83 let query_span = marker.range();
84 let db = fixture.db();
85 let current_file = offset.file_id;
86
87 let references = find_references(db, offset);
88
89 let mut file_paths = FxHashMap::default();
90 file_paths.insert(current_file, "current.sql");
91 file_paths.insert(builtins_file(db), "builtins.sql");
92
// Group references by file, keeping their 1-based position in the result.
93 let mut refs_by_file: FxHashMap<File, Vec<(usize, TextRange)>> = FxHashMap::default();
94 for (i, location) in references.iter().enumerate() {
95 refs_by_file
96 .entry(location.file)
97 .or_default()
98 .push((i + 1, location.range));
99 }
100
// NOTE(review): `¤t_file` here and on the lines below looks like
// `&current_file` mangled by HTML-entity decoding (`&curren` -> `¤`); it
// should read `&current_file`.
101 let multi_file = refs_by_file.len() > 1 || !refs_by_file.contains_key(¤t_file);
102
103 let mut snippet = Snippet::source(current_file.content(db).as_ref()).fold(true);
104 if multi_file {
105 snippet = snippet.path(*file_paths.get(¤t_file).unwrap());
106 }
107 snippet = snippet.annotation(AnnotationKind::Context.span(query_span).label("0. query"));
108 if let Some(current_refs) = refs_by_file.remove(¤t_file) {
109 snippet = annotate_refs(snippet, current_refs);
110 }
111
112 let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
113
// Whatever remains belongs to files other than the queried one.
// NOTE(review): `FxHashMap` iteration order is unspecified, so if references
// ever span several non-current files the group order may not be stable.
// `file_paths` only knows the queried file and the builtins file; the
// `unwrap` below panics for a reference in any other file.
114 for (ref_file, refs) in refs_by_file {
115 let path = file_paths.get(&ref_file).unwrap();
116 let other_snippet = Snippet::source(ref_file.content(db).as_ref())
117 .path(*path)
118 .fold(true);
119 let other_snippet = annotate_refs(other_snippet, refs);
120 groups.push(
121 Level::INFO
122 .primary_title("references")
123 .element(other_snippet),
124 );
125 }
126
// Render plainly with Unicode box drawing, then drop the group title so the
// snapshots contain only the annotated source.
127 let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
128 renderer
129 .render(&groups)
130 .to_string()
131 .replace("info: references", "")
132 }
133
/// Adds one `Context` annotation per `(label_index, range)` pair, labelled
/// `"{label_index}. reference"`, and returns the extended snippet.
134 fn annotate_refs<'a>(
135 mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
136 refs: Vec<(usize, TextRange)>,
137 ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
138 for (label_index, range) in refs {
139 snippet = snippet.annotation(
140 AnnotationKind::Context
141 .span(range.into())
142 .label(format!("{label_index}. reference")),
143 );
144 }
145 snippet
146 }
147
148 #[test]
149 fn simple_table_reference() {
150 assert_snapshot!(find_refs("
151create table t();
152drop table t$0;
153"), @r"
154 ╭▸
155 2 │ create table t();
156 │ ─ 1. reference
157 3 │ drop table t;
158 │ ┬
159 │ │
160 │ 0. query
161 ╰╴ 2. reference
162 ");
163 }
164
165 #[test]
166 fn multiple_references() {
167 assert_snapshot!(find_refs("
168create table users();
169drop table users$0;
170table users;
171"), @r"
172 ╭▸
173 2 │ create table users();
174 │ ───── 1. reference
175 3 │ drop table users;
176 │ ┬───┬
177 │ │ │
178 │ │ 0. query
179 │ 2. reference
180 4 │ table users;
181 ╰╴ ───── 3. reference
182 ");
183 }
184
185 #[test]
186 fn join_using_column() {
187 assert_snapshot!(find_refs("
188create table t(id int);
189create table u(id int);
190select * from t join u using (id$0);
191"), @r"
192 ╭▸
193 2 │ create table t(id int);
194 │ ── 1. reference
195 3 │ create table u(id int);
196 │ ── 2. reference
197 4 │ select * from t join u using (id);
198 │ ┬┬
199 │ ││
200 │ │0. query
201 ╰╴ 3. reference
202 ");
203 }
204
205 #[test]
206 fn find_from_definition() {
207 assert_snapshot!(find_refs("
208create table t$0();
209drop table t;
210"), @r"
211 ╭▸
212 2 │ create table t();
213 │ ┬
214 │ │
215 │ 0. query
216 │ 1. reference
217 3 │ drop table t;
218 ╰╴ ─ 2. reference
219 ");
220 }
221
222 #[test]
223 fn with_schema_qualified() {
224 assert_snapshot!(find_refs("
225create table public.users();
226drop table public.users$0;
227table users;
228"), @r"
229 ╭▸
230 2 │ create table public.users();
231 │ ───── 1. reference
232 3 │ drop table public.users;
233 │ ┬───┬
234 │ │ │
235 │ │ 0. query
236 │ 2. reference
237 4 │ table users;
238 ╰╴ ───── 3. reference
239 ");
240 }
241
// Querying the temp table reports only its own definition; per the snapshot,
// `drop table t` is not counted as a reference to it.
242 #[test]
243 fn temp_table_do_not_shadows_public() {
244 assert_snapshot!(find_refs("
245create table t();
246create temp table t$0();
247drop table t;
248"), @r"
249 ╭▸
250 3 │ create temp table t();
251 │ ┬
252 │ │
253 │ 0. query
254 ╰╴ 1. reference
255 ");
256 }
257
258 #[test]
259 fn different_schema_no_match() {
260 assert_snapshot!(find_refs("
261create table foo.t();
262create table bar.t$0();
263"), @r"
264 ╭▸
265 3 │ create table bar.t();
266 │ ┬
267 │ │
268 │ 0. query
269 ╰╴ 1. reference
270 ");
271 }
272
273 #[test]
274 fn with_search_path() {
275 assert_snapshot!(find_refs("
276set search_path to myschema;
277create table myschema.users$0();
278drop table users;
279"), @r"
280 ╭▸
281 3 │ create table myschema.users();
282 │ ┬───┬
283 │ │ │
284 │ │ 0. query
285 │ 1. reference
286 4 │ drop table users;
287 ╰╴ ───── 2. reference
288 ");
289 }
290
291 #[test]
292 fn temp_table_with_pg_temp_schema() {
293 assert_snapshot!(find_refs("
294create temp table t();
295drop table pg_temp.t$0;
296"), @r"
297 ╭▸
298 2 │ create temp table t();
299 │ ─ 1. reference
300 3 │ drop table pg_temp.t;
301 │ ┬
302 │ │
303 │ 0. query
304 ╰╴ 2. reference
305 ");
306 }
307
308 #[test]
309 fn case_insensitive() {
310 assert_snapshot!(find_refs("
311create table Users();
312drop table USERS$0;
313table users;
314"), @r"
315 ╭▸
316 2 │ create table Users();
317 │ ───── 1. reference
318 3 │ drop table USERS;
319 │ ┬───┬
320 │ │ │
321 │ │ 0. query
322 │ 2. reference
323 4 │ table users;
324 ╰╴ ───── 3. reference
325 ");
326 }
// The quoted `"Actors"` table is not reported (see snapshot); only the
// unquoted spellings resolve to `actors`.
327 #[test]
328 fn case_insensitive_part_2() {
329 assert_snapshot!(find_refs(r#"
331create table actors();
332create table "Actors"();
333drop table ACTORS$0;
334table actors;
335"#), @r#"
336 ╭▸
337 2 │ create table actors();
338 │ ────── 1. reference
339 3 │ create table "Actors"();
340 4 │ drop table ACTORS;
341 │ ┬────┬
342 │ │ │
343 │ │ 0. query
344 │ 2. reference
345 5 │ table actors;
346 ╰╴ ────── 3. reference
347 "#);
348 }
349
350 #[test]
351 fn case_insensitive_with_schema() {
352 assert_snapshot!(find_refs("
353create table Public.Users();
354drop table PUBLIC.USERS$0;
355table public.users;
356"), @r"
357 ╭▸
358 2 │ create table Public.Users();
359 │ ───── 1. reference
360 3 │ drop table PUBLIC.USERS;
361 │ ┬───┬
362 │ │ │
363 │ │ 0. query
364 │ 2. reference
365 4 │ table public.users;
366 ╰╴ ───── 3. reference
367 ");
368 }
369
370 #[test]
371 fn no_partial_match() {
372 assert_snapshot!(find_refs("
373create table t$0();
374create table temp_t();
375"), @r"
376 ╭▸
377 2 │ create table t();
378 │ ┬
379 │ │
380 │ 0. query
381 ╰╴ 1. reference
382 ");
383 }
384
385 #[test]
386 fn identifier_boundaries() {
387 assert_snapshot!(find_refs("
388create table foo$0();
389drop table foo;
390drop table foo1;
391drop table barfoo;
392drop table foo_bar;
393"), @r"
394 ╭▸
395 2 │ create table foo();
396 │ ┬─┬
397 │ │ │
398 │ │ 0. query
399 │ 1. reference
400 3 │ drop table foo;
401 ╰╴ ─── 2. reference
402 ");
403 }
404
// `-- include-builtins` presumably makes the builtins file available to the
// fixture; the definition of `now` there is rendered as a separate
// `builtins.sql` group, so both snippets show their paths.
405 #[test]
406 fn builtin_function_references() {
407 assert_snapshot!(find_refs("
408-- include-builtins
409select now$0();
410select now();
411"), @"
412 ╭▸ current.sql:3:8
413 │
414 3 │ select now();
415 │ ┬─┬
416 │ │ │
417 │ │ 0. query
418 │ 1. reference
419 4 │ select now();
420 │ ─── 2. reference
421 ╰╴
422
423 ╭▸ builtins.sql:11089:28
424 │
425 11089 │ create function pg_catalog.now() returns timestamp with time zone
426 ╰╴ ─── 3. reference
427 ");
428 }
429
// `bit` is spelled as a keyword type (`BitType`); these references are found
// because `is_reference_node` treats such type nodes as references.
430 #[test]
431 fn bit() {
432 assert_snapshot!(find_refs("
433create type pg_catalog.bit$0;
434
435create function pg_catalog.bit(bigint, integer) returns bit
436 language internal;
437
438create function pg_catalog.bit(bit, integer, boolean) returns bit
439 language internal;
440
441create function pg_catalog.bit(integer, integer) returns bit
442 language internal;
443"), @"
444 ╭▸
445 2 │ create type pg_catalog.bit;
446 │ ┬─┬
447 │ │ │
448 │ │ 0. query
449 │ 1. reference
450 3 │
451 4 │ create function pg_catalog.bit(bigint, integer) returns bit
452 │ ─── 2. reference
453 ‡
454 7 │ create function pg_catalog.bit(bit, integer, boolean) returns bit
455 │ ─── 3. reference ─── 4. reference
456 ‡
457 10 │ create function pg_catalog.bit(integer, integer) returns bit
458 ╰╴ ─── 5. reference
459 ");
460 }
461
// Per the snapshot, the keyword type `char` (`CharType`) and the explicit
// `bpchar` path both resolve to the `bpchar` type definition.
462 #[test]
463 fn char() {
464 assert_snapshot!(find_refs("
465create type pg_catalog.bpchar$0;
466
467select '1'::char;
468select '1'::bpchar;
469"), @"
470 ╭▸
471 2 │ create type pg_catalog.bpchar;
472 │ ┬────┬
473 │ │ │
474 │ │ 0. query
475 │ 1. reference
476 3 │
477 4 │ select '1'::char;
478 │ ──── 2. reference
479 5 │ select '1'::bpchar;
480 ╰╴ ────── 3. reference
481 ");
482 }
483}