1use crate::binder::Binder;
2use crate::builtins::{builtins_binder, parse_builtins};
3use crate::db::{File, bind, parse};
4use crate::goto_definition::{FileId, Location};
5use crate::offsets::token_from_offset;
6use crate::resolve;
7use rowan::TextSize;
8use salsa::Database as Db;
9use smallvec::{SmallVec, smallvec};
10use squawk_syntax::{
11 SyntaxNodePtr,
12 ast::{self, AstNode},
13 match_ast,
14};
15
16#[salsa::tracked]
17pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
18 let parse = parse(db, file);
19 let source_file = parse.tree();
20
21 let current_binder = bind(db, file);
22
23 let builtins_tree = parse_builtins(db).tree();
24 let builtins_binder = builtins_binder(db);
25
26 let Some((target_file, target_defs)) = find_target_defs(
27 &source_file,
28 offset,
29 ¤t_binder,
30 &builtins_tree,
31 &builtins_binder,
32 ) else {
33 return vec![];
34 };
35
36 let (binder, root) = match target_file {
37 FileId::Current => (¤t_binder, source_file.syntax()),
38 FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
39 };
40
41 let mut refs: Vec<Location> = vec![];
42
43 if target_file == FileId::Builtins {
44 for ptr in &target_defs {
45 refs.push(Location {
46 file: FileId::Builtins,
47 range: ptr.to_node(builtins_tree.syntax()).text_range(),
48 });
49 }
50 }
51
52 for node in source_file.syntax().descendants() {
53 match_ast! {
54 match node {
55 ast::NameRef(name_ref) => {
56 if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
58 && found_defs.iter().any(|def| target_defs.contains(def))
59 {
60 refs.push(Location {
61 file: FileId::Current,
62 range: name_ref.syntax().text_range(),
63 });
64 }
65 },
66 ast::Name(name) => {
67 let found = SyntaxNodePtr::new(name.syntax());
69 if target_defs.contains(&found) {
70 refs.push(Location {
71 file: FileId::Current,
72 range: name.syntax().text_range(),
73 });
74 }
75 },
76 _ => (),
77 }
78 }
79 }
80
81 refs.sort_by_key(|loc| (loc.file, loc.range.start()));
82 refs
83}
84
85fn find_target_defs(
86 file: &ast::SourceFile,
87 offset: TextSize,
88 current_binder: &Binder,
89 builtins_tree: &ast::SourceFile,
90 builtins_binder: &Binder,
91) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
92 let token = token_from_offset(file, offset)?;
93 let parent = token.parent()?;
94
95 if let Some(name) = ast::Name::cast(parent.clone()) {
96 return Some((
97 FileId::Current,
98 smallvec![SyntaxNodePtr::new(name.syntax())],
99 ));
100 }
101
102 if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
103 for file_id in [FileId::Current, FileId::Builtins] {
104 let (binder, root) = match file_id {
105 FileId::Current => (current_binder, file.syntax()),
106 FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
107 };
108 if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
109 return Some((file_id, ptrs));
110 }
111 }
112 }
113
114 None
115}
116
#[cfg(test)]
mod test {
    use crate::builtins::BUILTINS_SQL;
    use crate::db::{Database, File};
    use crate::find_references::find_references;
    use crate::goto_definition::FileId;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use rowan::TextRange;

    /// Runs [`find_references`] at the `$0` marker in `sql` and renders the
    /// result as annotated source text for snapshot comparison.
    ///
    /// The query position is annotated as `0. query`. Each returned location
    /// is labelled `N. reference`, numbered from 1 in the order
    /// `find_references` returned it. Builtin locations, if any, are rendered
    /// in a second snippet over `BUILTINS_SQL` named `builtin.sql`.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        // NOTE(review): `fixture` presumably strips the `$0` marker and
        // returns its offset alongside the cleaned SQL.
        let (mut offset, sql) = fixture(sql);
        // Query the character just before the marker, i.e. inside the
        // identifier the marker follows.
        offset = offset.checked_sub(1.into()).unwrap_or_default();
        let db = Database::default();
        let file = File::new(&db, sql.clone().into());
        // Fixtures must parse cleanly so that failures point at resolution.
        assert_eq!(crate::db::parse(&db, file).errors(), vec![]);

        let references = find_references(&db, file, offset);

        let offset_usize: usize = offset.into();

        // Split references by file; label 0 is reserved for the query itself.
        let mut current_refs = vec![];
        let mut builtin_refs = vec![];
        for (i, location) in references.iter().enumerate() {
            let label_index = i + 1;
            match location.file {
                FileId::Current => current_refs.push((label_index, location.range)),
                FileId::Builtins => builtin_refs.push((label_index, location.range)),
            }
        }

        let has_builtins = !builtin_refs.is_empty();

        // File paths are only shown when two snippets need telling apart.
        let mut snippet = Snippet::source(&sql).fold(true);
        if has_builtins {
            snippet = snippet.path("current.sql");
        }
        snippet = snippet.annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );
        snippet = annotate_refs(snippet, current_refs);

        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];

        if has_builtins {
            let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
            let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
            groups.push(
                Level::INFO
                    .primary_title("references")
                    .element(builtins_snippet),
            );
        }

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // The group titles carry no information; strip them from the snapshot.
        renderer
            .render(&groups)
            .to_string()
            .replace("info: references", "")
    }

    /// Adds one `Context` annotation per `(label_index, range)` pair, labelled
    /// `"{label_index}. reference"`.
    fn annotate_refs<'a>(
        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
        refs: Vec<(usize, TextRange)>,
    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
        for (label_index, range) in refs {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span(range.into())
                    .label(format!("{}. reference", label_index)),
            );
        }
        snippet
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }
    #[test]
    fn case_insensitive_part_2() {
        // Quoted identifiers are case-sensitive, so "Actors" is a distinct table.
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }

    #[test]
    fn builtin_function_references() {
        // References to a builtin are listed in the current file, and the
        // builtin's own definition is rendered from `builtin.sql`.
        assert_snapshot!(find_refs("
select now$0();
select now();
"), @"
              ╭▸ current.sql:2:8
              │
            2 │ select now();
              │        ┬─┬
              │        │ │
              │        │ 0. query
              │        1. reference
            3 │ select now();
              │        ─── 2. reference
              ╰╴

              ╭▸ builtin.sql:11089:28
              │
        11089 │ create function pg_catalog.now() returns timestamp with time zone
              ╰╴                           ─── 3. reference
        ");
    }
}