sh_layer4/mcp_bridge/bridge.rs
1//! MCP 桥接器
2//!
3//! MCP (Model Context Protocol) 协议的主要实现。
4//!
5//! # 功能
6//!
7//! - MCP 协议消息处理与路由
8//! - 工具注册与发现
9//! - 多种传输层支持 (stdio, tcp, unix socket)
10//! - 请求/响应生命周期管理
11//!
12//! # 用法示例
13//!
14//! ```rust,ignore
15//! use sh_layer4::mcp_bridge::{McpBridge, McpBridgeConfig, ToolDefinition, ToolResult, ContentBlock};
16//!
17//! // 创建桥接器
18//! let config = McpBridgeConfig {
19//! server_name: "my-server".to_string(),
20//! server_version: "1.0.0".to_string(),
21//! request_timeout_ms: 30000,
22//! max_concurrent_requests: 100,
23//! };
24//! let bridge = McpBridge::new(config);
25//!
26//! // 注册工具
27//! bridge.register_simple_tool("echo", "Echo input text", |_name, args| {
28//! Ok(ToolResult {
29//! is_error: false,
30//! content: vec![ContentBlock::Text { text: args.to_string() }],
31//! })
32//! });
33//!
34//! // 启动服务
35//! bridge.start().await?;
36//! ```
37
38use parking_lot::RwLock;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::sync::Arc;
43use tokio::sync::mpsc;
44use tracing::{debug, info, warn};
45
46use super::handler::{DefaultHandler, McpHandler, ToolExecutor};
47use super::protocol::{
48 McpMessage, McpNotification, McpRequest, McpResponse, RequestId, ToolDefinition, ToolResult,
49};
50use super::transport::McpTransport;
51use anyhow::{anyhow, Result};
52
53/// MCP 桥接器配置
54///
55/// 配置 MCP 服务端的基本参数。
56#[derive(Debug, Clone)]
57pub struct McpBridgeConfig {
58 /// 服务端名称,用于客户端识别
59 pub server_name: String,
60 /// 服务端版本号
61 pub server_version: String,
62 /// 请求超时时间 (毫秒)
63 pub request_timeout_ms: u64,
64 /// 最大并发请求数
65 pub max_concurrent_requests: usize,
66}
67
68impl Default for McpBridgeConfig {
69 fn default() -> Self {
70 Self {
71 server_name: "Continuum".to_string(),
72 server_version: "0.1.0".to_string(),
73 request_timeout_ms: 30000,
74 max_concurrent_requests: 100,
75 }
76 }
77}
78
79/// MCP 桥接器
80///
81/// MCP 协议的核心实现,负责消息处理、工具注册和传输层管理。
82///
83/// # 线程安全
84///
85/// 所有内部状态都通过 `RwLock` 或原子类型保护,支持多线程并发访问。
86pub struct McpBridge {
87 /// 传输层实例
88 transport: RwLock<Option<Arc<dyn McpTransport>>>,
89 /// 消息处理器
90 handler: Arc<DefaultHandler>,
91 /// 配置
92 config: McpBridgeConfig,
93 /// 请求 ID 计数器 (原子递增)
94 request_id_counter: AtomicU64,
95 /// 待处理响应映射表
96 pending_responses: Arc<RwLock<HashMap<RequestId, mpsc::Sender<McpResponse>>>>,
97 /// 运行状态
98 running: Arc<AtomicBool>,
99}
100
101impl McpBridge {
102 /// 创建新的 MCP 桥接器
103 ///
104 /// # 参数
105 ///
106 /// - `config`: 桥接器配置
107 ///
108 /// # 示例
109 ///
110 /// ```rust,ignore
111 /// let config = McpBridgeConfig::default();
112 /// let bridge = McpBridge::new(config);
113 /// ```
114 pub fn new(config: McpBridgeConfig) -> Self {
115 let handler = DefaultHandler::new(&config.server_name, &config.server_version);
116 Self {
117 transport: RwLock::new(None),
118 handler: Arc::new(handler),
119 config,
120 request_id_counter: AtomicU64::new(0),
121 pending_responses: Arc::new(RwLock::new(HashMap::new())),
122 running: Arc::new(AtomicBool::new(false)),
123 }
124 }
125
126 /// 设置传输层 (Builder 模式)
127 ///
128 /// # 参数
129 ///
130 /// - `transport`: 传输层实现 (StdioTransport, TcpTransport 等)
131 ///
132 /// # 示例
133 ///
134 /// ```rust,ignore
135 /// let bridge = McpBridge::new(config)
136 /// .with_transport(Box::new(StdioTransport::new()));
137 /// ```
138 pub fn with_transport(self, transport: Box<dyn McpTransport>) -> Self {
139 *self.transport.write() = Some(Arc::from(transport));
140 self
141 }
142
143 /// 注册工具及其执行器
144 ///
145 /// # 参数
146 ///
147 /// - `tool`: 工具定义
148 /// - `executor`: 工具执行器实现
149 pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
150 self.handler.register_tool(tool, executor);
151 }
152
153 /// 注册简单工具 (便捷方法)
154 ///
155 /// 适用于不需要复杂执行器逻辑的工具。
156 ///
157 /// # 参数
158 ///
159 /// - `name`: 工具名称
160 /// - `description`: 工具描述
161 /// - `executor`: 执行函数,接收工具名和参数,返回执行结果
162 ///
163 /// # 示例
164 ///
165 /// ```rust,ignore
166 /// bridge.register_simple_tool("echo", "Echo input", |_name, args| {
167 /// Ok(ToolResult {
168 /// is_error: false,
169 /// content: vec![ContentBlock::Text { text: args.to_string() }],
170 /// })
171 /// });
172 /// ```
173 pub fn register_simple_tool<F>(&self, name: &str, description: &str, executor: F)
174 where
175 F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync + 'static,
176 {
177 let tool = ToolDefinition {
178 name: name.to_string(),
179 description: Some(description.to_string()),
180 input_schema: None,
181 };
182 self.register_tool(tool, Arc::new(super::handler::SimpleToolExecutor(executor)));
183 }
184
185 /// 生成下一个请求 ID (内部方法)
186 fn next_request_id(&self) -> RequestId {
187 RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
188 }
189
190 /// 启动桥接器
191 ///
192 /// 初始化消息处理循环,开始接收和处理 MCP 消息。
193 ///
194 /// # 错误
195 ///
196 /// 如果传输层未初始化或启动失败,返回错误。
197 pub async fn start(&self) -> Result<()> {
198 self.running.store(true, Ordering::SeqCst);
199
200 // Clone transport Arc while holding lock
201 let transport_opt = {
202 let transport_guard = self.transport.read();
203 transport_guard.clone()
204 };
205
206 // 启动消息处理循环
207 let handler = self.handler.clone();
208 let pending = self.pending_responses.clone();
209 let running = self.running.clone();
210
211 tokio::spawn(async move {
212 info!("MCP message loop started");
213
214 if transport_opt.is_none() {
215 info!("No transport configured, message loop will idle");
216 }
217
218 loop {
219 // Check if we should stop
220 if !running.load(Ordering::SeqCst) {
221 info!("MCP message loop stopping");
222 break;
223 }
224
225 // Read message from transport
226 if let Some(ref t) = transport_opt {
227 match t.receive().await {
228 Ok(Some(message)) => {
229 // Process received message
230 match message {
231 McpMessage::Request(request) => {
232 // Handle incoming request
233 match handler.handle(&request).await {
234 Ok(response) => {
235 // Send response back
236 if let Err(e) =
237 t.send(&McpMessage::Response(response)).await
238 {
239 warn!("Failed to send response: {}", e);
240 }
241 }
242 Err(e) => {
243 warn!(
244 "Handler error for request {:?}: {}",
245 request.id, e
246 );
247 }
248 }
249 }
250 McpMessage::Notification(notification) => {
251 // Handle notification (no response needed)
252 if let Err(e) = handler.handle_notification(¬ification).await
253 {
254 warn!("Notification handler error: {}", e);
255 }
256 }
257 McpMessage::Response(response) => {
258 // Response to our request - find matching pending request
259 // Release lock before await to satisfy Send requirement
260 let sender_opt = pending.write().remove(&response.id);
261 if let Some(sender) = sender_opt {
262 if let Err(e) = sender.send(response).await {
263 warn!(
264 "Failed to forward response to pending request: {}",
265 e
266 );
267 }
268 } else {
269 debug!(
270 "Received response for unknown request {:?}",
271 response.id
272 );
273 }
274 }
275 McpMessage::Error(error) => {
276 warn!("Received error: {:?}", error);
277 }
278 }
279 }
280 Ok(None) => {
281 // No message available, brief sleep
282 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
283 }
284 Err(e) => {
285 warn!("Transport receive error: {}", e);
286 // Brief pause before retry to avoid tight loop on error
287 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
288 }
289 }
290 } else {
291 // No transport configured
292 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
293 }
294 }
295
296 // Clean up pending responses when stopping
297 pending.write().clear();
298 info!("MCP message loop stopped");
299 });
300
301 Ok(())
302 }
303
304 /// 停止桥接器
305 ///
306 /// 关闭消息处理循环并释放传输层资源。
307 ///
308 /// # 错误
309 ///
310 /// 如果传输层关闭失败,返回错误。
311 pub async fn stop(&self) -> Result<()> {
312 self.running.store(false, Ordering::SeqCst);
313
314 let transport = self.transport.write().take();
315 if let Some(transport) = transport {
316 transport.close().await?;
317 }
318
319 Ok(())
320 }
321
322 /// 发送请求并等待响应
323 ///
324 /// 向 MCP 服务端发送请求消息,并等待对应的响应。
325 ///
326 /// # 参数
327 ///
328 /// - `method`: MCP 方法名 (如 "tools/list", "tools/call")
329 /// - `params`: 可选的请求参数
330 ///
331 /// # 返回
332 ///
333 /// 返回对应的响应消息。
334 ///
335 /// # 错误
336 ///
337 /// - 如果传输层未初始化,返回 `Transport not initialized` 错误
338 /// - 如果请求超时,返回超时错误
339 pub async fn request(&self, method: &str, params: Option<Value>) -> Result<McpResponse> {
340 let id = self.next_request_id();
341 let timeout_duration = std::time::Duration::from_millis(self.config.request_timeout_ms);
342
343 // 创建响应接收通道
344 let (tx, mut rx) = mpsc::channel::<McpResponse>(1);
345
346 // 注册待处理请求
347 {
348 self.pending_responses.write().insert(id.clone(), tx);
349 }
350
351 let request = McpRequest {
352 id: id.clone(),
353 method: method.to_string(),
354 params,
355 };
356
357 let message = McpMessage::Request(request);
358
359 // 发送请求
360 let transport = {
361 let transport_guard = self.transport.read();
362 transport_guard
363 .as_ref()
364 .ok_or_else(|| anyhow!("Transport not initialized"))?
365 .clone()
366 };
367 transport.send(&message).await?;
368
369 // 等待响应,带超时
370 let result = tokio::time::timeout(timeout_duration, rx.recv()).await;
371
372 // 清理待处理请求(无论成功还是失败)
373 self.pending_responses.write().remove(&id);
374
375 match result {
376 Ok(Some(response)) => Ok(response),
377 Ok(None) => Err(anyhow!("Response channel closed")),
378 Err(_) => Err(anyhow!(
379 "Request timeout after {}ms",
380 self.config.request_timeout_ms
381 )),
382 }
383 }
384
385 /// 发送通知消息
386 ///
387 /// 向 MCP 服务端发送通知消息,不等待响应。
388 ///
389 /// # 参数
390 ///
391 /// - `method`: 通知方法名 (如 "notifications/initialized")
392 /// - `params`: 可选的通知参数
393 ///
394 /// # 错误
395 ///
396 /// 如果传输层未初始化,返回 `Transport not initialized` 错误。
397 #[allow(clippy::await_holding_lock)]
398 pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
399 let notification = McpNotification {
400 method: method.to_string(),
401 params,
402 };
403
404 let message = McpMessage::Notification(notification);
405
406 let transport_guard = self.transport.read();
407 let transport = transport_guard
408 .as_ref()
409 .ok_or_else(|| anyhow!("Transport not initialized"))?;
410 transport.send(&message).await?;
411
412 Ok(())
413 }
414
415 /// 列出可用工具
416 ///
417 /// 请求 MCP 服务端列出所有可用工具。
418 ///
419 /// # 返回
420 ///
421 /// 返回工具定义列表。
422 ///
423 /// # 错误
424 ///
425 /// 如果请求失败或响应解析失败,返回错误。
426 pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
427 let response = self.request("tools/list", None).await?;
428
429 if let Some(result) = response.result {
430 let tools: Vec<ToolDefinition> = serde_json::from_value(
431 result.get("tools").cloned().unwrap_or(Value::Array(vec![])),
432 )?;
433 Ok(tools)
434 } else {
435 Ok(vec![])
436 }
437 }
438
439 /// 调用工具
440 ///
441 /// 调用 MCP 服务端上的指定工具。
442 ///
443 /// # 参数
444 ///
445 /// - `name`: 工具名称
446 /// - `arguments`: 工具参数 (JSON 格式)
447 ///
448 /// # 返回
449 ///
450 /// 返回工具执行结果。
451 ///
452 /// # 错误
453 ///
454 /// - 如果工具不存在或执行失败,返回错误
455 /// - 如果响应解析失败,返回错误
456 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolResult> {
457 let params = serde_json::json!({
458 "name": name,
459 "arguments": arguments
460 });
461
462 let response = self.request("tools/call", Some(params)).await?;
463
464 if let Some(result) = response.result {
465 let tool_result: ToolResult = serde_json::from_value(result)?;
466 Ok(tool_result)
467 } else if let Some(error) = response.error {
468 Err(anyhow!("Tool call error: {}", error.message))
469 } else {
470 Err(anyhow!("Unknown error"))
471 }
472 }
473
474 /// 初始化 MCP 连接
475 ///
476 /// 与 MCP 服务端进行握手,交换协议版本和能力信息。
477 ///
478 /// # 参数
479 ///
480 /// - `client_info`: 客户端名称
481 /// - `version`: 客户端版本号
482 ///
483 /// # 流程
484 ///
485 /// 1. 发送 `initialize` 请求
486 /// 2. 接收服务端响应
487 /// 3. 发送 `notifications/initialized` 通知
488 ///
489 /// # 错误
490 ///
491 /// 如果初始化请求失败,返回 `Initialize failed` 错误。
492 pub async fn initialize(&self, client_info: &str, version: &str) -> Result<()> {
493 let params = serde_json::json!({
494 "protocol_version": "2024-11-05",
495 "capabilities": {},
496 "client_info": {
497 "name": client_info,
498 "version": version
499 }
500 });
501
502 let response = self.request("initialize", Some(params)).await?;
503
504 if response.error.is_some() {
505 return Err(anyhow!("Initialize failed"));
506 }
507
508 // 发送 initialized 通知
509 self.notify("notifications/initialized", None).await?;
510
511 Ok(())
512 }
513
514 /// 检查桥接器是否正在运行
515 ///
516 /// # 返回
517 ///
518 /// 如果桥接器已启动且正在处理消息,返回 `true`;否则返回 `false`。
519 pub fn is_running(&self) -> bool {
520 self.running.load(Ordering::SeqCst)
521 }
522}
523
524#[cfg(test)]
525mod tests {
526 use super::*;
527 use crate::mcp_bridge::protocol::ContentBlock;
528
529 #[tokio::test]
530 async fn test_bridge_creation() {
531 let config = McpBridgeConfig::default();
532 let bridge = McpBridge::new(config);
533
534 assert!(!bridge.is_running());
535 }
536
537 #[tokio::test]
538 async fn test_register_tool() {
539 let bridge = McpBridge::new(McpBridgeConfig::default());
540
541 bridge.register_simple_tool("test_tool", "A test tool", |_name, _args| {
542 Ok(ToolResult {
543 is_error: false,
544 content: vec![ContentBlock::Text {
545 text: "OK".to_string(),
546 }],
547 })
548 });
549
550 // 工具已注册
551 }
552
553 #[tokio::test]
554 async fn test_next_request_id() {
555 let bridge = McpBridge::new(McpBridgeConfig::default());
556
557 let id1 = bridge.next_request_id();
558 let id2 = bridge.next_request_id();
559
560 assert_ne!(id1, id2);
561 }
562
563 #[test]
564 fn test_config_default() {
565 let config = McpBridgeConfig::default();
566 assert_eq!(config.server_name, "Continuum");
567 assert_eq!(config.request_timeout_ms, 30000);
568 }
569}