Skip to main content

shape_lsp/
code_actions.rs

1//! Code actions provider for Shape
2//!
3//! Provides quick fixes, refactoring actions, and source actions.
4
5use crate::doc_actions::generate_doc_comment_action;
6use crate::module_cache::ModuleCache;
7use crate::util::{get_word_at_position, span_to_range};
8use shape_ast::ast::{ImportItems, Item};
9use shape_ast::parser::parse_program;
10use std::collections::{HashMap, HashSet};
11use tower_lsp_server::ls_types::{
12    CodeAction, CodeActionKind, CodeActionOrCommand, Diagnostic, Position, Range, TextEdit, Uri,
13    WorkspaceEdit,
14};
15
16/// Get code actions for a document at a given range
17pub fn get_code_actions(
18    text: &str,
19    uri: &Uri,
20    range: Range,
21    diagnostics: &[Diagnostic],
22    module_cache: Option<&ModuleCache>,
23    requested_kinds: Option<&[CodeActionKind]>,
24) -> Vec<CodeActionOrCommand> {
25    let mut actions = Vec::new();
26
27    if is_kind_requested(requested_kinds, CodeActionKind::QUICKFIX.as_str()) {
28        // Add quick fixes for diagnostics at/near the requested range.
29        for diagnostic in diagnostics {
30            if ranges_overlap(range, diagnostic.range) {
31                if let Some(fix_actions) = get_quick_fixes(text, uri, diagnostic, module_cache) {
32                    actions.extend(fix_actions);
33                }
34            }
35        }
36
37        // Also offer symbol-based auto-import when the cursor is on an unresolved
38        // type-like identifier, even if diagnostics are stale/misaligned.
39        if let Some(cache) = module_cache {
40            actions.extend(get_symbol_auto_import_actions(text, uri, range, cache));
41        }
42    }
43
44    // Add refactoring actions based on selection
45    if is_group_requested(requested_kinds, CodeActionKind::REFACTOR.as_str()) {
46        if let Some(doc_action) = generate_doc_comment_action(text, uri, range) {
47            actions.push(doc_action);
48        }
49        if let Some(refactor_actions) = get_refactor_actions(text, uri, range) {
50            actions.extend(refactor_actions);
51        }
52    }
53
54    // Add source actions (organize imports, etc.)
55    if is_group_requested(requested_kinds, CodeActionKind::SOURCE.as_str()) {
56        if let Some(source_actions) =
57            get_source_actions(text, uri, range, diagnostics, module_cache, requested_kinds)
58        {
59            actions.extend(source_actions);
60        }
61    }
62
63    dedupe_actions(actions)
64}
65
66/// Get quick fixes for a diagnostic
67fn get_quick_fixes(
68    text: &str,
69    uri: &Uri,
70    diagnostic: &Diagnostic,
71    module_cache: Option<&ModuleCache>,
72) -> Option<Vec<CodeActionOrCommand>> {
73    let mut fixes = Vec::new();
74    let message = &diagnostic.message;
75
76    // Fix for undefined variable - suggest declaration
77    if message.contains("undefined") || message.contains("not defined") {
78        if let Some(var_name) = extract_undefined_name(message) {
79            fixes.push(create_quick_fix(
80                format!("Declare variable '{}'", var_name),
81                uri.clone(),
82                vec![TextEdit {
83                    range: Range {
84                        start: Position {
85                            line: diagnostic.range.start.line,
86                            character: 0,
87                        },
88                        end: Position {
89                            line: diagnostic.range.start.line,
90                            character: 0,
91                        },
92                    },
93                    new_text: format!("let {} = undefined;\n", var_name),
94                }],
95                diagnostic.clone(),
96            ));
97        }
98    }
99
100    // Fix for missing semicolon
101    if message.contains("expected ';'") || message.contains("missing semicolon") {
102        fixes.push(create_quick_fix(
103            "Add missing semicolon".to_string(),
104            uri.clone(),
105            vec![TextEdit {
106                range: Range {
107                    start: diagnostic.range.end,
108                    end: diagnostic.range.end,
109                },
110                new_text: ";".to_string(),
111            }],
112            diagnostic.clone(),
113        ));
114    }
115
116    // Fix for missing closing brace
117    if message.contains("expected '}'") || message.contains("unclosed") {
118        fixes.push(create_quick_fix(
119            "Add missing closing brace".to_string(),
120            uri.clone(),
121            vec![TextEdit {
122                range: Range {
123                    start: diagnostic.range.end,
124                    end: diagnostic.range.end,
125                },
126                new_text: "\n}".to_string(),
127            }],
128            diagnostic.clone(),
129        ));
130    }
131
132    // Fix for var to let conversion suggestion
133    if message.contains("prefer 'let'") || message.contains("use 'let' instead of 'var'") {
134        let line = get_line(text, diagnostic.range.start.line as usize);
135        if let Some(line_text) = line {
136            if let Some(var_pos) = line_text.find("var ") {
137                fixes.push(create_quick_fix(
138                    "Change 'var' to 'let'".to_string(),
139                    uri.clone(),
140                    vec![TextEdit {
141                        range: Range {
142                            start: Position {
143                                line: diagnostic.range.start.line,
144                                character: var_pos as u32,
145                            },
146                            end: Position {
147                                line: diagnostic.range.start.line,
148                                character: (var_pos + 3) as u32,
149                            },
150                        },
151                        new_text: "let".to_string(),
152                    }],
153                    diagnostic.clone(),
154                ));
155            }
156        }
157    }
158
159    // Auto-import for unknown enum/type
160    if message.contains("Unknown enum type") || message.contains("Unknown variant") {
161        if let Some(cache) = module_cache {
162            if let Some(name) = extract_quoted_name(message) {
163                let symbols = if let Some(current_file) = uri.to_file_path() {
164                    cache.find_exported_symbol_with_context(&name, current_file.as_ref(), None)
165                } else {
166                    cache.find_exported_symbol(&name)
167                };
168                for (import_path, _export) in symbols {
169                    fixes.push(create_quick_fix(
170                        format!("Import '{}' from {}", name, import_path),
171                        uri.clone(),
172                        vec![TextEdit {
173                            range: Range {
174                                start: Position {
175                                    line: 0,
176                                    character: 0,
177                                },
178                                end: Position {
179                                    line: 0,
180                                    character: 0,
181                                },
182                            },
183                            new_text: format!("from {} use {{ {} }}\n", import_path, name),
184                        }],
185                        diagnostic.clone(),
186                    ));
187                }
188            }
189        }
190    }
191
192    if message.contains("match expression requires at least one arm") {
193        if let Some((insert_pos, indent)) = find_match_arm_insert_position(text, diagnostic.range) {
194            let arm_indent = format!("{indent}  ");
195            fixes.push(create_quick_fix(
196                "Add wildcard match arm".to_string(),
197                uri.clone(),
198                vec![TextEdit {
199                    range: Range {
200                        start: insert_pos,
201                        end: insert_pos,
202                    },
203                    new_text: format!("{arm_indent}_ => {{\n{arm_indent}}},\n"),
204                }],
205                diagnostic.clone(),
206            ));
207        }
208    }
209
210    if let Some((enum_name, missing_variants)) = parse_non_exhaustive_match(message) {
211        if let Some((insert_pos, indent)) = find_match_arm_insert_position(text, diagnostic.range) {
212            let arm_indent = format!("{indent}  ");
213            let mut new_text = String::new();
214            for variant in missing_variants {
215                new_text.push_str(&format!(
216                    "{arm_indent}{enum_name}::{variant} => {{\n{arm_indent}}},\n"
217                ));
218            }
219            fixes.push(create_quick_fix(
220                format!("Add missing match arms for {}", enum_name),
221                uri.clone(),
222                vec![TextEdit {
223                    range: Range {
224                        start: insert_pos,
225                        end: insert_pos,
226                    },
227                    new_text,
228                }],
229                diagnostic.clone(),
230            ));
231        }
232    }
233
234    // Fix for missing required trait method — suggest adding the method stub
235    if message.contains("Missing required method") {
236        if let Some(method_name) = extract_quoted_name(message) {
237            // Find the closing brace of the impl block on the diagnostic line
238            let impl_end_line = diagnostic.range.end.line;
239            // Insert just before the closing brace
240            fixes.push(create_quick_fix(
241                format!("Implement method '{}'", method_name),
242                uri.clone(),
243                vec![TextEdit {
244                    range: Range {
245                        start: Position {
246                            line: impl_end_line,
247                            character: 0,
248                        },
249                        end: Position {
250                            line: impl_end_line,
251                            character: 0,
252                        },
253                    },
254                    new_text: format!(
255                        "    method {}() {{\n        // TODO: implement\n    }}\n",
256                        method_name
257                    ),
258                }],
259                diagnostic.clone(),
260            ));
261        }
262    }
263
264    // Fix for unused variable - add underscore prefix
265    if message.contains("unused") {
266        if let Some(var_name) = extract_unused_name(message) {
267            if !var_name.starts_with('_') {
268                let line = get_line(text, diagnostic.range.start.line as usize);
269                if let Some(line_text) = line {
270                    if let Some(name_pos) = line_text.find(&var_name) {
271                        fixes.push(create_quick_fix(
272                            format!("Prefix with underscore: _{}", var_name),
273                            uri.clone(),
274                            vec![TextEdit {
275                                range: Range {
276                                    start: Position {
277                                        line: diagnostic.range.start.line,
278                                        character: name_pos as u32,
279                                    },
280                                    end: Position {
281                                        line: diagnostic.range.start.line,
282                                        character: (name_pos + var_name.len()) as u32,
283                                    },
284                                },
285                                new_text: format!("_{}", var_name),
286                            }],
287                            diagnostic.clone(),
288                        ));
289                    }
290                }
291            }
292        }
293    }
294
295    if fixes.is_empty() { None } else { Some(fixes) }
296}
297
298/// Get refactoring actions for a selection
299fn get_refactor_actions(text: &str, uri: &Uri, range: Range) -> Option<Vec<CodeActionOrCommand>> {
300    let mut actions = Vec::new();
301
302    // Get the selected text
303    let selected = get_text_in_range(text, range);
304    if selected.is_empty() {
305        return None;
306    }
307
308    // Extract to variable
309    if is_expression(&selected) {
310        actions.push(CodeActionOrCommand::CodeAction(CodeAction {
311            title: "Extract to variable".to_string(),
312            kind: Some(CodeActionKind::REFACTOR_EXTRACT),
313            diagnostics: None,
314            edit: Some(WorkspaceEdit {
315                changes: Some({
316                    let mut changes = HashMap::new();
317                    changes.insert(
318                        uri.clone(),
319                        vec![
320                            TextEdit {
321                                range: Range {
322                                    start: Position {
323                                        line: range.start.line,
324                                        character: 0,
325                                    },
326                                    end: Position {
327                                        line: range.start.line,
328                                        character: 0,
329                                    },
330                                },
331                                new_text: format!("let extracted = {};\n", selected),
332                            },
333                            TextEdit {
334                                range,
335                                new_text: "extracted".to_string(),
336                            },
337                        ],
338                    );
339                    changes
340                }),
341                document_changes: None,
342                change_annotations: None,
343            }),
344            command: None,
345            is_preferred: None,
346            disabled: None,
347            data: None,
348        }));
349    }
350
351    // Extract to function (for multi-line selections or complex expressions)
352    if selected.contains('\n') || selected.len() > 50 {
353        actions.push(CodeActionOrCommand::CodeAction(CodeAction {
354            title: "Extract to function".to_string(),
355            kind: Some(CodeActionKind::REFACTOR_EXTRACT),
356            diagnostics: None,
357            edit: Some(WorkspaceEdit {
358                changes: Some({
359                    let mut changes = HashMap::new();
360                    changes.insert(
361                        uri.clone(),
362                        vec![
363                            TextEdit {
364                                range: Range {
365                                    start: Position {
366                                        line: 0,
367                                        character: 0,
368                                    },
369                                    end: Position {
370                                        line: 0,
371                                        character: 0,
372                                    },
373                                },
374                                new_text: format!(
375                                    "fn extractedFunction() {{\n    {}\n}}\n\n",
376                                    selected.replace('\n', "\n    ")
377                                ),
378                            },
379                            TextEdit {
380                                range,
381                                new_text: "extractedFunction()".to_string(),
382                            },
383                        ],
384                    );
385                    changes
386                }),
387                document_changes: None,
388                change_annotations: None,
389            }),
390            command: None,
391            is_preferred: None,
392            disabled: None,
393            data: None,
394        }));
395    }
396
397    // Convert string concatenation to template string
398    if selected.contains(" + ") && selected.contains('"') {
399        // This is a simplification - real implementation would need proper parsing
400        actions.push(CodeActionOrCommand::CodeAction(CodeAction {
401            title: "Convert to template string".to_string(),
402            kind: Some(CodeActionKind::REFACTOR_REWRITE),
403            diagnostics: None,
404            edit: None, // Would need proper implementation
405            command: None,
406            is_preferred: None,
407            disabled: Some(tower_lsp_server::ls_types::CodeActionDisabled {
408                reason: "Complex conversion - manual edit recommended".to_string(),
409            }),
410            data: None,
411        }));
412    }
413
414    if actions.is_empty() {
415        None
416    } else {
417        Some(actions)
418    }
419}
420
421/// Get source actions for the document
422fn get_source_actions(
423    text: &str,
424    uri: &Uri,
425    range: Range,
426    diagnostics: &[Diagnostic],
427    module_cache: Option<&ModuleCache>,
428    requested_kinds: Option<&[CodeActionKind]>,
429) -> Option<Vec<CodeActionOrCommand>> {
430    let mut actions = Vec::new();
431
432    let import_ranges = import_statement_ranges(text);
433    let on_import_stmt = import_ranges.iter().any(|r| ranges_overlap(*r, range));
434    let organize_requested = is_kind_explicitly_requested(
435        requested_kinds,
436        CodeActionKind::SOURCE_ORGANIZE_IMPORTS.as_str(),
437    );
438
439    // Show organize-imports only when explicitly requested or when cursor is
440    // currently inside import declarations.
441    if !import_ranges.is_empty() && (organize_requested || on_import_stmt) {
442        actions.push(CodeActionOrCommand::CodeAction(CodeAction {
443            title: "Organize imports".to_string(),
444            kind: Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS),
445            diagnostics: None,
446            edit: None, // Would need proper implementation
447            command: None,
448            is_preferred: None,
449            disabled: None,
450            data: None,
451        }));
452    }
453
454    // Add "Fix all" only when requested explicitly or when there are fixable
455    // diagnostics on the current range.
456    let fix_all_requested =
457        is_kind_explicitly_requested(requested_kinds, CodeActionKind::SOURCE_FIX_ALL.as_str());
458    let has_fixable_here = diagnostics
459        .iter()
460        .filter(|d| ranges_overlap(d.range, range))
461        .any(|d| get_quick_fixes(text, uri, d, module_cache).is_some());
462    if fix_all_requested || has_fixable_here {
463        actions.push(CodeActionOrCommand::CodeAction(CodeAction {
464            title: "Fix all auto-fixable problems".to_string(),
465            kind: Some(CodeActionKind::SOURCE_FIX_ALL),
466            diagnostics: None,
467            edit: None, // Would need to collect all quick fixes
468            command: None,
469            is_preferred: None,
470            disabled: None,
471            data: None,
472        }));
473    }
474
475    if actions.is_empty() {
476        None
477    } else {
478        Some(actions)
479    }
480}
481
482/// Create a quick fix code action
483fn create_quick_fix(
484    title: String,
485    uri: Uri,
486    edits: Vec<TextEdit>,
487    diagnostic: Diagnostic,
488) -> CodeActionOrCommand {
489    let mut changes = HashMap::new();
490    changes.insert(uri, edits);
491
492    CodeActionOrCommand::CodeAction(CodeAction {
493        title,
494        kind: Some(CodeActionKind::QUICKFIX),
495        diagnostics: Some(vec![diagnostic]),
496        edit: Some(WorkspaceEdit {
497            changes: Some(changes),
498            document_changes: None,
499            change_annotations: None,
500        }),
501        command: None,
502        is_preferred: Some(true),
503        disabled: None,
504        data: None,
505    })
506}
507
/// Return the text between the first pair of single quotes in `message`.
///
/// For example, `"Unknown enum type 'Color'"` yields `Some("Color")`.
/// Returns `None` when the message contains fewer than two `'` characters.
fn extract_quoted_name(message: &str) -> Option<String> {
    let mut pieces = message.splitn(3, '\'');
    // Text before the opening quote is irrelevant.
    pieces.next()?;
    let quoted = pieces.next()?;
    // A third piece exists only if a closing quote was found.
    pieces.next().map(|_| quoted.to_string())
}
514
515/// Extract the undefined variable name from an error message
516fn extract_undefined_name(message: &str) -> Option<String> {
517    // Pattern: "undefined variable 'name'" or "'name' is not defined"
518    if let Some(start) = message.find('\'') {
519        if let Some(end) = message[start + 1..].find('\'') {
520            return Some(message[start + 1..start + 1 + end].to_string());
521        }
522    }
523    None
524}
525
526/// Extract the unused variable name from an error message
527fn extract_unused_name(message: &str) -> Option<String> {
528    // Pattern: "unused variable 'name'" or "'name' is unused"
529    if let Some(start) = message.find('\'') {
530        if let Some(end) = message[start + 1..].find('\'') {
531            return Some(message[start + 1..start + 1 + end].to_string());
532        }
533    }
534    None
535}
536
/// Parse a non-exhaustive-match diagnostic into `(enum_name, missing_variants)`.
///
/// Expected form: `Non-exhaustive match on 'Enum': missing variants A, B`.
/// Variant names are split on `,` and trimmed, and blank entries are dropped.
/// Returns `None` if the message does not have that shape, the enum name is
/// blank, or no variant names remain.
fn parse_non_exhaustive_match(message: &str) -> Option<(String, Vec<String>)> {
    let rest = message.strip_prefix("Non-exhaustive match on '")?;
    let (raw_enum, raw_variants) = rest.split_once("': missing variants ")?;

    let enum_name = raw_enum.trim();
    if enum_name.is_empty() {
        return None;
    }

    let mut variants = Vec::new();
    for piece in raw_variants.split(',') {
        let piece = piece.trim();
        if !piece.is_empty() {
            variants.push(piece.to_owned());
        }
    }

    if variants.is_empty() {
        None
    } else {
        Some((enum_name.to_owned(), variants))
    }
}
559
560fn find_match_arm_insert_position(text: &str, range: Range) -> Option<(Position, String)> {
561    let lines: Vec<&str> = text.lines().collect();
562    if lines.is_empty() {
563        return None;
564    }
565    let start_line = range.start.line as usize;
566    let mut line_index = start_line.min(lines.len().saturating_sub(1));
567    while line_index < lines.len() {
568        let line = lines[line_index];
569        let trimmed = line.trim_start();
570        if trimmed.starts_with('}') {
571            let indent_len = line.len().saturating_sub(trimmed.len());
572            let indent = " ".repeat(indent_len);
573            return Some((
574                Position {
575                    line: line_index as u32,
576                    character: 0,
577                },
578                indent,
579            ));
580        }
581        line_index += 1;
582    }
583    None
584}
585
586/// Check if two ranges overlap
587fn ranges_overlap(a: Range, b: Range) -> bool {
588    !(a.end.line < b.start.line
589        || (a.end.line == b.start.line && a.end.character < b.start.character)
590        || b.end.line < a.start.line
591        || (b.end.line == a.start.line && b.end.character < a.start.character))
592}
593
/// Return the 0-based line `line` of `text` without its line terminator, or
/// `None` if the text has fewer lines.
fn get_line(text: &str, line: usize) -> Option<&str> {
    for (index, content) in text.lines().enumerate() {
        if index == line {
            return Some(content);
        }
    }
    None
}
598
599/// Get text within a range
600fn get_text_in_range(text: &str, range: Range) -> String {
601    let lines: Vec<&str> = text.lines().collect();
602
603    if range.start.line == range.end.line {
604        // Single line selection
605        if let Some(line) = lines.get(range.start.line as usize) {
606            let start = range.start.character as usize;
607            let end = range.end.character as usize;
608            if start < line.len() && end <= line.len() {
609                return line[start..end].to_string();
610            }
611        }
612    } else {
613        // Multi-line selection
614        let mut result = String::new();
615
616        for (i, line) in lines.iter().enumerate() {
617            let line_num = i as u32;
618
619            if line_num < range.start.line {
620                continue;
621            }
622            if line_num > range.end.line {
623                break;
624            }
625
626            if line_num == range.start.line {
627                let start = range.start.character as usize;
628                if start < line.len() {
629                    result.push_str(&line[start..]);
630                }
631            } else if line_num == range.end.line {
632                let end = range.end.character as usize;
633                if end <= line.len() {
634                    result.push_str(&line[..end]);
635                }
636            } else {
637                result.push_str(line);
638            }
639
640            if line_num != range.end.line {
641                result.push('\n');
642            }
643        }
644
645        return result;
646    }
647
648    String::new()
649}
650
651/// Check if a string looks like an expression
652fn is_expression(text: &str) -> bool {
653    let trimmed = text.trim();
654
655    // Empty or whitespace-only is not an expression
656    if trimmed.is_empty() {
657        return false;
658    }
659
660    // Statements are not expressions (simplified check)
661    if trimmed.starts_with("let ")
662        || trimmed.starts_with("var ")
663        || trimmed.starts_with("const ")
664        || trimmed.starts_with("fn ")
665        || trimmed.starts_with("function ")
666        || trimmed.starts_with("if ")
667        || trimmed.starts_with("for ")
668        || trimmed.starts_with("while ")
669        || trimmed.starts_with("return ")
670    {
671        return false;
672    }
673
674    // Try to parse as expression
675    let test_code = format!("let _test = {};", trimmed);
676    parse_program(&test_code).is_ok()
677}
678
679/// True when `requested_kinds` allows the exact `target` kind.
680/// If a broad parent kind is requested (e.g. `source`), sub-kinds are allowed.
681fn is_kind_requested(requested_kinds: Option<&[CodeActionKind]>, target: &str) -> bool {
682    match requested_kinds {
683        None => true,
684        Some(kinds) if kinds.is_empty() => true,
685        Some(kinds) => kinds.iter().any(|k| {
686            let requested = k.as_str();
687            requested == target || target.starts_with(&format!("{requested}."))
688        }),
689    }
690}
691
692/// True when `requested_kinds` allows a kind group (`quickfix`, `source`, `refactor`).
693fn is_group_requested(requested_kinds: Option<&[CodeActionKind]>, group: &str) -> bool {
694    match requested_kinds {
695        None => true,
696        Some(kinds) if kinds.is_empty() => true,
697        Some(kinds) => kinds.iter().any(|k| {
698            let requested = k.as_str();
699            requested == group || requested.starts_with(&format!("{group}."))
700        }),
701    }
702}
703
704/// True only when a request explicitly provided `only` kinds including `target`
705/// (or a parent kind such as `source` for `source.organizeImports`).
706fn is_kind_explicitly_requested(requested_kinds: Option<&[CodeActionKind]>, target: &str) -> bool {
707    let Some(kinds) = requested_kinds else {
708        return false;
709    };
710    if kinds.is_empty() {
711        return false;
712    }
713    kinds.iter().any(|k| {
714        let requested = k.as_str();
715        requested == target || target.starts_with(&format!("{requested}."))
716    })
717}
718
719/// Deduplicate actions by `(kind, title)`.
720fn dedupe_actions(actions: Vec<CodeActionOrCommand>) -> Vec<CodeActionOrCommand> {
721    let mut seen = HashSet::new();
722    let mut deduped = Vec::new();
723
724    for action in actions {
725        let key = match &action {
726            CodeActionOrCommand::CodeAction(ca) => format!(
727                "{}::{}",
728                ca.kind.as_ref().map(|k| k.as_str()).unwrap_or(""),
729                ca.title
730            ),
731            CodeActionOrCommand::Command(cmd) => format!("command::{}", cmd.title),
732        };
733
734        if seen.insert(key) {
735            deduped.push(action);
736        }
737    }
738
739    deduped
740}
741
742/// Collect LSP ranges for parsed import statements.
743fn import_statement_ranges(text: &str) -> Vec<Range> {
744    if let Ok(program) = parse_program(text) {
745        let mut ranges = Vec::new();
746        for item in &program.items {
747            if let Item::Import(_, span) = item {
748                ranges.push(span_to_range(text, span));
749            }
750        }
751        return ranges;
752    }
753
754    // Parse fallback: use line-based detection for unfinished import lines.
755    text.lines()
756        .enumerate()
757        .filter_map(|(line, raw)| {
758            let trimmed = raw.trim_start();
759            if trimmed.starts_with("from ") || trimmed.starts_with("use ") {
760                Some(Range {
761                    start: Position {
762                        line: line as u32,
763                        character: 0,
764                    },
765                    end: Position {
766                        line: line as u32,
767                        character: raw.len() as u32,
768                    },
769                })
770            } else {
771                None
772            }
773        })
774        .collect()
775}
776
777/// Return local names currently imported into scope.
778fn collect_imported_local_names(text: &str) -> HashSet<String> {
779    let Ok(program) = parse_program(text) else {
780        return HashSet::new();
781    };
782
783    let mut imported = HashSet::new();
784    for item in &program.items {
785        let Item::Import(import_stmt, _) = item else {
786            continue;
787        };
788        match &import_stmt.items {
789            ImportItems::Named(specs) => {
790                for spec in specs {
791                    imported.insert(spec.alias.clone().unwrap_or_else(|| spec.name.clone()));
792                }
793            }
794            ImportItems::Namespace { name, alias } => {
795                imported.insert(alias.clone().unwrap_or_else(|| name.clone()));
796            }
797        }
798    }
799
800    imported
801}
802
803fn import_insert_position(text: &str) -> Position {
804    let import_ranges = import_statement_ranges(text);
805    if let Some(last_line) = import_ranges.iter().map(|r| r.end.line).max() {
806        Position {
807            line: last_line + 1,
808            character: 0,
809        }
810    } else {
811        Position {
812            line: 0,
813            character: 0,
814        }
815    }
816}
817
818fn get_symbol_auto_import_actions(
819    text: &str,
820    uri: &Uri,
821    range: Range,
822    cache: &ModuleCache,
823) -> Vec<CodeActionOrCommand> {
824    let Some(symbol) = symbol_at_or_in_range(text, range) else {
825        return Vec::new();
826    };
827    if !is_import_candidate_symbol(&symbol) {
828        return Vec::new();
829    }
830
831    let imported_names = collect_imported_local_names(text);
832    if imported_names.contains(&symbol) {
833        return Vec::new();
834    }
835
836    let matches = if let Some(current_file) = uri.to_file_path() {
837        cache.find_exported_symbol_with_context(&symbol, current_file.as_ref(), None)
838    } else {
839        cache.find_exported_symbol(&symbol)
840    };
841    if matches.is_empty() {
842        return Vec::new();
843    }
844
845    let mut out = Vec::new();
846    let insert_at = import_insert_position(text);
847    for (import_path, _export) in matches {
848        out.push(CodeActionOrCommand::CodeAction(CodeAction {
849            title: format!("Import '{}' from {}", symbol, import_path),
850            kind: Some(CodeActionKind::QUICKFIX),
851            diagnostics: None,
852            edit: Some(WorkspaceEdit {
853                changes: Some({
854                    let mut changes = HashMap::new();
855                    changes.insert(
856                        uri.clone(),
857                        vec![TextEdit {
858                            range: Range {
859                                start: insert_at,
860                                end: insert_at,
861                            },
862                            new_text: format!("from {} use {{ {} }}\n", import_path, symbol),
863                        }],
864                    );
865                    changes
866                }),
867                document_changes: None,
868                change_annotations: None,
869            }),
870            command: None,
871            is_preferred: Some(true),
872            disabled: None,
873            data: None,
874        }));
875    }
876
877    out
878}
879
880fn symbol_at_or_in_range(text: &str, range: Range) -> Option<String> {
881    let selected = get_text_in_range(text, range);
882    let selected = selected.trim();
883    if !selected.is_empty() && is_identifier(selected) {
884        return Some(selected.to_string());
885    }
886    get_word_at_position(text, range.start)
887}
888
/// Reports whether `name` is a plain ASCII identifier.
///
/// The first character must be an ASCII letter or `_`. Every later character
/// must be an ASCII letter, digit, or `_`. The empty string is rejected.
/// Any non-ASCII character makes the name invalid.
fn is_identifier(name: &str) -> bool {
    // Working on bytes is equivalent here: every accepted character is ASCII,
    // and any byte of a multi-byte UTF-8 sequence is >= 0x80, so it fails both checks.
    match name.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
        None => false,
    }
}
899
900fn is_import_candidate_symbol(name: &str) -> bool {
901    is_identifier(name) && name.chars().next().is_some_and(|c| c.is_ascii_uppercase())
902}
903
#[cfg(test)]
mod tests {
    use super::*;

