Skip to main content

smcp_computer/mcp_clients/
manager.rs

1/**
2* 文件名: manager
3* 作者: JQQ
4* 创建日期: 2025/12/15
5* 最后修改日期: 2025/12/15
6* 版权: 2023 JQQ. All rights reserved.
7* 依赖: tokio, async-trait, serde_json
8* 描述: MCP服务器管理器,负责管理多个MCP服务器连接和工具调用路由
9*/
10use super::model::*;
11use super::utils::client_factory;
12use super::vrl_runtime::VrlRuntime;
13use crate::errors::ComputerError;
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::sync::Arc as StdArc;
18use tokio::sync::{watch, RwLock};
19use tracing::{debug, error, info, warn};
20
21/// 工具名称重复错误 / Tool name duplication error
22#[derive(Debug, thiserror::Error)]
23#[error("Tool '{tool_name}' exists in multiple servers: {servers:?}")]
24pub struct ToolNameDuplicatedError {
25    pub tool_name: String,
26    pub servers: Vec<String>,
27}
28
29/// MCP服务器管理器 / MCP server manager
30pub struct MCPServerManager {
31    /// 服务器配置映射 / Server configuration mapping
32    servers_config: Arc<RwLock<HashMap<ServerName, MCPServerConfig>>>,
33    /// 活动客户端映射 / Active client mapping
34    active_clients: Arc<RwLock<HashMap<ServerName, StdArc<dyn MCPClientProtocol>>>>,
35    /// 工具到服务器的映射 / Tool to server mapping
36    tool_mapping: Arc<RwLock<HashMap<ToolName, ServerName>>>,
37    /// 别名映射 / Alias mapping
38    alias_mapping: Arc<RwLock<HashMap<String, (ServerName, ToolName)>>>,
39    /// 禁用工具集合 / Disabled tools set
40    disabled_tools: Arc<RwLock<HashSet<ToolName>>>,
41    /// 自动重连标志 / Auto reconnect flag
42    auto_reconnect: Arc<RwLock<bool>>,
43    /// 自动连接标志 / Auto connect flag
44    auto_connect: Arc<RwLock<bool>>,
45    /// 状态变化通知器 / State change notifier
46    state_notifier: watch::Sender<ManagerState>,
47    /// 健康检查配置 / Health check configuration
48    health_check_config: Arc<RwLock<HealthCheckConfig>>,
49    /// 重连策略 / Reconnect policy
50    reconnect_policy: Arc<RwLock<ReconnectPolicy>>,
51    /// 健康监控任务句柄 / Health monitor task handle
52    health_monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
53    /// 重试计数器(服务器名 -> 重试次数)/ Retry counters (server name -> retry count)
54    retry_counts: Arc<RwLock<HashMap<ServerName, u32>>>,
55}
56
/// Lifecycle state published by the manager over its `watch` channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ManagerState {
    /// Nothing loaded yet; this is also the state after the manager is closed.
    Uninitialized,
    /// Configurations are loaded but clients are not (all) running.
    Initialized,
    /// All enabled servers were started.
    Running,
    /// The manager is in an error state.
    Error,
}
69
70impl MCPServerManager {
71    /// 创建新的管理器 / Create new manager
72    pub fn new() -> Self {
73        let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
74
75        Self {
76            servers_config: Arc::new(RwLock::new(HashMap::new())),
77            active_clients: Arc::new(RwLock::new(HashMap::new())),
78            tool_mapping: Arc::new(RwLock::new(HashMap::new())),
79            alias_mapping: Arc::new(RwLock::new(HashMap::new())),
80            disabled_tools: Arc::new(RwLock::new(HashSet::new())),
81            auto_reconnect: Arc::new(RwLock::new(true)),
82            auto_connect: Arc::new(RwLock::new(false)),
83            state_notifier: state_tx,
84            health_check_config: Arc::new(RwLock::new(HealthCheckConfig::default())),
85            reconnect_policy: Arc::new(RwLock::new(ReconnectPolicy::default())),
86            health_monitor_handle: Arc::new(RwLock::new(None)),
87            retry_counts: Arc::new(RwLock::new(HashMap::new())),
88        }
89    }
90
91    /// 使用自定义配置创建管理器 / Create manager with custom configuration
92    pub fn with_config(
93        health_check_config: HealthCheckConfig,
94        reconnect_policy: ReconnectPolicy,
95    ) -> Self {
96        let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
97
98        Self {
99            servers_config: Arc::new(RwLock::new(HashMap::new())),
100            active_clients: Arc::new(RwLock::new(HashMap::new())),
101            tool_mapping: Arc::new(RwLock::new(HashMap::new())),
102            alias_mapping: Arc::new(RwLock::new(HashMap::new())),
103            disabled_tools: Arc::new(RwLock::new(HashSet::new())),
104            auto_reconnect: Arc::new(RwLock::new(reconnect_policy.enabled)),
105            auto_connect: Arc::new(RwLock::new(false)),
106            state_notifier: state_tx,
107            health_check_config: Arc::new(RwLock::new(health_check_config)),
108            reconnect_policy: Arc::new(RwLock::new(reconnect_policy)),
109            health_monitor_handle: Arc::new(RwLock::new(None)),
110            retry_counts: Arc::new(RwLock::new(HashMap::new())),
111        }
112    }
113
114    /// 获取状态通知器 / Get state notifier
115    pub fn get_state_notifier(&self) -> watch::Receiver<ManagerState> {
116        self.state_notifier.subscribe()
117    }
118
119    /// 更新管理器状态 / Update manager state
120    async fn update_state(&self, state: ManagerState) {
121        let _ = self.state_notifier.send(state);
122    }
123
124    /// 初始化管理器 / Initialize manager
125    pub async fn initialize(&self, servers: Vec<MCPServerConfig>) -> Result<(), ComputerError> {
126        // 停止所有现有客户端 / Stop all existing clients
127        self.stop_all().await?;
128
129        // 清空所有状态 / Clear all state
130        self.clear_all().await;
131
132        // 添加新配置 / Add new configurations
133        {
134            let mut configs = self.servers_config.write().await;
135            for server in servers {
136                configs.insert(server.name().to_string(), server);
137            }
138        }
139
140        // 刷新工具映射 / Refresh tool mapping
141        self.refresh_tool_mapping().await?;
142
143        // 更新状态 / Update state
144        self.update_state(ManagerState::Initialized).await;
145
146        info!("Manager initialized successfully");
147        Ok(())
148    }
149
150    /// 添加或更新服务器配置 / Add or update server configuration
151    pub async fn add_or_update_server(&self, config: MCPServerConfig) -> Result<(), ComputerError> {
152        let server_name = config.name().to_string();
153
154        // 检查是否已激活 / Check if already active
155        let is_active = {
156            let clients = self.active_clients.read().await;
157            clients.contains_key(&server_name)
158        };
159
160        if is_active {
161            let auto_reconnect = *self.auto_reconnect.read().await;
162            if auto_reconnect {
163                // 重启服务器 / Restart server
164                self.restart_server(&server_name).await?;
165            } else {
166                return Err(ComputerError::InvalidConfiguration(format!(
167                    "Server {} is active. Stop it before updating config",
168                    server_name
169                )));
170            }
171        }
172
173        // 更新配置 / Update configuration
174        {
175            let mut configs = self.servers_config.write().await;
176            configs.insert(server_name.clone(), config);
177        }
178
179        // 检查是否需要自动连接 / Check if need auto connect
180        let auto_connect = *self.auto_connect.read().await;
181        if auto_connect && !is_active {
182            self.start_client(&server_name).await?;
183        }
184
185        // 刷新工具映射 / Refresh tool mapping
186        self.refresh_tool_mapping().await?;
187
188        Ok(())
189    }
190
191    /// 移除服务器配置 / Remove server configuration
192    pub async fn remove_server(&self, server_name: &str) -> Result<(), ComputerError> {
193        // 停止客户端 / Stop client
194        self.stop_client(server_name).await?;
195
196        // 移除配置 / Remove configuration
197        {
198            let mut configs = self.servers_config.write().await;
199            configs.remove(server_name);
200        }
201
202        // 刷新工具映射 / Refresh tool mapping
203        self.refresh_tool_mapping().await?;
204
205        Ok(())
206    }
207
208    /// 启动所有启用的服务器 / Start all enabled servers
209    pub async fn start_all(&self) -> Result<(), ComputerError> {
210        let configs = self.servers_config.read().await;
211        let server_names: Vec<String> = configs
212            .iter()
213            .filter(|(_, config)| !config.disabled())
214            .map(|(name, _)| name.clone())
215            .collect();
216
217        drop(configs);
218
219        for server_name in server_names {
220            self.start_client(&server_name).await?;
221        }
222
223        // 更新状态 / Update state
224        self.update_state(ManagerState::Running).await;
225
226        info!("All servers started successfully");
227        Ok(())
228    }
229
230    /// 启动单个客户端 / Start single client
231    pub async fn start_client(&self, server_name: &str) -> Result<(), ComputerError> {
232        // 获取配置 / Get configuration
233        let config = {
234            let configs = self.servers_config.read().await;
235            configs.get(server_name).cloned()
236        };
237
238        let config = config.ok_or_else(|| {
239            ComputerError::InvalidConfiguration(format!("Unknown server: {}", server_name))
240        })?;
241
242        if config.disabled() {
243            return Err(ComputerError::InvalidConfiguration(format!(
244                "Cannot start disabled server: {}",
245                server_name
246            )));
247        }
248
249        // 检查是否已启动 / Check if already started
250        {
251            let clients = self.active_clients.read().await;
252            if clients.contains_key(server_name) {
253                return Ok(()); // 已经启动 / Already started
254            }
255        }
256
257        // 创建客户端 / Create client
258        let client = client_factory(config);
259
260        // 连接服务器 / Connect to server
261        client.connect().await.map_err(|e| {
262            ComputerError::ConnectionError(format!("Failed to connect to {}: {}", server_name, e))
263        })?;
264
265        // 添加到活动客户端 / Add to active clients
266        {
267            let mut clients = self.active_clients.write().await;
268            clients.insert(server_name.to_string(), client);
269        }
270
271        // 刷新工具映射 / Refresh tool mapping
272        self.refresh_tool_mapping().await?;
273
274        info!("Client {} started successfully", server_name);
275        Ok(())
276    }
277
278    /// 停止单个客户端 / Stop single client
279    pub async fn stop_client(&self, server_name: &str) -> Result<(), ComputerError> {
280        // 移除客户端 / Remove client
281        let mut client = {
282            let mut clients = self.active_clients.write().await;
283            clients.remove(server_name)
284        };
285
286        // 断开连接 / Disconnect
287        if let Some(ref mut c) = client {
288            c.disconnect().await.map_err(|e| {
289                ComputerError::ConnectionError(format!(
290                    "Failed to disconnect from {}: {}",
291                    server_name, e
292                ))
293            })?;
294        }
295
296        // 刷新工具映射 / Refresh tool mapping
297        self.refresh_tool_mapping().await?;
298
299        info!("Client {} stopped successfully", server_name);
300        Ok(())
301    }
302
303    /// 重启服务器 / Restart server
304    async fn restart_server(&self, server_name: &str) -> Result<(), ComputerError> {
305        self.stop_client(server_name).await?;
306
307        // 检查是否启用 / Check if enabled
308        let enabled = {
309            let configs = self.servers_config.read().await;
310            configs
311                .get(server_name)
312                .map(|c| !c.disabled())
313                .unwrap_or(false)
314        };
315
316        if enabled {
317            self.start_client(server_name).await?;
318        }
319
320        Ok(())
321    }
322
323    /// 停止所有客户端 / Stop all clients
324    pub async fn stop_all(&self) -> Result<(), ComputerError> {
325        let server_names: Vec<String> = {
326            let clients = self.active_clients.read().await;
327            clients.keys().cloned().collect()
328        };
329
330        for server_name in server_names {
331            self.stop_client(&server_name).await?;
332        }
333
334        // 更新状态 / Update state
335        self.update_state(ManagerState::Initialized).await;
336
337        info!("All servers stopped successfully");
338        Ok(())
339    }
340
341    /// 清空所有状态 / Clear all state
342    async fn clear_all(&self) {
343        self.servers_config.write().await.clear();
344        self.active_clients.write().await.clear();
345        self.tool_mapping.write().await.clear();
346        self.alias_mapping.write().await.clear();
347        self.disabled_tools.write().await.clear();
348    }
349
350    /// 关闭管理器 / Close manager
351    pub async fn close(&self) -> Result<(), ComputerError> {
352        self.stop_all().await?;
353        self.clear_all().await;
354        self.update_state(ManagerState::Uninitialized).await;
355        info!("Manager closed successfully");
356        Ok(())
357    }
358
359    /// 刷新工具映射 / Refresh tool mapping
360    async fn refresh_tool_mapping(&self) -> Result<(), ComputerError> {
361        // 清空现有映射 / Clear existing mappings
362        self.tool_mapping.write().await.clear();
363        self.alias_mapping.write().await.clear();
364        self.disabled_tools.write().await.clear();
365
366        // 临时存储工具源服务器 / Temporarily store tool source servers
367        let mut tool_sources: HashMap<ToolName, Vec<ServerName>> = HashMap::new();
368
369        // 收集所有活动服务器的工具 / Collect tools from all active servers
370        let clients = self.active_clients.read().await;
371        let configs = self.servers_config.read().await;
372
373        for (server_name, client) in clients.iter() {
374            let config = match configs.get(server_name) {
375                Some(c) => c,
376                None => continue,
377            };
378
379            // 获取工具列表 / Get tool list
380            match client.list_tools().await {
381                Ok(tools) => {
382                    for tool in tools {
383                        let original_tool_name = tool.name.clone();
384
385                        // 获取合并后的工具元数据 / Get merged tool metadata
386                        let tool_meta = self.merged_tool_meta(config, &original_tool_name);
387
388                        // 确定最终显示的工具名 / Determine final display name
389                        let display_name = tool_meta
390                            .and_then(|meta| meta.alias)
391                            .unwrap_or_else(|| original_tool_name.clone());
392
393                        // 如果使用别名,更新别名映射 / Update alias mapping if using alias
394                        if display_name != original_tool_name {
395                            let mut alias_map = self.alias_mapping.write().await;
396                            alias_map.insert(
397                                display_name.clone(),
398                                (server_name.clone(), original_tool_name.clone()),
399                            );
400                        }
401
402                        // 添加到工具源映射 / Add to tool source mapping
403                        tool_sources
404                            .entry(display_name.clone())
405                            .or_default()
406                            .push(server_name.clone());
407
408                        // 检查是否为禁用工具 / Check if disabled tool
409                        let forbidden_tools = config.forbidden_tools();
410                        if forbidden_tools.contains(&display_name)
411                            || forbidden_tools.contains(&original_tool_name)
412                        {
413                            let mut disabled = self.disabled_tools.write().await;
414                            disabled.insert(display_name);
415                        }
416                    }
417                }
418                Err(e) => {
419                    error!("Error listing tools for {}: {}", server_name, e);
420                }
421            }
422        }
423
424        // 构建最终映射(处理工具名冲突) / Build final mapping (handle tool name conflicts)
425        for (tool, sources) in tool_sources {
426            if sources.len() > 1 {
427                warn!("Tool '{}' exists in multiple servers: {:?}", tool, sources);
428                let suggestion =
429                    "Please use the 'alias' feature in ToolMeta to resolve conflicts. \
430                    Each tool should have a unique name or alias across all servers.";
431                return Err(ComputerError::InvalidConfiguration(format!(
432                    "Tool '{}' exists in multiple servers: {:?}\n{}",
433                    tool, sources, suggestion
434                )));
435            }
436            let mut mapping = self.tool_mapping.write().await;
437            mapping.insert(tool, sources[0].clone());
438        }
439
440        debug!("Tool mapping refreshed successfully");
441        Ok(())
442    }
443
444    /// 验证工具调用 / Validate tool call
445    pub async fn validate_tool_call(
446        &self,
447        tool_name: &str,
448        _parameters: &serde_json::Value,
449    ) -> Result<(ServerName, ToolName), ComputerError> {
450        // 检查工具是否可用 / Check if tool is available
451        let disabled = self.disabled_tools.read().await;
452        if disabled.contains(tool_name) {
453            return Err(ComputerError::PermissionError(format!(
454                "Tool '{}' is disabled by configuration",
455                tool_name
456            )));
457        }
458
459        // 获取服务器名称 / Get server name
460        let server_name = {
461            let mapping = self.tool_mapping.read().await;
462            mapping.get(tool_name).cloned()
463        };
464
465        let server_name = server_name.ok_or_else(|| {
466            ComputerError::InvalidConfiguration(format!(
467                "Tool '{}' not found in any active server",
468                tool_name
469            ))
470        })?;
471
472        // 检查是否为别名 / Check if it's an alias
473        let original_tool_name = {
474            let alias_map = self.alias_mapping.read().await;
475            if let Some((_, original)) = alias_map.get(tool_name) {
476                original.clone()
477            } else {
478                tool_name.to_string()
479            }
480        };
481
482        Ok((server_name, original_tool_name))
483    }
484
485    /// 调用工具 / Call tool
486    pub async fn call_tool(
487        &self,
488        server_name: &str,
489        tool_name: &str,
490        parameters: serde_json::Value,
491        timeout: Option<std::time::Duration>,
492    ) -> Result<CallToolResult, ComputerError> {
493        // 获取客户端引用 / Get client reference
494        let client = {
495            let clients = self.active_clients.read().await;
496            clients
497                .get(server_name)
498                .ok_or_else(|| {
499                    ComputerError::InvalidConfiguration(format!(
500                        "Server '{}' for tool '{}' is not active",
501                        server_name, tool_name
502                    ))
503                })?
504                .clone()
505        };
506
507        // 执行工具调用 / Execute tool call
508        let result = if let Some(timeout) = timeout {
509            tokio::time::timeout(timeout, client.call_tool(tool_name, parameters))
510                .await
511                .map_err(|_| ComputerError::TimeoutError("Tool execution timed out".to_string()))?
512        } else {
513            client.call_tool(tool_name, parameters).await
514        };
515
516        let mut result = result
517            .map_err(|e| ComputerError::ProtocolError(format!("Tool execution failed: {}", e)))?;
518
519        // 添加工具元数据到结果 / Add tool metadata to result
520        let config = {
521            let configs = self.servers_config.read().await;
522            configs.get(server_name).cloned()
523        };
524
525        if let Some(config) = config {
526            if let Some(tool_meta) = self.merged_tool_meta(&config, tool_name) {
527                if result.meta.is_none() {
528                    result.meta = Some(std::collections::HashMap::new());
529                }
530                if let Some(ref mut meta) = result.meta {
531                    meta.insert(
532                        A2C_TOOL_META.to_string(),
533                        serde_json::to_value(tool_meta).unwrap(),
534                    );
535                }
536            }
537
538            // VRL转换 / VRL transformation
539            if let Some(vrl_script) = config.vrl() {
540                // 获取原始参数用于VRL处理
541                // Note: 这里需要从调用栈获取原始参数,暂时使用空对象
542                let parameters = serde_json::json!({});
543
544                // 创建VRL事件,包含工具调用结果和元数据
545                let mut event = serde_json::to_value(&result).unwrap_or_default();
546                if let Value::Object(ref mut map) = event {
547                    map.insert(
548                        "tool_name".to_string(),
549                        Value::String(tool_name.to_string()),
550                    );
551                    map.insert("parameters".to_string(), parameters);
552                }
553
554                // 执行VRL转换
555                let mut runtime = VrlRuntime::new();
556                match runtime.run(vrl_script, event, "UTC") {
557                    Ok(vrl_result) => {
558                        // 将转换后的结果存储到meta中
559                        if result.meta.is_none() {
560                            result.meta = Some(std::collections::HashMap::new());
561                        }
562                        if let Some(ref mut meta) = result.meta {
563                            // 将转换后的结果序列化为JSON字符串
564                            if let Ok(transformed_json) =
565                                serde_json::to_string(&vrl_result.processed_event)
566                            {
567                                meta.insert(
568                                    A2C_VRL_TRANSFORMED.to_string(),
569                                    Value::String(transformed_json),
570                                );
571                            }
572                        }
573                        debug!(
574                            "VRL转换成功 / VRL transformation succeeded for tool '{}'",
575                            tool_name
576                        );
577                    }
578                    Err(e) => {
579                        warn!(
580                            "VRL转换失败 / VRL transformation failed for tool '{}': {}. 原始结果将正常返回 / Original result will be returned normally.",
581                            tool_name, e
582                        );
583                    }
584                }
585            }
586        }
587
588        Ok(result)
589    }
590
591    /// 执行工具(支持别名) / Execute tool (supports alias)
592    pub async fn execute_tool(
593        &self,
594        tool_name: &str,
595        parameters: serde_json::Value,
596        timeout: Option<std::time::Duration>,
597    ) -> Result<CallToolResult, ComputerError> {
598        let (server_name, original_tool_name) =
599            self.validate_tool_call(tool_name, &parameters).await?;
600        self.call_tool(&server_name, &original_tool_name, parameters, timeout)
601            .await
602    }
603
604    /// 获取服务器状态列表 / Get server status list
605    pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
606        let configs = self.servers_config.read().await;
607        let clients = self.active_clients.read().await;
608
609        configs
610            .keys()
611            .map(|name| {
612                let is_active = clients.contains_key(name);
613                let state = if is_active {
614                    clients
615                        .get(name)
616                        .map(|c| c.state().to_string())
617                        .unwrap_or_else(|| "unknown".to_string())
618                } else {
619                    "pending".to_string()
620                };
621                (name.clone(), is_active, state)
622            })
623            .collect()
624    }
625
626    /// 获取所有服务器配置(用于 GetComputerConfigRet)
627    /// Get all server configurations (for GetComputerConfigRet)
628    /// 返回格式:{ server_name: { type, status, disabled, ... } }
629    /// Returns format: { server_name: { type, status, disabled, ... } }
630    pub async fn get_server_configs(&self) -> serde_json::Value {
631        let configs = self.servers_config.read().await;
632        let clients = self.active_clients.read().await;
633
634        let mut result = serde_json::Map::new();
635
636        for (name, config) in configs.iter() {
637            let is_active = clients.contains_key(name);
638            let state = if is_active {
639                clients
640                    .get(name)
641                    .map(|c| c.state().to_string())
642                    .unwrap_or_else(|| "unknown".to_string())
643            } else {
644                "pending".to_string()
645            };
646
647            // 构建服务器配置信息 / Build server config info
648            let mut server_info = serde_json::Map::new();
649
650            // 添加类型信息 / Add type info
651            let server_type = match config {
652                MCPServerConfig::Stdio(_) => "stdio",
653                MCPServerConfig::Sse(_) => "sse",
654                MCPServerConfig::Http(_) => "http",
655            };
656            server_info.insert(
657                "type".to_string(),
658                serde_json::Value::String(server_type.to_string()),
659            );
660
661            // 添加状态信息 / Add status info
662            server_info.insert("status".to_string(), serde_json::Value::String(state));
663            server_info.insert("is_active".to_string(), serde_json::Value::Bool(is_active));
664            server_info.insert(
665                "disabled".to_string(),
666                serde_json::Value::Bool(config.disabled()),
667            );
668
669            // 添加禁用工具列表 / Add forbidden tools list
670            let forbidden_tools: Vec<serde_json::Value> = config
671                .forbidden_tools()
672                .iter()
673                .map(|t| serde_json::Value::String(t.clone()))
674                .collect();
675            server_info.insert(
676                "forbidden_tools".to_string(),
677                serde_json::Value::Array(forbidden_tools),
678            );
679
680            // 添加工具元数据 / Add tool metadata
681            if let Ok(tool_meta_json) = serde_json::to_value(config.tool_meta()) {
682                server_info.insert("tool_meta".to_string(), tool_meta_json);
683            }
684
685            // 添加默认工具元数据 / Add default tool metadata
686            if let Some(default_meta) = config.default_tool_meta() {
687                if let Ok(default_meta_json) = serde_json::to_value(default_meta) {
688                    server_info.insert("default_tool_meta".to_string(), default_meta_json);
689                }
690            }
691
692            // 添加 VRL 脚本(如果有)/ Add VRL script if present
693            if let Some(vrl) = config.vrl() {
694                server_info.insert(
695                    "vrl".to_string(),
696                    serde_json::Value::String(vrl.to_string()),
697                );
698            }
699
700            // 添加服务器参数(根据类型)/ Add server parameters based on type
701            match config {
702                MCPServerConfig::Stdio(stdio_config) => {
703                    if let Ok(params_json) = serde_json::to_value(&stdio_config.server_parameters) {
704                        server_info.insert("server_parameters".to_string(), params_json);
705                    }
706                }
707                MCPServerConfig::Sse(sse_config) => {
708                    if let Ok(params_json) = serde_json::to_value(&sse_config.server_parameters) {
709                        server_info.insert("server_parameters".to_string(), params_json);
710                    }
711                }
712                MCPServerConfig::Http(http_config) => {
713                    if let Ok(params_json) = serde_json::to_value(&http_config.server_parameters) {
714                        server_info.insert("server_parameters".to_string(), params_json);
715                    }
716                }
717            }
718
719            result.insert(name.clone(), serde_json::Value::Object(server_info));
720        }
721
722        serde_json::Value::Object(result)
723    }
724
725    /// 获取可用工具列表 / Get available tools list
726    pub async fn list_available_tools(&self) -> Vec<Tool> {
727        let mut tools = Vec::new();
728        let mapping = self.tool_mapping.read().await;
729        let alias_map = self.alias_mapping.read().await;
730
731        for (display_name, server_name) in mapping.iter() {
732            let client = {
733                let clients = self.active_clients.read().await;
734                clients.get(server_name).cloned()
735            };
736
737            if let Some(client) = client {
738                // 获取原始工具名称 / Get original tool name
739                let original_name = alias_map
740                    .get(display_name)
741                    .map(|(_, original)| original.clone())
742                    .unwrap_or_else(|| display_name.clone());
743
744                // 获取工具列表 / Get tool list
745                if let Ok(tool_list) = client.list_tools().await {
746                    if let Some(tool) = tool_list.into_iter().find(|t| t.name == original_name) {
747                        // 更新工具名称为显示名称 / Update tool name to display name
748                        let mut display_tool = tool;
749                        display_tool.name = display_name.clone();
750                        tools.push(display_tool);
751                    }
752                }
753            }
754        }
755
756        tools
757    }
758
759    /// 合并工具元数据 / Merge tool metadata
760    fn merged_tool_meta(&self, config: &MCPServerConfig, tool_name: &str) -> Option<ToolMeta> {
761        let specific = config.tool_meta().get(tool_name);
762        let default = config.default_tool_meta();
763
764        match (specific, default) {
765            (None, None) => None,
766            (Some(s), None) => Some(s.clone()),
767            (None, Some(d)) => Some(d.clone()),
768            (Some(s), Some(d)) => {
769                // 浅合并,specific优先 / Shallow merge, specific takes priority
770                let mut merged = d.clone();
771                if s.auto_apply.is_some() {
772                    merged.auto_apply = s.auto_apply;
773                }
774                if s.alias.is_some() {
775                    merged.alias = s.alias.clone();
776                }
777                if s.tags.is_some() {
778                    merged.tags = s.tags.clone();
779                }
780                if s.ret_object_mapper.is_some() {
781                    merged.ret_object_mapper = s.ret_object_mapper.clone();
782                }
783                Some(merged)
784            }
785        }
786    }
787
788    /// 启用自动连接 / Enable auto connect
789    pub async fn enable_auto_connect(&self) {
790        *self.auto_connect.write().await = true;
791    }
792
793    /// 禁用自动连接 / Disable auto connect
794    pub async fn disable_auto_connect(&self) {
795        *self.auto_connect.write().await = false;
796    }
797
798    /// 启用自动重连 / Enable auto reconnect
799    pub async fn enable_auto_reconnect(&self) {
800        *self.auto_reconnect.write().await = true;
801    }
802
803    /// 禁用自动重连 / Disable auto reconnect
804    pub async fn disable_auto_reconnect(&self) {
805        *self.auto_reconnect.write().await = false;
806    }
807
808    /// 设置健康检查配置 / Set health check configuration
809    pub async fn set_health_check_config(&self, config: HealthCheckConfig) {
810        *self.health_check_config.write().await = config;
811    }
812
813    /// 获取健康检查配置 / Get health check configuration
814    pub async fn get_health_check_config(&self) -> HealthCheckConfig {
815        self.health_check_config.read().await.clone()
816    }
817
818    /// 设置重连策略 / Set reconnect policy
819    pub async fn set_reconnect_policy(&self, policy: ReconnectPolicy) {
820        *self.reconnect_policy.write().await = policy;
821    }
822
823    /// 获取重连策略 / Get reconnect policy
824    pub async fn get_reconnect_policy(&self) -> ReconnectPolicy {
825        self.reconnect_policy.read().await.clone()
826    }
827
828    /// 启动健康监控 / Start health monitoring
829    /// 定期检查所有活动客户端的健康状态,并在检测到故障时自动重连
830    /// Periodically checks health of all active clients and auto-reconnects on failure
831    pub async fn start_health_monitor(&self) {
832        // 先停止现有的监控任务 / Stop existing monitor task first
833        self.stop_health_monitor().await;
834
835        let health_config = self.health_check_config.clone();
836        let reconnect_policy = self.reconnect_policy.clone();
837        let active_clients = self.active_clients.clone();
838        let _servers_config = self.servers_config.clone();
839        let retry_counts = self.retry_counts.clone();
840        let auto_reconnect = self.auto_reconnect.clone();
841
842        let handle = tokio::spawn(async move {
843            loop {
844                let config = health_config.read().await.clone();
845                if !config.enabled {
846                    // 健康检查禁用,等待一段时间后重新检查配置
847                    // Health check disabled, wait and re-check config
848                    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
849                    continue;
850                }
851
852                // 获取所有活动客户端 / Get all active clients
853                let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
854                    let clients_guard = active_clients.read().await;
855                    clients_guard
856                        .iter()
857                        .map(|(k, v)| (k.clone(), v.clone()))
858                        .collect()
859                };
860
861                // 对每个客户端执行健康检查 / Perform health check on each client
862                for (server_name, client) in clients {
863                    let check_result = tokio::time::timeout(
864                        std::time::Duration::from_secs(config.timeout_secs),
865                        client.health_check(),
866                    )
867                    .await;
868
869                    let is_healthy = match check_result {
870                        Ok(result) => result.is_healthy,
871                        Err(_) => {
872                            warn!("Health check timed out for server: {}", server_name);
873                            false
874                        }
875                    };
876
877                    if !is_healthy {
878                        warn!("Server {} is unhealthy", server_name);
879
880                        // 检查是否启用自动重连 / Check if auto-reconnect is enabled
881                        let should_reconnect = *auto_reconnect.read().await;
882                        if !should_reconnect {
883                            continue;
884                        }
885
886                        let policy = reconnect_policy.read().await.clone();
887                        let mut retries = retry_counts.write().await;
888                        let retry_count = retries.entry(server_name.clone()).or_insert(0);
889
890                        if policy.should_retry(*retry_count) {
891                            let delay = policy.calculate_delay(*retry_count);
892                            info!(
893                                "Attempting to reconnect {} (retry {}/{}), delay {:?}",
894                                server_name,
895                                *retry_count + 1,
896                                if policy.max_retries == 0 {
897                                    "∞".to_string()
898                                } else {
899                                    policy.max_retries.to_string()
900                                },
901                                delay
902                            );
903
904                            tokio::time::sleep(delay).await;
905
906                            // 尝试断开并重新连接 / Try disconnect and reconnect
907                            if let Err(e) = client.disconnect().await {
908                                warn!("Failed to disconnect {}: {}", server_name, e);
909                            }
910
911                            match client.connect().await {
912                                Ok(_) => {
913                                    info!("Successfully reconnected to {}", server_name);
914                                    // 重置重试计数 / Reset retry count
915                                    *retry_count = 0;
916                                }
917                                Err(e) => {
918                                    error!("Failed to reconnect to {}: {}", server_name, e);
919                                    *retry_count += 1;
920                                }
921                            }
922                        } else {
923                            error!(
924                                "Max retries ({}) reached for server {}. Giving up.",
925                                policy.max_retries, server_name
926                            );
927                            // 可以考虑从活动客户端中移除 / Consider removing from active clients
928                        }
929                    } else {
930                        // 健康检查通过,重置重试计数 / Health check passed, reset retry count
931                        let mut retries = retry_counts.write().await;
932                        retries.remove(&server_name);
933                        debug!("Server {} is healthy", server_name);
934                    }
935                }
936
937                // 等待下一次健康检查 / Wait for next health check
938                tokio::time::sleep(std::time::Duration::from_secs(config.interval_secs)).await;
939            }
940        });
941
942        *self.health_monitor_handle.write().await = Some(handle);
943        info!("Health monitor started");
944    }
945
946    /// 停止健康监控 / Stop health monitoring
947    pub async fn stop_health_monitor(&self) {
948        if let Some(handle) = self.health_monitor_handle.write().await.take() {
949            handle.abort();
950            info!("Health monitor stopped");
951        }
952    }
953
954    /// 检查单个服务器的健康状态 / Check health of a single server
955    pub async fn check_server_health(&self, server_name: &str) -> Option<HealthCheckResult> {
956        let clients = self.active_clients.read().await;
957        if let Some(client) = clients.get(server_name) {
958            let config = self.health_check_config.read().await;
959            let result = tokio::time::timeout(
960                std::time::Duration::from_secs(config.timeout_secs),
961                client.health_check(),
962            )
963            .await;
964
965            match result {
966                Ok(health_result) => Some(health_result),
967                Err(_) => Some(HealthCheckResult {
968                    is_healthy: false,
969                    checked_at: std::time::Instant::now(),
970                    error: Some("Health check timed out".to_string()),
971                    response_time_ms: None,
972                }),
973            }
974        } else {
975            None
976        }
977    }
978
979    /// 检查所有服务器的健康状态 / Check health of all servers
980    pub async fn check_all_health(&self) -> HashMap<String, HealthCheckResult> {
981        let mut results = HashMap::new();
982        let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
983            let clients_guard = self.active_clients.read().await;
984            clients_guard
985                .iter()
986                .map(|(k, v)| (k.clone(), v.clone()))
987                .collect()
988        };
989
990        let config = self.health_check_config.read().await.clone();
991
992        for (server_name, client) in clients {
993            let result = tokio::time::timeout(
994                std::time::Duration::from_secs(config.timeout_secs),
995                client.health_check(),
996            )
997            .await;
998
999            let health_result = match result {
1000                Ok(hr) => hr,
1001                Err(_) => HealthCheckResult {
1002                    is_healthy: false,
1003                    checked_at: std::time::Instant::now(),
1004                    error: Some("Health check timed out".to_string()),
1005                    response_time_ms: None,
1006                },
1007            };
1008
1009            results.insert(server_name, health_result);
1010        }
1011
1012        results
1013    }
1014
1015    /// 获取重试计数 / Get retry counts
1016    pub async fn get_retry_counts(&self) -> HashMap<String, u32> {
1017        self.retry_counts.read().await.clone()
1018    }
1019
1020    /// 重置特定服务器的重试计数 / Reset retry count for a specific server
1021    pub async fn reset_retry_count(&self, server_name: &str) {
1022        self.retry_counts.write().await.remove(server_name);
1023    }
1024
1025    /// 重置所有重试计数 / Reset all retry counts
1026    pub async fn reset_all_retry_counts(&self) {
1027        self.retry_counts.write().await.clear();
1028    }
1029}
1030
1031impl Default for MCPServerManager {
1032    fn default() -> Self {
1033        Self::new()
1034    }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::time::{sleep, Duration};

    /// Build an enabled STDIO server config named `name` that runs `echo args...`.
    fn stdio_config(name: &str, args: &[&str]) -> MCPServerConfig {
        MCPServerConfig::Stdio(StdioServerConfig {
            name: name.to_string(),
            disabled: false,
            forbidden_tools: vec![],
            tool_meta: HashMap::new(),
            default_tool_meta: None,
            vrl: None,
            server_parameters: StdioServerParameters {
                command: "echo".to_string(),
                args: args.iter().map(|a| a.to_string()).collect(),
                env: HashMap::new(),
                cwd: None,
            },
        })
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = MCPServerManager::new();
        let status = manager.get_server_status().await;
        assert!(status.is_empty());
    }

    #[tokio::test]
    async fn test_manager_initialization() {
        let manager = MCPServerManager::new();

        let configs = vec![
            // Enabled STDIO server.
            stdio_config("test_stdio", &["hello"]),
            // HTTP server that is explicitly disabled.
            MCPServerConfig::Http(HttpServerConfig {
                name: "test_http".to_string(),
                disabled: true,
                forbidden_tools: vec![],
                tool_meta: HashMap::new(),
                default_tool_meta: None,
                vrl: None,
                server_parameters: HttpServerParameters {
                    url: "http://localhost:8080".to_string(),
                    headers: HashMap::new(),
                },
            }),
        ];

        let result = manager.initialize(configs).await;
        assert!(result.is_ok());

        // Both configs are registered...
        let status = manager.get_server_status().await;
        assert_eq!(status.len(), 2);

        // ...and neither is active right after initialize().
        let stdio_status = status
            .iter()
            .find(|(name, _, _)| name == "test_stdio")
            .unwrap();
        assert!(!stdio_status.1);

        let http_status = status
            .iter()
            .find(|(name, _, _)| name == "test_http")
            .unwrap();
        assert!(!http_status.1);
    }

    #[tokio::test]
    async fn test_add_server() {
        let manager = MCPServerManager::new();

        let result = manager
            .add_or_update_server(stdio_config("test_server", &[]))
            .await;
        assert!(result.is_ok());

        let status = manager.get_server_status().await;
        assert_eq!(status.len(), 1);
        assert_eq!(status[0].0, "test_server");
    }

    #[tokio::test]
    async fn test_remove_server() {
        let manager = MCPServerManager::new();

        manager
            .add_or_update_server(stdio_config("test_server", &[]))
            .await
            .unwrap();

        let result = manager.remove_server("test_server").await;
        assert!(result.is_ok());

        let status = manager.get_server_status().await;
        assert!(status.is_empty());
    }

    #[tokio::test]
    async fn test_tool_conflict_detection() {
        let manager = MCPServerManager::new();

        // Two servers that could expose tools with the same name.
        let configs = vec![
            stdio_config("server1", &["server1"]),
            stdio_config("server2", &["server2"]),
        ];

        let result = manager.initialize(configs).await;
        assert!(result.is_ok());

        // start_all() may fail because of tool-name conflicts; that outcome is
        // accepted, so its result is deliberately ignored.
        // NOTE(review): nothing below asserts on conflict detection; as written
        // this test only verifies that initialize() succeeds.
        let _result = manager.start_all().await;

        // Give any spawned connections time to settle.
        sleep(Duration::from_millis(200)).await;
    }

    #[tokio::test]
    async fn test_health_check_config() {
        let manager = MCPServerManager::new();

        // Defaults.
        let config = manager.get_health_check_config().await;
        assert_eq!(config.interval_secs, 30);
        assert_eq!(config.timeout_secs, 5);
        assert!(config.enabled);

        // Update and read back.
        let new_config = HealthCheckConfig {
            interval_secs: 60,
            timeout_secs: 10,
            enabled: false,
        };
        manager.set_health_check_config(new_config.clone()).await;

        let updated = manager.get_health_check_config().await;
        assert_eq!(updated.interval_secs, 60);
        assert_eq!(updated.timeout_secs, 10);
        assert!(!updated.enabled);
    }

    #[tokio::test]
    async fn test_reconnect_policy() {
        let manager = MCPServerManager::new();

        // Defaults.
        let policy = manager.get_reconnect_policy().await;
        assert!(policy.enabled);
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_delay_ms, 1000);
        assert_eq!(policy.max_delay_ms, 30000);
        assert_eq!(policy.backoff_factor, 2.0);

        // Exponential backoff.
        assert_eq!(policy.calculate_delay(0).as_millis(), 1000);
        assert_eq!(policy.calculate_delay(1).as_millis(), 2000);
        assert_eq!(policy.calculate_delay(2).as_millis(), 4000);
        assert_eq!(policy.calculate_delay(3).as_millis(), 8000);

        // should_retry honours max_retries.
        assert!(policy.should_retry(0));
        assert!(policy.should_retry(4));
        assert!(!policy.should_retry(5)); // max is 5

        // max_retries == 0 means retry forever.
        let infinite_policy = ReconnectPolicy {
            enabled: true,
            max_retries: 0,
            ..Default::default()
        };
        assert!(infinite_policy.should_retry(100));
    }

    #[tokio::test]
    async fn test_retry_counts() {
        let manager = MCPServerManager::new();

        // Empty initially.
        let counts = manager.get_retry_counts().await;
        assert!(counts.is_empty());

        // Seed counters through the private field.
        {
            let mut counts = manager.retry_counts.write().await;
            counts.insert("server1".to_string(), 3);
            counts.insert("server2".to_string(), 5);
        }

        let counts = manager.get_retry_counts().await;
        assert_eq!(counts.get("server1"), Some(&3));
        assert_eq!(counts.get("server2"), Some(&5));

        // Reset one server.
        manager.reset_retry_count("server1").await;
        let counts = manager.get_retry_counts().await;
        assert!(!counts.contains_key("server1"));
        assert_eq!(counts.get("server2"), Some(&5));

        // Reset all.
        manager.reset_all_retry_counts().await;
        let counts = manager.get_retry_counts().await;
        assert!(counts.is_empty());
    }

    #[tokio::test]
    async fn test_manager_with_custom_config() {
        let health_config = HealthCheckConfig {
            interval_secs: 15,
            timeout_secs: 3,
            enabled: true,
        };
        let reconnect_policy = ReconnectPolicy {
            enabled: true,
            max_retries: 10,
            initial_delay_ms: 500,
            max_delay_ms: 60000,
            backoff_factor: 1.5,
        };

        let manager =
            MCPServerManager::with_config(health_config.clone(), reconnect_policy.clone());

        let got_health = manager.get_health_check_config().await;
        assert_eq!(got_health.interval_secs, 15);
        assert_eq!(got_health.timeout_secs, 3);

        let got_reconnect = manager.get_reconnect_policy().await;
        assert_eq!(got_reconnect.max_retries, 10);
        assert_eq!(got_reconnect.initial_delay_ms, 500);
    }

    #[tokio::test]
    async fn test_health_check_unknown_server() {
        let manager = MCPServerManager::new();

        // No active clients: single lookup yields None, bulk check yields nothing.
        assert!(manager.check_server_health("missing").await.is_none());
        assert!(manager.check_all_health().await.is_empty());
    }

    #[tokio::test]
    async fn test_auto_flags_toggle() {
        let manager = MCPServerManager::new();

        manager.enable_auto_connect().await;
        assert!(*manager.auto_connect.read().await);
        manager.disable_auto_connect().await;
        assert!(!*manager.auto_connect.read().await);

        manager.enable_auto_reconnect().await;
        assert!(*manager.auto_reconnect.read().await);
        manager.disable_auto_reconnect().await;
        assert!(!*manager.auto_reconnect.read().await);
    }

    #[tokio::test]
    async fn test_health_monitor_start_stop() {
        let manager = MCPServerManager::new();

        manager.start_health_monitor().await;
        assert!(manager.health_monitor_handle.read().await.is_some());

        // Restarting keeps exactly one stored handle.
        manager.start_health_monitor().await;
        assert!(manager.health_monitor_handle.read().await.is_some());

        manager.stop_health_monitor().await;
        assert!(manager.health_monitor_handle.read().await.is_none());
    }
}