squawk_server/
lib.rs

1use anyhow::{Context, Result};
2use line_index::LineIndex;
3use log::info;
4use lsp_server::{Connection, Message, Notification, Response};
5use lsp_types::{
6    CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams,
7    CodeActionProviderCapability, CodeActionResponse, Command, Diagnostic,
8    DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
9    GotoDefinitionParams, GotoDefinitionResponse, InitializeParams, Location, OneOf,
10    PublishDiagnosticsParams, SelectionRangeParams, SelectionRangeProviderCapability,
11    ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, Url,
12    WorkDoneProgressOptions, WorkspaceEdit,
13    notification::{
14        DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _,
15        PublishDiagnostics,
16    },
17    request::{CodeActionRequest, GotoDefinition, Request, SelectionRangeRequest},
18};
19use rowan::TextRange;
20use squawk_ide::code_actions::code_actions;
21use squawk_ide::goto_definition::goto_definition;
22use squawk_syntax::{Parse, SourceFile};
23use std::collections::HashMap;
24
25use diagnostic::DIAGNOSTIC_NAME;
26
27use crate::diagnostic::AssociatedDiagnosticData;
28mod diagnostic;
29mod ignore;
30mod lint;
31mod lsp_utils;
32
/// The server's copy of a document the client has opened.
struct DocumentState {
    /// Current full text, replaced on `didOpen` and updated on `didChange`.
    content: String,
    /// Version number the client sent with the most recent open/change.
    version: i32,
}
37
38pub fn run() -> Result<()> {
39    info!("Starting Squawk LSP server");
40
41    let (connection, io_threads) = Connection::stdio();
42
43    let server_capabilities = serde_json::to_value(&ServerCapabilities {
44        text_document_sync: Some(TextDocumentSyncCapability::Kind(
45            TextDocumentSyncKind::INCREMENTAL,
46        )),
47        code_action_provider: Some(CodeActionProviderCapability::Options(CodeActionOptions {
48            code_action_kinds: Some(vec![
49                CodeActionKind::QUICKFIX,
50                CodeActionKind::REFACTOR_REWRITE,
51            ]),
52            work_done_progress_options: WorkDoneProgressOptions {
53                work_done_progress: None,
54            },
55            resolve_provider: None,
56        })),
57        selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)),
58        definition_provider: Some(OneOf::Left(true)),
59        ..Default::default()
60    })
61    .unwrap();
62
63    info!("LSP server initializing connection...");
64    let initialization_params = connection.initialize(server_capabilities)?;
65    info!("LSP server initialized, entering main loop");
66
67    main_loop(connection, initialization_params)?;
68
69    info!("LSP server shutting down");
70
71    io_threads.join()?;
72    Ok(())
73}
74
75fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
76    info!("Server main loop");
77
78    let init_params: InitializeParams = serde_json::from_value(params).unwrap_or_default();
79    info!("Client process ID: {:?}", init_params.process_id);
80    let client_name = init_params.client_info.map(|x| x.name);
81    info!("Client name: {client_name:?}");
82
83    let mut documents: HashMap<Url, DocumentState> = HashMap::new();
84
85    for msg in &connection.receiver {
86        match msg {
87            Message::Request(req) => {
88                info!("Received request: method={}, id={:?}", req.method, req.id);
89
90                if connection.handle_shutdown(&req)? {
91                    info!("Received shutdown request, exiting");
92                    return Ok(());
93                }
94
95                match req.method.as_ref() {
96                    GotoDefinition::METHOD => {
97                        handle_goto_definition(&connection, req, &documents)?;
98                    }
99                    CodeActionRequest::METHOD => {
100                        handle_code_action(&connection, req, &documents)?;
101                    }
102                    SelectionRangeRequest::METHOD => {
103                        handle_selection_range(&connection, req, &documents)?;
104                    }
105                    "squawk/syntaxTree" => {
106                        handle_syntax_tree(&connection, req, &documents)?;
107                    }
108                    "squawk/tokens" => {
109                        handle_tokens(&connection, req, &documents)?;
110                    }
111                    _ => {
112                        info!("Ignoring unhandled request: {}", req.method);
113                    }
114                }
115            }
116            Message::Response(resp) => {
117                info!("Received response: id={:?}", resp.id);
118            }
119            Message::Notification(notif) => {
120                info!("Received notification: method={}", notif.method);
121                match notif.method.as_ref() {
122                    DidOpenTextDocument::METHOD => {
123                        handle_did_open(&connection, notif, &mut documents)?;
124                    }
125                    DidChangeTextDocument::METHOD => {
126                        handle_did_change(&connection, notif, &mut documents)?;
127                    }
128                    DidCloseTextDocument::METHOD => {
129                        handle_did_close(&connection, notif, &mut documents)?;
130                    }
131                    _ => {
132                        info!("Ignoring unhandled notification: {}", notif.method);
133                    }
134                }
135            }
136        }
137    }
138    Ok(())
139}
140
141fn handle_goto_definition(
142    connection: &Connection,
143    req: lsp_server::Request,
144    documents: &HashMap<Url, DocumentState>,
145) -> Result<()> {
146    let params: GotoDefinitionParams = serde_json::from_value(req.params)?;
147    let uri = params.text_document_position_params.text_document.uri;
148    let position = params.text_document_position_params.position;
149
150    let content = documents.get(&uri).map_or("", |doc| &doc.content);
151    let parse: Parse<SourceFile> = SourceFile::parse(content);
152    let file = parse.tree();
153    let line_index = LineIndex::new(content);
154    let offset = lsp_utils::offset(&line_index, position).unwrap();
155
156    let range = goto_definition(file, offset);
157
158    let result = match range {
159        Some(target_range) => {
160            debug_assert!(
161                !target_range.contains(offset),
162                "Our target destination range must not include the source range otherwise go to def won't work in vscode."
163            );
164            GotoDefinitionResponse::Scalar(Location {
165                uri: uri.clone(),
166                range: lsp_utils::range(&line_index, target_range),
167            })
168        }
169        None => GotoDefinitionResponse::Array(vec![]),
170    };
171
172    let resp = Response {
173        id: req.id,
174        result: Some(serde_json::to_value(&result).unwrap()),
175        error: None,
176    };
177
178    connection.sender.send(Message::Response(resp))?;
179    Ok(())
180}
181
182fn handle_selection_range(
183    connection: &Connection,
184    req: lsp_server::Request,
185    documents: &HashMap<Url, DocumentState>,
186) -> Result<()> {
187    let params: SelectionRangeParams = serde_json::from_value(req.params)?;
188    let uri = params.text_document.uri;
189
190    let content = documents.get(&uri).map_or("", |doc| &doc.content);
191    let parse: Parse<SourceFile> = SourceFile::parse(content);
192    let root = parse.syntax_node();
193    let line_index = LineIndex::new(content);
194
195    let mut selection_ranges = vec![];
196
197    for position in params.positions {
198        let Some(offset) = lsp_utils::offset(&line_index, position) else {
199            continue;
200        };
201
202        let mut ranges = Vec::new();
203        {
204            let mut range = TextRange::new(offset, offset);
205            loop {
206                ranges.push(range);
207                let next = squawk_ide::expand_selection::extend_selection(&root, range);
208                if next == range {
209                    break;
210                } else {
211                    range = next
212                }
213            }
214        }
215
216        let mut range = lsp_types::SelectionRange {
217            range: lsp_utils::range(&line_index, *ranges.last().unwrap()),
218            parent: None,
219        };
220        for &r in ranges.iter().rev().skip(1) {
221            range = lsp_types::SelectionRange {
222                range: lsp_utils::range(&line_index, r),
223                parent: Some(Box::new(range)),
224            }
225        }
226        selection_ranges.push(range);
227    }
228
229    let resp = Response {
230        id: req.id,
231        result: Some(serde_json::to_value(&selection_ranges).unwrap()),
232        error: None,
233    };
234
235    connection.sender.send(Message::Response(resp))?;
236    Ok(())
237}
238
239fn handle_code_action(
240    connection: &Connection,
241    req: lsp_server::Request,
242    documents: &HashMap<Url, DocumentState>,
243) -> Result<()> {
244    let params: CodeActionParams = serde_json::from_value(req.params)?;
245    let uri = params.text_document.uri;
246
247    let mut actions: CodeActionResponse = Vec::new();
248
249    let content = documents.get(&uri).map_or("", |doc| &doc.content);
250    let parse: Parse<SourceFile> = SourceFile::parse(content);
251    let file = parse.tree();
252    let line_index = LineIndex::new(content);
253    let offset = lsp_utils::offset(&line_index, params.range.start).unwrap();
254
255    let ide_actions = code_actions(file, offset).unwrap_or_default();
256
257    for action in ide_actions {
258        let lsp_action = lsp_utils::code_action(&line_index, uri.clone(), action);
259        actions.push(CodeActionOrCommand::CodeAction(lsp_action));
260    }
261
262    for mut diagnostic in params
263        .context
264        .diagnostics
265        .into_iter()
266        .filter(|diagnostic| diagnostic.source.as_deref() == Some(DIAGNOSTIC_NAME))
267    {
268        let Some(rule_name) = diagnostic.code.as_ref().map(|x| match x {
269            lsp_types::NumberOrString::String(s) => s.clone(),
270            lsp_types::NumberOrString::Number(n) => n.to_string(),
271        }) else {
272            continue;
273        };
274        let Some(data) = diagnostic.data.take() else {
275            continue;
276        };
277
278        let associated_data: AssociatedDiagnosticData =
279            serde_json::from_value(data).context("deserializing diagnostic data")?;
280
281        if let Some(ignore_line_edit) = associated_data.ignore_line_edit {
282            let disable_line_action = CodeAction {
283                title: format!("Disable {rule_name} for this line"),
284                kind: Some(CodeActionKind::QUICKFIX),
285                diagnostics: Some(vec![diagnostic.clone()]),
286                edit: Some(WorkspaceEdit {
287                    changes: Some({
288                        let mut changes = HashMap::new();
289                        changes.insert(uri.clone(), vec![ignore_line_edit]);
290                        changes
291                    }),
292                    ..Default::default()
293                }),
294                command: None,
295                is_preferred: Some(false),
296                disabled: None,
297                data: None,
298            };
299            actions.push(CodeActionOrCommand::CodeAction(disable_line_action));
300        }
301        if let Some(ignore_file_edit) = associated_data.ignore_file_edit {
302            let disable_file_action = CodeAction {
303                title: format!("Disable {rule_name} for the entire file"),
304                kind: Some(CodeActionKind::QUICKFIX),
305                diagnostics: Some(vec![diagnostic.clone()]),
306                edit: Some(WorkspaceEdit {
307                    changes: Some({
308                        let mut changes = HashMap::new();
309                        changes.insert(uri.clone(), vec![ignore_file_edit]);
310                        changes
311                    }),
312                    ..Default::default()
313                }),
314                command: None,
315                is_preferred: Some(false),
316                disabled: None,
317                data: None,
318            };
319            actions.push(CodeActionOrCommand::CodeAction(disable_file_action));
320        }
321
322        let title = format!("Show documentation for {rule_name}");
323        let documentation_action = CodeAction {
324            title: title.clone(),
325            kind: Some(CodeActionKind::QUICKFIX),
326            diagnostics: Some(vec![diagnostic.clone()]),
327            edit: None,
328            command: Some(Command {
329                title,
330                command: "vscode.open".to_string(),
331                arguments: Some(vec![serde_json::to_value(format!(
332                    "https://squawkhq.com/docs/{rule_name}"
333                ))?]),
334            }),
335            is_preferred: Some(false),
336            disabled: None,
337            data: None,
338        };
339        actions.push(CodeActionOrCommand::CodeAction(documentation_action));
340
341        if !associated_data.title.is_empty() && !associated_data.edits.is_empty() {
342            let fix_action = CodeAction {
343                title: associated_data.title,
344                kind: Some(CodeActionKind::QUICKFIX),
345                diagnostics: Some(vec![diagnostic.clone()]),
346                edit: Some(WorkspaceEdit {
347                    changes: Some({
348                        let mut changes = HashMap::new();
349                        changes.insert(uri.clone(), associated_data.edits);
350                        changes
351                    }),
352                    ..Default::default()
353                }),
354                command: None,
355                is_preferred: Some(true),
356                disabled: None,
357                data: None,
358            };
359            actions.push(CodeActionOrCommand::CodeAction(fix_action));
360        }
361    }
362
363    let result: CodeActionResponse = actions;
364    let resp = Response {
365        id: req.id,
366        result: Some(serde_json::to_value(&result).unwrap()),
367        error: None,
368    };
369
370    connection.sender.send(Message::Response(resp))?;
371    Ok(())
372}
373
374fn publish_diagnostics(
375    connection: &Connection,
376    uri: Url,
377    version: i32,
378    diagnostics: Vec<Diagnostic>,
379) -> Result<()> {
380    let publish_params = PublishDiagnosticsParams {
381        uri,
382        diagnostics,
383        version: Some(version),
384    };
385
386    let notification = Notification {
387        method: PublishDiagnostics::METHOD.to_owned(),
388        params: serde_json::to_value(publish_params)?,
389    };
390
391    connection
392        .sender
393        .send(Message::Notification(notification))?;
394    Ok(())
395}
396
397fn handle_did_open(
398    connection: &Connection,
399    notif: lsp_server::Notification,
400    documents: &mut HashMap<Url, DocumentState>,
401) -> Result<()> {
402    let params: DidOpenTextDocumentParams = serde_json::from_value(notif.params)?;
403    let uri = params.text_document.uri;
404    let content = params.text_document.text;
405    let version = params.text_document.version;
406
407    documents.insert(uri.clone(), DocumentState { content, version });
408
409    let content = documents.get(&uri).map_or("", |doc| &doc.content);
410
411    // TODO: we need a better setup for "run func when input changed"
412    let diagnostics = lint::lint(content);
413    publish_diagnostics(connection, uri, version, diagnostics)?;
414
415    Ok(())
416}
417
418fn handle_did_change(
419    connection: &Connection,
420    notif: lsp_server::Notification,
421    documents: &mut HashMap<Url, DocumentState>,
422) -> Result<()> {
423    let params: DidChangeTextDocumentParams = serde_json::from_value(notif.params)?;
424    let uri = params.text_document.uri;
425    let version = params.text_document.version;
426
427    let Some(doc_state) = documents.get_mut(&uri) else {
428        return Ok(());
429    };
430
431    doc_state.content =
432        lsp_utils::apply_incremental_changes(&doc_state.content, params.content_changes);
433    doc_state.version = version;
434
435    let diagnostics = lint::lint(&doc_state.content);
436    publish_diagnostics(connection, uri, version, diagnostics)?;
437
438    Ok(())
439}
440
441fn handle_did_close(
442    connection: &Connection,
443    notif: lsp_server::Notification,
444    documents: &mut HashMap<Url, DocumentState>,
445) -> Result<()> {
446    let params: DidCloseTextDocumentParams = serde_json::from_value(notif.params)?;
447    let uri = params.text_document.uri;
448
449    documents.remove(&uri);
450
451    let publish_params = PublishDiagnosticsParams {
452        uri,
453        diagnostics: vec![],
454        version: None,
455    };
456
457    let notification = Notification {
458        method: PublishDiagnostics::METHOD.to_owned(),
459        params: serde_json::to_value(publish_params)?,
460    };
461
462    connection
463        .sender
464        .send(Message::Notification(notification))?;
465
466    Ok(())
467}
468
469#[derive(serde::Deserialize)]
470struct SyntaxTreeParams {
471    #[serde(rename = "textDocument")]
472    text_document: lsp_types::TextDocumentIdentifier,
473}
474
475fn handle_syntax_tree(
476    connection: &Connection,
477    req: lsp_server::Request,
478    documents: &HashMap<Url, DocumentState>,
479) -> Result<()> {
480    let params: SyntaxTreeParams = serde_json::from_value(req.params)?;
481    let uri = params.text_document.uri;
482
483    info!("Generating syntax tree for: {uri}");
484
485    let content = documents.get(&uri).map_or("", |doc| &doc.content);
486
487    let parse: Parse<SourceFile> = SourceFile::parse(content);
488    let syntax_tree = format!("{:#?}", parse.syntax_node());
489
490    let resp = Response {
491        id: req.id,
492        result: Some(serde_json::to_value(&syntax_tree).unwrap()),
493        error: None,
494    };
495
496    connection.sender.send(Message::Response(resp))?;
497    Ok(())
498}
499
500#[derive(serde::Deserialize)]
501struct TokensParams {
502    #[serde(rename = "textDocument")]
503    text_document: lsp_types::TextDocumentIdentifier,
504}
505
506fn handle_tokens(
507    connection: &Connection,
508    req: lsp_server::Request,
509    documents: &HashMap<Url, DocumentState>,
510) -> Result<()> {
511    let params: TokensParams = serde_json::from_value(req.params)?;
512    let uri = params.text_document.uri;
513
514    info!("Generating tokens for: {uri}");
515
516    let content = documents.get(&uri).map_or("", |doc| &doc.content);
517
518    let tokens = squawk_lexer::tokenize(content);
519
520    let mut output = Vec::new();
521    let mut char_pos = 0;
522    for token in tokens {
523        let token_start = char_pos;
524        let token_end = token_start + token.len as usize;
525        let token_text = &content[token_start..token_end];
526        output.push(format!(
527            "{:?}@{}..{} {:?}",
528            token.kind, token_start, token_end, token_text
529        ));
530        char_pos = token_end;
531    }
532
533    let tokens_output = output.join("\n");
534
535    let resp = Response {
536        id: req.id,
537        result: Some(serde_json::to_value(&tokens_output).unwrap()),
538        error: None,
539    };
540
541    connection.sender.send(Message::Response(resp))?;
542    Ok(())
543}