Skip to main content

shape_lsp/
code_lens.rs

//! Code lens provider for Shape
//!
//! Provides actionable code lenses for functions, traits, and tests.
//! A lookup helper for patterns exists but is not currently used to emit lenses.
4
5use shape_ast::ast::Item;
6use shape_ast::parser::parse_program;
7use tower_lsp_server::ls_types::{CodeLens, Command, Position, Range, Uri};
8
9/// Get code lenses for a document
10pub fn get_code_lenses(text: &str, uri: &Uri) -> Vec<CodeLens> {
11    let mut lenses = Vec::new();
12
13    // Parse the document, falling back to resilient parser
14    let program = match parse_program(text) {
15        Ok(p) => p,
16        Err(_) => {
17            let partial = shape_ast::parse_program_resilient(text);
18            if partial.items.is_empty() {
19                return lenses;
20            }
21            partial.into_program()
22        }
23    };
24
25    for item in &program.items {
26        collect_lenses_for_item(item, text, uri, &mut lenses);
27    }
28
29    lenses
30}
31
32/// Resolve a code lens (add the command)
33pub fn resolve_code_lens(lens: CodeLens) -> CodeLens {
34    // Code lenses are already resolved in get_code_lenses
35    lens
36}
37
38/// Collect code lenses for an item
39fn collect_lenses_for_item(item: &Item, text: &str, uri: &Uri, lenses: &mut Vec<CodeLens>) {
40    match item {
41        Item::Function(func, _) => {
42            // Find the line where the function is defined
43            if let Some((line, keyword_end_col)) = find_function_line(text, &func.name) {
44                // Reference count lens
45                let ref_count = count_references(text, &func.name);
46                lenses.push(CodeLens {
47                    range: Range {
48                        start: Position { line, character: 0 },
49                        end: Position { line, character: 0 },
50                    },
51                    command: Some(Command {
52                        title: format!(
53                            "{} reference{}",
54                            ref_count,
55                            if ref_count == 1 { "" } else { "s" }
56                        ),
57                        command: "shape.findReferences".to_string(),
58                        arguments: Some(vec![
59                            serde_json::json!(uri.to_string()),
60                            serde_json::json!(line),
61                            serde_json::json!(keyword_end_col),
62                        ]),
63                    }),
64                    data: None,
65                });
66
67                // Add code lenses for annotations on self function
68                for annotation in &func.annotations {
69                    lenses.push(CodeLens {
70                        range: Range {
71                            start: Position { line, character: 0 },
72                            end: Position { line, character: 0 },
73                        },
74                        command: Some(Command {
75                            title: format!("@{}", annotation.name),
76                            command: "shape.showAnnotation".to_string(),
77                            arguments: Some(vec![
78                                serde_json::json!(uri.to_string()),
79                                serde_json::json!(annotation.name),
80                                serde_json::json!(func.name),
81                            ]),
82                        }),
83                        data: None,
84                    });
85                }
86            }
87        }
88        Item::Trait(trait_def, _) => {
89            // Add "N implementations" lens on the trait definition
90            if let Some(line) = find_trait_line(text, &trait_def.name) {
91                let impl_count = count_trait_implementations(text, &trait_def.name);
92                lenses.push(CodeLens {
93                    range: Range {
94                        start: Position { line, character: 0 },
95                        end: Position { line, character: 0 },
96                    },
97                    command: Some(Command {
98                        title: format!(
99                            "{} implementation{}",
100                            impl_count,
101                            if impl_count == 1 { "" } else { "s" }
102                        ),
103                        command: "shape.findImplementations".to_string(),
104                        arguments: Some(vec![
105                            serde_json::json!(uri.to_string()),
106                            serde_json::json!(trait_def.name),
107                        ]),
108                    }),
109                    data: None,
110                });
111            }
112
113            // Add per-method lenses showing if the method has a default implementation
114            for member in &trait_def.members {
115                let (method_name, is_default) = match member {
116                    shape_ast::ast::TraitMember::Required(
117                        shape_ast::ast::InterfaceMember::Method { name, .. },
118                    ) => (name.as_str(), false),
119                    shape_ast::ast::TraitMember::Default(method_def) => {
120                        (method_def.name.as_str(), true)
121                    }
122                    _ => continue,
123                };
124
125                if let Some(method_line) = find_method_in_trait(text, &trait_def.name, method_name)
126                {
127                    if is_default {
128                        lenses.push(CodeLens {
129                            range: Range {
130                                start: Position {
131                                    line: method_line,
132                                    character: 0,
133                                },
134                                end: Position {
135                                    line: method_line,
136                                    character: 0,
137                                },
138                            },
139                            command: Some(Command {
140                                title: "(default)".to_string(),
141                                command: "shape.showTraitMethod".to_string(),
142                                arguments: Some(vec![
143                                    serde_json::json!(uri.to_string()),
144                                    serde_json::json!(trait_def.name),
145                                    serde_json::json!(method_name),
146                                ]),
147                            }),
148                            data: None,
149                        });
150                    }
151                }
152            }
153        }
154        Item::Test(test, _) => {
155            if let Some(line) = find_test_line(text, &test.name) {
156                // Run all tests lens
157                lenses.push(CodeLens {
158                    range: Range {
159                        start: Position { line, character: 0 },
160                        end: Position { line, character: 0 },
161                    },
162                    command: Some(Command {
163                        title: "▶ Run All Tests".to_string(),
164                        command: "shape.runTests".to_string(),
165                        arguments: Some(vec![
166                            serde_json::json!(uri.to_string()),
167                            serde_json::json!(test.name),
168                        ]),
169                    }),
170                    data: None,
171                });
172
173                // Debug tests lens
174                lenses.push(CodeLens {
175                    range: Range {
176                        start: Position { line, character: 0 },
177                        end: Position { line, character: 0 },
178                    },
179                    command: Some(Command {
180                        title: "🐛 Debug Tests".to_string(),
181                        command: "shape.debugTests".to_string(),
182                        arguments: Some(vec![
183                            serde_json::json!(uri.to_string()),
184                            serde_json::json!(test.name),
185                        ]),
186                    }),
187                    data: None,
188                });
189            }
190        }
191        _ => {}
192    }
193}
194
/// Locate the first line of `text` that declares the function `name`.
///
/// Each line is searched for `fn <name>` and then for `function <name>`. A hit only counts
/// when the keyword is not the tail of a longer identifier and the name is not the head of
/// one, so searching for `foo` skips both `fn foobar` and `defn foo`.
///
/// Returns `Some((line, column))`, both zero-based, where `column` is the position of the
/// first character of the name within its line, measured in UTF-16 code units (the default
/// LSP position encoding). Returns `None` when `name` is empty, when no declaration is
/// found, or when a coordinate does not fit in `u32`.
fn find_function_line(text: &str, name: &str) -> Option<(u32, u32)> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Byte offset of the name inside the first whole-word occurrence of `pattern`
    /// (keyword + name) in `line`.
    fn name_offset(line: &str, pattern: &str, keyword_len: usize) -> Option<usize> {
        line.match_indices(pattern)
            .find(|&(start, matched)| {
                let before = line[..start].chars().next_back();
                let after = line[start + matched.len()..].chars().next();
                !matches!(before, Some(c) if is_ident_char(c))
                    && !matches!(after, Some(c) if is_ident_char(c))
            })
            .map(|(start, _)| start + keyword_len)
    }

    if name.is_empty() {
        return None;
    }

    let fn_pattern = format!("fn {}", name);
    let function_pattern = format!("function {}", name);

    for (line_num, line) in text.lines().enumerate() {
        let name_start = name_offset(line, &fn_pattern, "fn ".len())
            .or_else(|| name_offset(line, &function_pattern, "function ".len()));

        if let Some(byte_col) = name_start {
            let line_num = u32::try_from(line_num).ok()?;
            // Convert the byte offset to UTF-16 code units so non-ASCII text earlier on
            // the line does not shift the reported column.
            let utf16_col = u32::try_from(line[..byte_col].encode_utf16().count()).ok()?;
            return Some((line_num, utf16_col));
        }
    }
    None
}
210
/// Locate the first line of `text` that declares the test `name`.
///
/// The quoted form `test "<name>"` is searched for across the whole text first; only if
/// it is absent is the bare form `test <name>` tried. In both forms the `test` keyword
/// must not be the tail of a longer identifier (`latest foo` is not a declaration), and
/// in the bare form the name must not continue into a longer identifier (`test foobar`
/// does not declare `foo`).
///
/// Returns the zero-based line number, or `None` when no declaration is found or the
/// line number does not fit in `u32`.
fn find_test_line(text: &str, name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// True when `line` contains `pattern` with no identifier character directly before
    /// it and, when `check_after` is set, none directly after it either.
    fn has_declaration(line: &str, pattern: &str, check_after: bool) -> bool {
        line.match_indices(pattern).any(|(start, matched)| {
            let before = line[..start].chars().next_back();
            let after = line[start + matched.len()..].chars().next();
            !matches!(before, Some(c) if is_ident_char(c))
                && !(check_after && matches!(after, Some(c) if is_ident_char(c)))
        })
    }

    let quoted = format!("test \"{}\"", name);
    let bare = format!("test {}", name);

    text.lines()
        // The closing quote already delimits the name, so no trailing check is needed.
        .position(|line| has_declaration(line, &quoted, false))
        .or_else(|| text.lines().position(|line| has_declaration(line, &bare, true)))
        .and_then(|line_num| u32::try_from(line_num).ok())
}
228
/// Locate the first line of `text` that declares the pattern `name`.
///
/// A line matches when it contains `pattern <name>` as whole words: the keyword must not
/// be the tail of a longer identifier and the name must not be the head of one, so
/// searching for `hammer` skips `pattern hammer_inverted`.
///
/// Returns the zero-based line number, or `None` when no declaration is found or the
/// line number does not fit in `u32`.
// Not called from the lens collection code in this file; kept (and unit-tested) as a helper.
#[allow(dead_code)]
fn find_pattern_line(text: &str, name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let declaration = format!("pattern {}", name);

    text.lines()
        .position(|line| {
            line.match_indices(declaration.as_str()).any(|(start, matched)| {
                let before = line[..start].chars().next_back();
                let after = line[start + matched.len()..].chars().next();
                !matches!(before, Some(c) if is_ident_char(c))
                    && !matches!(after, Some(c) if is_ident_char(c))
            })
        })
        .and_then(|line_num| u32::try_from(line_num).ok())
}
240
/// Locate the first line of `text` that declares the trait `name`.
///
/// A line matches when, ignoring leading whitespace, it starts with `trait <name>` and
/// the name is not immediately followed by another identifier character, so searching
/// for `Query` does not stop at `trait Queryable`.
///
/// Returns the zero-based line number, or `None` when no declaration is found or the
/// line number does not fit in `u32`.
fn find_trait_line(text: &str, name: &str) -> Option<u32> {
    let header = format!("trait {}", name);

    text.lines()
        .position(|line| match line.trim_start().strip_prefix(header.as_str()) {
            // The name must end here: reject a header that continues into a longer identifier.
            Some(rest) => !matches!(rest.chars().next(), Some(c) if c.is_alphanumeric() || c == '_'),
            None => false,
        })
        .and_then(|line_num| u32::try_from(line_num).ok())
}
251
/// Count the lines of `text` that contain `impl <trait_name> for`.
///
/// This is a purely textual, per-line check: a line counts once whether the phrase sits
/// at the start of the line or anywhere else on it.
fn count_trait_implementations(text: &str, trait_name: &str) -> usize {
    let impl_header = format!("impl {} for", trait_name);
    let mut total = 0;
    for line in text.lines() {
        // The phrase starts and ends with a non-whitespace character, so trimming the
        // line first cannot change whether it is contained.
        if line.contains(impl_header.as_str()) {
            total += 1;
        }
    }
    total
}
259
/// Locate the line of `method_name` inside the body of the trait `trait_name`.
///
/// The text is scanned line by line. A line that (after trimming) starts with
/// `trait <trait_name>` as a whole word opens the trait; brace counting closes it again
/// once the depth returns to zero on a line containing `}`. While inside the trait, the
/// first line containing `<method_name>(` is returned, provided the name is not the tail
/// of a longer identifier (`filter(` does not match `prefilter(`) and the line is not
/// itself a `trait ` header.
///
/// Returns the zero-based line number, or `None` when `method_name` is empty, when no
/// such line exists, or when the line number does not fit in `u32`.
fn find_method_in_trait(text: &str, trait_name: &str, method_name: &str) -> Option<u32> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // An empty name would reduce the search pattern to "(" and match any parenthesis.
    if method_name.is_empty() {
        return None;
    }

    let trait_header = format!("trait {}", trait_name);
    let method_pattern = format!("{}(", method_name);
    let mut in_trait = false;
    let mut brace_depth: i32 = 0;

    for (line_num, line) in text.lines().enumerate() {
        let trimmed = line.trim();

        // Require the trait name to end right after the header so that a different
        // trait sharing the same prefix does not open the scan.
        let opens_trait = match trimmed.strip_prefix(trait_header.as_str()) {
            Some(rest) => !matches!(rest.chars().next(), Some(c) if is_ident_char(c)),
            None => false,
        };
        if opens_trait {
            in_trait = true;
        }
        if !in_trait {
            continue;
        }

        brace_depth += line.matches('{').count() as i32;
        brace_depth -= line.matches('}').count() as i32;

        // Check whether this line mentions the method as a whole identifier.
        let names_method = !trimmed.starts_with("trait ")
            && trimmed.match_indices(method_pattern.as_str()).any(|(start, _)| {
                !matches!(trimmed[..start].chars().next_back(), Some(c) if is_ident_char(c))
            });
        if names_method {
            return u32::try_from(line_num).ok();
        }

        if brace_depth == 0 && line.contains('}') {
            in_trait = false;
        }
    }
    None
}
291
/// Count the references to the symbol `name` in `text`, excluding its definition.
///
/// Every whole-word occurrence of `name` is counted, where a word boundary means the
/// neighbouring character is neither alphanumeric nor `_` (so `foo` is not counted inside
/// `foo_bar` or `my_foo`). One occurrence is then subtracted on the assumption that it is
/// the definition itself. The search is purely textual: occurrences in comments and
/// string literals count as well.
///
/// Returns 0 when `name` is empty or does not occur.
fn count_references(text: &str, name: &str) -> usize {
    // An empty needle matches at every character boundary; it names no symbol.
    if name.is_empty() {
        return 0;
    }

    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let occurrences = text
        .match_indices(name)
        .filter(|&(start, matched)| {
            let before = text[..start].chars().next_back();
            let after = text[start + matched.len()..].chars().next();
            !matches!(before, Some(c) if is_ident_char(c))
                && !matches!(after, Some(c) if is_ident_char(c))
        })
        .count();

    // Drop one occurrence for the definition itself.
    occurrences.saturating_sub(1)
}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// The definition is excluded from the reference total.
    #[test]
    fn test_count_references() {
        let source = "let foo = 1;\nlet bar = foo + foo;";

        // foo: 1 definition + 2 uses; bar: definition only; baz: never appears.
        let cases = [("foo", 2), ("bar", 0), ("baz", 0)];
        for (symbol, expected) in cases {
            assert_eq!(
                count_references(source, symbol),
                expected,
                "symbol: {}",
                symbol
            );
        }
    }

    /// Both the `function` and the `fn` keyword are recognised.
    #[test]
    fn test_find_function_line() {
        let long_form = "// comment\nfunction myFunc() {\n    return 1;\n}";
        assert_eq!(find_function_line(long_form, "myFunc"), Some((1, 9)));

        let short_form = "// comment\nfn myFunc() {\n    return 1;\n}";
        assert_eq!(find_function_line(short_form, "myFunc"), Some((1, 3)));
        assert_eq!(find_function_line(short_form, "nonexistent"), None);
    }

    #[test]
    fn test_find_trait_line() {
        let source = "// comment\ntrait Queryable {\n    filter(pred): any\n}\n";
        assert_eq!(find_trait_line(source, "Queryable"), Some(1));
        assert_eq!(find_trait_line(source, "NonExistent"), None);
    }

    #[test]
    fn test_count_trait_implementations() {
        let source = "trait Queryable {\n    filter(pred): any\n}\nimpl Queryable for Table {\n    method filter(pred) { self }\n}\nimpl Queryable for DataFrame {\n    method filter(pred) { self }\n}\n";
        assert_eq!(count_trait_implementations(source, "Queryable"), 2);
        assert_eq!(count_trait_implementations(source, "NonExistent"), 0);
    }

    /// End-to-end: a parsed trait yields an implementation-count lens.
    #[test]
    fn test_trait_code_lens() {
        let source = "trait Queryable {\n    filter(pred): any\n}\nimpl Queryable for Table {\n    method filter(pred) { self }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        let lenses = get_code_lenses(source, &uri);

        let has_impl_lens = lenses
            .iter()
            .filter_map(|lens| lens.command.as_ref())
            .any(|command| command.title.contains("implementation"));

        // Should have at least one code lens for the trait
        assert!(
            has_impl_lens,
            "Should have implementation count lens for trait. Got: {:?}",
            lenses
                .iter()
                .map(|l| l.command.as_ref().map(|c| c.title.clone()))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_find_pattern_line() {
        let source = "// comment\npattern hammer {\n    close > open\n}";
        assert_eq!(find_pattern_line(source, "hammer"), Some(1));
        assert_eq!(find_pattern_line(source, "doji"), None);
    }
}