    /// `extract_undefined_name` pulls the quoted name out of both supported
    /// "undefined" message shapes and yields `None` for unrelated messages.
    #[test]
    fn test_extract_undefined_name() {
        assert_eq!(
            extract_undefined_name("undefined variable 'foo'"),
            Some("foo".to_string())
        );
        assert_eq!(
            extract_undefined_name("'bar' is not defined"),
            Some("bar".to_string())
        );
        assert_eq!(extract_undefined_name("some other message"), None);
    }

    /// Ranges that share part of a line overlap; ranges on different lines
    /// (here line 1 vs line 2) do not.
    #[test]
    fn test_ranges_overlap() {
        // Line 1, columns 0..10.
        let r1 = Range {
            start: Position {
                line: 1,
                character: 0,
            },
            end: Position {
                line: 1,
                character: 10,
            },
        };
        // Line 1, columns 5..15: intersects r1.
        let r2 = Range {
            start: Position {
                line: 1,
                character: 5,
            },
            end: Position {
                line: 1,
                character: 15,
            },
        };
        // Line 2, columns 0..10: entirely after r1.
        let r3 = Range {
            start: Position {
                line: 2,
                character: 0,
            },
            end: Position {
                line: 2,
                character: 10,
            },
        };

        assert!(ranges_overlap(r1, r2));
        assert!(!ranges_overlap(r1, r3));
    }

