Skip to main content

spring_lsp/
di_validator.rs

1//! 依赖注入验证模块
2//!
3//! 提供依赖注入的验证功能,包括:
4//! - 组件注册验证
5//! - 组件类型存在性验证
6//! - 组件名称匹配验证
7//! - 循环依赖检测
8//! - 配置注入验证
9
10use crate::index::IndexManager;
11use crate::macro_analyzer::{InjectMacro, InjectType, RustDocument, SpringMacro};
12use crate::toml_analyzer::TomlDocument;
13use lsp_types::{Diagnostic, DiagnosticSeverity, Location, NumberOrString};
14use std::collections::{HashMap, HashSet};
15
16/// 依赖注入验证器
17pub struct DependencyInjectionValidator {
18    /// 索引管理器
19    index_manager: IndexManager,
20}
21
22impl DependencyInjectionValidator {
23    /// 创建新的依赖注入验证器
24    pub fn new(index_manager: IndexManager) -> Self {
25        Self { index_manager }
26    }
27
28    /// 验证依赖注入
29    ///
30    /// 验证 Rust 文档中的所有依赖注入,包括:
31    /// - 组件注册验证
32    /// - 组件类型存在性验证
33    /// - 组件名称匹配验证
34    /// - 循环依赖检测
35    /// - 配置注入验证
36    ///
37    /// # Arguments
38    ///
39    /// * `rust_docs` - Rust 文档列表
40    /// * `toml_docs` - TOML 配置文档列表(包含 URI 和文档内容)
41    ///
42    /// # Returns
43    ///
44    /// 返回诊断信息列表
45    ///
46    /// # Requirements
47    ///
48    /// - 11.1: 组件注册验证
49    /// - 11.2: 组件类型存在性验证
50    /// - 11.3: 组件名称匹配验证
51    /// - 11.4: 循环依赖检测
52    /// - 11.5: 配置注入验证
53    pub fn validate(
54        &self,
55        rust_docs: &[RustDocument],
56        toml_docs: &[(lsp_types::Url, TomlDocument)],
57    ) -> Vec<Diagnostic> {
58        let mut diagnostics = Vec::new();
59
60        // 提取所有服务和注入信息
61        let services = self.extract_services(rust_docs);
62
63        // 验证每个服务的依赖注入
64        for (service_name, service_info) in &services {
65            for field in &service_info.fields {
66                if let Some(inject) = &field.inject {
67                    match inject.inject_type {
68                        InjectType::Component => {
69                            // 验证组件注入
70                            diagnostics.extend(self.validate_component_injection(
71                                service_name,
72                                field,
73                                inject,
74                                &service_info.location,
75                            ));
76                        }
77                        InjectType::Config => {
78                            // 验证配置注入
79                            diagnostics.extend(self.validate_config_injection(
80                                field,
81                                inject,
82                                toml_docs,
83                                &service_info.location,
84                            ));
85                        }
86                    }
87                }
88            }
89        }
90
91        // 检测循环依赖
92        diagnostics.extend(self.detect_circular_dependencies(&services));
93
94        diagnostics
95    }
96
97    /// 验证组件注入
98    ///
99    /// # Requirements
100    ///
101    /// - 11.1: 验证组件是否已注册
102    /// - 11.2: 验证组件类型是否存在
103    /// - 11.3: 验证组件名称是否匹配
104    fn validate_component_injection(
105        &self,
106        _service_name: &str,
107        field: &FieldInfo,
108        inject: &InjectMacro,
109        location: &Location,
110    ) -> Vec<Diagnostic> {
111        let mut diagnostics = Vec::new();
112
113        // 获取组件名称(如果指定)
114        let component_name = inject.component_name.as_deref().unwrap_or(&field.type_name);
115
116        // 验证组件是否已注册(需求 11.1)
117        if let Some(component_info) = self.index_manager.find_component(component_name) {
118            // 组件已注册,验证类型是否匹配
119            if component_info.type_name != field.type_name {
120                diagnostics.push(Diagnostic {
121                    range: location.range,
122                    severity: Some(DiagnosticSeverity::WARNING),
123                    code: Some(NumberOrString::String(
124                        "component-type-mismatch".to_string(),
125                    )),
126                    message: format!(
127                        "组件 '{}' 的类型不匹配。期望类型: {},实际类型: {}",
128                        component_name, field.type_name, component_info.type_name
129                    ),
130                    source: Some("spring-lsp".to_string()),
131                    ..Default::default()
132                });
133            }
134        } else {
135            // 组件未注册,检查类型是否存在(需求 11.2)
136            let symbols = self.index_manager.find_symbol(&field.type_name);
137            if symbols.is_empty() {
138                diagnostics.push(Diagnostic {
139                    range: location.range,
140                    severity: Some(DiagnosticSeverity::ERROR),
141                    code: Some(NumberOrString::String(
142                        "component-type-not-found".to_string(),
143                    )),
144                    message: format!(
145                        "组件类型 '{}' 不存在。请确保该类型已定义。",
146                        field.type_name
147                    ),
148                    source: Some("spring-lsp".to_string()),
149                    ..Default::default()
150                });
151            } else {
152                // 类型存在但组件未注册(需求 11.1)
153                diagnostics.push(Diagnostic {
154                    range: location.range,
155                    severity: Some(DiagnosticSeverity::ERROR),
156                    code: Some(NumberOrString::String(
157                        "component-not-registered".to_string(),
158                    )),
159                    message: format!(
160                        "组件 '{}' 未注册。请确保该组件已通过插件注册。",
161                        component_name
162                    ),
163                    source: Some("spring-lsp".to_string()),
164                    ..Default::default()
165                });
166            }
167        }
168
169        // 如果指定了组件名称,验证名称是否匹配(需求 11.3)
170        if let Some(specified_name) = &inject.component_name {
171            if let Some(component_info) = self.index_manager.find_component(specified_name) {
172                // 组件存在,验证类型是否匹配
173                if component_info.type_name != field.type_name {
174                    // 获取所有可用的同类型组件
175                    let available_components =
176                        self.get_available_components_by_type(&field.type_name);
177
178                    let suggestion = if available_components.is_empty() {
179                        String::new()
180                    } else {
181                        format!(
182                            "\n可用的 {} 类型组件: {}",
183                            field.type_name,
184                            available_components.join(", ")
185                        )
186                    };
187
188                    diagnostics.push(Diagnostic {
189                        range: location.range,
190                        severity: Some(DiagnosticSeverity::ERROR),
191                        code: Some(NumberOrString::String(
192                            "component-name-mismatch".to_string(),
193                        )),
194                        message: format!(
195                            "组件名称 '{}' 的类型不匹配。期望类型: {},实际类型: {}{}",
196                            specified_name, field.type_name, component_info.type_name, suggestion
197                        ),
198                        source: Some("spring-lsp".to_string()),
199                        ..Default::default()
200                    });
201                }
202            } else {
203                // 指定的组件名称不存在
204                let available_components = self.get_available_components_by_type(&field.type_name);
205
206                let suggestion = if available_components.is_empty() {
207                    String::new()
208                } else {
209                    format!(
210                        "\n可用的 {} 类型组件: {}",
211                        field.type_name,
212                        available_components.join(", ")
213                    )
214                };
215
216                diagnostics.push(Diagnostic {
217                    range: location.range,
218                    severity: Some(DiagnosticSeverity::ERROR),
219                    code: Some(NumberOrString::String(
220                        "component-name-not-found".to_string(),
221                    )),
222                    message: format!("组件名称 '{}' 不存在。{}", specified_name, suggestion),
223                    source: Some("spring-lsp".to_string()),
224                    ..Default::default()
225                });
226            }
227        }
228
229        diagnostics
230    }
231
232    /// 验证配置注入
233    ///
234    /// # Requirements
235    ///
236    /// - 11.5: 验证配置项是否存在
237    fn validate_config_injection(
238        &self,
239        field: &FieldInfo,
240        _inject: &InjectMacro,
241        toml_docs: &[(lsp_types::Url, TomlDocument)],
242        location: &Location,
243    ) -> Vec<Diagnostic> {
244        let mut diagnostics = Vec::new();
245
246        // 从类型名称中提取配置前缀
247        // 例如:UserConfig -> user, DatabaseConfig -> database
248        let config_prefix = self.extract_config_prefix(&field.type_name);
249
250        // 检查配置是否存在
251        let mut config_found = false;
252        let mut config_file_uri = None;
253
254        for (uri, toml_doc) in toml_docs {
255            if toml_doc.config_sections.contains_key(&config_prefix) {
256                config_found = true;
257                config_file_uri = Some(uri.clone());
258                break;
259            }
260        }
261
262        if !config_found {
263            let message = if let Some(uri) = &config_file_uri {
264                format!(
265                    "配置项 '{}' 不存在。请在配置文件 {} 中添加 [{}] 配置节。",
266                    config_prefix, uri, config_prefix
267                )
268            } else {
269                format!(
270                    "配置项 '{}' 不存在。请在配置文件中添加 [{}] 配置节。",
271                    config_prefix, config_prefix
272                )
273            };
274
275            diagnostics.push(Diagnostic {
276                range: location.range,
277                severity: Some(DiagnosticSeverity::ERROR),
278                code: Some(NumberOrString::String("config-not-found".to_string())),
279                message,
280                source: Some("spring-lsp".to_string()),
281                related_information: config_file_uri.map(|uri| {
282                    vec![lsp_types::DiagnosticRelatedInformation {
283                        location: Location {
284                            uri,
285                            range: lsp_types::Range {
286                                start: lsp_types::Position {
287                                    line: 0,
288                                    character: 0,
289                                },
290                                end: lsp_types::Position {
291                                    line: 0,
292                                    character: 0,
293                                },
294                            },
295                        },
296                        message: format!("在此文件中添加 [{}] 配置节", config_prefix),
297                    }]
298                }),
299                ..Default::default()
300            });
301        }
302
303        diagnostics
304    }
305
306    /// 检测循环依赖
307    ///
308    /// # Requirements
309    ///
310    /// - 11.4: 检测循环依赖并建议使用 LazyComponent
311    fn detect_circular_dependencies(
312        &self,
313        services: &HashMap<String, ServiceInfo>,
314    ) -> Vec<Diagnostic> {
315        let mut diagnostics = Vec::new();
316
317        // 构建依赖图
318        let mut dependency_graph: HashMap<String, Vec<String>> = HashMap::new();
319        for (service_name, service_info) in services {
320            let mut dependencies = Vec::new();
321            for field in &service_info.fields {
322                if let Some(inject) = &field.inject {
323                    if inject.inject_type == InjectType::Component {
324                        // 检查是否是 LazyComponent
325                        if !field.type_name.contains("LazyComponent") {
326                            dependencies.push(field.type_name.clone());
327                        }
328                    }
329                }
330            }
331            dependency_graph.insert(service_name.clone(), dependencies);
332        }
333
334        // 使用 DFS 检测循环
335        let mut visited = HashSet::new();
336        let mut rec_stack = HashSet::new();
337        let mut path = Vec::new();
338
339        for service_name in services.keys() {
340            if !visited.contains(service_name) {
341                if let Some(cycle) = self.detect_cycle_dfs(
342                    service_name,
343                    &dependency_graph,
344                    &mut visited,
345                    &mut rec_stack,
346                    &mut path,
347                ) {
348                    // 找到循环依赖
349                    if let Some(service_info) = services.get(service_name) {
350                        diagnostics.push(Diagnostic {
351                            range: service_info.location.range,
352                            severity: Some(DiagnosticSeverity::WARNING),
353                            code: Some(NumberOrString::String("circular-dependency".to_string())),
354                            message: format!(
355                                "检测到循环依赖: {}。建议使用 LazyComponent<T> 打破循环。",
356                                cycle.join(" -> ")
357                            ),
358                            source: Some("spring-lsp".to_string()),
359                            ..Default::default()
360                        });
361                    }
362                }
363            }
364        }
365
366        diagnostics
367    }
368
369    /// DFS 检测循环依赖
370    fn detect_cycle_dfs(
371        &self,
372        node: &str,
373        graph: &HashMap<String, Vec<String>>,
374        visited: &mut HashSet<String>,
375        rec_stack: &mut HashSet<String>,
376        path: &mut Vec<String>,
377    ) -> Option<Vec<String>> {
378        visited.insert(node.to_string());
379        rec_stack.insert(node.to_string());
380        path.push(node.to_string());
381
382        if let Some(neighbors) = graph.get(node) {
383            for neighbor in neighbors {
384                if !visited.contains(neighbor) {
385                    if let Some(cycle) =
386                        self.detect_cycle_dfs(neighbor, graph, visited, rec_stack, path)
387                    {
388                        return Some(cycle);
389                    }
390                } else if rec_stack.contains(neighbor) {
391                    // 找到循环
392                    let cycle_start = path.iter().position(|n| n == neighbor).unwrap();
393                    let mut cycle = path[cycle_start..].to_vec();
394                    cycle.push(neighbor.to_string());
395                    return Some(cycle);
396                }
397            }
398        }
399
400        rec_stack.remove(node);
401        path.pop();
402        None
403    }
404
405    /// 提取服务信息
406    fn extract_services(&self, rust_docs: &[RustDocument]) -> HashMap<String, ServiceInfo> {
407        let mut services = HashMap::new();
408
409        for doc in rust_docs {
410            for spring_macro in &doc.macros {
411                if let SpringMacro::DeriveService(service_macro) = spring_macro {
412                    let service_info = ServiceInfo {
413                        name: service_macro.struct_name.clone(),
414                        fields: service_macro
415                            .fields
416                            .iter()
417                            .map(|f| FieldInfo {
418                                name: f.name.clone(),
419                                type_name: f.type_name.clone(),
420                                inject: f.inject.clone(),
421                            })
422                            .collect(),
423                        location: Location {
424                            uri: doc.uri.clone(),
425                            range: service_macro.range,
426                        },
427                    };
428                    services.insert(service_macro.struct_name.clone(), service_info);
429                }
430            }
431        }
432
433        services
434    }
435
436    /// 提取配置前缀
437    ///
438    /// 从类型名称中提取配置前缀
439    /// 例如:UserConfig -> user, DatabaseConfig -> database
440    fn extract_config_prefix(&self, type_name: &str) -> String {
441        // 移除 "Config" 后缀
442        let prefix = type_name.strip_suffix("Config").unwrap_or(type_name);
443
444        // 转换为小写并用连字符分隔
445        // 例如:UserProfile -> user-profile
446        let mut result = String::new();
447        for (i, ch) in prefix.chars().enumerate() {
448            if i > 0 && ch.is_uppercase() {
449                result.push('-');
450            }
451            result.push(ch.to_lowercase().next().unwrap());
452        }
453
454        result
455    }
456
457    /// 获取指定类型的所有可用组件
458    fn get_available_components_by_type(&self, _type_name: &str) -> Vec<String> {
459        // TODO: 实现从索引中查找所有指定类型的组件
460        // 当前返回空列表
461        Vec::new()
462    }
463}
464
465/// 服务信息
466struct ServiceInfo {
467    /// 服务名称
468    #[allow(dead_code)]
469    name: String,
470    /// 字段列表
471    fields: Vec<FieldInfo>,
472    /// 位置
473    location: Location,
474}
475
476/// 字段信息
477struct FieldInfo {
478    /// 字段名称
479    #[allow(dead_code)]
480    name: String,
481    /// 字段类型
482    type_name: String,
483    /// 注入宏
484    inject: Option<InjectMacro>,
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validator over an empty index.
    fn empty_validator() -> DependencyInjectionValidator {
        DependencyInjectionValidator::new(IndexManager::new())
    }

    #[test]
    fn test_extract_config_prefix() {
        let validator = empty_validator();
        let cases = [
            ("UserConfig", "user"),
            ("DatabaseConfig", "database"),
            ("UserProfileConfig", "user-profile"),
            ("Config", ""),
            ("User", "user"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(
                validator.extract_config_prefix(input),
                expected,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_dependency_injection_validator_new() {
        // Validating nothing must produce nothing.
        let validator = empty_validator();
        assert!(validator.validate(&[], &[]).is_empty());
    }
}