ts_bridge/protocol/text_document/
code_action.rs

//! =============================================================================
//! textDocument/codeAction
//! =============================================================================
//!
//! Bridges LSP code actions to tsserver's `getCodeFixes` command. The numeric
//! codes of the diagnostics supplied by the client are forwarded as `errorCodes`
//! so tsserver can suggest quick fixes (missing imports, unreachable code, etc.).
//! Results are converted into `CodeAction` entries with ready-to-apply workspace
//! edits. When tsserver also reports a `fixId`, we surface a companion "fix all"
//! action that is resolved lazily via `codeAction/resolve`. Requests that only
//! ask for organize imports go straight to tsserver's `organizeImports` command;
//! otherwise an organize-imports placeholder (carrying `data`, no edit) may be
//! appended to the quick fixes.
11
12use std::collections::HashMap;
13
14use anyhow::{Context, Result};
15use lsp_types::{
16    CodeAction, CodeActionContext, CodeActionKind, CodeActionOrCommand, CodeActionParams,
17    CodeActionResponse, Diagnostic, NumberOrString, TextEdit, Uri, WorkspaceEdit,
18};
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21
22use crate::protocol::{AdapterResult, RequestSpec};
23use crate::rpc::{Priority, Route};
24use crate::utils::{tsserver_file_to_uri, tsserver_range_from_value_lsp, uri_to_file_path};
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "type")]
28pub enum CodeActionData {
29    #[serde(rename = "fixAll")]
30    FixAll(FixAllData),
31    #[serde(rename = "organizeImports")]
32    OrganizeImports(OrganizeImportsData),
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct FixAllData {
37    pub file: String,
38    pub fix_id: String,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct OrganizeImportsData {
43    pub file: String,
44}
45
46#[derive(Debug, Deserialize)]
47struct AdapterContext {
48    file: String,
49    context: CodeActionContext,
50    #[serde(default, rename = "includeOrganize")]
51    include_organize: bool,
52}
53
54pub fn handle(params: CodeActionParams) -> RequestSpec {
55    let CodeActionParams {
56        text_document,
57        range,
58        context,
59        work_done_progress_params: _,
60        partial_result_params: _,
61    } = params;
62    let uri = text_document.uri;
63    let file = uri_to_file_path(uri.as_str()).unwrap_or_else(|| uri.to_string());
64    let context_only = context.only.clone();
65    let wants_organize = context_only
66        .as_ref()
67        .map(|list| {
68            list.iter()
69                .any(|kind| matches_kind(kind, CodeActionKind::SOURCE_ORGANIZE_IMPORTS.as_str()))
70        })
71        .unwrap_or(false);
72    let wants_quickfix = context_only
73        .as_ref()
74        .map(|list| {
75            list.iter()
76                .any(|kind| matches_kind(kind, CodeActionKind::QUICKFIX.as_str()))
77        })
78        .unwrap_or(true);
79
80    let has_filter = context_only
81        .as_ref()
82        .map(|list| !list.is_empty())
83        .unwrap_or(false);
84
85    if wants_organize && !wants_quickfix {
86        return organize_imports_request(file);
87    }
88
89    // When the client didn't filter (`only` empty/missing), include organize imports alongside
90    // quick fixes so the default picker shows it.
91    let include_organize = wants_organize || !has_filter;
92
93    let error_codes = collect_error_codes(&context);
94
95    let request = json!({
96        "command": "getCodeFixes",
97        "arguments": {
98            "file": file,
99            "startLine": range.start.line + 1,
100            "startOffset": range.start.character + 1,
101            "endLine": range.end.line + 1,
102            "endOffset": range.end.character + 1,
103            "errorCodes": error_codes,
104        }
105    });
106
107    let adapter_context = json!({
108        "file": file,
109        "context": context,
110        "includeOrganize": include_organize,
111    });
112
113    RequestSpec {
114        route: Route::Syntax,
115        payload: request,
116        priority: Priority::Normal,
117        on_response: Some(adapt_code_actions),
118        response_context: Some(adapter_context),
119    }
120}
121
122fn organize_imports_request(file: String) -> RequestSpec {
123    let request = organize_imports_payload(&file);
124
125    RequestSpec {
126        route: Route::Syntax,
127        payload: request,
128        priority: Priority::Low,
129        on_response: Some(adapt_organize_imports),
130        response_context: None,
131    }
132}
133
134fn adapt_code_actions(payload: &Value, context: Option<&Value>) -> Result<AdapterResult> {
135    let adapter_ctx: AdapterContext =
136        serde_json::from_value(context.cloned().context("code action context missing")?)?;
137    let fixes = payload
138        .get("body")
139        .and_then(|value| value.as_array())
140        .cloned()
141        .unwrap_or_default();
142
143    let mut actions: Vec<CodeActionOrCommand> = Vec::new();
144    for fix in fixes {
145        if let Some(action) = build_quick_fix(&fix, &adapter_ctx) {
146            actions.push(CodeActionOrCommand::CodeAction(action));
147        }
148        if let Some(action) = build_fix_all_action(&fix, &adapter_ctx) {
149            actions.push(CodeActionOrCommand::CodeAction(action));
150        }
151    }
152
153    if adapter_ctx.include_organize {
154        if let Some(action) = organize_imports_placeholder(&adapter_ctx.file) {
155            actions.push(CodeActionOrCommand::CodeAction(action));
156        }
157    }
158
159    Ok(AdapterResult::ready(serde_json::to_value(
160        CodeActionResponse::from(actions),
161    )?))
162}
163
164fn build_quick_fix(fix: &Value, ctx: &AdapterContext) -> Option<CodeAction> {
165    let title = fix.get("description")?.as_str()?.to_string();
166    let changes = fix.get("changes")?.as_array()?;
167    let edit = workspace_edit_from_tsserver_changes(changes)?;
168    let diagnostics = diagnostics_for_action(&ctx.context);
169
170    let mut action = CodeAction {
171        title,
172        kind: Some(CodeActionKind::QUICKFIX),
173        diagnostics,
174        edit: Some(edit),
175        ..CodeAction::default()
176    };
177
178    if let Some(preferred) = fix.get("isPreferred").and_then(|v| v.as_bool()) {
179        action.is_preferred = Some(preferred);
180    }
181
182    Some(action)
183}
184
185fn build_fix_all_action(fix: &Value, ctx: &AdapterContext) -> Option<CodeAction> {
186    let fix_id = fix.get("fixId")?.as_str()?;
187    let description = fix.get("fixAllDescription")?.as_str()?;
188    let diagnostics = diagnostics_for_action(&ctx.context);
189
190    // "Fix all" belongs to the source action family so clients can filter by
191    // `source.fixAll` in picker UIs. Classifying it as a quick fix hides the
192    // entry whenever a client explicitly requests source actions only.
193    let mut action = CodeAction {
194        title: description.to_string(),
195        kind: Some(CodeActionKind::SOURCE_FIX_ALL),
196        diagnostics,
197        ..CodeAction::default()
198    };
199
200    let data = CodeActionData::FixAll(FixAllData {
201        file: ctx.file.clone(),
202        fix_id: fix_id.to_string(),
203    });
204    action.data = Some(serde_json::to_value(data).ok()?);
205
206    Some(action)
207}
208
209fn organize_imports_placeholder(file: &str) -> Option<CodeAction> {
210    let data = CodeActionData::OrganizeImports(OrganizeImportsData {
211        file: file.to_string(),
212    });
213    Some(CodeAction {
214        title: "Organize Imports".to_string(),
215        kind: Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS),
216        data: Some(serde_json::to_value(data).ok()?),
217        ..CodeAction::default()
218    })
219}
220
221fn diagnostics_for_action(context: &CodeActionContext) -> Option<Vec<Diagnostic>> {
222    if context.diagnostics.is_empty() {
223        None
224    } else {
225        Some(context.diagnostics.clone())
226    }
227}
228
229fn collect_error_codes(context: &CodeActionContext) -> Vec<i32> {
230    let mut codes = Vec::new();
231    for diagnostic in &context.diagnostics {
232        if let Some(NumberOrString::Number(value)) = diagnostic.code.clone() {
233            codes.push(value as i32);
234        }
235    }
236    codes
237}
238
239pub(crate) fn workspace_edit_from_tsserver_changes(changes: &[Value]) -> Option<WorkspaceEdit> {
240    let mut map: HashMap<Uri, Vec<TextEdit>> = HashMap::new();
241    for change in changes {
242        let file_name = change
243            .get("fileName")
244            .or_else(|| change.get("file"))?
245            .as_str()?;
246        let uri = tsserver_file_to_uri(file_name)?;
247        let text_changes = change.get("textChanges")?.as_array()?;
248        let entry = map.entry(uri).or_default();
249        for text_change in text_changes {
250            let range = tsserver_range_from_value_lsp(text_change)?;
251            let new_text = text_change.get("newText")?.as_str()?.to_string();
252            entry.push(TextEdit { range, new_text });
253        }
254    }
255
256    if map.is_empty() {
257        None
258    } else {
259        Some(WorkspaceEdit {
260            changes: Some(map),
261            document_changes: None,
262            change_annotations: None,
263        })
264    }
265}
266
267fn adapt_organize_imports(payload: &Value, _context: Option<&Value>) -> Result<AdapterResult> {
268    let changes = payload
269        .get("body")
270        .and_then(|value| value.as_array())
271        .cloned()
272        .unwrap_or_default();
273
274    let mut actions: Vec<CodeActionOrCommand> = Vec::new();
275    if let Some(edit) = workspace_edit_from_tsserver_changes(&changes) {
276        let action = CodeAction {
277            title: "Organize Imports".to_string(),
278            kind: Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS),
279            edit: Some(edit),
280            ..CodeAction::default()
281        };
282        actions.push(CodeActionOrCommand::CodeAction(action));
283    }
284
285    Ok(AdapterResult::ready(serde_json::to_value(
286        CodeActionResponse::from(actions),
287    )?))
288}
289
290fn matches_kind(kind: &CodeActionKind, needle: &str) -> bool {
291    let value = kind.as_str();
292    value == needle || value.starts_with(&(needle.to_string() + "."))
293}
294
295pub(crate) fn organize_imports_payload(file: &str) -> Value {
296    json!({
297        "command": "organizeImports",
298        "arguments": {
299            "scope": {
300                "type": "file",
301                "args": {
302                    "file": file,
303                }
304            }
305        }
306    })
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use lsp_types::{
        CodeActionContext, CodeActionKind, Diagnostic, Position, Range, TextDocumentIdentifier, Uri,
    };
    use serde_json::json;
    use std::str::FromStr;

    const FILE_URI: &str = "file:///workspace/app.ts";
    const FILE_PATH: &str = "/workspace/app.ts";

    /// Range covering the first character of the first line.
    fn first_char_range() -> Range {
        Range::new(Position::new(0, 0), Position::new(0, 1))
    }

    /// Diagnostic on the first character carrying the numeric `code`.
    fn sample_diagnostic(code: i32) -> Diagnostic {
        Diagnostic {
            range: first_char_range(),
            code: Some(NumberOrString::Number(code)),
            ..Diagnostic::default()
        }
    }

    /// Unfiltered code action context holding `diagnostics`.
    fn context_with(diagnostics: Vec<Diagnostic>) -> CodeActionContext {
        CodeActionContext {
            diagnostics,
            only: None,
            trigger_kind: None,
        }
    }

    fn sample_context() -> CodeActionContext {
        context_with(vec![sample_diagnostic(6133)])
    }

    /// JSON adapter context in the shape `handle` produces.
    fn adapter_context(include_organize: bool) -> Value {
        json!({
            "file": FILE_PATH,
            "context": sample_context(),
            "includeOrganize": include_organize,
        })
    }

    /// Returns the inner `CodeAction`, panicking with `what` for a bare command.
    fn expect_code_action<'a>(entry: &'a CodeActionOrCommand, what: &str) -> &'a CodeAction {
        match entry {
            CodeActionOrCommand::CodeAction(action) => action,
            _ => panic!("{what}"),
        }
    }

    #[test]
    fn build_fix_all_action_sets_source_kind() {
        let ctx = AdapterContext {
            file: FILE_PATH.to_string(),
            context: sample_context(),
            include_organize: false,
        };
        let fix = json!({
            "fixId": "fixAllMissingImports",
            "fixAllDescription": "Fix all missing imports",
        });

        let action = build_fix_all_action(&fix, &ctx).expect("fix all action");
        assert_eq!(action.kind, Some(CodeActionKind::SOURCE_FIX_ALL));

        let data: CodeActionData =
            serde_json::from_value(action.data.expect("data")).expect("code action data");
        let CodeActionData::FixAll(fix_all) = data else {
            panic!("expected fix all data");
        };
        assert_eq!(fix_all.file, FILE_PATH);
        assert_eq!(fix_all.fix_id, "fixAllMissingImports");
    }

    #[test]
    fn adapt_code_actions_emits_quick_fix_fix_all_and_organize() {
        let payload = json!({
            "body": [{
                "description": "Add missing import",
                "changes": [{
                    "fileName": FILE_PATH,
                    "textChanges": [{
                        "start": { "line": 1, "offset": 1 },
                        "end": { "line": 1, "offset": 1 },
                        "newText": "import { foo } from 'foo';\n"
                    }]
                }],
                "fixId": "fixAllMissingImports",
                "fixAllDescription": "Fix all missing imports",
                "isPreferred": true,
            }]
        });
        let ctx_value = adapter_context(true);

        let AdapterResult::Ready(value) =
            adapt_code_actions(&payload, Some(&ctx_value)).expect("adapt")
        else {
            panic!("expected ready code action response");
        };
        let actions = serde_json::from_value::<CodeActionResponse>(value)
            .unwrap_or_else(|err| panic!("failed to deserialize code action response: {err}"));
        assert_eq!(actions.len(), 3, "quick fix, fix all, organize placeholder");

        let quick_fix = expect_code_action(&actions[0], "expected code action");
        assert_eq!(quick_fix.kind, Some(CodeActionKind::QUICKFIX));
        assert!(quick_fix.edit.is_some(), "quick fix should have edit");
        assert_eq!(quick_fix.is_preferred, Some(true));

        let fix_all = expect_code_action(&actions[1], "expected fix all code action");
        assert_eq!(fix_all.kind, Some(CodeActionKind::SOURCE_FIX_ALL));
        let data: CodeActionData =
            serde_json::from_value(fix_all.data.clone().unwrap()).expect("fix all data");
        let CodeActionData::FixAll(fix_all_data) = data else {
            panic!("expected fix all data");
        };
        assert_eq!(fix_all_data.file, FILE_PATH);

        let organize = expect_code_action(&actions[2], "expected organize imports action");
        assert_eq!(organize.kind, Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS));
        assert!(organize.data.is_some());
    }

    #[test]
    fn handle_collects_error_codes_from_context() {
        let params = CodeActionParams {
            text_document: TextDocumentIdentifier {
                uri: Uri::from_str(FILE_URI).expect("uri"),
            },
            range: first_char_range(),
            context: context_with(vec![sample_diagnostic(1234)]),
            work_done_progress_params: Default::default(),
            partial_result_params: Default::default(),
        };

        let spec = handle(params);
        let error_codes = spec
            .payload
            .pointer("/arguments/errorCodes")
            .and_then(Value::as_array)
            .expect("error codes array");
        assert_eq!(error_codes, &[json!(1234)]);
    }
}