    /// `get_text_in_range` returns exactly the characters between start and
    /// end on a single line (end column exclusive).
    #[test]
    fn test_get_text_in_range() {
        let text = "let x = 42;\nlet y = 10;";

        // Column 4 of "let x = 42;" is `x`.
        let range = Range {
            start: Position {
                line: 0,
                character: 4,
            },
            end: Position {
                line: 0,
                character: 5,
            },
        };
        assert_eq!(get_text_in_range(text, range), "x");

        // Columns 8..10 cover the literal `42`.
        let range = Range {
            start: Position {
                line: 0,
                character: 8,
            },
            end: Position {
                line: 0,
                character: 10,
            },
        };
        assert_eq!(get_text_in_range(text, range), "42");
    }

    /// `is_expression` accepts bare expressions and rejects declarations
    /// (`let` bindings, function definitions).
    #[test]
    fn test_is_expression() {
        assert!(is_expression("42"));
        assert!(is_expression("x + y"));
        assert!(is_expression("foo()"));
        assert!(!is_expression("let x = 42"));
        assert!(!is_expression("function foo() {}"));
    }

    /// `extract_quoted_name` returns the first single-quoted name in a
    /// compiler message. For "Unknown variant" that is the variant, not the enum.
    #[test]
    fn test_extract_quoted_name_from_compiler_errors() {
        // Matches actual compiler error message format from checking.rs
        assert_eq!(
            extract_quoted_name(
                "Unknown enum type 'Snapshot'. Make sure it is imported or defined."
            ),
            Some("Snapshot".to_string())
        );
        assert_eq!(
            extract_quoted_name("Unknown variant 'BadVariant' for enum 'Color'"),
            Some("BadVariant".to_string())
        );
        assert_eq!(extract_quoted_name("no quotes here"), None);
    }

