Skip to main content

typr_cli/
lsp.rs

1//! LSP server for TypR.
2//!
3//! Currently exposes:
4//!   - **Hover** provider: shows inferred types with Markdown syntax highlighting
5//!   - **Completion** provider: context-aware autocompletion for variables, functions, and type aliases
6//!     - Trigger characters: `.`, `$`, `>` (for `|>`), `:` (for type annotations)
7//!   - **Diagnostics** (push model): real-time error checking via `textDocument/publishDiagnostics`
8//!     - Diagnostics are published on `didOpen` and `didChange` events
9//!   - **Go to Definition** provider: jump to symbol definitions (variables, functions, type aliases)
10//!   - **Workspace Symbol** provider: search for symbols across all open documents
11//!
12//! Launch with `typr lsp`.  The server communicates over stdin/stdout using
13//! the standard LSP JSON-RPC protocol.
14
15use crate::lsp_parser;
16use nom_locate::LocatedSpan;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use tower_lsp::jsonrpc::Result;
21use tower_lsp::lsp_types::*;
22use tower_lsp::{Client, LanguageServer, LspService, Server};
23use typr_core::components::context::Context;
24use typr_core::components::language::var::Var;
25use typr_core::components::language::Lang;
26use typr_core::processes::parsing::parse;
27
28type Span<'a> = LocatedSpan<&'a str, String>;
29
30/// Shared state: one copy of each open document's full text.
31struct Backend {
32    client: Client,
33    documents: Arc<RwLock<HashMap<Url, String>>>,
34}
35
36#[tower_lsp::async_trait]
37impl LanguageServer for Backend {
38    // ── initialisation ──────────────────────────────────────────────────────
39    async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
40        Ok(InitializeResult {
41            capabilities: ServerCapabilities {
42                text_document_sync: Some(TextDocumentSyncCapability::Kind(
43                    TextDocumentSyncKind::FULL,
44                )),
45                hover_provider: Some(HoverProviderCapability::Simple(true)),
46                completion_provider: Some(CompletionOptions {
47                    trigger_characters: Some(vec![".".into(), "$".into(), ">".into(), ":".into()]),
48                    resolve_provider: None,
49                    ..Default::default()
50                }),
51                workspace_symbol_provider: Some(OneOf::Left(true)),
52                definition_provider: Some(OneOf::Left(true)),
53                ..Default::default()
54            },
55            ..Default::default()
56        })
57    }
58
59    async fn initialized(&self, _: InitializedParams) {
60        self.client
61            .log_message(MessageType::INFO, "TypR LSP server initialized.")
62            .await;
63    }
64
65    async fn shutdown(&self) -> Result<()> {
66        Ok(())
67    }
68
69    // ── document sync ───────────────────────────────────────────────────────
70    async fn did_open(&self, params: DidOpenTextDocumentParams) {
71        let uri = params.text_document.uri.clone();
72        let content = params.text_document.text.clone();
73
74        let mut docs = self.documents.write().await;
75        docs.insert(uri.clone(), content.clone());
76        drop(docs); // Release the lock before computing diagnostics
77
78        // Compute and publish diagnostics
79        let diagnostics = self.compute_diagnostics(&content, &uri).await;
80        self.client
81            .publish_diagnostics(uri, diagnostics, None)
82            .await;
83    }
84
85    async fn did_change(&self, params: DidChangeTextDocumentParams) {
86        let uri = params.text_document.uri.clone();
87        let mut docs = self.documents.write().await;
88
89        // Full-sync mode: each change event contains the complete text.
90        if let Some(change) = params.content_changes.into_iter().next() {
91            let content = change.text.clone();
92            docs.insert(uri.clone(), content.clone());
93            drop(docs); // Release the lock before computing diagnostics
94
95            // Compute and publish diagnostics
96            let diagnostics = self.compute_diagnostics(&content, &uri).await;
97            self.client
98                .publish_diagnostics(uri, diagnostics, None)
99                .await;
100        }
101    }
102
103    // ── hover ───────────────────────────────────────────────────────────────
104    async fn hover(&self, params: HoverParams) -> Result<Option<Hover>> {
105        let uri = params.text_document_position_params.text_document.uri;
106        let position = params.text_document_position_params.position;
107
108        let docs = self.documents.read().await;
109        let content = match docs.get(&uri) {
110            Some(c) => c,
111            None => return Ok(None),
112        };
113
114        // Offload parsing + typing to a blocking thread so we don't stall
115        // the LSP event loop.
116        let content_owned = content.clone();
117        let info = tokio::task::spawn_blocking(move || {
118            lsp_parser::find_type_at(&content_owned, position.line, position.character)
119        })
120        .await
121        .ok() // if the blocking task panicked, treat as None
122        .flatten();
123
124        match info {
125            Some(hover_info) => Ok(Some(Hover {
126                contents: HoverContents::Markup(MarkupContent {
127                    kind: MarkupKind::Markdown,
128                    value: hover_info.type_display,
129                }),
130                range: Some(hover_info.range),
131            })),
132            None => Ok(None),
133        }
134    }
135
136    // ── completion ──────────────────────────────────────────────────────────
137    async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
138        let uri = params.text_document_position.text_document.uri;
139        let position = params.text_document_position.position;
140
141        let docs = self.documents.read().await;
142        let content = match docs.get(&uri) {
143            Some(c) => c,
144            None => return Ok(None),
145        };
146
147        // Offload parsing + typing to a blocking thread (same strategy as hover).
148        let content_owned = content.clone();
149        let items = tokio::task::spawn_blocking(move || {
150            lsp_parser::get_completions_at(&content_owned, position.line, position.character)
151        })
152        .await
153        .ok()
154        .unwrap_or_default();
155
156        if items.is_empty() {
157            Ok(None)
158        } else {
159            Ok(Some(CompletionResponse::Array(items)))
160        }
161    }
162
163    // ── workspace/symbol ─────────────────────────────────────────────────────
164    async fn symbol(
165        &self,
166        params: WorkspaceSymbolParams,
167    ) -> Result<Option<Vec<SymbolInformation>>> {
168        let query = params.query.to_lowercase();
169        let docs = self.documents.read().await;
170
171        let mut all_symbols = Vec::new();
172
173        for (uri, content) in docs.iter() {
174            let content_owned = content.clone();
175            let uri_owned = uri.clone();
176
177            // Offload parsing to a blocking thread
178            let symbols = tokio::task::spawn_blocking(move || {
179                get_workspace_symbols(&content_owned, &uri_owned)
180            })
181            .await
182            .ok()
183            .unwrap_or_default();
184
185            all_symbols.extend(symbols);
186        }
187
188        // Filter symbols by query (case-insensitive substring match)
189        if !query.is_empty() {
190            all_symbols.retain(|sym| sym.name.to_lowercase().contains(&query));
191        }
192
193        if all_symbols.is_empty() {
194            Ok(None)
195        } else {
196            Ok(Some(all_symbols))
197        }
198    }
199
200    // ── textDocument/definition ───────────────────────────────────────────────
201    async fn goto_definition(
202        &self,
203        params: GotoDefinitionParams,
204    ) -> Result<Option<GotoDefinitionResponse>> {
205        let uri = params.text_document_position_params.text_document.uri;
206        let position = params.text_document_position_params.position;
207
208        let docs = self.documents.read().await;
209        let content = match docs.get(&uri) {
210            Some(c) => c,
211            None => return Ok(None),
212        };
213
214        // Extract the file path from the URI for cross-file definition lookup.
215        let file_path = uri
216            .to_file_path()
217            .ok()
218            .map(|p| p.to_string_lossy().to_string())
219            .unwrap_or_default();
220
221        // Offload parsing + typing to a blocking thread so we don't stall
222        // the LSP event loop.
223        let content_owned = content.clone();
224        let uri_owned = uri.clone();
225        let file_path_owned = file_path.clone();
226        let info = tokio::task::spawn_blocking(move || {
227            lsp_parser::find_definition_at(
228                &content_owned,
229                position.line,
230                position.character,
231                &file_path_owned,
232            )
233        })
234        .await
235        .ok()
236        .flatten();
237
238        match info {
239            Some(def_info) => {
240                // Use the definition's file path if available, otherwise use current file
241                let target_uri = match &def_info.file_path {
242                    Some(path) => Url::from_file_path(path).unwrap_or(uri_owned),
243                    None => uri_owned,
244                };
245                Ok(Some(GotoDefinitionResponse::Scalar(Location {
246                    uri: target_uri,
247                    range: def_info.range,
248                })))
249            }
250            None => Ok(None),
251        }
252    }
253}
254
255// ══════════════════════════════════════════════════════════════════════════
256// ── DIAGNOSTICS ───────────────────────────────────────────────────────────
257// ══════════════════════════════════════════════════════════════════════════
258
259impl Backend {
260    /// Compute diagnostics for a document by parsing and type-checking.
261    async fn compute_diagnostics(&self, content: &str, uri: &Url) -> Vec<Diagnostic> {
262        let content_owned = content.to_string();
263        let file_name = uri.to_string();
264
265        tokio::task::spawn_blocking(move || {
266            check_code_and_extract_errors(&content_owned, &file_name)
267        })
268        .await
269        .unwrap_or_else(|_| {
270            // If the blocking task panicked, return a generic diagnostic
271            vec![Diagnostic {
272                range: Range::new(Position::new(0, 0), Position::new(0, 0)),
273                severity: Some(DiagnosticSeverity::ERROR),
274                message: "Internal error while checking code".to_string(),
275                source: Some("typr".to_string()),
276                ..Default::default()
277            }]
278        })
279    }
280}
281
282/// Check the code and extract errors from parsing and type-checking.
283fn check_code_and_extract_errors(content: &str, file_name: &str) -> Vec<Diagnostic> {
284    let mut diagnostics = Vec::new();
285
286    // Convert URI to file path if needed
287    let path = if file_name.starts_with("file://") {
288        file_name.strip_prefix("file://").unwrap_or(file_name)
289    } else {
290        file_name
291    };
292
293    // 1. Attempt parsing
294    let span: Span = LocatedSpan::new_extra(content, path.to_string());
295    let parse_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| parse(span)));
296
297    let ast = match parse_result {
298        Ok(result) => {
299            // Collect syntax errors from the parsed AST
300            for syntax_error in &result.errors {
301                let msg = syntax_error.simple_message();
302                let range = if let Some(help_data) = syntax_error.get_help_data() {
303                    let offset = help_data.get_offset();
304                    let pos = offset_to_position(offset, content);
305                    let end_col = find_token_end(content, offset, pos);
306                    Range::new(pos, Position::new(pos.line, end_col))
307                } else {
308                    Range::new(Position::new(0, 0), Position::new(0, 1))
309                };
310                diagnostics.push(Diagnostic {
311                    range,
312                    severity: Some(DiagnosticSeverity::WARNING),
313                    message: msg,
314                    source: Some("typr".to_string()),
315                    ..Default::default()
316                });
317            }
318            result.ast
319        }
320        Err(panic_info) => {
321            // Extract diagnostic from the panic
322            if let Some(diagnostic) = extract_diagnostic_from_panic(&panic_info, content) {
323                diagnostics.push(diagnostic);
324            } else {
325                diagnostics.push(Diagnostic {
326                    range: Range::new(Position::new(0, 0), Position::new(0, 1)),
327                    severity: Some(DiagnosticSeverity::ERROR),
328                    message: "Syntax error in code".to_string(),
329                    source: Some("typr".to_string()),
330                    ..Default::default()
331                });
332            }
333            return diagnostics;
334        }
335    };
336
337    // 2. Attempt type checking with error collection
338    let context = Context::default();
339    let typing_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
340        typr_core::typing_with_errors(&context, &ast)
341    }));
342
343    match typing_result {
344        Ok(result) => {
345            use typr_core::TypRError;
346
347            for error in result.get_errors() {
348                let msg = error.simple_message();
349                let severity = match error {
350                    TypRError::Type(_) => DiagnosticSeverity::ERROR,
351                    TypRError::Syntax(_) => DiagnosticSeverity::WARNING,
352                };
353                let range = if let Some(help_data) = error.get_help_data() {
354                    let offset = help_data.get_offset();
355                    let pos = offset_to_position(offset, content);
356                    let end_col = find_token_end(content, offset, pos);
357                    Range::new(pos, Position::new(pos.line, end_col))
358                } else {
359                    Range::new(Position::new(0, 0), Position::new(0, 1))
360                };
361                diagnostics.push(Diagnostic {
362                    range,
363                    severity: Some(severity),
364                    message: msg,
365                    source: Some("typr".to_string()),
366                    ..Default::default()
367                });
368            }
369        }
370        Err(panic_info) => {
371            if let Some(diagnostic) = extract_diagnostic_from_panic(&panic_info, content) {
372                diagnostics.push(diagnostic);
373            } else {
374                diagnostics.push(Diagnostic {
375                    range: Range::new(Position::new(0, 0), Position::new(0, 1)),
376                    severity: Some(DiagnosticSeverity::ERROR),
377                    message: "Type error in code".to_string(),
378                    source: Some("typr".to_string()),
379                    ..Default::default()
380                });
381            }
382        }
383    }
384
385    diagnostics
386}
387
388/// Extract an LSP diagnostic from a panic payload.
389fn extract_diagnostic_from_panic(
390    panic_info: &Box<dyn std::any::Any + Send>,
391    content: &str,
392) -> Option<Diagnostic> {
393    let message = if let Some(s) = panic_info.downcast_ref::<String>() {
394        s.as_str()
395    } else if let Some(s) = panic_info.downcast_ref::<&str>() {
396        *s
397    } else {
398        return None;
399    };
400
401    let range = extract_position_from_error(message, content)
402        .unwrap_or_else(|| Range::new(Position::new(0, 0), Position::new(0, 1)));
403
404    Some(Diagnostic {
405        range,
406        severity: Some(DiagnosticSeverity::ERROR),
407        message: clean_error_message(message),
408        source: Some("typr".to_string()),
409        ..Default::default()
410    })
411}
412
413/// Clean an error message for display in the LSP.
414fn clean_error_message(msg: &str) -> String {
415    let without_ansi = strip_ansi_codes(msg);
416
417    for line in without_ansi.lines() {
418        let trimmed = line.trim();
419        if let Some(pos) = trimmed.find("× ") {
420            let message = &trimmed[pos + 2..];
421            return message.trim().trim_end_matches('.').to_string();
422        }
423    }
424
425    without_ansi
426        .lines()
427        .find(|line| !line.trim().is_empty())
428        .unwrap_or(&without_ansi)
429        .trim()
430        .to_string()
431}
432
/// Remove ANSI/VT terminal escape sequences from `s`.
///
/// Handled forms:
/// - CSI: `ESC [` followed by parameter/intermediate bytes and ending at the
///   first final byte in `@`..=`~` (ECMA-48), e.g. the colour code `ESC[31m`.
/// - OSC: `ESC ]` followed by a payload ending at BEL or `ESC \`, e.g. the
///   `ESC]8;;url ESC\` hyperlinks that terminal report renderers may emit.
/// - Any other `ESC x` pair: only the `ESC` itself is dropped.
///
/// An unterminated sequence at the end of the input is dropped entirely.
fn strip_ansi_codes(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '\x1b' {
            result.push(ch);
            continue;
        }

        match chars.peek().copied() {
            Some('[') => {
                chars.next();
                // Parameter (0x30-0x3F) and intermediate (0x20-0x2F) bytes
                // never fall in the final-byte range 0x40-0x7E. Stopping only
                // at ASCII letters would swallow real text after sequences
                // such as `ESC[200~`.
                for c in chars.by_ref() {
                    if ('@'..='~').contains(&c) {
                        break;
                    }
                }
            }
            Some(']') => {
                chars.next();
                // OSC payload runs until BEL or the string terminator `ESC \`.
                while let Some(c) = chars.next() {
                    if c == '\x07' {
                        break;
                    }
                    if c == '\x1b' && chars.peek() == Some(&'\\') {
                        chars.next();
                        break;
                    }
                }
            }
            // Lone ESC or an unsupported two-character escape: drop ESC only.
            _ => {}
        }
    }

    result
}
456
457/// Attempt to extract position information from an error message.
458fn extract_position_from_error(message: &str, content: &str) -> Option<Range> {
459    let without_ansi = strip_ansi_codes(message);
460
461    for line in without_ansi.lines() {
462        if let Some(bracket_start) = line.find('[') {
463            if let Some(bracket_end) = line[bracket_start..].find(']') {
464                let location = &line[bracket_start + 1..bracket_start + bracket_end];
465
466                if let Some(last_colon) = location.rfind(':') {
467                    if let Some(second_last_colon) = location[..last_colon].rfind(':') {
468                        let line_str = &location[second_last_colon + 1..last_colon];
469                        let col_str = &location[last_colon + 1..];
470
471                        if let (Ok(line_num), Ok(col_num)) =
472                            (line_str.parse::<u32>(), col_str.parse::<u32>())
473                        {
474                            let line = line_num.saturating_sub(1);
475                            let col = col_num.saturating_sub(1);
476                            let length = extract_error_length(&without_ansi, content, line);
477
478                            return Some(Range::new(
479                                Position::new(line, col),
480                                Position::new(line, col + length),
481                            ));
482                        }
483                    }
484                }
485            }
486        }
487    }
488
489    None
490}
491
492/// Try to extract the length of the error token from the miette diagram.
493fn extract_error_length(message: &str, content: &str, line: u32) -> u32 {
494    let mut found_line_number = false;
495    let mut marker_col = None;
496
497    for msg_line in message.lines() {
498        let trimmed = msg_line.trim_start();
499
500        if let Some(pipe_pos) = trimmed.find("│") {
501            if let Ok(num) = trimmed[..pipe_pos].trim().parse::<u32>() {
502                if num == line + 1 {
503                    found_line_number = true;
504                    continue;
505                }
506            }
507        }
508
509        if found_line_number && trimmed.contains("▲") {
510            if let Some(marker_pos) = trimmed.find("▲") {
511                marker_col = Some(marker_pos as u32);
512                break;
513            }
514        }
515    }
516
517    if let Some(_col) = marker_col {
518        if let Some(content_line) = content.lines().nth(line as usize) {
519            return content_line
520                .trim()
521                .split_whitespace()
522                .next()
523                .map(|s| s.len() as u32)
524                .unwrap_or(1);
525        }
526    }
527
528    1
529}
530
531/// Convert a character offset to a Position (line, column).
532fn offset_to_position(offset: usize, content: &str) -> Position {
533    let mut line = 0u32;
534    let mut col = 0u32;
535
536    for (i, ch) in content.chars().enumerate() {
537        if i >= offset {
538            break;
539        }
540        if ch == '\n' {
541            line += 1;
542            col = 0;
543        } else {
544            col += 1;
545        }
546    }
547
548    Position::new(line, col)
549}
550
551/// Find the end column of a token starting at the given offset.
552fn find_token_end(content: &str, offset: usize, start_pos: Position) -> u32 {
553    let bytes = content.as_bytes();
554    let mut end_offset = offset;
555
556    while end_offset < bytes.len() {
557        let ch = bytes[end_offset] as char;
558        if ch.is_whitespace() || ch == ';' || ch == ',' || ch == ')' || ch == ']' || ch == '}' {
559            break;
560        }
561        end_offset += 1;
562    }
563
564    let token_len = (end_offset - offset) as u32;
565    if token_len == 0 {
566        start_pos.character + 1
567    } else {
568        start_pos.character + token_len
569    }
570}
571
572// ══════════════════════════════════════════════════════════════════════════
573// ── WORKSPACE SYMBOLS ─────────────────────────────────────────────────────
574// ══════════════════════════════════════════════════════════════════════════
575
576/// Get all symbols from a document for workspace/symbol support.
577#[allow(deprecated)]
578fn get_workspace_symbols(content: &str, file_uri: &Url) -> Vec<SymbolInformation> {
579    let mut symbols = Vec::new();
580
581    let span: Span = LocatedSpan::new_extra(content, file_uri.path().to_string());
582    let parse_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| parse(span)));
583
584    let ast = match parse_result {
585        Ok(result) => result.ast,
586        Err(_) => return symbols,
587    };
588
589    collect_symbols_from_ast(&ast, content, file_uri, None, &mut symbols);
590
591    symbols
592}
593
594/// Recursively collect symbols from an AST node.
595#[allow(deprecated)]
596fn collect_symbols_from_ast(
597    lang: &Lang,
598    content: &str,
599    file_uri: &Url,
600    container_name: Option<String>,
601    symbols: &mut Vec<SymbolInformation>,
602) {
603    match lang {
604        Lang::Lines(statements, _) => {
605            for stmt in statements {
606                collect_symbols_from_ast(stmt, content, file_uri, container_name.clone(), symbols);
607            }
608        }
609
610        Lang::Scope(statements, _) => {
611            for stmt in statements {
612                collect_symbols_from_ast(stmt, content, file_uri, container_name.clone(), symbols);
613            }
614        }
615
616        Lang::Let(var_lang, typ, body, _) => {
617            if let Ok(var) = Var::try_from(var_lang) {
618                let name = var.get_name();
619                let help_data = var.get_help_data();
620                let offset = help_data.get_offset();
621                let pos = offset_to_position(offset, content);
622                let end_col = find_token_end(content, offset, pos);
623
624                let kind = if typ.is_function() || body.is_function() {
625                    SymbolKind::FUNCTION
626                } else {
627                    SymbolKind::VARIABLE
628                };
629
630                symbols.push(SymbolInformation {
631                    name: name.clone(),
632                    kind,
633                    location: Location {
634                        uri: file_uri.clone(),
635                        range: Range::new(pos, Position::new(pos.line, end_col)),
636                    },
637                    deprecated: None,
638                    container_name: container_name.clone(),
639                    tags: None,
640                });
641
642                collect_symbols_from_ast(body, content, file_uri, Some(name), symbols);
643            }
644        }
645
646        Lang::Alias(var_lang, _params, _typ, _) => {
647            if let Ok(var) = Var::try_from(var_lang) {
648                let name = var.get_name();
649                let help_data = var.get_help_data();
650                let offset = help_data.get_offset();
651                let pos = offset_to_position(offset, content);
652                let end_col = find_token_end(content, offset, pos);
653
654                symbols.push(SymbolInformation {
655                    name,
656                    kind: SymbolKind::TYPE_PARAMETER,
657                    location: Location {
658                        uri: file_uri.clone(),
659                        range: Range::new(pos, Position::new(pos.line, end_col)),
660                    },
661                    deprecated: None,
662                    container_name: container_name.clone(),
663                    tags: None,
664                });
665            }
666        }
667
668        Lang::Module(name, members, _, _, help_data) => {
669            let offset = help_data.get_offset();
670            let pos = offset_to_position(offset, content);
671            let end_col = pos.character + name.len() as u32;
672
673            symbols.push(SymbolInformation {
674                name: name.clone(),
675                kind: SymbolKind::MODULE,
676                location: Location {
677                    uri: file_uri.clone(),
678                    range: Range::new(pos, Position::new(pos.line, end_col)),
679                },
680                deprecated: None,
681                container_name: container_name.clone(),
682                tags: None,
683            });
684
685            for member in members {
686                collect_symbols_from_ast(member, content, file_uri, Some(name.clone()), symbols);
687            }
688        }
689
690        Lang::Signature(var, _typ, _) => {
691            let name = var.get_name();
692            let help_data = var.get_help_data();
693            let offset = help_data.get_offset();
694            let pos = offset_to_position(offset, content);
695            let end_col = find_token_end(content, offset, pos);
696
697            symbols.push(SymbolInformation {
698                name,
699                kind: SymbolKind::FUNCTION,
700                location: Location {
701                    uri: file_uri.clone(),
702                    range: Range::new(pos, Position::new(pos.line, end_col)),
703                },
704                deprecated: None,
705                container_name: container_name.clone(),
706                tags: None,
707            });
708        }
709
710        Lang::Function(_, _, body, _) => {
711            collect_symbols_from_ast(body, content, file_uri, container_name, symbols);
712        }
713
714        _ => {}
715    }
716}
717
718/// Start the LSP server.  Blocks until the client disconnects.
719pub async fn run_lsp() {
720    let stdin = tokio::io::stdin();
721    let stdout = tokio::io::stdout();
722
723    let (service, socket) = LspService::new(|client| Backend {
724        client,
725        documents: Arc::new(RwLock::new(HashMap::new())),
726    });
727
728    Server::new(stdin, stdout, socket).serve(service).await;
729}