1use shape_ast::ast::Item;
6use shape_ast::parser::parse_program;
7use tower_lsp_server::ls_types::{CodeLens, Command, Position, Range, Uri};
8
9pub fn get_code_lenses(text: &str, uri: &Uri) -> Vec<CodeLens> {
11 let mut lenses = Vec::new();
12
13 let program = match parse_program(text) {
15 Ok(p) => p,
16 Err(_) => {
17 let partial = shape_ast::parse_program_resilient(text);
18 if partial.items.is_empty() {
19 return lenses;
20 }
21 partial.into_program()
22 }
23 };
24
25 for item in &program.items {
26 collect_lenses_for_item(item, text, uri, &mut lenses);
27 }
28
29 lenses
30}
31
32pub fn resolve_code_lens(lens: CodeLens) -> CodeLens {
34 lens
36}
37
38fn collect_lenses_for_item(item: &Item, text: &str, uri: &Uri, lenses: &mut Vec<CodeLens>) {
40 match item {
41 Item::Function(func, _) => {
42 if let Some((line, keyword_end_col)) = find_function_line(text, &func.name) {
44 let ref_count = count_references(text, &func.name);
46 lenses.push(CodeLens {
47 range: Range {
48 start: Position { line, character: 0 },
49 end: Position { line, character: 0 },
50 },
51 command: Some(Command {
52 title: format!(
53 "{} reference{}",
54 ref_count,
55 if ref_count == 1 { "" } else { "s" }
56 ),
57 command: "shape.findReferences".to_string(),
58 arguments: Some(vec![
59 serde_json::json!(uri.to_string()),
60 serde_json::json!(line),
61 serde_json::json!(keyword_end_col),
62 ]),
63 }),
64 data: None,
65 });
66
67 for annotation in &func.annotations {
69 lenses.push(CodeLens {
70 range: Range {
71 start: Position { line, character: 0 },
72 end: Position { line, character: 0 },
73 },
74 command: Some(Command {
75 title: format!("@{}", annotation.name),
76 command: "shape.showAnnotation".to_string(),
77 arguments: Some(vec![
78 serde_json::json!(uri.to_string()),
79 serde_json::json!(annotation.name),
80 serde_json::json!(func.name),
81 ]),
82 }),
83 data: None,
84 });
85 }
86 }
87 }
88 Item::Trait(trait_def, _) => {
89 if let Some(line) = find_trait_line(text, &trait_def.name) {
91 let impl_count = count_trait_implementations(text, &trait_def.name);
92 lenses.push(CodeLens {
93 range: Range {
94 start: Position { line, character: 0 },
95 end: Position { line, character: 0 },
96 },
97 command: Some(Command {
98 title: format!(
99 "{} implementation{}",
100 impl_count,
101 if impl_count == 1 { "" } else { "s" }
102 ),
103 command: "shape.findImplementations".to_string(),
104 arguments: Some(vec![
105 serde_json::json!(uri.to_string()),
106 serde_json::json!(trait_def.name),
107 ]),
108 }),
109 data: None,
110 });
111 }
112
113 for member in &trait_def.members {
115 let (method_name, is_default) = match member {
116 shape_ast::ast::TraitMember::Required(
117 shape_ast::ast::InterfaceMember::Method { name, .. },
118 ) => (name.as_str(), false),
119 shape_ast::ast::TraitMember::Default(method_def) => {
120 (method_def.name.as_str(), true)
121 }
122 _ => continue,
123 };
124
125 if let Some(method_line) = find_method_in_trait(text, &trait_def.name, method_name)
126 {
127 if is_default {
128 lenses.push(CodeLens {
129 range: Range {
130 start: Position {
131 line: method_line,
132 character: 0,
133 },
134 end: Position {
135 line: method_line,
136 character: 0,
137 },
138 },
139 command: Some(Command {
140 title: "(default)".to_string(),
141 command: "shape.showTraitMethod".to_string(),
142 arguments: Some(vec![
143 serde_json::json!(uri.to_string()),
144 serde_json::json!(trait_def.name),
145 serde_json::json!(method_name),
146 ]),
147 }),
148 data: None,
149 });
150 }
151 }
152 }
153 }
154 Item::Test(test, _) => {
155 if let Some(line) = find_test_line(text, &test.name) {
156 lenses.push(CodeLens {
158 range: Range {
159 start: Position { line, character: 0 },
160 end: Position { line, character: 0 },
161 },
162 command: Some(Command {
163 title: "▶ Run All Tests".to_string(),
164 command: "shape.runTests".to_string(),
165 arguments: Some(vec![
166 serde_json::json!(uri.to_string()),
167 serde_json::json!(test.name),
168 ]),
169 }),
170 data: None,
171 });
172
173 lenses.push(CodeLens {
175 range: Range {
176 start: Position { line, character: 0 },
177 end: Position { line, character: 0 },
178 },
179 command: Some(Command {
180 title: "🐛 Debug Tests".to_string(),
181 command: "shape.debugTests".to_string(),
182 arguments: Some(vec![
183 serde_json::json!(uri.to_string()),
184 serde_json::json!(test.name),
185 ]),
186 }),
187 data: None,
188 });
189 }
190 }
191 _ => {}
192 }
193}
194
/// Locates the declaration of function `name` in `text`.
///
/// Both the `fn name` and the `function name` spellings are recognised. On each
/// line `fn` is tried before `function`.
///
/// A match only counts on identifier boundaries: the keyword must not be the
/// tail of a longer word (`defn foo`), and `name` must not be the prefix of a
/// longer identifier (`fn foobar`, `fn foo_bar`).
///
/// Returns the 0-based line and the byte column at which `name` starts on that
/// line, or `None` when no declaration is found.
fn find_function_line(text: &str, name: &str) -> Option<(u32, u32)> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Byte column of `name` in the first bounded `keyword name` on `line`.
    fn find_decl(line: &str, keyword: &str, name: &str) -> Option<usize> {
        let pattern = format!("{} {}", keyword, name);
        line.match_indices(pattern.as_str()).find_map(|(start, _)| {
            let before = line[..start].chars().next_back();
            let after = line[start + pattern.len()..].chars().next();
            let bounded = !matches!(before, Some(c) if is_ident_char(c))
                && !matches!(after, Some(c) if is_ident_char(c));
            // Skip the keyword and the single separating space.
            bounded.then_some(start + keyword.len() + 1)
        })
    }

    text.lines().enumerate().find_map(|(line_num, line)| {
        let col = find_decl(line, "fn", name).or_else(|| find_decl(line, "function", name))?;
        Some((line_num as u32, col as u32))
    })
}
210
/// Returns the 0-based line of the `test` declaration named `name`.
///
/// The quoted form `test "name"` is searched across the whole text first. Only
/// if it is absent is the bare form `test name` searched.
///
/// In both forms `test` must not be the tail of a longer word (`latest`). In
/// the bare form `name` must also end at an identifier boundary, so `foo`
/// does not match `test foobar`. The closing quote already bounds the quoted
/// form.
fn find_test_line(text: &str, name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Whether `pattern` occurs on `line` with a left boundary and, if
    /// `check_after` is set, a right boundary.
    fn occurs_bounded(line: &str, pattern: &str, check_after: bool) -> bool {
        line.match_indices(pattern).any(|(start, _)| {
            let before = line[..start].chars().next_back();
            let after = line[start + pattern.len()..].chars().next();
            !matches!(before, Some(c) if is_ident_char(c))
                && !(check_after && matches!(after, Some(c) if is_ident_char(c)))
        })
    }

    let first_line_with = |pattern: &str, check_after: bool| {
        text.lines()
            .position(|line| occurs_bounded(line, pattern, check_after))
            .map(|line_num| line_num as u32)
    };

    let quoted = format!("test \"{}\"", name);
    let bare = format!("test {}", name);
    first_line_with(&quoted, false).or_else(|| first_line_with(&bare, true))
}
228
/// Returns the 0-based line of the first `pattern <name>` declaration in `text`.
///
/// `pattern` must not be the tail of a longer word. `name` must end at an
/// identifier boundary, so `hammer` does not match `pattern hammerhead`.
// Only referenced from the unit tests in this file (the lens collector has no
// branch for pattern items), so non-test builds would flag it as unused.
#[allow(dead_code)]
fn find_pattern_line(text: &str, name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let pattern = format!("pattern {}", name);
    text.lines()
        .position(|line| {
            line.match_indices(pattern.as_str()).any(|(start, _)| {
                let before = line[..start].chars().next_back();
                let after = line[start + pattern.len()..].chars().next();
                !matches!(before, Some(c) if is_ident_char(c))
                    && !matches!(after, Some(c) if is_ident_char(c))
            })
        })
        .map(|line_num| line_num as u32)
}
240
/// Returns the 0-based line whose trimmed text begins with `trait <name>`.
///
/// `name` must end at an identifier boundary, so looking up `Foo` does not
/// match `trait Foobar`. Headers such as `trait Foo {` or `trait Foo<T>` still
/// match.
fn find_trait_line(text: &str, name: &str) -> Option<u32> {
    let pattern = format!("trait {}", name);
    text.lines()
        .position(|line| {
            matches!(
                line.trim().strip_prefix(pattern.as_str()),
                Some(rest) if !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_')
            )
        })
        .map(|line_num| line_num as u32)
}
251
/// Counts the lines of `text` containing an `impl <trait_name> for`
/// occurrence.
///
/// The occurrence must sit on identifier boundaries: `impl` is not the tail of
/// a longer word (`reimpl`), and `for` is not the head of one. Several impls on
/// a single line count once.
fn count_trait_implementations(text: &str, trait_name: &str) -> usize {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let pattern = format!("impl {} for", trait_name);
    text.lines()
        .filter(|line| {
            line.match_indices(pattern.as_str()).any(|(start, _)| {
                let before = line[..start].chars().next_back();
                let after = line[start + pattern.len()..].chars().next();
                !matches!(before, Some(c) if is_ident_char(c))
                    && !matches!(after, Some(c) if is_ident_char(c))
            })
        })
        .count()
}
259
/// Finds the 0-based line declaring `method_name` inside trait `trait_name`.
///
/// Scanning starts at a line whose trimmed text begins with
/// `trait <trait_name>`, where the name must end at an identifier boundary.
/// Scanning stops once a line containing `}` brings the running brace balance
/// back to zero.
///
/// Inside that region, the first line (other than a `trait ` header) with
/// `method_name(` starting at an identifier boundary is returned. This covers
/// both `name(...)` and `method name(...)`. The boundary keeps `refilter(` from
/// being taken for `filter(`.
///
/// This is a textual heuristic: a call to `method_name(` inside an earlier
/// method body inside the trait would also match.
fn find_method_in_trait(text: &str, trait_name: &str, method_name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // Built once rather than once per scanned line.
    let trait_pattern = format!("trait {}", trait_name);
    let call_pattern = format!("{}(", method_name);

    let mut in_trait = false;
    let mut brace_count: i32 = 0;

    for (line_num, line) in text.lines().enumerate() {
        let trimmed = line.trim();

        if matches!(
            trimmed.strip_prefix(trait_pattern.as_str()),
            Some(rest) if !rest.starts_with(is_ident_char)
        ) {
            in_trait = true;
        }

        if !in_trait {
            continue;
        }

        brace_count += line.matches('{').count() as i32;
        brace_count -= line.matches('}').count() as i32;

        let declares_method = trimmed
            .match_indices(call_pattern.as_str())
            .any(|(start, _)| !matches!(trimmed[..start].chars().next_back(), Some(c) if is_ident_char(c)));
        if declares_method && !trimmed.starts_with("trait ") {
            return Some(line_num as u32);
        }

        // The trait body is closed once the balance returns to zero on a line
        // that actually contains a closing brace.
        if brace_count == 0 && line.contains('}') {
            in_trait = false;
        }
    }
    None
}
291
/// Counts whole-identifier occurrences of `name` in `text`, minus one.
///
/// An occurrence counts only when it is not adjacent to another identifier
/// character (alphanumeric or `_`) on either side. So `foo` is not counted
/// inside `foo_bar` or `my_foo`.
///
/// One occurrence is assumed to be the declaration itself, hence the
/// subtraction (clamped at zero). Occurrences inside comments or string
/// literals are counted too, because this is a textual scan. An empty `name`
/// yields 0.
fn count_references(text: &str, name: &str) -> usize {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // An empty needle matches at every char boundary; nothing meaningful to count.
    if name.is_empty() {
        return 0;
    }

    let occurrences = text
        .match_indices(name)
        .filter(|&(start, _)| {
            let before = text[..start].chars().next_back();
            let after = text[start + name.len()..].chars().next();
            !matches!(before, Some(c) if is_ident_char(c))
                && !matches!(after, Some(c) if is_ident_char(c))
        })
        .count();

    occurrences.saturating_sub(1)
}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// The declaring occurrence is excluded from the reference total.
    #[test]
    fn test_count_references() {
        let source = "let foo = 1;\nlet bar = foo + foo;";
        for (needle, expected) in [("foo", 2), ("bar", 0), ("baz", 0)] {
            assert_eq!(count_references(source, needle), expected);
        }
    }

    /// Both `function name` and `fn name` are found; the column points at the name.
    #[test]
    fn test_find_function_line() {
        let long_keyword = "// comment\nfunction myFunc() {\n return 1;\n}";
        assert_eq!(find_function_line(long_keyword, "myFunc"), Some((1, 9)));

        let short_keyword = "// comment\nfn myFunc() {\n return 1;\n}";
        assert_eq!(find_function_line(short_keyword, "myFunc"), Some((1, 3)));
        assert_eq!(find_function_line(short_keyword, "nonexistent"), None);
    }

    /// A trait header is reported by its 0-based line.
    #[test]
    fn test_find_trait_line() {
        let source = "// comment\ntrait Queryable {\n filter(pred): any\n}\n";
        assert_eq!(find_trait_line(source, "Queryable"), Some(1));
        assert_eq!(find_trait_line(source, "NonExistent"), None);
    }

    /// Each `impl <Trait> for <Type>` line is counted once.
    #[test]
    fn test_count_trait_implementations() {
        let source = "trait Queryable {\n filter(pred): any\n}\nimpl Queryable for Table {\n method filter(pred) { self }\n}\nimpl Queryable for DataFrame {\n method filter(pred) { self }\n}\n";
        assert_eq!(count_trait_implementations(source, "Queryable"), 2);
        assert_eq!(count_trait_implementations(source, "NonExistent"), 0);
    }

    /// A parsed trait yields an implementation-count lens end to end.
    #[test]
    fn test_trait_code_lens() {
        let source = "trait Queryable {\n filter(pred): any\n}\nimpl Queryable for Table {\n method filter(pred) { self }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        let titles: Vec<Option<String>> = get_code_lenses(source, &uri)
            .iter()
            .map(|lens| lens.command.as_ref().map(|cmd| cmd.title.clone()))
            .collect();
        assert!(
            titles
                .iter()
                .flatten()
                .any(|title| title.contains("implementation")),
            "Should have implementation count lens for trait. Got: {:?}",
            titles
        );
    }

    /// `pattern <name>` declarations are reported by their 0-based line.
    #[test]
    fn test_find_pattern_line() {
        let source = "// comment\npattern hammer {\n close > open\n}";
        assert_eq!(find_pattern_line(source, "hammer"), Some(1));
        assert_eq!(find_pattern_line(source, "doji"), None);
    }
}