    /// An E0401 "Missing required method" diagnostic on an impl block should
    /// produce an "Implement method '<name>'" quick fix.
    #[test]
    fn test_missing_trait_method_quick_fix() {
        // Trait Q declares `filter` and `select`; the impl on lines 4-6 only
        // provides `filter`.
        let text = "trait Q {\n    filter(p): any;\n    select(c): any\n}\nimpl Q for T {\n    method filter(p) { self }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        let diagnostic = Diagnostic {
            // Spans the whole `impl Q for T { ... }` block.
            range: Range {
                start: Position {
                    line: 4,
                    character: 0,
                },
                end: Position {
                    line: 6,
                    character: 1,
                },
            },
            severity: Some(tower_lsp_server::ls_types::DiagnosticSeverity::ERROR),
            code: Some(tower_lsp_server::ls_types::NumberOrString::String(
                "E0401".to_string(),
            )),
            message: "Missing required method 'select' in impl Q for T.".to_string(),
            ..Default::default()
        };
        // The requested range equals the diagnostic range, so the diagnostic
        // is guaranteed to pass the overlap check in `get_code_actions`.
        let actions = get_code_actions(text, &uri, diagnostic.range, &[diagnostic], None, None);
        assert!(
            actions.iter().any(|a| {
                if let CodeActionOrCommand::CodeAction(action) = a {
                    action.title.contains("Implement method 'select'")
                } else {
                    false
                }
            }),
            "Should have quick fix to implement missing method. Got: {:?}",
            actions
                .iter()
                .map(|a| match a {
                    CodeActionOrCommand::CodeAction(action) => action.title.clone(),
                    CodeActionOrCommand::Command(cmd) => cmd.title.clone(),
                })
                .collect::<Vec<_>>()
        );
    }

    /// Guards the import-line format emitted by auto-import. The format
    /// string below duplicates the one in `get_symbol_auto_import_actions`,
    /// so keep the two in sync.
    #[test]
    fn test_auto_import_generates_valid_syntax() {
        // Verify the generated import text matches Shape grammar
        let name = "Snapshot";
        let import_path = "std::core::snapshot";
        let import_text = format!("from {} use {{ {} }}\n", import_path, name);
        assert_eq!(import_text, "from std::core::snapshot use { Snapshot }\n");
        // Verify it parses as valid Shape
        let full_code = format!("{}let x = 1\n", import_text);
        assert!(
            shape_ast::parser::parse_program(&full_code).is_ok(),
            "Generated import should be valid Shape syntax: {}",
            full_code
        );
    }

    /// "Organize imports" should be offered when the cursor sits on an import
    /// line and suppressed when it sits on ordinary code.
    #[test]
    fn test_source_actions_organize_imports_only_on_import_lines() {
        let text = "from std::core::math use { abs }\nlet x = abs(1)\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();

        // Empty range (cursor) inside the import on line 0.
        let import_range = Range {
            start: Position {
                line: 0,
                character: 5,
            },
            end: Position {
                line: 0,
                character: 5,
            },
        };
        // Empty range (cursor) on the `let` statement on line 1.
        let non_import_range = Range {
            start: Position {
                line: 1,
                character: 4,
            },
            end: Position {
                line: 1,
                character: 4,
            },
        };

        let on_import = get_code_actions(text, &uri, import_range, &[], None, None);
        let away_from_import = get_code_actions(text, &uri, non_import_range, &[], None, None);

        let has_organize = |actions: &[CodeActionOrCommand]| {
            actions.iter().any(|a| {
                matches!(
                    a,
                    CodeActionOrCommand::CodeAction(CodeAction {
                        kind: Some(kind),
                        ..
                    }) if kind == &CodeActionKind::SOURCE_ORGANIZE_IMPORTS
                )
            })
        };

        assert!(has_organize(&on_import));
        assert!(!has_organize(&away_from_import));
    }

    /// With no diagnostics at all, the symbol-based auto-import path should
    /// still offer to import `Snapshot` when the cursor is on it.
    #[test]
    fn test_symbol_auto_import_action_from_cursor() {
        let text = "match snapshot() {\n  Snapshot::Resumed => { }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        // Columns 3..11 of "  Snapshot::Resumed" select "napshot:". That is
        // not an identifier, so `symbol_at_or_in_range` falls back to the
        // word at the range start (expected to resolve to `Snapshot`).
        let range = Range {
            start: Position {
                line: 1,
                character: 3,
            },
            end: Position {
                line: 1,
                character: 11,
            },
        };

        // NOTE(review): assumes a fresh `ModuleCache` can resolve stdlib
        // exports such as `std::core::snapshot` without extra setup.
        let cache = ModuleCache::new();
        let actions = get_code_actions(text, &uri, range, &[], Some(&cache), None);
        assert!(
            actions.iter().any(|a| {
                matches!(
                    a,
                    CodeActionOrCommand::CodeAction(CodeAction { title, .. })
                        if title.contains("Import 'Snapshot' from std::core::snapshot")
                )
            }),
            "Expected auto-import action for Snapshot. Got: {:?}",
            actions
                .iter()
                .map(|a| match a {
                    CodeActionOrCommand::CodeAction(action) => action.title.clone(),
                    CodeActionOrCommand::Command(cmd) => cmd.title.clone(),
                })
                .collect::<Vec<_>>()
        );
    }

    /// An empty `match` reported as "requires at least one arm" should get an
    /// "Add wildcard match arm" quick fix.
    #[test]
    fn test_empty_match_quick_fix_adds_wildcard_arm() {
        let text = "fn afunc(c) {\n  match c {\n\n  }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        // The diagnostic spans the `match c { ... }` expression on lines 1-3.
        let diagnostic = Diagnostic {
            range: Range {
                start: Position {
                    line: 1,
                    character: 2,
                },
                end: Position {
                    line: 3,
                    character: 3,
                },
            },
            severity: None,
            code: None,
            code_description: None,
            source: Some("shape".to_string()),
            message: "match expression requires at least one arm".to_string(),
            related_information: None,
            tags: None,
            data: None,
        };
        // Cursor on the blank line inside the match body. It lies within the
        // diagnostic range, so the overlap check in `get_code_actions` passes.
        let range = Range {
            start: Position {
                line: 2,
                character: 2,
            },
            end: Position {
                line: 2,
                character: 2,
            },
        };

        let actions = get_code_actions(text, &uri, range, &[diagnostic], None, None);
        assert!(
            actions.iter().any(|a| matches!(
                a,
                CodeActionOrCommand::CodeAction(CodeAction { title, .. })
                    if title == "Add wildcard match arm"
            )),
            "Expected wildcard match-arm quick fix. Got: {:?}",
            actions
                .iter()
                .map(|a| match a {
                    CodeActionOrCommand::CodeAction(action) => action.title.clone(),
                    CodeActionOrCommand::Command(cmd) => cmd.title.clone(),
                })
                .collect::<Vec<_>>()
        );
    }

    /// A "Non-exhaustive match on '<Enum>'" diagnostic should get an
    /// "Add missing match arms for <Enum>" quick fix.
    #[test]
    fn test_non_exhaustive_match_quick_fix_adds_missing_arms() {
        let text = "match snapshot() {\n  Snapshot::Resumed => { }\n}\n";
        let uri = Uri::from_file_path("/tmp/test.shape").unwrap();
        // The diagnostic covers the whole match (lines 0-2) and names the
        // missing variant `Hash`.
        let diagnostic = Diagnostic {
            range: Range {
                start: Position {
                    line: 0,
                    character: 0,
                },
                end: Position {
                    line: 2,
                    character: 1,
                },
            },
            severity: None,
            code: None,
            code_description: None,
            source: Some("shape".to_string()),
            message: "Non-exhaustive match on 'Snapshot': missing variants Hash".to_string(),
            related_information: None,
            tags: None,
            data: None,
        };
        // Cursor inside the existing `Snapshot::Resumed` arm.
        let range = Range {
            start: Position {
                line: 1,
                character: 5,
            },
            end: Position {
                line: 1,
                character: 5,
            },
        };

        let actions = get_code_actions(text, &uri, range, &[diagnostic], None, None);
        assert!(
            actions.iter().any(|a| matches!(
                a,
                CodeActionOrCommand::CodeAction(CodeAction { title, .. })
                    if title == "Add missing match arms for Snapshot"
            )),
            "Expected missing-arms quick fix. Got: {:?}",
            actions
                .iter()
                .map(|a| match a {
                    CodeActionOrCommand::CodeAction(action) => action.title.clone(),
                    CodeActionOrCommand::Command(cmd) => cmd.title.clone(),
                })
                .collect::<Vec<_>>()
        );
    }
}