sql_lsp/
server.rs

1use crate::dialect::Dialect;
2use crate::dialects::DialectRegistry;
3use crate::schema::{Schema, SchemaId, SchemaManager};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tower_lsp::jsonrpc::Result;
7use tower_lsp::lsp_types::*;
8use tower_lsp::{Client, LanguageServer};
9
10/// 文档管理器,用于存储和管理打开的文档内容
11#[derive(Clone)]
12struct DocumentManager {
13    documents: Arc<DashMap<String, String>>,
14}
15
16impl DocumentManager {
17    fn new() -> Self {
18        Self {
19            documents: Arc::new(DashMap::new()),
20        }
21    }
22
23    fn update(&self, uri: String, text: String) {
24        self.documents.insert(uri, text);
25    }
26
27    fn get(&self, uri: &str) -> Option<String> {
28        self.documents.get(uri).map(|v| v.clone())
29    }
30
31    fn remove(&self, uri: &str) {
32        self.documents.remove(uri);
33    }
34}
35
36/// SQL LSP 服务器
37pub struct SqlLspServer {
38    client: Client,
39    /// 方言注册表
40    dialect_registry: Arc<DialectRegistry>,
41    /// Schema 管理器
42    schema_manager: Arc<SchemaManager>,
43    /// 文件到方言的映射
44    file_dialects: Arc<DashMap<String, String>>,
45    /// 文件到 Schema ID 的映射
46    file_schemas: Arc<DashMap<String, SchemaId>>,
47    /// 文档管理器
48    document_manager: DocumentManager,
49}
50
51impl SqlLspServer {
52    pub fn new(client: Client) -> Self {
53        tracing::info!("Creating new SQL LSP server instance");
54        Self {
55            client,
56            dialect_registry: Arc::new(DialectRegistry::new()),
57            schema_manager: Arc::new(SchemaManager::new()),
58            file_dialects: Arc::new(DashMap::new()),
59            file_schemas: Arc::new(DashMap::new()),
60            document_manager: DocumentManager::new(),
61        }
62    }
63
64    /// 获取文件的方言
65    fn get_dialect_for_file(&self, uri: &str) -> Option<Arc<dyn Dialect>> {
66        self.file_dialects
67            .get(uri)
68            .and_then(|dialect_name| self.dialect_registry.get_by_name(dialect_name.value()))
69    }
70
71    /// 获取文件的 Schema
72    /// 如果文件没有显式关联的 Schema,则根据 SQL 内容自动推断最佳匹配的 Schema
73    fn get_schema_for_file(&self, uri: &str) -> Option<Schema> {
74        // 1. 先检查是否有显式关联的 Schema
75        if let Some(schema_id) = self.file_schemas.get(uri) {
76            return self.schema_manager.get(*schema_id.value());
77        }
78
79        // 2. 从文档内容推断 Schema
80        if let Some(text) = self.document_manager.get(uri) {
81            use crate::parser::SqlParser;
82            let mut parser = SqlParser::new();
83            let parse_result = parser.parse(&text);
84
85            if let Some(tree) = parse_result.tree {
86                let tables = parser.extract_tables(&tree, &text);
87
88                if !tables.is_empty() {
89                    // 3. 查找最佳匹配的 Schema
90                    let best_match = self
91                        .schema_manager
92                        .list_ids()
93                        .iter()
94                        .filter_map(|&schema_id| {
95                            let schema = self.schema_manager.get(schema_id)?;
96                            let score = self.calculate_schema_match_score(&tables, &schema);
97                            if score > 0 {
98                                Some((schema_id, score))
99                            } else {
100                                None
101                            }
102                        })
103                        .max_by_key(|(_, score)| *score);
104
105                    if let Some((schema_id, _score)) = best_match {
106                        // 缓存推断结果
107                        self.file_schemas.insert(uri.to_string(), schema_id);
108                        return self.schema_manager.get(schema_id);
109                    }
110                }
111            }
112        }
113
114        None
115    }
116
117    /// 计算 Schema 匹配分数
118    /// 返回匹配分数,分数越高表示匹配度越高
119    fn calculate_schema_match_score(&self, tables: &[String], schema: &Schema) -> i32 {
120        let mut score = 0;
121
122        for table_name in tables {
123            // 完全匹配:+10 分
124            if schema.tables.iter().any(|t| t.name == *table_name) {
125                score += 10;
126            } else {
127                // 模糊匹配:+5 分
128                for schema_table in &schema.tables {
129                    if schema_table.name.contains(table_name)
130                        || table_name.contains(&schema_table.name)
131                    {
132                        score += 5;
133                        break; // 每个表名只匹配一次
134                    }
135                }
136            }
137        }
138
139        // 如果匹配的表数量越多,额外加分
140        let matched_count = tables
141            .iter()
142            .filter(|table_name| schema.tables.iter().any(|t| t.name == **table_name))
143            .count();
144
145        if matched_count > 1 {
146            score += matched_count as i32 * 2; // 多表匹配额外加分
147        }
148
149        score
150    }
151
152    /// 将 LSP Position 转换为字符串字节偏移
153    fn position_to_offset(&self, text: &str, position: tower_lsp::lsp_types::Position) -> usize {
154        let mut offset = 0;
155        for (line_idx, line) in text.lines().enumerate() {
156            if line_idx < position.line as usize {
157                offset += line.len() + 1; // +1 for newline
158            } else {
159                offset += position.character.min(line.len() as u32) as usize;
160                break;
161            }
162        }
163        offset.min(text.len())
164    }
165}
166
167#[tower_lsp::async_trait]
168impl LanguageServer for SqlLspServer {
169    async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
170        Ok(InitializeResult {
171            server_info: Some(ServerInfo {
172                name: "sql-lsp".to_string(),
173                version: Some("0.1.0".to_string()),
174            }),
175            capabilities: ServerCapabilities {
176                text_document_sync: Some(TextDocumentSyncCapability::Kind(
177                    TextDocumentSyncKind::INCREMENTAL,
178                )),
179                completion_provider: Some(CompletionOptions {
180                    resolve_provider: Some(false),
181                    trigger_characters: Some(vec![
182                        ".".to_string(),
183                        " ".to_string(),
184                        "(".to_string(),
185                    ]),
186                    ..Default::default()
187                }),
188                hover_provider: Some(HoverProviderCapability::Simple(true)),
189                definition_provider: Some(OneOf::Left(true)),
190                references_provider: Some(OneOf::Left(true)),
191                document_formatting_provider: Some(OneOf::Left(true)),
192                diagnostic_provider: Some(DiagnosticServerCapabilities::Options(
193                    DiagnosticOptions {
194                        identifier: Some("sql-lsp".to_string()),
195                        inter_file_dependencies: true,
196                        workspace_diagnostics: false,
197                        ..Default::default()
198                    },
199                )),
200                ..Default::default()
201            },
202        })
203    }
204
205    async fn initialized(&self, _: InitializedParams) {
206        tracing::info!("SQL LSP server initialized and ready");
207        self.client
208            .log_message(MessageType::INFO, "SQL LSP server initialized")
209            .await;
210    }
211
212    async fn shutdown(&self) -> Result<()> {
213        Ok(())
214    }
215
216    async fn did_change_configuration(&self, params: DidChangeConfigurationParams) {
217        tracing::debug!("Received configuration change");
218        // 解析配置 JSON
219        if let Some(settings) = params.settings.as_object() {
220            // 处理 schemas 配置
221            if let Some(schemas_value) = settings.get("schemas") {
222                if let Ok(schemas) =
223                    serde_json::from_value::<Vec<crate::schema::Schema>>(schemas_value.clone())
224                {
225                    // 清空旧的 schema 并注册新的
226                    self.schema_manager.clear();
227                    let count = schemas.len();
228                    for schema in schemas {
229                        self.schema_manager.register(schema);
230                    }
231                    self.client
232                        .log_message(MessageType::INFO, format!("Updated {} schemas", count))
233                        .await;
234                } else {
235                    self.client
236                        .log_message(
237                            MessageType::WARNING,
238                            "Failed to parse schemas configuration",
239                        )
240                        .await;
241                }
242            }
243
244            // 处理文件到 schema 的映射配置
245            if let Some(file_schemas_value) = settings.get("fileSchemas") {
246                if let Some(file_schemas_obj) = file_schemas_value.as_object() {
247                    for (uri, schema_id_str) in file_schemas_obj {
248                        if let Some(id_str) = schema_id_str.as_str() {
249                            if let Ok(schema_id) = id_str.parse::<crate::schema::SchemaId>() {
250                                self.file_schemas.insert(uri.clone(), schema_id);
251                            }
252                        }
253                    }
254                    self.client
255                        .log_message(MessageType::INFO, "Updated file-schema mappings")
256                        .await;
257                }
258            }
259        }
260    }
261
262    async fn did_open(&self, params: DidOpenTextDocumentParams) {
263        let uri = params.text_document.uri.to_string();
264        let text = params.text_document.text.clone();
265        let language_id = params.text_document.language_id.clone();
266
267        // 存储文档内容
268        self.document_manager.update(uri.clone(), text.clone());
269
270        // 尝试从 URI 和 languageId 推断方言
271        // 优先级:URI 扩展名 > languageId > 默认 MySQL
272        let dialect_name = infer_dialect_from_uri_and_language(&uri, &language_id);
273        self.file_dialects.insert(uri.clone(), dialect_name.clone());
274
275        // 发布诊断
276        if let Some(dialect) = self.get_dialect_for_file(&uri) {
277            let schema = self.get_schema_for_file(&uri);
278            let diagnostics = dialect.parse(&text, schema.as_ref()).await;
279            self.client
280                .publish_diagnostics(params.text_document.uri, diagnostics, None)
281                .await;
282        }
283    }
284
285    async fn did_change(&self, params: DidChangeTextDocumentParams) {
286        let uri = params.text_document.uri.to_string();
287
288        // 处理增量同步
289        for change in params.content_changes {
290            if let Some(range) = change.range {
291                // 增量更新:应用部分文本变更
292                if let Some(mut current_text) = self.document_manager.get(&uri) {
293                    // 将 LSP Range 转换为字节偏移
294                    let start_offset = self.position_to_offset(&current_text, range.start);
295                    let end_offset = self.position_to_offset(&current_text, range.end);
296
297                    // 应用变更
298                    current_text.replace_range(start_offset..end_offset, &change.text);
299                    self.document_manager
300                        .update(uri.clone(), current_text.clone());
301
302                    // 重新解析并发布诊断
303                    if let Some(dialect) = self.get_dialect_for_file(&uri) {
304                        let schema = self.get_schema_for_file(&uri);
305                        let diagnostics = dialect.parse(&current_text, schema.as_ref()).await;
306                        self.client
307                            .publish_diagnostics(
308                                params.text_document.uri.clone(),
309                                diagnostics,
310                                None,
311                            )
312                            .await;
313                    }
314                }
315            } else {
316                // 完整文档更新
317                let text = change.text.clone();
318                self.document_manager.update(uri.clone(), text.clone());
319
320                if let Some(dialect) = self.get_dialect_for_file(&uri) {
321                    let schema = self.get_schema_for_file(&uri);
322                    let diagnostics = dialect.parse(&text, schema.as_ref()).await;
323                    self.client
324                        .publish_diagnostics(params.text_document.uri.clone(), diagnostics, None)
325                        .await;
326                }
327            }
328        }
329    }
330
331    async fn did_close(&self, params: DidCloseTextDocumentParams) {
332        let uri = params.text_document.uri.to_string();
333        // 清理文档
334        self.document_manager.remove(&uri);
335    }
336
337    async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
338        let uri = params.text_document_position.text_document.uri.to_string();
339        let position = params.text_document_position.position;
340
341        let text = self.document_manager.get(&uri).unwrap_or_default();
342
343        if let Some(dialect) = self.get_dialect_for_file(&uri) {
344            let schema = self.get_schema_for_file(&uri);
345            let items = dialect.completion(&text, position, schema.as_ref()).await;
346            return Ok(Some(CompletionResponse::Array(items)));
347        }
348
349        Ok(None)
350    }
351
352    async fn hover(&self, params: HoverParams) -> Result<Option<Hover>> {
353        let uri = params
354            .text_document_position_params
355            .text_document
356            .uri
357            .to_string();
358        let position = params.text_document_position_params.position;
359
360        let text = self.document_manager.get(&uri).unwrap_or_default();
361
362        if let Some(dialect) = self.get_dialect_for_file(&uri) {
363            let schema = self.get_schema_for_file(&uri);
364            return Ok(dialect.hover(&text, position, schema.as_ref()).await);
365        }
366
367        Ok(None)
368    }
369
370    async fn goto_definition(
371        &self,
372        params: GotoDefinitionParams,
373    ) -> Result<Option<GotoDefinitionResponse>> {
374        let uri = params
375            .text_document_position_params
376            .text_document
377            .uri
378            .to_string();
379        let position = params.text_document_position_params.position;
380
381        let text = self.document_manager.get(&uri).unwrap_or_default();
382
383        if let Some(dialect) = self.get_dialect_for_file(&uri) {
384            let schema = self.get_schema_for_file(&uri);
385            if let Some(location) = dialect
386                .goto_definition(&text, position, schema.as_ref())
387                .await
388            {
389                return Ok(Some(GotoDefinitionResponse::Scalar(location)));
390            }
391        }
392
393        Ok(None)
394    }
395
396    async fn references(&self, params: ReferenceParams) -> Result<Option<Vec<Location>>> {
397        let uri = params.text_document_position.text_document.uri.to_string();
398        let position = params.text_document_position.position;
399
400        let text = self.document_manager.get(&uri).unwrap_or_default();
401
402        if let Some(dialect) = self.get_dialect_for_file(&uri) {
403            let schema = self.get_schema_for_file(&uri);
404            let locations = dialect.references(&text, position, schema.as_ref()).await;
405            return Ok(Some(locations));
406        }
407
408        Ok(None)
409    }
410
411    async fn formatting(&self, params: DocumentFormattingParams) -> Result<Option<Vec<TextEdit>>> {
412        let uri = params.text_document.uri.to_string();
413        let text = self.document_manager.get(&uri).unwrap_or_default();
414
415        if let Some(dialect) = self.get_dialect_for_file(&uri) {
416            let formatted = dialect.format(&text).await;
417            let line_count = if text.is_empty() {
418                0
419            } else {
420                text.lines().count() as u32
421            };
422            let range = Range {
423                start: Position {
424                    line: 0,
425                    character: 0,
426                },
427                end: Position {
428                    line: line_count.saturating_sub(1),
429                    character: 0,
430                },
431            };
432            return Ok(Some(vec![TextEdit {
433                range,
434                new_text: formatted,
435            }]));
436        }
437
438        Ok(None)
439    }
440}
441
/// Infers the dialect name for a document from its URI and LSP `languageId`.
///
/// Only the lower-cased URI suffix is inspected, so this works for any URI
/// scheme: `file://`, `untitled:` (unsaved editor buffers) or custom schemes.
///
/// Priority:
/// 1. URI suffix (e.g. `.mysql.sql`, `.pgsql`, `.hql`, `.es.json`).
/// 2. `languageId`, case-insensitively (e.g. `mysql`, `postgresql`, `hive`,
///    `es-dsl`). A plain `json` language id selects Elasticsearch DSL only
///    when the URI contains `elasticsearch`.
/// 3. Fallback: `"mysql"`.
fn infer_dialect_from_uri_and_language(uri: &str, language_id: &str) -> String {
    // Checked in order; the first dialect with a matching suffix wins.
    const SUFFIX_DIALECTS: &[(&[&str], &str)] = &[
        (&[".mysql.sql", ".mysql"], "mysql"),
        (&[".postgres.sql", ".pgsql"], "postgres"),
        (&[".hive.sql", ".hql"], "hive"),
        (&[".es.eql", ".eql"], "elasticsearch-eql"),
        (&[".es.dsl", ".es.json", ".elasticsearch"], "elasticsearch-dsl"),
        (&[".ch.sql", ".clickhouse"], "clickhouse"),
        (&[".redis.sql", ".redis"], "redis"),
    ];

    let uri_lower = uri.to_lowercase();

    let by_suffix = SUFFIX_DIALECTS
        .iter()
        .find(|(suffixes, _)| suffixes.iter().any(|&suffix| uri_lower.ends_with(suffix)))
        .map(|&(_, dialect)| dialect);
    if let Some(dialect) = by_suffix {
        return dialect.to_string();
    }

    let dialect = match language_id.to_lowercase().as_str() {
        "mysql" | "mysql-sql" => "mysql",
        "postgresql" | "postgres" | "pgsql" => "postgres",
        "hive" | "hql" => "hive",
        "elasticsearch-eql" | "eql" => "elasticsearch-eql",
        "elasticsearch-dsl" | "es-dsl" => "elasticsearch-dsl",
        // Bare JSON is too generic to claim; only treat it as DSL when the URI says so.
        "json" if uri_lower.contains("elasticsearch") => "elasticsearch-dsl",
        "clickhouse" | "ch" => "clickhouse",
        "redis" => "redis",
        _ => "mysql",
    };
    dialect.to_string()
}
490}