Skip to main content

sh_layer3/query_engine/
lsp_query_engine.rs

//! # LSP Query Engine
//!
//! Code query engine implemented on top of the Language Server Protocol (LSP).

5use crate::lsp::client::LspClient;
6use crate::lsp::types::SymbolKind as LspSymbolKind;
7use crate::lsp::types::{DocumentSymbol, Hover, HoverContents, Location, MarkedString, Position};
8use crate::query_engine::{
9    CodeAnalyzer, CodeMatch, CodeStructure, DetectedPattern, PatternType, QueryEngine, SymbolInfo,
10    SymbolKind,
11};
12use crate::types::{CodeLocation, CodeRange, Layer3Result, QueryResult, QueryType};
13use async_trait::async_trait;
14use std::path::PathBuf;
15use std::sync::Arc;
16
17/// LSP 查询引擎
18pub struct LspQueryEngine {
19    client: Arc<LspClient>,
20}
21
22impl LspQueryEngine {
23    /// 创建新的 LSP 查询引擎
24    pub fn new() -> Self {
25        Self {
26            client: Arc::new(LspClient::new()),
27        }
28    }
29
30    /// 使用现有的 LSP 客户端创建
31    pub fn with_client(client: Arc<LspClient>) -> Self {
32        Self { client }
33    }
34
35    /// 获取 LSP 客户端引用
36    pub fn client(&self) -> &Arc<LspClient> {
37        &self.client
38    }
39
40    /// 确保语言服务器已连接
41    async fn ensure_connected(&self, file: &PathBuf) -> Layer3Result<String> {
42        let language = detect_language(file);
43
44        if !self.client.is_connected(&language).await {
45            self.client
46                .initialize(&language, file.parent().unwrap_or(&PathBuf::from(".")))
47                .await
48                .map_err(|e| anyhow::anyhow!("Failed to initialize LSP server: {}", e))?;
49        }
50
51        Ok(language)
52    }
53
54    /// 将 LSP Location 转换为 QueryResult
55    fn location_to_result(loc: &Location, query_type: QueryType) -> QueryResult {
56        let file = PathBuf::from(loc.uri.replace("file://", ""));
57        QueryResult {
58            query_type,
59            location: CodeLocation {
60                file: file.clone(),
61                line: loc.range.start.line,
62                column: loc.range.start.character,
63            },
64            range: Some(CodeRange {
65                start: CodeLocation {
66                    file: file.clone(),
67                    line: loc.range.start.line,
68                    column: loc.range.start.character,
69                },
70                end: CodeLocation {
71                    file,
72                    line: loc.range.end.line,
73                    column: loc.range.end.character,
74                },
75            }),
76            display_text: String::new(),
77            snippet: None,
78        }
79    }
80
81    /// 将 DocumentSymbol 转换为 SymbolInfo
82    fn doc_symbol_to_info(symbol: &DocumentSymbol, file: &PathBuf) -> SymbolInfo {
83        SymbolInfo {
84            name: symbol.name.clone(),
85            kind: lsp_kind_to_query_kind(symbol.kind),
86            location: CodeLocation {
87                file: file.clone(),
88                line: symbol.range.start.line,
89                column: symbol.range.start.character,
90            },
91            range: Some(CodeRange {
92                start: CodeLocation {
93                    file: file.clone(),
94                    line: symbol.range.start.line,
95                    column: symbol.range.start.character,
96                },
97                end: CodeLocation {
98                    file: file.clone(),
99                    line: symbol.range.end.line,
100                    column: symbol.range.end.character,
101                },
102            }),
103            container_name: None,
104        }
105    }
106}
107
108impl Default for LspQueryEngine {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[async_trait]
115impl QueryEngine for LspQueryEngine {
116    async fn go_to_definition(&self, location: CodeLocation) -> Layer3Result<Option<QueryResult>> {
117        let language = self.ensure_connected(&location.file).await?;
118
119        let result = self
120            .client
121            .go_to_definition(
122                &language,
123                &location.file,
124                Position::new(location.line, location.column),
125            )
126            .await
127            .map_err(|e| anyhow::anyhow!("LSP go_to_definition failed: {}", e))?;
128
129        Ok(result
130            .first()
131            .map(|loc| Self::location_to_result(loc, QueryType::Definition)))
132    }
133
134    async fn find_references(&self, location: CodeLocation) -> Layer3Result<Vec<QueryResult>> {
135        let language = self.ensure_connected(&location.file).await?;
136
137        let result = self
138            .client
139            .find_references(
140                &language,
141                &location.file,
142                Position::new(location.line, location.column),
143                true,
144            )
145            .await
146            .map_err(|e| anyhow::anyhow!("LSP find_references failed: {}", e))?;
147
148        Ok(result
149            .iter()
150            .map(|loc| Self::location_to_result(loc, QueryType::References))
151            .collect())
152    }
153
154    async fn go_to_implementation(&self, location: CodeLocation) -> Layer3Result<Vec<QueryResult>> {
155        let language = self.ensure_connected(&location.file).await?;
156
157        let result = self
158            .client
159            .go_to_implementation(
160                &language,
161                &location.file,
162                Position::new(location.line, location.column),
163            )
164            .await
165            .map_err(|e| anyhow::anyhow!("LSP go_to_implementation failed: {}", e))?;
166
167        Ok(result
168            .iter()
169            .map(|loc| Self::location_to_result(loc, QueryType::Implementations))
170            .collect())
171    }
172
173    async fn go_to_type_definition(
174        &self,
175        location: CodeLocation,
176    ) -> Layer3Result<Option<QueryResult>> {
177        let language = self.ensure_connected(&location.file).await?;
178
179        let result = self
180            .client
181            .go_to_type_definition(
182                &language,
183                &location.file,
184                Position::new(location.line, location.column),
185            )
186            .await
187            .map_err(|e| anyhow::anyhow!("LSP go_to_type_definition failed: {}", e))?;
188
189        Ok(result.map(|loc| Self::location_to_result(&loc, QueryType::TypeDefinition)))
190    }
191
192    async fn hover(&self, location: CodeLocation) -> Layer3Result<Option<String>> {
193        let language = self.ensure_connected(&location.file).await?;
194
195        let result = self
196            .client
197            .get_hover(
198                &language,
199                &location.file,
200                Position::new(location.line, location.column),
201            )
202            .await
203            .map_err(|e| anyhow::anyhow!("LSP hover failed: {}", e))?;
204
205        match result {
206            Some(Hover {
207                contents: HoverContents::Markup(markup),
208                ..
209            }) => Ok(Some(markup.value)),
210            Some(Hover {
211                contents: HoverContents::String(s),
212                ..
213            }) => Ok(Some(s)),
214            Some(Hover {
215                contents: HoverContents::Array(arr),
216                ..
217            }) => Ok(Some(
218                arr.iter()
219                    .map(|ms| match ms {
220                        MarkedString::String(s) => s.clone(),
221                        MarkedString::LanguageString(ls) => {
222                            format!("```{}\n{}\n```", ls.language, ls.value)
223                        }
224                    })
225                    .collect::<Vec<_>>()
226                    .join("\n"),
227            )),
228            None => Ok(None),
229        }
230    }
231
232    async fn document_symbols(&self, file: PathBuf) -> Layer3Result<Vec<SymbolInfo>> {
233        let language = self.ensure_connected(&file).await?;
234
235        let result = self
236            .client
237            .get_document_symbols(&language, &file)
238            .await
239            .map_err(|e| anyhow::anyhow!("LSP document_symbols failed: {}", e))?;
240
241        Ok(result
242            .iter()
243            .map(|s| Self::doc_symbol_to_info(s, &file))
244            .collect())
245    }
246
247    async fn workspace_symbols(&self, query: &str) -> Layer3Result<Vec<SymbolInfo>> {
248        // Try to get workspace symbols from any connected language server
249        let languages = vec!["rust", "python", "typescript", "go", "java"];
250
251        for language in languages {
252            if self.client.is_connected(language).await {
253                let result = self
254                    .client
255                    .get_workspace_symbols(language, query)
256                    .await
257                    .map_err(|e| anyhow::anyhow!("LSP workspace_symbols failed: {}", e))?;
258
259                return Ok(result
260                    .iter()
261                    .map(|s| SymbolInfo {
262                        name: s.name.clone(),
263                        kind: lsp_kind_to_query_kind(s.kind),
264                        location: CodeLocation {
265                            file: PathBuf::from(s.location.uri.replace("file://", "")),
266                            line: s.location.range.start.line,
267                            column: s.location.range.start.character,
268                        },
269                        range: Some(CodeRange {
270                            start: CodeLocation {
271                                file: PathBuf::from(s.location.uri.replace("file://", "")),
272                                line: s.location.range.start.line,
273                                column: s.location.range.start.character,
274                            },
275                            end: CodeLocation {
276                                file: PathBuf::from(s.location.uri.replace("file://", "")),
277                                line: s.location.range.end.line,
278                                column: s.location.range.end.character,
279                            },
280                        }),
281                        container_name: s.container_name.clone(),
282                    })
283                    .collect());
284            }
285        }
286
287        Ok(Vec::new())
288    }
289
290    async fn query(
291        &self,
292        query_type: QueryType,
293        location: CodeLocation,
294    ) -> Layer3Result<Vec<QueryResult>> {
295        match query_type {
296            QueryType::Definition => {
297                let result = self.go_to_definition(location).await?;
298                Ok(result.map(|r| vec![r]).unwrap_or_default())
299            }
300            QueryType::References => self.find_references(location).await,
301            QueryType::Implementations => self.go_to_implementation(location).await,
302            QueryType::TypeDefinition => {
303                let result = self.go_to_type_definition(location).await?;
304                Ok(result.map(|r| vec![r]).unwrap_or_default())
305            }
306            QueryType::DocumentSymbols => {
307                let symbols = self.document_symbols(location.file.clone()).await?;
308                Ok(symbols
309                    .iter()
310                    .map(|s| QueryResult {
311                        query_type: QueryType::DocumentSymbols,
312                        location: s.location.clone(),
313                        range: s.range.clone(),
314                        display_text: s.name.clone(),
315                        snippet: None,
316                    })
317                    .collect())
318            }
319            QueryType::WorkspaceSymbols => {
320                let symbols = self
321                    .workspace_symbols(&location.file.to_string_lossy())
322                    .await?;
323                Ok(symbols
324                    .iter()
325                    .map(|s| QueryResult {
326                        query_type: QueryType::WorkspaceSymbols,
327                        location: s.location.clone(),
328                        range: s.range.clone(),
329                        display_text: s.name.clone(),
330                        snippet: None,
331                    })
332                    .collect())
333            }
334            QueryType::Hover => {
335                let result = self.hover(location.clone()).await?;
336                Ok(result
337                    .map(|content| {
338                        vec![QueryResult {
339                            query_type: QueryType::Hover,
340                            location,
341                            range: None,
342                            display_text: content,
343                            snippet: None,
344                        }]
345                    })
346                    .unwrap_or_default())
347            }
348        }
349    }
350}
351
352/// 根据 LSP SymbolKind 转换为查询引擎 SymbolKind
353fn lsp_kind_to_query_kind(kind: LspSymbolKind) -> SymbolKind {
354    match kind {
355        LspSymbolKind::File => SymbolKind::File,
356        LspSymbolKind::Module => SymbolKind::Module,
357        LspSymbolKind::Namespace => SymbolKind::Namespace,
358        LspSymbolKind::Package => SymbolKind::Package,
359        LspSymbolKind::Class => SymbolKind::Class,
360        LspSymbolKind::Method => SymbolKind::Method,
361        LspSymbolKind::Property => SymbolKind::Property,
362        LspSymbolKind::Field => SymbolKind::Field,
363        LspSymbolKind::Constructor => SymbolKind::Constructor,
364        LspSymbolKind::Enum => SymbolKind::Enum,
365        LspSymbolKind::Interface => SymbolKind::Interface,
366        LspSymbolKind::Function => SymbolKind::Function,
367        LspSymbolKind::Variable => SymbolKind::Variable,
368        LspSymbolKind::Constant => SymbolKind::Constant,
369        LspSymbolKind::String => SymbolKind::String,
370        LspSymbolKind::Number => SymbolKind::Number,
371        LspSymbolKind::Boolean => SymbolKind::Boolean,
372        LspSymbolKind::Array => SymbolKind::Array,
373        LspSymbolKind::Object => SymbolKind::Object,
374        LspSymbolKind::Key => SymbolKind::Key,
375        LspSymbolKind::Null => SymbolKind::Null,
376        LspSymbolKind::EnumMember => SymbolKind::EnumMember,
377        LspSymbolKind::Struct => SymbolKind::Struct,
378        LspSymbolKind::Event => SymbolKind::Event,
379        LspSymbolKind::Operator => SymbolKind::Operator,
380        LspSymbolKind::TypeParameter => SymbolKind::TypeParameter,
381    }
382}
383
/// Infers a language identifier from a file's extension.
///
/// Matching is ASCII case-insensitive, so `MAIN.RS` is detected as Rust. A file whose
/// extension is missing, not valid UTF-8, or unrecognised maps to `"text"`. The returned
/// identifier is the language key passed to the LSP client.
fn detect_language(file: &std::path::Path) -> String {
    let extension = file
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);

    let language = match extension.as_deref() {
        Some("rs") => "rust",
        Some("py") | Some("pyi") => "python",
        Some("ts") | Some("tsx") | Some("mts") | Some("cts") => "typescript",
        Some("js") | Some("jsx") | Some("mjs") | Some("cjs") => "javascript",
        Some("go") => "go",
        Some("java") => "java",
        Some("c") | Some("h") => "c",
        Some("cpp") | Some("hpp") | Some("cc") | Some("cxx") | Some("hh") | Some("hxx") => "cpp",
        Some("json") => "json",
        Some("yaml") | Some("yml") => "yaml",
        Some("md") => "markdown",
        Some("html") | Some("htm") => "html",
        Some("css") => "css",
        Some("sql") => "sql",
        _ => "text",
    };

    language.to_string()
}
405
406/// LSP 代码分析器
407pub struct LspCodeAnalyzer {
408    engine: Arc<LspQueryEngine>,
409}
410
411impl LspCodeAnalyzer {
412    /// 创建新的代码分析器
413    pub fn new(engine: Arc<LspQueryEngine>) -> Self {
414        Self { engine }
415    }
416}
417
418#[async_trait]
419impl CodeAnalyzer for LspCodeAnalyzer {
420    async fn analyze_structure(&self, file: PathBuf) -> Layer3Result<CodeStructure> {
421        let symbols = self.engine.document_symbols(file.clone()).await?;
422
423        let imports = symbols
424            .iter()
425            .filter(|s| matches!(s.kind, SymbolKind::Module | SymbolKind::Namespace))
426            .map(|s| s.name.clone())
427            .collect();
428
429        let exports = symbols
430            .iter()
431            .filter(|s| s.kind == SymbolKind::Function || s.kind == SymbolKind::Class)
432            .map(|s| s.name.clone())
433            .collect();
434
435        let lines = std::fs::read_to_string(&file)
436            .map(|content| content.lines().count())
437            .unwrap_or(0);
438
439        Ok(CodeStructure {
440            file,
441            imports,
442            exports,
443            symbols,
444            lines,
445        })
446    }
447
448    async fn find_similar(&self, snippet: &str, threshold: f32) -> Layer3Result<Vec<CodeMatch>> {
449        let mut matches = Vec::new();
450
451        let symbols = self.engine.workspace_symbols(snippet).await?;
452
453        for symbol in symbols {
454            let similarity = calculate_similarity(snippet, &symbol.name);
455            if similarity >= threshold {
456                matches.push(CodeMatch {
457                    location: symbol.location,
458                    content: symbol.name,
459                    similarity,
460                });
461            }
462        }
463
464        Ok(matches)
465    }
466
467    async fn detect_patterns(&self, file: PathBuf) -> Layer3Result<Vec<DetectedPattern>> {
468        let symbols = self.engine.document_symbols(file.clone()).await?;
469
470        let mut patterns = Vec::new();
471
472        let classes: Vec<_> = symbols
473            .iter()
474            .filter(|s| s.kind == SymbolKind::Class)
475            .collect();
476        if classes.len() == 1 {
477            patterns.push(DetectedPattern {
478                name: "Potential Singleton".to_string(),
479                pattern_type: PatternType::DesignPattern,
480                locations: classes.iter().map(|c| c.location.clone()).collect(),
481            });
482        }
483
484        for symbol in &symbols {
485            if symbol.kind == SymbolKind::Function {
486                if let Some(range) = &symbol.range {
487                    let lines = range.end.line - range.start.line;
488                    if lines > 50 {
489                        patterns.push(DetectedPattern {
490                            name: format!("Long Function ({})", symbol.name),
491                            pattern_type: PatternType::AntiPattern,
492                            locations: vec![symbol.location.clone()],
493                        });
494                    }
495                }
496            }
497        }
498
499        Ok(patterns)
500    }
501}
502
/// Word-level Jaccard similarity between two strings, in the range `[0.0, 1.0]`.
///
/// Each string is split on whitespace and treated as a *set* of words, and the result is
/// `|A ∩ B| / |A ∪ B|`. Special cases:
/// - Identical strings, including two empty strings, score 1.0.
/// - If exactly one string is empty, or neither contains any word, the score is 0.0.
fn calculate_similarity(a: &str, b: &str) -> f32 {
    use std::collections::HashSet;

    if a == b {
        return 1.0;
    }

    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Sets keep repeated words from being counted more than once. With plain word lists,
    // "a a" vs "a" would score 2.0, outside the Jaccard range.
    let words_a: HashSet<&str> = a.split_whitespace().collect();
    let words_b: HashSet<&str> = b.split_whitespace().collect();

    let intersection = words_a.intersection(&words_b).count();
    let union = words_a.union(&words_b).count();

    if union == 0 {
        return 0.0;
    }

    intersection as f32 / union as f32
}
525
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_language() {
        assert_eq!(detect_language(&PathBuf::from("test.rs")), "rust");
        assert_eq!(detect_language(&PathBuf::from("test.py")), "python");
        assert_eq!(detect_language(&PathBuf::from("test.ts")), "typescript");
        assert_eq!(detect_language(&PathBuf::from("test.go")), "go");
        assert_eq!(detect_language(&PathBuf::from("test.jsx")), "javascript");
        assert_eq!(detect_language(&PathBuf::from("test.yml")), "yaml");
        // Missing or unknown extensions fall back to plain text.
        assert_eq!(detect_language(&PathBuf::from("Makefile")), "text");
        assert_eq!(detect_language(&PathBuf::from("test.unknown")), "text");
    }

    #[test]
    fn test_similarity_calculation() {
        assert_eq!(calculate_similarity("test", "test"), 1.0);
        assert!((calculate_similarity("hello world", "hello") - 0.5).abs() < 0.01);
        assert_eq!(calculate_similarity("", "hello"), 0.0);
        assert_eq!(calculate_similarity("alpha beta", "gamma delta"), 0.0);
    }

    #[test]
    fn test_lsp_query_engine_creation() {
        // Construction is synchronous and must not panic.
        let _engine = LspQueryEngine::new();
    }

    #[test]
    fn test_with_client_shares_client() {
        let client = Arc::new(LspClient::new());
        let engine = LspQueryEngine::with_client(Arc::clone(&client));
        assert!(Arc::ptr_eq(engine.client(), &client));
    }
}