Skip to main content

voirs_cli/lsp/
code_actions.rs

//! LSP code actions provider.
//!
//! Provides quick fixes and refactorings for SSML documents.
4
5use super::{Position, Range};
6use serde_json::Value;
7
/// Kind of a code action, mirroring the LSP `CodeActionKind` string values.
///
/// Use [`CodeActionKind::as_str`] to obtain the value sent over the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Not every variant is constructed by this module (e.g. `RefactorExtract`,
// `RefactorInline`, `SourceOrganizeImports`); they are kept so the enum covers
// the full set of LSP kinds listed here.
#[allow(dead_code)]
pub enum CodeActionKind {
    /// Quick fix for errors
    QuickFix,
    /// Refactor code
    Refactor,
    /// Refactor extract
    RefactorExtract,
    /// Refactor inline
    RefactorInline,
    /// Refactor rewrite
    RefactorRewrite,
    /// Source action
    Source,
    /// Source organize imports
    SourceOrganizeImports,
}

impl CodeActionKind {
    /// Get the LSP string representation (e.g. `"refactor.rewrite"`).
    ///
    /// Dotted values are sub-kinds in the LSP kind hierarchy, so
    /// `"refactor.extract"` refines `"refactor"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            CodeActionKind::QuickFix => "quickfix",
            CodeActionKind::Refactor => "refactor",
            CodeActionKind::RefactorExtract => "refactor.extract",
            CodeActionKind::RefactorInline => "refactor.inline",
            CodeActionKind::RefactorRewrite => "refactor.rewrite",
            CodeActionKind::Source => "source",
            CodeActionKind::SourceOrganizeImports => "source.organizeImports",
        }
    }
}
42
43/// Get code actions for a document range
44pub fn get_code_actions(document_text: &str, range: Range) -> Vec<Value> {
45    let mut actions = Vec::new();
46
47    // Extract text in range
48    if let Some(selection) = extract_range_text(document_text, range) {
49        // Wrap in prosody tag
50        if !selection.trim().starts_with('<') {
51            actions.push(create_code_action(
52                "Wrap in prosody tag",
53                CodeActionKind::RefactorRewrite,
54                &format!("<prosody rate=\"1.0\">{}</prosody>", selection),
55                range,
56            ));
57        }
58
59        // Wrap in emphasis tag
60        if !selection.trim().starts_with('<') {
61            actions.push(create_code_action(
62                "Add emphasis",
63                CodeActionKind::RefactorRewrite,
64                &format!("<emphasis level=\"moderate\">{}</emphasis>", selection),
65                range,
66            ));
67        }
68
69        // Wrap in voice tag
70        if !selection.trim().starts_with('<') {
71            actions.push(create_code_action(
72                "Change voice",
73                CodeActionKind::RefactorRewrite,
74                &format!("<voice name=\"kokoro-en\">{}</voice>", selection),
75                range,
76            ));
77        }
78    }
79
80    // Add break before/after
81    actions.push(create_insert_action(
82        "Insert pause before",
83        range.start,
84        "<break time=\"500ms\"/>",
85    ));
86
87    actions.push(create_insert_action(
88        "Insert pause after",
89        range.end,
90        "<break time=\"500ms\"/>",
91    ));
92
93    // Format SSML
94    actions.push(create_source_action(
95        "Format SSML",
96        CodeActionKind::Source,
97        "format",
98    ));
99
100    // Validate SSML
101    actions.push(create_source_action(
102        "Validate SSML",
103        CodeActionKind::Source,
104        "validate",
105    ));
106
107    actions
108}
109
110/// Create a code action
111fn create_code_action(title: &str, kind: CodeActionKind, new_text: &str, range: Range) -> Value {
112    serde_json::json!({
113        "title": title,
114        "kind": kind.as_str(),
115        "edit": {
116            "changes": {
117                "document": [{
118                    "range": {
119                        "start": {
120                            "line": range.start.line,
121                            "character": range.start.character
122                        },
123                        "end": {
124                            "line": range.end.line,
125                            "character": range.end.character
126                        }
127                    },
128                    "newText": new_text
129                }]
130            }
131        }
132    })
133}
134
135/// Create an insert action
136fn create_insert_action(title: &str, position: Position, text: &str) -> Value {
137    serde_json::json!({
138        "title": title,
139        "kind": CodeActionKind::RefactorRewrite.as_str(),
140        "edit": {
141            "changes": {
142                "document": [{
143                    "range": {
144                        "start": {
145                            "line": position.line,
146                            "character": position.character
147                        },
148                        "end": {
149                            "line": position.line,
150                            "character": position.character
151                        }
152                    },
153                    "newText": text
154                }]
155            }
156        }
157    })
158}
159
160/// Create a source action
161fn create_source_action(title: &str, kind: CodeActionKind, command: &str) -> Value {
162    serde_json::json!({
163        "title": title,
164        "kind": kind.as_str(),
165        "command": {
166            "title": title,
167            "command": format!("voirs.{}", command),
168            "arguments": []
169        }
170    })
171}
172
173/// Extract text from a range
174fn extract_range_text(text: &str, range: Range) -> Option<String> {
175    let lines: Vec<&str> = text.lines().collect();
176
177    if range.start.line == range.end.line {
178        // Single line selection
179        if let Some(line) = lines.get(range.start.line as usize) {
180            let start = range.start.character as usize;
181            let end = range.end.character as usize;
182            if start < line.len() && end <= line.len() {
183                return Some(line[start..end].to_string());
184            }
185        }
186    } else {
187        // Multi-line selection
188        let mut result = String::new();
189        for (i, line) in lines.iter().enumerate() {
190            let line_num = i as u32;
191            if line_num < range.start.line || line_num > range.end.line {
192                continue;
193            }
194
195            if line_num == range.start.line {
196                // First line
197                let start = range.start.character as usize;
198                if start < line.len() {
199                    result.push_str(&line[start..]);
200                    result.push('\n');
201                }
202            } else if line_num == range.end.line {
203                // Last line
204                let end = range.end.character as usize;
205                if end <= line.len() {
206                    result.push_str(&line[..end]);
207                }
208            } else {
209                // Middle lines
210                result.push_str(line);
211                result.push('\n');
212            }
213        }
214        if !result.is_empty() {
215            return Some(result);
216        }
217    }
218
219    None
220}
221
222/// Get quick fixes for common SSML errors
223pub fn get_quick_fixes(error_message: &str, range: Range) -> Vec<Value> {
224    let mut fixes = Vec::new();
225
226    if error_message.contains("unclosed tag") {
227        if let Some(tag_name) = extract_tag_name(error_message) {
228            fixes.push(create_code_action(
229                &format!("Close <{}> tag", tag_name),
230                CodeActionKind::QuickFix,
231                &format!("</{}>", tag_name),
232                range,
233            ));
234        }
235    }
236
237    if error_message.contains("invalid attribute") {
238        fixes.push(create_code_action(
239            "Remove invalid attribute",
240            CodeActionKind::QuickFix,
241            "",
242            range,
243        ));
244    }
245
246    if error_message.contains("invalid voice") {
247        fixes.push(create_code_action(
248            "Replace with 'kokoro-en'",
249            CodeActionKind::QuickFix,
250            "kokoro-en",
251            range,
252        ));
253    }
254
255    fixes
256}
257
/// Extract the tag name from the first `<...>` span in an error message.
///
/// Only the first whitespace-separated word inside the brackets is returned,
/// so attributes are dropped (`"<voice name=\"x\">"` gives `"voice"`).
/// Returns `None` if there is no `<` followed later by `>`, or if the brackets
/// contain only whitespace.
fn extract_tag_name(message: &str) -> Option<String> {
    let (_, after_open) = message.split_once('<')?;
    let (inner, _) = after_open.split_once('>')?;
    inner.split_whitespace().next().map(str::to_owned)
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect the `title` of each action, substituting `""` when absent.
    fn action_titles(actions: &[Value]) -> Vec<&str> {
        actions
            .iter()
            .map(|a| a["title"].as_str().unwrap_or(""))
            .collect()
    }

    #[test]
    fn test_get_code_actions() {
        let text = "Hello world";
        let range = Range::single_line(0, 0, 11);

        let actions = get_code_actions(text, range);
        assert!(!actions.is_empty());

        let titles = action_titles(&actions);
        for expected in &[
            "Wrap in prosody tag",
            "Add emphasis",
            "Change voice",
            "Insert pause before",
            "Insert pause after",
            "Format SSML",
            "Validate SSML",
        ] {
            assert!(titles.contains(expected), "missing action {:?}", expected);
        }
    }

    #[test]
    fn test_source_actions_use_voirs_commands() {
        let actions = get_code_actions("Hello world", Range::single_line(0, 0, 11));

        let format = actions
            .iter()
            .find(|a| a["title"].as_str() == Some("Format SSML"))
            .expect("format action present");
        assert_eq!(format["command"]["command"].as_str(), Some("voirs.format"));

        let validate = actions
            .iter()
            .find(|a| a["title"].as_str() == Some("Validate SSML"))
            .expect("validate action present");
        assert_eq!(validate["command"]["command"].as_str(), Some("voirs.validate"));
    }

    #[test]
    fn test_extract_range_text_single_line() {
        let text = "Hello world";
        let range = Range::single_line(0, 0, 5);

        let extracted = extract_range_text(text, range);
        assert_eq!(extracted, Some("Hello".to_string()));
    }

    #[test]
    fn test_extract_range_text_multi_line() {
        let text = "Line 1\nLine 2\nLine 3";
        let range = Range::new(Position::new(0, 5), Position::new(1, 4));

        // Tail of line 0, the line break, then the head of line 1.
        let extracted = extract_range_text(text, range);
        assert_eq!(extracted, Some("1\nLine".to_string()));
    }

    #[test]
    fn test_get_quick_fixes_unclosed_tag() {
        let error = "Error: unclosed tag <speak>";
        let range = Range::single_line(0, 0, 7);

        let fixes = get_quick_fixes(error, range);
        assert!(!fixes.is_empty());
        assert_eq!(fixes[0]["title"].as_str().unwrap(), "Close <speak> tag");
    }

    #[test]
    fn test_get_quick_fixes_other_errors() {
        let range = Range::single_line(0, 0, 3);

        let voice_fixes = get_quick_fixes("invalid voice: foo", range);
        assert_eq!(action_titles(&voice_fixes), vec!["Replace with 'kokoro-en'"]);

        let attribute_fixes = get_quick_fixes("invalid attribute: bar", range);
        assert_eq!(action_titles(&attribute_fixes), vec!["Remove invalid attribute"]);

        assert!(get_quick_fixes("all good", range).is_empty());
    }

    #[test]
    fn test_extract_tag_name() {
        let message = "unclosed tag <prosody>";
        let tag = extract_tag_name(message);
        assert_eq!(tag, Some("prosody".to_string()));

        assert_eq!(extract_tag_name("no tag here"), None);
    }

    #[test]
    fn test_code_action_kinds() {
        assert_eq!(CodeActionKind::QuickFix.as_str(), "quickfix");
        assert_eq!(CodeActionKind::Refactor.as_str(), "refactor");
        assert_eq!(CodeActionKind::Source.as_str(), "source");
        assert_eq!(CodeActionKind::RefactorRewrite.as_str(), "refactor.rewrite");
    }

    #[test]
    fn test_create_insert_action() {
        let action = create_insert_action("Test insert", Position::new(5, 10), "test text");

        assert_eq!(action["title"].as_str().unwrap(), "Test insert");

        let edit = &action["edit"]["changes"]["document"][0];
        assert_eq!(edit["newText"].as_str().unwrap(), "test text");
        // An insertion is an empty range: start and end are the same position.
        assert_eq!(edit["range"]["start"], edit["range"]["end"]);
        assert_eq!(edit["range"]["start"]["line"].as_u64(), Some(5));
        assert_eq!(edit["range"]["start"]["character"].as_u64(), Some(10));
    }
}
350}