Skip to main content

sh_layer3/query_engine/
lsp_query_engine.rs

1//! # LSP Query Engine
2//!
3//! 基于 LSP 的代码查询引擎实现。
4
5use crate::lsp::client::LspClient;
6use crate::lsp::types::SymbolKind as LspSymbolKind;
7use crate::lsp::types::{DocumentSymbol, Hover, HoverContents, Location, MarkedString, Position};
8use crate::query_engine::{
9    CodeAnalyzer, CodeMatch, CodeStructure, DetectedPattern, PatternType, QueryEngine, SymbolInfo,
10    SymbolKind,
11};
12use crate::types::{CodeLocation, CodeRange, Layer3Result, QueryResult, QueryType};
13use async_trait::async_trait;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16
17/// LSP 查询引擎
18pub struct LspQueryEngine {
19    client: Arc<LspClient>,
20}
21
22impl LspQueryEngine {
23    /// 创建新的 LSP 查询引擎
24    pub fn new() -> Self {
25        Self {
26            client: Arc::new(LspClient::new()),
27        }
28    }
29
30    /// 使用现有的 LSP 客户端创建
31    pub fn with_client(client: Arc<LspClient>) -> Self {
32        Self { client }
33    }
34
35    /// 获取 LSP 客户端引用
36    pub fn client(&self) -> &Arc<LspClient> {
37        &self.client
38    }
39
40    /// 确保语言服务器已连接
41    async fn ensure_connected(&self, file: &Path) -> Layer3Result<String> {
42        let language = detect_language(file);
43
44        if !self.client.is_connected(&language).await {
45            self.client
46                .initialize(&language, file.parent().unwrap_or(Path::new(".")))
47                .await
48                .map_err(|e| anyhow::anyhow!("Failed to initialize LSP server: {}", e))?;
49        }
50
51        Ok(language)
52    }
53
54    /// 将 LSP Location 转换为 QueryResult
55    fn location_to_result(loc: &Location, query_type: QueryType) -> QueryResult {
56        let file = PathBuf::from(loc.uri.replace("file://", ""));
57        QueryResult {
58            query_type,
59            location: CodeLocation {
60                file: file.clone(),
61                line: loc.range.start.line,
62                column: loc.range.start.character,
63            },
64            range: Some(CodeRange {
65                start: CodeLocation {
66                    file: file.clone(),
67                    line: loc.range.start.line,
68                    column: loc.range.start.character,
69                },
70                end: CodeLocation {
71                    file,
72                    line: loc.range.end.line,
73                    column: loc.range.end.character,
74                },
75            }),
76            display_text: String::new(),
77            snippet: None,
78        }
79    }
80
81    /// 将 DocumentSymbol 转换为 SymbolInfo
82    fn doc_symbol_to_info(symbol: &DocumentSymbol, file: &Path) -> SymbolInfo {
83        let file = file.to_path_buf();
84        SymbolInfo {
85            name: symbol.name.clone(),
86            kind: lsp_kind_to_query_kind(symbol.kind),
87            location: CodeLocation {
88                file: file.clone(),
89                line: symbol.range.start.line,
90                column: symbol.range.start.character,
91            },
92            range: Some(CodeRange {
93                start: CodeLocation {
94                    file: file.clone(),
95                    line: symbol.range.start.line,
96                    column: symbol.range.start.character,
97                },
98                end: CodeLocation {
99                    file,
100                    line: symbol.range.end.line,
101                    column: symbol.range.end.character,
102                },
103            }),
104            container_name: None,
105        }
106    }
107}
108
109impl Default for LspQueryEngine {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115#[async_trait]
116impl QueryEngine for LspQueryEngine {
117    async fn go_to_definition(&self, location: CodeLocation) -> Layer3Result<Option<QueryResult>> {
118        let language = self.ensure_connected(&location.file).await?;
119
120        let result = self
121            .client
122            .go_to_definition(
123                &language,
124                &location.file,
125                Position::new(location.line, location.column),
126            )
127            .await
128            .map_err(|e| anyhow::anyhow!("LSP go_to_definition failed: {}", e))?;
129
130        Ok(result
131            .first()
132            .map(|loc| Self::location_to_result(loc, QueryType::Definition)))
133    }
134
135    async fn find_references(&self, location: CodeLocation) -> Layer3Result<Vec<QueryResult>> {
136        let language = self.ensure_connected(&location.file).await?;
137
138        let result = self
139            .client
140            .find_references(
141                &language,
142                &location.file,
143                Position::new(location.line, location.column),
144                true,
145            )
146            .await
147            .map_err(|e| anyhow::anyhow!("LSP find_references failed: {}", e))?;
148
149        Ok(result
150            .iter()
151            .map(|loc| Self::location_to_result(loc, QueryType::References))
152            .collect())
153    }
154
155    async fn go_to_implementation(&self, location: CodeLocation) -> Layer3Result<Vec<QueryResult>> {
156        let language = self.ensure_connected(&location.file).await?;
157
158        let result = self
159            .client
160            .go_to_implementation(
161                &language,
162                &location.file,
163                Position::new(location.line, location.column),
164            )
165            .await
166            .map_err(|e| anyhow::anyhow!("LSP go_to_implementation failed: {}", e))?;
167
168        Ok(result
169            .iter()
170            .map(|loc| Self::location_to_result(loc, QueryType::Implementations))
171            .collect())
172    }
173
174    async fn go_to_type_definition(
175        &self,
176        location: CodeLocation,
177    ) -> Layer3Result<Option<QueryResult>> {
178        let language = self.ensure_connected(&location.file).await?;
179
180        let result = self
181            .client
182            .go_to_type_definition(
183                &language,
184                &location.file,
185                Position::new(location.line, location.column),
186            )
187            .await
188            .map_err(|e| anyhow::anyhow!("LSP go_to_type_definition failed: {}", e))?;
189
190        Ok(result.map(|loc| Self::location_to_result(&loc, QueryType::TypeDefinition)))
191    }
192
193    async fn hover(&self, location: CodeLocation) -> Layer3Result<Option<String>> {
194        let language = self.ensure_connected(&location.file).await?;
195
196        let result = self
197            .client
198            .get_hover(
199                &language,
200                &location.file,
201                Position::new(location.line, location.column),
202            )
203            .await
204            .map_err(|e| anyhow::anyhow!("LSP hover failed: {}", e))?;
205
206        match result {
207            Some(Hover {
208                contents: HoverContents::Markup(markup),
209                ..
210            }) => Ok(Some(markup.value)),
211            Some(Hover {
212                contents: HoverContents::String(s),
213                ..
214            }) => Ok(Some(s)),
215            Some(Hover {
216                contents: HoverContents::Array(arr),
217                ..
218            }) => Ok(Some(
219                arr.iter()
220                    .map(|ms| match ms {
221                        MarkedString::String(s) => s.clone(),
222                        MarkedString::LanguageString(ls) => {
223                            format!("```{}\n{}\n```", ls.language, ls.value)
224                        }
225                    })
226                    .collect::<Vec<_>>()
227                    .join("\n"),
228            )),
229            None => Ok(None),
230        }
231    }
232
233    async fn document_symbols(&self, file: PathBuf) -> Layer3Result<Vec<SymbolInfo>> {
234        let language = self.ensure_connected(&file).await?;
235
236        let result = self
237            .client
238            .get_document_symbols(&language, &file)
239            .await
240            .map_err(|e| anyhow::anyhow!("LSP document_symbols failed: {}", e))?;
241
242        Ok(result
243            .iter()
244            .map(|s| Self::doc_symbol_to_info(s, &file))
245            .collect())
246    }
247
248    async fn workspace_symbols(&self, query: &str) -> Layer3Result<Vec<SymbolInfo>> {
249        // Try to get workspace symbols from any connected language server
250        let languages = vec!["rust", "python", "typescript", "go", "java"];
251
252        for language in languages {
253            if self.client.is_connected(language).await {
254                let result = self
255                    .client
256                    .get_workspace_symbols(language, query)
257                    .await
258                    .map_err(|e| anyhow::anyhow!("LSP workspace_symbols failed: {}", e))?;
259
260                return Ok(result
261                    .iter()
262                    .map(|s| SymbolInfo {
263                        name: s.name.clone(),
264                        kind: lsp_kind_to_query_kind(s.kind),
265                        location: CodeLocation {
266                            file: PathBuf::from(s.location.uri.replace("file://", "")),
267                            line: s.location.range.start.line,
268                            column: s.location.range.start.character,
269                        },
270                        range: Some(CodeRange {
271                            start: CodeLocation {
272                                file: PathBuf::from(s.location.uri.replace("file://", "")),
273                                line: s.location.range.start.line,
274                                column: s.location.range.start.character,
275                            },
276                            end: CodeLocation {
277                                file: PathBuf::from(s.location.uri.replace("file://", "")),
278                                line: s.location.range.end.line,
279                                column: s.location.range.end.character,
280                            },
281                        }),
282                        container_name: s.container_name.clone(),
283                    })
284                    .collect());
285            }
286        }
287
288        Ok(Vec::new())
289    }
290
291    async fn query(
292        &self,
293        query_type: QueryType,
294        location: CodeLocation,
295    ) -> Layer3Result<Vec<QueryResult>> {
296        match query_type {
297            QueryType::Definition => {
298                let result = self.go_to_definition(location).await?;
299                Ok(result.map(|r| vec![r]).unwrap_or_default())
300            }
301            QueryType::References => self.find_references(location).await,
302            QueryType::Implementations => self.go_to_implementation(location).await,
303            QueryType::TypeDefinition => {
304                let result = self.go_to_type_definition(location).await?;
305                Ok(result.map(|r| vec![r]).unwrap_or_default())
306            }
307            QueryType::DocumentSymbols => {
308                let symbols = self.document_symbols(location.file.clone()).await?;
309                Ok(symbols
310                    .iter()
311                    .map(|s| QueryResult {
312                        query_type: QueryType::DocumentSymbols,
313                        location: s.location.clone(),
314                        range: s.range.clone(),
315                        display_text: s.name.clone(),
316                        snippet: None,
317                    })
318                    .collect())
319            }
320            QueryType::WorkspaceSymbols => {
321                let symbols = self
322                    .workspace_symbols(&location.file.to_string_lossy())
323                    .await?;
324                Ok(symbols
325                    .iter()
326                    .map(|s| QueryResult {
327                        query_type: QueryType::WorkspaceSymbols,
328                        location: s.location.clone(),
329                        range: s.range.clone(),
330                        display_text: s.name.clone(),
331                        snippet: None,
332                    })
333                    .collect())
334            }
335            QueryType::Hover => {
336                let result = self.hover(location.clone()).await?;
337                Ok(result
338                    .map(|content| {
339                        vec![QueryResult {
340                            query_type: QueryType::Hover,
341                            location,
342                            range: None,
343                            display_text: content,
344                            snippet: None,
345                        }]
346                    })
347                    .unwrap_or_default())
348            }
349        }
350    }
351}
352
353/// 根据 LSP SymbolKind 转换为查询引擎 SymbolKind
354fn lsp_kind_to_query_kind(kind: LspSymbolKind) -> SymbolKind {
355    match kind {
356        LspSymbolKind::File => SymbolKind::File,
357        LspSymbolKind::Module => SymbolKind::Module,
358        LspSymbolKind::Namespace => SymbolKind::Namespace,
359        LspSymbolKind::Package => SymbolKind::Package,
360        LspSymbolKind::Class => SymbolKind::Class,
361        LspSymbolKind::Method => SymbolKind::Method,
362        LspSymbolKind::Property => SymbolKind::Property,
363        LspSymbolKind::Field => SymbolKind::Field,
364        LspSymbolKind::Constructor => SymbolKind::Constructor,
365        LspSymbolKind::Enum => SymbolKind::Enum,
366        LspSymbolKind::Interface => SymbolKind::Interface,
367        LspSymbolKind::Function => SymbolKind::Function,
368        LspSymbolKind::Variable => SymbolKind::Variable,
369        LspSymbolKind::Constant => SymbolKind::Constant,
370        LspSymbolKind::String => SymbolKind::String,
371        LspSymbolKind::Number => SymbolKind::Number,
372        LspSymbolKind::Boolean => SymbolKind::Boolean,
373        LspSymbolKind::Array => SymbolKind::Array,
374        LspSymbolKind::Object => SymbolKind::Object,
375        LspSymbolKind::Key => SymbolKind::Key,
376        LspSymbolKind::Null => SymbolKind::Null,
377        LspSymbolKind::EnumMember => SymbolKind::EnumMember,
378        LspSymbolKind::Struct => SymbolKind::Struct,
379        LspSymbolKind::Event => SymbolKind::Event,
380        LspSymbolKind::Operator => SymbolKind::Operator,
381        LspSymbolKind::TypeParameter => SymbolKind::TypeParameter,
382    }
383}
384
/// Detects a language identifier from a file's extension.
///
/// Matching is ASCII case-insensitive, so `Main.RS` and `README.MD` are
/// recognised the same way as their lowercase forms. Files without an
/// extension, with a non-UTF-8 extension, or with an unknown extension map to
/// `"text"`.
///
/// # Returns
/// The language identifier as an owned string (for example `"rust"`).
fn detect_language(file: &Path) -> String {
    let extension = file
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);

    match extension.as_deref() {
        Some("rs") => "rust",
        Some("py") => "python",
        Some("ts") | Some("tsx") => "typescript",
        Some("js") | Some("jsx") => "javascript",
        Some("go") => "go",
        Some("java") => "java",
        Some("c") | Some("h") => "c",
        Some("cpp") | Some("hpp") | Some("cc") | Some("cxx") | Some("hxx") | Some("hh") => "cpp",
        Some("json") => "json",
        Some("yaml") | Some("yml") => "yaml",
        Some("md") => "markdown",
        Some("html") | Some("htm") => "html",
        Some("css") => "css",
        Some("sql") => "sql",
        _ => "text",
    }
    .to_string()
}
406
407/// LSP 代码分析器
408pub struct LspCodeAnalyzer {
409    engine: Arc<LspQueryEngine>,
410}
411
412impl LspCodeAnalyzer {
413    /// 创建新的代码分析器
414    pub fn new(engine: Arc<LspQueryEngine>) -> Self {
415        Self { engine }
416    }
417}
418
419#[async_trait]
420impl CodeAnalyzer for LspCodeAnalyzer {
421    async fn analyze_structure(&self, file: PathBuf) -> Layer3Result<CodeStructure> {
422        let symbols = self.engine.document_symbols(file.clone()).await?;
423
424        let imports = symbols
425            .iter()
426            .filter(|s| matches!(s.kind, SymbolKind::Module | SymbolKind::Namespace))
427            .map(|s| s.name.clone())
428            .collect();
429
430        let exports = symbols
431            .iter()
432            .filter(|s| s.kind == SymbolKind::Function || s.kind == SymbolKind::Class)
433            .map(|s| s.name.clone())
434            .collect();
435
436        let lines = std::fs::read_to_string(&file)
437            .map(|content| content.lines().count())
438            .unwrap_or(0);
439
440        Ok(CodeStructure {
441            file,
442            imports,
443            exports,
444            symbols,
445            lines,
446        })
447    }
448
449    async fn find_similar(&self, snippet: &str, threshold: f32) -> Layer3Result<Vec<CodeMatch>> {
450        let mut matches = Vec::new();
451
452        let symbols = self.engine.workspace_symbols(snippet).await?;
453
454        for symbol in symbols {
455            let similarity = calculate_similarity(snippet, &symbol.name);
456            if similarity >= threshold {
457                matches.push(CodeMatch {
458                    location: symbol.location,
459                    content: symbol.name,
460                    similarity,
461                });
462            }
463        }
464
465        Ok(matches)
466    }
467
468    async fn detect_patterns(&self, file: PathBuf) -> Layer3Result<Vec<DetectedPattern>> {
469        let symbols = self.engine.document_symbols(file.clone()).await?;
470
471        let mut patterns = Vec::new();
472
473        let classes: Vec<_> = symbols
474            .iter()
475            .filter(|s| s.kind == SymbolKind::Class)
476            .collect();
477        if classes.len() == 1 {
478            patterns.push(DetectedPattern {
479                name: "Potential Singleton".to_string(),
480                pattern_type: PatternType::DesignPattern,
481                locations: classes.iter().map(|c| c.location.clone()).collect(),
482            });
483        }
484
485        for symbol in &symbols {
486            if symbol.kind == SymbolKind::Function {
487                if let Some(range) = &symbol.range {
488                    let lines = range.end.line - range.start.line;
489                    if lines > 50 {
490                        patterns.push(DetectedPattern {
491                            name: format!("Long Function ({})", symbol.name),
492                            pattern_type: PatternType::AntiPattern,
493                            locations: vec![symbol.location.clone()],
494                        });
495                    }
496                }
497            }
498        }
499
500        Ok(patterns)
501    }
502}
503
/// Computes the Jaccard similarity of two strings over their word sets.
///
/// Each input is split on whitespace into a set of distinct words; the score
/// is `|A ∩ B| / |A ∪ B|`. Working on sets (rather than word lists) keeps the
/// result within `0.0..=1.0` and makes it symmetric even when an input
/// repeats a word.
///
/// # Returns
/// * `1.0` when the strings are identical (including both empty).
/// * `0.0` when either string is empty or neither contains any word.
/// * Otherwise the Jaccard index of the two word sets.
fn calculate_similarity(a: &str, b: &str) -> f32 {
    if a == b {
        return 1.0;
    }

    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
    let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();

    let intersection = words_a.intersection(&words_b).count();
    // The intersection is a subset of both sets, so this cannot underflow.
    let union = words_a.len() + words_b.len() - intersection;

    // Both inputs consisted solely of whitespace.
    if union == 0 {
        return 0.0;
    }

    intersection as f32 / union as f32
}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Common source-file extensions map to their language identifiers.
    #[test]
    fn test_detect_language() {
        let cases = [
            ("test.rs", "rust"),
            ("test.py", "python"),
            ("test.ts", "typescript"),
            ("test.go", "go"),
        ];

        for &(file_name, expected) in cases.iter() {
            assert_eq!(detect_language(&PathBuf::from(file_name)), expected);
        }
    }

    /// Identical strings score 1.0; sharing one of two words scores 0.5.
    #[test]
    fn test_similarity_calculation() {
        let identical = calculate_similarity("test", "test");
        assert_eq!(identical, 1.0);

        let half_overlap = calculate_similarity("hello world", "hello");
        assert!((half_overlap - 0.5).abs() < 0.01);
    }

    /// Constructing the engine is synchronous and must not panic.
    #[test]
    fn test_lsp_query_engine_creation() {
        let _engine = LspQueryEngine::new();
    }
}