Skip to main content

sh_layer4/mcp_bridge/
bridge.rs

1//! MCP 桥接器
2//!
3//! MCP (Model Context Protocol) 协议的主要实现。
4//!
5//! # 功能
6//!
7//! - MCP 协议消息处理与路由
8//! - 工具注册与发现
9//! - 多种传输层支持 (stdio, tcp, unix socket)
10//! - 请求/响应生命周期管理
11//!
12//! # 用法示例
13//!
14//! ```rust,ignore
15//! use sh_layer4::mcp_bridge::{McpBridge, McpBridgeConfig, ToolDefinition, ToolResult, ContentBlock};
16//!
17//! // 创建桥接器
18//! let config = McpBridgeConfig {
19//!     server_name: "my-server".to_string(),
20//!     server_version: "1.0.0".to_string(),
21//!     request_timeout_ms: 30000,
22//!     max_concurrent_requests: 100,
23//! };
24//! let bridge = McpBridge::new(config);
25//!
26//! // 注册工具
27//! bridge.register_simple_tool("echo", "Echo input text", |_name, args| {
28//!     Ok(ToolResult {
29//!         is_error: false,
30//!         content: vec![ContentBlock::Text { text: args.to_string() }],
31//!     })
32//! });
33//!
34//! // 启动服务
35//! bridge.start().await?;
36//! ```
37
38use parking_lot::RwLock;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::sync::Arc;
43use tokio::sync::mpsc;
44use tracing::{debug, info, warn};
45
46use super::handler::{DefaultHandler, McpHandler, ToolExecutor};
47use super::protocol::{
48    McpMessage, McpNotification, McpRequest, McpResponse, RequestId, ToolDefinition, ToolResult,
49};
50use super::transport::McpTransport;
51use anyhow::{anyhow, Result};
52
53/// MCP 桥接器配置
54///
55/// 配置 MCP 服务端的基本参数。
56#[derive(Debug, Clone)]
57pub struct McpBridgeConfig {
58    /// 服务端名称,用于客户端识别
59    pub server_name: String,
60    /// 服务端版本号
61    pub server_version: String,
62    /// 请求超时时间 (毫秒)
63    pub request_timeout_ms: u64,
64    /// 最大并发请求数
65    pub max_concurrent_requests: usize,
66}
67
68impl Default for McpBridgeConfig {
69    fn default() -> Self {
70        Self {
71            server_name: "Continuum".to_string(),
72            server_version: "0.1.0".to_string(),
73            request_timeout_ms: 30000,
74            max_concurrent_requests: 100,
75        }
76    }
77}
78
79/// MCP 桥接器
80///
81/// MCP 协议的核心实现,负责消息处理、工具注册和传输层管理。
82///
83/// # 线程安全
84///
85/// 所有内部状态都通过 `RwLock` 或原子类型保护,支持多线程并发访问。
86pub struct McpBridge {
87    /// 传输层实例
88    transport: RwLock<Option<Arc<dyn McpTransport>>>,
89    /// 消息处理器
90    handler: Arc<DefaultHandler>,
91    /// 配置
92    config: McpBridgeConfig,
93    /// 请求 ID 计数器 (原子递增)
94    request_id_counter: AtomicU64,
95    /// 待处理响应映射表
96    pending_responses: Arc<RwLock<HashMap<RequestId, mpsc::Sender<McpResponse>>>>,
97    /// 运行状态
98    running: Arc<AtomicBool>,
99}
100
101impl McpBridge {
102    /// 创建新的 MCP 桥接器
103    ///
104    /// # 参数
105    ///
106    /// - `config`: 桥接器配置
107    ///
108    /// # 示例
109    ///
110    /// ```rust,ignore
111    /// let config = McpBridgeConfig::default();
112    /// let bridge = McpBridge::new(config);
113    /// ```
114    pub fn new(config: McpBridgeConfig) -> Self {
115        let handler = DefaultHandler::new(&config.server_name, &config.server_version);
116        Self {
117            transport: RwLock::new(None),
118            handler: Arc::new(handler),
119            config,
120            request_id_counter: AtomicU64::new(0),
121            pending_responses: Arc::new(RwLock::new(HashMap::new())),
122            running: Arc::new(AtomicBool::new(false)),
123        }
124    }
125
126    /// 设置传输层 (Builder 模式)
127    ///
128    /// # 参数
129    ///
130    /// - `transport`: 传输层实现 (StdioTransport, TcpTransport 等)
131    ///
132    /// # 示例
133    ///
134    /// ```rust,ignore
135    /// let bridge = McpBridge::new(config)
136    ///     .with_transport(Box::new(StdioTransport::new()));
137    /// ```
138    pub fn with_transport(self, transport: Box<dyn McpTransport>) -> Self {
139        *self.transport.write() = Some(Arc::from(transport));
140        self
141    }
142
143    /// 注册工具及其执行器
144    ///
145    /// # 参数
146    ///
147    /// - `tool`: 工具定义
148    /// - `executor`: 工具执行器实现
149    pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
150        self.handler.register_tool(tool, executor);
151    }
152
153    /// 注册简单工具 (便捷方法)
154    ///
155    /// 适用于不需要复杂执行器逻辑的工具。
156    ///
157    /// # 参数
158    ///
159    /// - `name`: 工具名称
160    /// - `description`: 工具描述
161    /// - `executor`: 执行函数,接收工具名和参数,返回执行结果
162    ///
163    /// # 示例
164    ///
165    /// ```rust,ignore
166    /// bridge.register_simple_tool("echo", "Echo input", |_name, args| {
167    ///     Ok(ToolResult {
168    ///         is_error: false,
169    ///         content: vec![ContentBlock::Text { text: args.to_string() }],
170    ///     })
171    /// });
172    /// ```
173    pub fn register_simple_tool<F>(&self, name: &str, description: &str, executor: F)
174    where
175        F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync + 'static,
176    {
177        let tool = ToolDefinition {
178            name: name.to_string(),
179            description: Some(description.to_string()),
180            input_schema: None,
181        };
182        self.register_tool(tool, Arc::new(super::handler::SimpleToolExecutor(executor)));
183    }
184
185    /// 生成下一个请求 ID (内部方法)
186    fn next_request_id(&self) -> RequestId {
187        RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
188    }
189
190    /// 启动桥接器
191    ///
192    /// 初始化消息处理循环,开始接收和处理 MCP 消息。
193    ///
194    /// # 错误
195    ///
196    /// 如果传输层未初始化或启动失败,返回错误。
197    pub async fn start(&self) -> Result<()> {
198        self.running.store(true, Ordering::SeqCst);
199
200        // Clone transport Arc while holding lock
201        let transport_opt = {
202            let transport_guard = self.transport.read();
203            transport_guard.clone()
204        };
205
206        // 启动消息处理循环
207        let handler = self.handler.clone();
208        let pending = self.pending_responses.clone();
209        let running = self.running.clone();
210
211        tokio::spawn(async move {
212            info!("MCP message loop started");
213
214            if transport_opt.is_none() {
215                info!("No transport configured, message loop will idle");
216            }
217
218            loop {
219                // Check if we should stop
220                if !running.load(Ordering::SeqCst) {
221                    info!("MCP message loop stopping");
222                    break;
223                }
224
225                // Read message from transport
226                if let Some(ref t) = transport_opt {
227                    match t.receive().await {
228                        Ok(Some(message)) => {
229                            // Process received message
230                            match message {
231                                McpMessage::Request(request) => {
232                                    // Handle incoming request
233                                    match handler.handle(&request).await {
234                                        Ok(response) => {
235                                            // Send response back
236                                            if let Err(e) =
237                                                t.send(&McpMessage::Response(response)).await
238                                            {
239                                                warn!("Failed to send response: {}", e);
240                                            }
241                                        }
242                                        Err(e) => {
243                                            warn!(
244                                                "Handler error for request {:?}: {}",
245                                                request.id, e
246                                            );
247                                        }
248                                    }
249                                }
250                                McpMessage::Notification(notification) => {
251                                    // Handle notification (no response needed)
252                                    if let Err(e) = handler.handle_notification(&notification).await
253                                    {
254                                        warn!("Notification handler error: {}", e);
255                                    }
256                                }
257                                McpMessage::Response(response) => {
258                                    // Response to our request - find matching pending request
259                                    // Release lock before await to satisfy Send requirement
260                                    let sender_opt = pending.write().remove(&response.id);
261                                    if let Some(sender) = sender_opt {
262                                        if let Err(e) = sender.send(response).await {
263                                            warn!(
264                                                "Failed to forward response to pending request: {}",
265                                                e
266                                            );
267                                        }
268                                    } else {
269                                        debug!(
270                                            "Received response for unknown request {:?}",
271                                            response.id
272                                        );
273                                    }
274                                }
275                                McpMessage::Error(error) => {
276                                    warn!("Received error: {:?}", error);
277                                }
278                            }
279                        }
280                        Ok(None) => {
281                            // No message available, brief sleep
282                            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
283                        }
284                        Err(e) => {
285                            warn!("Transport receive error: {}", e);
286                            // Brief pause before retry to avoid tight loop on error
287                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
288                        }
289                    }
290                } else {
291                    // No transport configured
292                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
293                }
294            }
295
296            // Clean up pending responses when stopping
297            pending.write().clear();
298            info!("MCP message loop stopped");
299        });
300
301        Ok(())
302    }
303
304    /// 停止桥接器
305    ///
306    /// 关闭消息处理循环并释放传输层资源。
307    ///
308    /// # 错误
309    ///
310    /// 如果传输层关闭失败,返回错误。
311    pub async fn stop(&self) -> Result<()> {
312        self.running.store(false, Ordering::SeqCst);
313
314        let transport = self.transport.write().take();
315        if let Some(transport) = transport {
316            transport.close().await?;
317        }
318
319        Ok(())
320    }
321
322    /// 发送请求并等待响应
323    ///
324    /// 向 MCP 服务端发送请求消息,并等待对应的响应。
325    ///
326    /// # 参数
327    ///
328    /// - `method`: MCP 方法名 (如 "tools/list", "tools/call")
329    /// - `params`: 可选的请求参数
330    ///
331    /// # 返回
332    ///
333    /// 返回对应的响应消息。
334    ///
335    /// # 错误
336    ///
337    /// - 如果传输层未初始化,返回 `Transport not initialized` 错误
338    /// - 如果请求超时,返回超时错误
339    pub async fn request(&self, method: &str, params: Option<Value>) -> Result<McpResponse> {
340        let id = self.next_request_id();
341        let timeout_duration = std::time::Duration::from_millis(self.config.request_timeout_ms);
342
343        // 创建响应接收通道
344        let (tx, mut rx) = mpsc::channel::<McpResponse>(1);
345
346        // 注册待处理请求
347        {
348            self.pending_responses.write().insert(id.clone(), tx);
349        }
350
351        let request = McpRequest {
352            id: id.clone(),
353            method: method.to_string(),
354            params,
355        };
356
357        let message = McpMessage::Request(request);
358
359        // 发送请求
360        let transport = {
361            let transport_guard = self.transport.read();
362            transport_guard
363                .as_ref()
364                .ok_or_else(|| anyhow!("Transport not initialized"))?
365                .clone()
366        };
367        transport.send(&message).await?;
368
369        // 等待响应,带超时
370        let result = tokio::time::timeout(timeout_duration, rx.recv()).await;
371
372        // 清理待处理请求(无论成功还是失败)
373        self.pending_responses.write().remove(&id);
374
375        match result {
376            Ok(Some(response)) => Ok(response),
377            Ok(None) => Err(anyhow!("Response channel closed")),
378            Err(_) => Err(anyhow!(
379                "Request timeout after {}ms",
380                self.config.request_timeout_ms
381            )),
382        }
383    }
384
385    /// 发送通知消息
386    ///
387    /// 向 MCP 服务端发送通知消息,不等待响应。
388    ///
389    /// # 参数
390    ///
391    /// - `method`: 通知方法名 (如 "notifications/initialized")
392    /// - `params`: 可选的通知参数
393    ///
394    /// # 错误
395    ///
396    /// 如果传输层未初始化,返回 `Transport not initialized` 错误。
397    #[allow(clippy::await_holding_lock)]
398    pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
399        let notification = McpNotification {
400            method: method.to_string(),
401            params,
402        };
403
404        let message = McpMessage::Notification(notification);
405
406        let transport_guard = self.transport.read();
407        let transport = transport_guard
408            .as_ref()
409            .ok_or_else(|| anyhow!("Transport not initialized"))?;
410        transport.send(&message).await?;
411
412        Ok(())
413    }
414
415    /// 列出可用工具
416    ///
417    /// 请求 MCP 服务端列出所有可用工具。
418    ///
419    /// # 返回
420    ///
421    /// 返回工具定义列表。
422    ///
423    /// # 错误
424    ///
425    /// 如果请求失败或响应解析失败,返回错误。
426    pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
427        let response = self.request("tools/list", None).await?;
428
429        if let Some(result) = response.result {
430            let tools: Vec<ToolDefinition> = serde_json::from_value(
431                result.get("tools").cloned().unwrap_or(Value::Array(vec![])),
432            )?;
433            Ok(tools)
434        } else {
435            Ok(vec![])
436        }
437    }
438
439    /// 调用工具
440    ///
441    /// 调用 MCP 服务端上的指定工具。
442    ///
443    /// # 参数
444    ///
445    /// - `name`: 工具名称
446    /// - `arguments`: 工具参数 (JSON 格式)
447    ///
448    /// # 返回
449    ///
450    /// 返回工具执行结果。
451    ///
452    /// # 错误
453    ///
454    /// - 如果工具不存在或执行失败,返回错误
455    /// - 如果响应解析失败,返回错误
456    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolResult> {
457        let params = serde_json::json!({
458            "name": name,
459            "arguments": arguments
460        });
461
462        let response = self.request("tools/call", Some(params)).await?;
463
464        if let Some(result) = response.result {
465            let tool_result: ToolResult = serde_json::from_value(result)?;
466            Ok(tool_result)
467        } else if let Some(error) = response.error {
468            Err(anyhow!("Tool call error: {}", error.message))
469        } else {
470            Err(anyhow!("Unknown error"))
471        }
472    }
473
474    /// 初始化 MCP 连接
475    ///
476    /// 与 MCP 服务端进行握手,交换协议版本和能力信息。
477    ///
478    /// # 参数
479    ///
480    /// - `client_info`: 客户端名称
481    /// - `version`: 客户端版本号
482    ///
483    /// # 流程
484    ///
485    /// 1. 发送 `initialize` 请求
486    /// 2. 接收服务端响应
487    /// 3. 发送 `notifications/initialized` 通知
488    ///
489    /// # 错误
490    ///
491    /// 如果初始化请求失败,返回 `Initialize failed` 错误。
492    pub async fn initialize(&self, client_info: &str, version: &str) -> Result<()> {
493        let params = serde_json::json!({
494            "protocol_version": "2024-11-05",
495            "capabilities": {},
496            "client_info": {
497                "name": client_info,
498                "version": version
499            }
500        });
501
502        let response = self.request("initialize", Some(params)).await?;
503
504        if response.error.is_some() {
505            return Err(anyhow!("Initialize failed"));
506        }
507
508        // 发送 initialized 通知
509        self.notify("notifications/initialized", None).await?;
510
511        Ok(())
512    }
513
514    /// 检查桥接器是否正在运行
515    ///
516    /// # 返回
517    ///
518    /// 如果桥接器已启动且正在处理消息,返回 `true`;否则返回 `false`。
519    pub fn is_running(&self) -> bool {
520        self.running.load(Ordering::SeqCst)
521    }
522}
523
524#[cfg(test)]
525mod tests {
526    use super::*;
527    use crate::mcp_bridge::protocol::ContentBlock;
528
529    #[tokio::test]
530    async fn test_bridge_creation() {
531        let config = McpBridgeConfig::default();
532        let bridge = McpBridge::new(config);
533
534        assert!(!bridge.is_running());
535    }
536
537    #[tokio::test]
538    async fn test_register_tool() {
539        let bridge = McpBridge::new(McpBridgeConfig::default());
540
541        bridge.register_simple_tool("test_tool", "A test tool", |_name, _args| {
542            Ok(ToolResult {
543                is_error: false,
544                content: vec![ContentBlock::Text {
545                    text: "OK".to_string(),
546                }],
547            })
548        });
549
550        // 工具已注册
551    }
552
553    #[tokio::test]
554    async fn test_next_request_id() {
555        let bridge = McpBridge::new(McpBridgeConfig::default());
556
557        let id1 = bridge.next_request_id();
558        let id2 = bridge.next_request_id();
559
560        assert_ne!(id1, id2);
561    }
562
563    #[test]
564    fn test_config_default() {
565        let config = McpBridgeConfig::default();
566        assert_eq!(config.server_name, "Continuum");
567        assert_eq!(config.request_timeout_ms, 30000);
568    }
569}