1use crate::binder::{self, Binder};
2use crate::builtins::BUILTINS_SQL;
3use crate::goto_definition::{FileId, Location};
4use crate::offsets::token_from_offset;
5use crate::resolve;
6use rowan::TextSize;
7use smallvec::{SmallVec, smallvec};
8use squawk_syntax::{
9 SyntaxNodePtr,
10 ast::{self, AstNode},
11 match_ast,
12};
13
14pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<Location> {
15 let current_binder = binder::bind(file);
16
17 let builtins_tree = ast::SourceFile::parse(BUILTINS_SQL).tree();
18 let builtins_binder = binder::bind(&builtins_tree);
19
20 let Some((target_file, target_defs)) = find_target_defs(
21 file,
22 offset,
23 ¤t_binder,
24 &builtins_tree,
25 &builtins_binder,
26 ) else {
27 return vec![];
28 };
29
30 let (binder, root) = match target_file {
31 FileId::Current => (¤t_binder, file.syntax()),
32 FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
33 };
34
35 let mut refs: Vec<Location> = vec![];
36
37 if target_file == FileId::Builtins {
38 for ptr in &target_defs {
39 refs.push(Location {
40 file: FileId::Builtins,
41 range: ptr.to_node(builtins_tree.syntax()).text_range(),
42 });
43 }
44 }
45
46 for node in file.syntax().descendants() {
47 match_ast! {
48 match node {
49 ast::NameRef(name_ref) => {
50 if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
52 && found_defs.iter().any(|def| target_defs.contains(def))
53 {
54 refs.push(Location {
55 file: FileId::Current,
56 range: name_ref.syntax().text_range(),
57 });
58 }
59 },
60 ast::Name(name) => {
61 let found = SyntaxNodePtr::new(name.syntax());
63 if target_defs.contains(&found) {
64 refs.push(Location {
65 file: FileId::Current,
66 range: name.syntax().text_range(),
67 });
68 }
69 },
70 _ => (),
71 }
72 }
73 }
74
75 refs.sort_by_key(|loc| (loc.file, loc.range.start()));
76 refs
77}
78
79fn find_target_defs(
80 file: &ast::SourceFile,
81 offset: TextSize,
82 current_binder: &Binder,
83 builtins_tree: &ast::SourceFile,
84 builtins_binder: &Binder,
85) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
86 let token = token_from_offset(file, offset)?;
87 let parent = token.parent()?;
88
89 if let Some(name) = ast::Name::cast(parent.clone()) {
90 return Some((
91 FileId::Current,
92 smallvec![SyntaxNodePtr::new(name.syntax())],
93 ));
94 }
95
96 if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
97 for file_id in [FileId::Current, FileId::Builtins] {
98 let (binder, root) = match file_id {
99 FileId::Current => (current_binder, file.syntax()),
100 FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
101 };
102 if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
103 return Some((file_id, ptrs));
104 }
105 }
106 }
107
108 None
109}
110
111#[cfg(test)]
112mod test {
113 use crate::builtins::BUILTINS_SQL;
114 use crate::find_references::find_references;
115 use crate::goto_definition::FileId;
116 use crate::test_utils::fixture;
117 use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
118 use insta::assert_snapshot;
119 use rowan::TextRange;
120 use squawk_syntax::ast;
121
122 #[track_caller]
123 fn find_refs(sql: &str) -> String {
124 let (mut offset, sql) = fixture(sql);
125 offset = offset.checked_sub(1.into()).unwrap_or_default();
126 let parse = ast::SourceFile::parse(&sql);
127 assert_eq!(parse.errors(), vec![]);
128 let file: ast::SourceFile = parse.tree();
129
130 let references = find_references(&file, offset);
131
132 let offset_usize: usize = offset.into();
133
134 let mut current_refs = vec![];
135 let mut builtin_refs = vec![];
136 for (i, location) in references.iter().enumerate() {
137 let label_index = i + 1;
138 match location.file {
139 FileId::Current => current_refs.push((label_index, location.range)),
140 FileId::Builtins => builtin_refs.push((label_index, location.range)),
141 }
142 }
143
144 let has_builtins = !builtin_refs.is_empty();
145
146 let mut snippet = Snippet::source(&sql).fold(true);
147 if has_builtins {
148 snippet = snippet.path("current.sql");
149 }
150 snippet = snippet.annotation(
151 AnnotationKind::Context
152 .span(offset_usize..offset_usize + 1)
153 .label("0. query"),
154 );
155 snippet = annotate_refs(snippet, current_refs);
156
157 let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
158
159 if has_builtins {
160 let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
161 let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
162 groups.push(
163 Level::INFO
164 .primary_title("references")
165 .element(builtins_snippet),
166 );
167 }
168
169 let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
170 renderer
171 .render(&groups)
172 .to_string()
173 .replace("info: references", "")
174 }
175
176 fn annotate_refs<'a>(
177 mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
178 refs: Vec<(usize, TextRange)>,
179 ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
180 for (label_index, range) in refs {
181 snippet = snippet.annotation(
182 AnnotationKind::Context
183 .span(range.into())
184 .label(format!("{}. reference", label_index)),
185 );
186 }
187 snippet
188 }
189
190 #[test]
191 fn simple_table_reference() {
192 assert_snapshot!(find_refs("
193create table t();
194drop table t$0;
195"), @r"
196 ╭▸
197 2 │ create table t();
198 │ ─ 1. reference
199 3 │ drop table t;
200 │ ┬
201 │ │
202 │ 0. query
203 ╰╴ 2. reference
204 ");
205 }
206
207 #[test]
208 fn multiple_references() {
209 assert_snapshot!(find_refs("
210create table users();
211drop table users$0;
212table users;
213"), @r"
214 ╭▸
215 2 │ create table users();
216 │ ───── 1. reference
217 3 │ drop table users;
218 │ ┬───┬
219 │ │ │
220 │ │ 0. query
221 │ 2. reference
222 4 │ table users;
223 ╰╴ ───── 3. reference
224 ");
225 }
226
227 #[test]
228 fn join_using_column() {
229 assert_snapshot!(find_refs("
230create table t(id int);
231create table u(id int);
232select * from t join u using (id$0);
233"), @r"
234 ╭▸
235 2 │ create table t(id int);
236 │ ── 1. reference
237 3 │ create table u(id int);
238 │ ── 2. reference
239 4 │ select * from t join u using (id);
240 │ ┬┬
241 │ ││
242 │ │0. query
243 ╰╴ 3. reference
244 ");
245 }
246
247 #[test]
248 fn find_from_definition() {
249 assert_snapshot!(find_refs("
250create table t$0();
251drop table t;
252"), @r"
253 ╭▸
254 2 │ create table t();
255 │ ┬
256 │ │
257 │ 0. query
258 │ 1. reference
259 3 │ drop table t;
260 ╰╴ ─ 2. reference
261 ");
262 }
263
264 #[test]
265 fn with_schema_qualified() {
266 assert_snapshot!(find_refs("
267create table public.users();
268drop table public.users$0;
269table users;
270"), @r"
271 ╭▸
272 2 │ create table public.users();
273 │ ───── 1. reference
274 3 │ drop table public.users;
275 │ ┬───┬
276 │ │ │
277 │ │ 0. query
278 │ 2. reference
279 4 │ table users;
280 ╰╴ ───── 3. reference
281 ");
282 }
283
284 #[test]
285 fn temp_table_do_not_shadows_public() {
286 assert_snapshot!(find_refs("
287create table t();
288create temp table t$0();
289drop table t;
290"), @r"
291 ╭▸
292 3 │ create temp table t();
293 │ ┬
294 │ │
295 │ 0. query
296 ╰╴ 1. reference
297 ");
298 }
299
300 #[test]
301 fn different_schema_no_match() {
302 assert_snapshot!(find_refs("
303create table foo.t();
304create table bar.t$0();
305"), @r"
306 ╭▸
307 3 │ create table bar.t();
308 │ ┬
309 │ │
310 │ 0. query
311 ╰╴ 1. reference
312 ");
313 }
314
315 #[test]
316 fn with_search_path() {
317 assert_snapshot!(find_refs("
318set search_path to myschema;
319create table myschema.users$0();
320drop table users;
321"), @r"
322 ╭▸
323 3 │ create table myschema.users();
324 │ ┬───┬
325 │ │ │
326 │ │ 0. query
327 │ 1. reference
328 4 │ drop table users;
329 ╰╴ ───── 2. reference
330 ");
331 }
332
333 #[test]
334 fn temp_table_with_pg_temp_schema() {
335 assert_snapshot!(find_refs("
336create temp table t();
337drop table pg_temp.t$0;
338"), @r"
339 ╭▸
340 2 │ create temp table t();
341 │ ─ 1. reference
342 3 │ drop table pg_temp.t;
343 │ ┬
344 │ │
345 │ 0. query
346 ╰╴ 2. reference
347 ");
348 }
349
350 #[test]
351 fn case_insensitive() {
352 assert_snapshot!(find_refs("
353create table Users();
354drop table USERS$0;
355table users;
356"), @r"
357 ╭▸
358 2 │ create table Users();
359 │ ───── 1. reference
360 3 │ drop table USERS;
361 │ ┬───┬
362 │ │ │
363 │ │ 0. query
364 │ 2. reference
365 4 │ table users;
366 ╰╴ ───── 3. reference
367 ");
368 }
369 #[test]
370 fn case_insensitive_part_2() {
371 assert_snapshot!(find_refs(r#"
373create table actors();
374create table "Actors"();
375drop table ACTORS$0;
376table actors;
377"#), @r#"
378 ╭▸
379 2 │ create table actors();
380 │ ────── 1. reference
381 3 │ create table "Actors"();
382 4 │ drop table ACTORS;
383 │ ┬────┬
384 │ │ │
385 │ │ 0. query
386 │ 2. reference
387 5 │ table actors;
388 ╰╴ ────── 3. reference
389 "#);
390 }
391
392 #[test]
393 fn case_insensitive_with_schema() {
394 assert_snapshot!(find_refs("
395create table Public.Users();
396drop table PUBLIC.USERS$0;
397table public.users;
398"), @r"
399 ╭▸
400 2 │ create table Public.Users();
401 │ ───── 1. reference
402 3 │ drop table PUBLIC.USERS;
403 │ ┬───┬
404 │ │ │
405 │ │ 0. query
406 │ 2. reference
407 4 │ table public.users;
408 ╰╴ ───── 3. reference
409 ");
410 }
411
412 #[test]
413 fn no_partial_match() {
414 assert_snapshot!(find_refs("
415create table t$0();
416create table temp_t();
417"), @r"
418 ╭▸
419 2 │ create table t();
420 │ ┬
421 │ │
422 │ 0. query
423 ╰╴ 1. reference
424 ");
425 }
426
427 #[test]
428 fn identifier_boundaries() {
429 assert_snapshot!(find_refs("
430create table foo$0();
431drop table foo;
432drop table foo1;
433drop table barfoo;
434drop table foo_bar;
435"), @r"
436 ╭▸
437 2 │ create table foo();
438 │ ┬─┬
439 │ │ │
440 │ │ 0. query
441 │ 1. reference
442 3 │ drop table foo;
443 ╰╴ ─── 2. reference
444 ");
445 }
446
447 #[test]
448 fn builtin_function_references() {
449 assert_snapshot!(find_refs("
450select now$0();
451select now();
452"), @r"
453 ╭▸ current.sql:2:8
454 │
455 2 │ select now();
456 │ ┬─┬
457 │ │ │
458 │ │ 0. query
459 │ 1. reference
460 3 │ select now();
461 │ ─── 2. reference
462 ╰╴
463
464 ╭▸ builtin.sql:10798:28
465 │
466 10798 │ create function pg_catalog.now() returns timestamp with time zone
467 ╰╴ ─── 3. reference
468 ");
469 }
470}