Skip to main content

sh_layer4/mcp_bridge/
client.rs

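//! MCP client manager.
//!
//! Manages connections to multiple MCP servers.
//!
//! Supports real stdio, TCP and (on Unix) Unix-socket transports, with the full MCP initialization handshake.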
6
7use super::protocol::{McpMessage, McpRequest, RequestId, ToolDefinition, ToolResult};
8#[cfg(unix)]
9use super::transport::UnixSocketTransport;
10use super::transport::{McpTransport, McpTransportType, StdioTransport, TcpTransport};
11use anyhow::{anyhow, Result};
12use parking_lot::RwLock as ParkingRwLock;
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use tracing::{debug, info, warn};
18
19/// MCP 服务器配置
20#[derive(Debug, Clone)]
21pub struct McpServerConfig {
22    /// 服务器名称
23    pub name: String,
24    /// 传输类型
25    pub transport: McpTransportType,
26    /// 自动重连
27    pub auto_reconnect: bool,
28    /// 重连间隔 (毫秒)
29    pub reconnect_interval_ms: u64,
30}
31
32/// 已连接的 MCP 服务器
33struct ConnectedServer {
34    #[allow(dead_code)]
35    config: McpServerConfig,
36    transport: Arc<dyn McpTransport>,
37    tools: Vec<ToolDefinition>,
38}
39
40/// MCP 客户端管理器
41pub struct McpClientManager {
42    /// 服务器配置
43    configs: ParkingRwLock<HashMap<String, McpServerConfig>>,
44    /// 已连接的服务器
45    servers: ParkingRwLock<HashMap<String, ConnectedServer>>,
46    /// 工具到服务器的映射
47    tool_mapping: ParkingRwLock<HashMap<String, String>>,
48    /// 请求 ID 计数器
49    request_id_counter: AtomicU64,
50}
51
52impl Default for McpClientManager {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl McpClientManager {
59    /// 创建新的管理器
60    pub fn new() -> Self {
61        Self {
62            configs: ParkingRwLock::new(HashMap::new()),
63            servers: ParkingRwLock::new(HashMap::new()),
64            tool_mapping: ParkingRwLock::new(HashMap::new()),
65            request_id_counter: AtomicU64::new(1),
66        }
67    }
68
69    /// 生成下一个请求 ID
70    fn next_request_id(&self) -> RequestId {
71        RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
72    }
73
74    /// 添加服务器配置
75    pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
76        let name = config.name.clone();
77        let mut configs = self.configs.write();
78        configs.insert(name, config);
79        Ok(())
80    }
81
82    /// 连接到服务器
83    pub async fn connect(&self, name: &str) -> Result<()> {
84        let config = {
85            let configs = self.configs.read();
86            configs
87                .get(name)
88                .ok_or_else(|| anyhow!("Server not found: {}", name))?
89                .clone()
90        };
91
92        info!(server = %name, transport = ?config.transport, "Connecting to MCP server");
93
94        // 根据传输类型创建真实连接
95        let transport: Arc<dyn McpTransport> = match &config.transport {
96            McpTransportType::Stdio { command, args } => {
97                // 创建真实的 Stdio 传输
98                let stdio_transport = StdioTransport::new(command, args)?;
99                stdio_transport.start(command, args).await?;
100                info!(server = %name, command = %command, "Stdio transport started");
101                Arc::new(stdio_transport)
102            }
103            McpTransportType::Tcp { addr } => {
104                // 创建真实的 TCP 连接
105                let tcp_transport = TcpTransport::connect(addr).await?;
106                info!(server = %name, addr = %addr, "TCP transport connected");
107                Arc::new(tcp_transport)
108            }
109            #[cfg(unix)]
110            McpTransportType::Unix { path } => {
111                // 创建真实的 Unix socket 连接
112                let unix_transport = UnixSocketTransport::connect(path).await?;
113                info!(server = %name, path = %path, "Unix socket transport connected");
114                Arc::new(unix_transport)
115            }
116        };
117
118        // 发送初始化请求(MCP 协议握手第一步)
119        let init_params = serde_json::json!({
120            "protocolVersion": "2024-11-05",
121            "capabilities": {
122                "roots": {
123                    "listChanged": true
124                },
125                "sampling": {}
126            },
127            "clientInfo": {
128                "name": "continuum",
129                "version": env!("CARGO_PKG_VERSION")
130            }
131        });
132
133        let request = McpRequest {
134            id: self.next_request_id(),
135            method: "initialize".to_string(),
136            params: Some(init_params),
137        };
138
139        debug!(server = %name, "Sending initialize request");
140        transport.send(&McpMessage::Request(request)).await?;
141
142        // 等待响应(带超时处理)
143        let response =
144            tokio::time::timeout(std::time::Duration::from_secs(30), transport.receive())
145                .await
146                .map_err(|_| anyhow!("Initialize timeout for server: {}", name))??;
147
148        match response {
149            Some(McpMessage::Response(response)) => {
150                if let Some(error) = &response.error {
151                    warn!(server = %name, code = ?error.code, message = %error.message, "Initialize failed");
152                    return Err(anyhow!(
153                        "Initialize failed (code {}): {}",
154                        error.code,
155                        error.message
156                    ));
157                }
158
159                // 记录服务器能力
160                if let Some(result) = &response.result {
161                    debug!(server = %name, result = ?result, "Server capabilities received");
162                    if let Some(server_info) = result.get("serverInfo") {
163                        info!(server = %name, server_info = ?server_info, "Connected to MCP server");
164                    }
165                }
166            }
167            Some(McpMessage::Error(error)) => {
168                warn!(server = %name, error = ?error, "Received error response");
169                return Err(anyhow!("Server error: {:?}", error));
170            }
171            Some(other) => {
172                warn!(server = %name, message = ?other, "Unexpected message type");
173                return Err(anyhow!("Unexpected response type during initialization"));
174            }
175            None => {
176                warn!(server = %name, "No response received");
177                return Err(anyhow!("No response from server during initialization"));
178            }
179        }
180
181        // 发送 initialized 通知(MCP 协议握手第二步)
182        let notification = McpMessage::Notification(super::protocol::McpNotification {
183            method: "notifications/initialized".to_string(),
184            params: None,
185        });
186        transport.send(&notification).await?;
187        debug!(server = %name, "Sent initialized notification");
188
189        // 列出可用工具
190        let list_tools_request = McpRequest {
191            id: self.next_request_id(),
192            method: "tools/list".to_string(),
193            params: None,
194        };
195        transport
196            .send(&McpMessage::Request(list_tools_request))
197            .await?;
198
199        let tools_response =
200            tokio::time::timeout(std::time::Duration::from_secs(10), transport.receive())
201                .await
202                .map_err(|_| anyhow!("Tools list timeout for server: {}", name))??;
203
204        let tools = match tools_response {
205            Some(McpMessage::Response(response)) => {
206                if let Some(result) = &response.result {
207                    if let Some(tools_array) = result.get("tools") {
208                        match serde_json::from_value::<Vec<ToolDefinition>>(tools_array.clone()) {
209                            Ok(t) => {
210                                info!(server = %name, tool_count = t.len(), "Tools discovered");
211                                t
212                            }
213                            Err(e) => {
214                                warn!(server = %name, error = %e, "Failed to parse tools");
215                                Vec::new()
216                            }
217                        }
218                    } else {
219                        Vec::new()
220                    }
221                } else {
222                    Vec::new()
223                }
224            }
225            _ => Vec::new(),
226        };
227
228        // 更新服务器状态
229        let mut servers = self.servers.write();
230        servers.insert(
231            name.to_string(),
232            ConnectedServer {
233                config,
234                transport,
235                tools,
236            },
237        );
238
239        info!(server = %name, "MCP connection established successfully");
240        Ok(())
241    }
242
243    /// 连接所有服务器
244    pub async fn connect_all(&self) -> Result<Vec<String>> {
245        let names: Vec<String> = self.configs.read().keys().cloned().collect();
246        let mut results = Vec::new();
247
248        for name in &names {
249            if self.connect(name).await.is_ok() {
250                results.push(name.clone());
251            }
252        }
253
254        Ok(results)
255    }
256
257    /// 断开服务器连接
258    pub async fn disconnect(&self, name: &str) -> Result<()> {
259        let server = {
260            let mut servers = self.servers.write();
261            servers.remove(name)
262        };
263        if let Some(s) = server {
264            s.transport.close().await?;
265        }
266        Ok(())
267    }
268
269    /// 获取服务器状态
270    pub fn get_server_status(&self, name: &str) -> Option<bool> {
271        let servers = self.servers.read();
272        servers.get(name).map(|_| true)
273    }
274
275    /// 列出所有服务器
276    pub fn list_servers(&self) -> Vec<(String, bool)> {
277        let servers = self.servers.read();
278        let configs = self.configs.read();
279
280        let mut result = Vec::new();
281        for name in configs.keys() {
282            let connected = servers.contains_key(name);
283            result.push((name.clone(), connected));
284        }
285        result
286    }
287
288    /// 获取所有可用工具
289    pub fn list_all_tools(&self) -> Vec<(String, ToolDefinition)> {
290        let servers = self.servers.read();
291        let mut tools = Vec::new();
292
293        for (server_name, server) in servers.iter() {
294            for tool in &server.tools {
295                tools.push((server_name.clone(), tool.clone()));
296            }
297        }
298
299        tools
300    }
301
302    /// 调用工具
303    pub async fn call_tool(&self, tool_name: &str, arguments: Value) -> Result<ToolResult> {
304        // 查找工具所在的服务器
305        let server_name = {
306            let tool_mapping = self.tool_mapping.read();
307            tool_mapping
308                .get(tool_name)
309                .ok_or_else(|| anyhow!("Tool not found: {}", tool_name))?
310                .clone()
311        };
312
313        // 获取服务器和传输层
314        let (transport, request_id) = {
315            let servers = self.servers.read();
316            let server = servers
317                .get(&server_name)
318                .ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
319
320            (server.transport.clone(), self.next_request_id())
321        };
322
323        // 构建请求
324        let params = serde_json::json!({
325            "name": tool_name,
326            "arguments": arguments
327        });
328
329        let request = McpRequest {
330            id: request_id,
331            method: "tools/call".to_string(),
332            params: Some(params),
333        };
334
335        // 发送请求
336        transport.send(&McpMessage::Request(request)).await?;
337
338        // 接收响应
339        match transport.receive().await? {
340            Some(McpMessage::Response(response)) => {
341                if let Some(error) = response.error {
342                    return Err(anyhow!("Tool call error: {}", error.message));
343                }
344
345                if let Some(result) = response.result {
346                    let tool_result: ToolResult =
347                        serde_json::from_value(result).unwrap_or_else(|_| ToolResult {
348                            is_error: false,
349                            content: vec![super::protocol::ContentBlock::Text {
350                                text: "Tool executed successfully".to_string(),
351                            }],
352                        });
353                    Ok(tool_result)
354                } else {
355                    Err(anyhow!("Empty response"))
356                }
357            }
358            Some(McpMessage::Error(error)) => Err(anyhow!("Error: {:?}", error)),
359            _ => Err(anyhow!("Unexpected response type")),
360        }
361    }
362
363    /// 注册工具到服务器
364    pub async fn register_tools(
365        &self,
366        server_name: &str,
367        tools: Vec<ToolDefinition>,
368    ) -> Result<()> {
369        let mut servers = self.servers.write();
370        let server = servers
371            .get_mut(server_name)
372            .ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
373
374        let mut tool_mapping = self.tool_mapping.write();
375        for tool in &tools {
376            tool_mapping.insert(tool.name.clone(), server_name.to_string());
377        }
378
379        server.tools = tools;
380        Ok(())
381    }
382
383    /// 渲染服务器列表
384    pub fn render_status(&self) -> String {
385        let servers = self.servers.read();
386        let configs = self.configs.read();
387        let mut output = String::new();
388
389        output.push_str("MCP Servers:\n");
390
391        if configs.is_empty() {
392            output.push_str("  No servers configured\n");
393        } else {
394            for name in configs.keys() {
395                let server = servers.get(name);
396                let status = if server.is_some() { "🟢" } else { "🔴" };
397                let tool_count = server.map(|s| s.tools.len()).unwrap_or(0);
398                output.push_str(&format!("  {} {} ({} tools)\n", status, name, tool_count));
399            }
400        }
401
402        output
403    }
404}
405
406/// 预设的 MCP 服务器配置
407pub fn preset_servers() -> Vec<McpServerConfig> {
408    vec![
409        // 文件系统 MCP
410        McpServerConfig {
411            name: "filesystem".to_string(),
412            transport: McpTransportType::Stdio {
413                command: "mcp-server-filesystem".to_string(),
414                args: vec!["--root".to_string(), ".".to_string()],
415            },
416            auto_reconnect: true,
417            reconnect_interval_ms: 5000,
418        },
419        // GitHub MCP
420        McpServerConfig {
421            name: "github".to_string(),
422            transport: McpTransportType::Stdio {
423                command: "mcp-server-github".to_string(),
424                args: vec![],
425            },
426            auto_reconnect: true,
427            reconnect_interval_ms: 5000,
428        },
429        // Playwright MCP - 浏览器自动化
430        McpServerConfig {
431            name: "playwright".to_string(),
432            transport: McpTransportType::Stdio {
433                command: "npx".to_string(),
434                args: vec![
435                    "@playwright/mcp@latest".to_string(),
436                    "--headless".to_string(),
437                    "--browser".to_string(),
438                    "chrome".to_string(),
439                ],
440            },
441            auto_reconnect: true,
442            reconnect_interval_ms: 5000,
443        },
444    ]
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stdio config; these tests never actually spawn the command.
    fn stdio_config(name: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportType::Stdio {
                command: "test-command".to_string(),
                args: vec![],
            },
            auto_reconnect: false,
            reconnect_interval_ms: 1000,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = McpClientManager::new();
        let servers = manager.list_servers();
        assert!(servers.is_empty());
    }

    #[tokio::test]
    async fn test_add_server() {
        let manager = McpClientManager::new();

        manager.add_server(stdio_config("test")).await.unwrap();
        let servers = manager.list_servers();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].0, "test");
        assert!(!servers[0].1); // configured but not connected
    }

    #[tokio::test]
    async fn test_add_server_replaces_same_name() {
        let manager = McpClientManager::new();
        manager.add_server(stdio_config("test")).await.unwrap();
        manager.add_server(stdio_config("test")).await.unwrap();
        assert_eq!(manager.list_servers().len(), 1);
    }

    #[test]
    fn test_preset_servers() {
        let presets = preset_servers();
        assert!(!presets.is_empty());
        assert!(presets.iter().any(|s| s.name == "filesystem"));
        assert!(presets.iter().any(|s| s.name == "github"));
        assert!(presets.iter().any(|s| s.name == "playwright"));
    }

    #[test]
    fn test_request_id_generation() {
        let manager = McpClientManager::new();
        let id1 = manager.next_request_id();
        let id2 = manager.next_request_id();
        assert_ne!(id1, id2);
    }

    #[tokio::test]
    async fn test_render_status() {
        let manager = McpClientManager::new();
        assert_eq!(
            manager.render_status(),
            "MCP Servers:\n  No servers configured\n"
        );

        manager.add_server(stdio_config("test")).await.unwrap();
        assert_eq!(manager.render_status(), "MCP Servers:\n  🔴 test (0 tools)\n");
    }

    #[tokio::test]
    async fn test_unknown_server_and_tool() {
        let manager = McpClientManager::new();

        assert_eq!(manager.get_server_status("missing"), None);
        assert!(manager.disconnect("missing").await.is_ok());
        assert!(manager.register_tools("missing", Vec::new()).await.is_err());
        assert!(manager.list_all_tools().is_empty());

        let err = manager
            .call_tool("missing_tool", serde_json::json!({}))
            .await
            .err()
            .expect("calling an unknown tool must fail");
        assert!(err.to_string().contains("Tool not found"));
    }

    #[tokio::test]
    async fn test_connect_unconfigured_server_fails() {
        let manager = McpClientManager::new();
        let err = manager
            .connect("missing")
            .await
            .err()
            .expect("connecting an unconfigured server must fail");
        assert!(err.to_string().contains("Server not found"));
    }
}