sh_layer4/mcp_bridge/bridge.rs
1//! MCP 桥接器
2//!
3//! MCP (Model Context Protocol) 协议的主要实现。
4//!
5//! # 功能
6//!
7//! - MCP 协议消息处理与路由
8//! - 工具注册与发现
9//! - 多种传输层支持 (stdio, tcp, unix socket)
10//! - 请求/响应生命周期管理
11//!
12//! # 用法示例
13//!
14//! ```rust,ignore
15//! use sh_layer4::mcp_bridge::{McpBridge, McpBridgeConfig, ToolDefinition, ToolResult, ContentBlock};
16//!
17//! // 创建桥接器
18//! let config = McpBridgeConfig {
19//! server_name: "my-server".to_string(),
20//! server_version: "1.0.0".to_string(),
21//! request_timeout_ms: 30000,
22//! max_concurrent_requests: 100,
23//! };
24//! let bridge = McpBridge::new(config);
25//!
26//! // 注册工具
27//! bridge.register_simple_tool("echo", "Echo input text", |_name, args| {
28//! Ok(ToolResult {
29//! is_error: false,
30//! content: vec![ContentBlock::Text { text: args.to_string() }],
31//! })
32//! });
33//!
34//! // 启动服务
35//! bridge.start().await?;
36//! ```
37
38use parking_lot::RwLock;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::sync::Arc;
43use tokio::sync::mpsc;
44use tracing::{debug, info, warn};
45
46use super::handler::{DefaultHandler, McpHandler, ToolExecutor};
47use super::protocol::{
48 McpMessage, McpNotification, McpRequest, McpResponse, RequestId, ToolDefinition, ToolResult,
49};
50use super::transport::McpTransport;
51use anyhow::{anyhow, Result};
52
53/// MCP 桥接器配置
54///
55/// 配置 MCP 服务端的基本参数。
56#[derive(Debug, Clone)]
57pub struct McpBridgeConfig {
58 /// 服务端名称,用于客户端识别
59 pub server_name: String,
60 /// 服务端版本号
61 pub server_version: String,
62 /// 请求超时时间 (毫秒)
63 pub request_timeout_ms: u64,
64 /// 最大并发请求数
65 pub max_concurrent_requests: usize,
66}
67
68impl Default for McpBridgeConfig {
69 fn default() -> Self {
70 Self {
71 server_name: "Continuum".to_string(),
72 server_version: "0.1.0".to_string(),
73 request_timeout_ms: 30000,
74 max_concurrent_requests: 100,
75 }
76 }
77}
78
79/// MCP 桥接器
80///
81/// MCP 协议的核心实现,负责消息处理、工具注册和传输层管理。
82///
83/// # 线程安全
84///
85/// 所有内部状态都通过 `RwLock` 或原子类型保护,支持多线程并发访问。
86pub struct McpBridge {
87 /// 传输层实例
88 transport: RwLock<Option<Arc<dyn McpTransport>>>,
89 /// 消息处理器
90 handler: Arc<DefaultHandler>,
91 /// 配置
92 config: McpBridgeConfig,
93 /// 请求 ID 计数器 (原子递增)
94 request_id_counter: AtomicU64,
95 /// 待处理响应映射表
96 pending_responses: Arc<RwLock<HashMap<RequestId, mpsc::Sender<McpResponse>>>>,
97 /// 运行状态
98 running: Arc<AtomicBool>,
99}
100
101impl McpBridge {
102 /// 创建新的 MCP 桥接器
103 ///
104 /// # 参数
105 ///
106 /// - `config`: 桥接器配置
107 ///
108 /// # 示例
109 ///
110 /// ```rust,ignore
111 /// let config = McpBridgeConfig::default();
112 /// let bridge = McpBridge::new(config);
113 /// ```
114 pub fn new(config: McpBridgeConfig) -> Self {
115 let handler = DefaultHandler::new(&config.server_name, &config.server_version);
116 Self {
117 transport: RwLock::new(None),
118 handler: Arc::new(handler),
119 config,
120 request_id_counter: AtomicU64::new(0),
121 pending_responses: Arc::new(RwLock::new(HashMap::new())),
122 running: Arc::new(AtomicBool::new(false)),
123 }
124 }
125
126 /// 设置传输层 (Builder 模式)
127 ///
128 /// # 参数
129 ///
130 /// - `transport`: 传输层实现 (StdioTransport, TcpTransport 等)
131 ///
132 /// # 示例
133 ///
134 /// ```rust,ignore
135 /// let bridge = McpBridge::new(config)
136 /// .with_transport(Box::new(StdioTransport::new()));
137 /// ```
138 pub fn with_transport(self, transport: Box<dyn McpTransport>) -> Self {
139 *self.transport.write() = Some(Arc::from(transport));
140 self
141 }
142
143 /// 注册工具及其执行器
144 ///
145 /// # 参数
146 ///
147 /// - `tool`: 工具定义
148 /// - `executor`: 工具执行器实现
149 pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
150 self.handler.register_tool(tool, executor);
151 }
152
153 /// 注册简单工具 (便捷方法)
154 ///
155 /// 适用于不需要复杂执行器逻辑的工具。
156 ///
157 /// # 参数
158 ///
159 /// - `name`: 工具名称
160 /// - `description`: 工具描述
161 /// - `executor`: 执行函数,接收工具名和参数,返回执行结果
162 ///
163 /// # 示例
164 ///
165 /// ```rust,ignore
166 /// bridge.register_simple_tool("echo", "Echo input", |_name, args| {
167 /// Ok(ToolResult {
168 /// is_error: false,
169 /// content: vec![ContentBlock::Text { text: args.to_string() }],
170 /// })
171 /// });
172 /// ```
173 pub fn register_simple_tool<F>(&self, name: &str, description: &str, executor: F)
174 where
175 F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync + 'static,
176 {
177 let tool = ToolDefinition {
178 name: name.to_string(),
179 description: Some(description.to_string()),
180 input_schema: None,
181 };
182 self.register_tool(tool, Arc::new(super::handler::SimpleToolExecutor(executor)));
183 }
184
185 /// 生成下一个请求 ID (内部方法)
186 fn next_request_id(&self) -> RequestId {
187 RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
188 }
189
190 /// 启动桥接器
191 ///
192 /// 初始化消息处理循环,开始接收和处理 MCP 消息。
193 ///
194 /// # 错误
195 ///
196 /// 如果传输层未初始化或启动失败,返回错误。
197 pub async fn start(&self) -> Result<()> {
198 self.running.store(true, Ordering::SeqCst);
199
200 // Clone transport Arc while holding lock
201 let transport_opt = {
202 let transport_guard = self.transport.read();
203 transport_guard.clone()
204 };
205
206 // 启动消息处理循环
207 let handler = self.handler.clone();
208 let pending = self.pending_responses.clone();
209 let running = self.running.clone();
210
211 tokio::spawn(async move {
212 info!("MCP message loop started");
213
214 if transport_opt.is_none() {
215 info!("No transport configured, message loop will idle");
216 }
217
218 loop {
219 // Check if we should stop
220 if !running.load(Ordering::SeqCst) {
221 info!("MCP message loop stopping");
222 break;
223 }
224
225 // Read message from transport
226 if let Some(ref t) = transport_opt {
227 match t.receive().await {
228 Ok(Some(message)) => {
229 // Process received message
230 match message {
231 McpMessage::Request(request) => {
232 // Handle incoming request
233 match handler.handle(&request).await {
234 Ok(response) => {
235 // Send response back
236 if let Err(e) =
237 t.send(&McpMessage::Response(response)).await
238 {
239 warn!("Failed to send response: {}", e);
240 }
241 }
242 Err(e) => {
243 warn!(
244 "Handler error for request {:?}: {}",
245 request.id, e
246 );
247 }
248 }
249 }
250 McpMessage::Notification(notification) => {
251 // Handle notification (no response needed)
252 if let Err(e) = handler.handle_notification(¬ification).await
253 {
254 warn!("Notification handler error: {}", e);
255 }
256 }
257 McpMessage::Response(response) => {
258 // Response to our request - find matching pending request
259 // Release lock before await to satisfy Send requirement
260 let sender_opt = pending.write().remove(&response.id);
261 if let Some(sender) = sender_opt {
262 if let Err(e) = sender.send(response).await {
263 warn!(
264 "Failed to forward response to pending request: {}",
265 e
266 );
267 }
268 } else {
269 debug!(
270 "Received response for unknown request {:?}",
271 response.id
272 );
273 }
274 }
275 McpMessage::Error(error) => {
276 warn!("Received error: {:?}", error);
277 }
278 }
279 }
280 Ok(None) => {
281 // No message available, brief sleep
282 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
283 }
284 Err(e) => {
285 warn!("Transport receive error: {}", e);
286 // Brief pause before retry to avoid tight loop on error
287 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
288 }
289 }
290 } else {
291 // No transport configured
292 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
293 }
294 }
295
296 // Clean up pending responses when stopping
297 pending.write().clear();
298 info!("MCP message loop stopped");
299 });
300
301 Ok(())
302 }
303
304 /// 停止桥接器
305 ///
306 /// 关闭消息处理循环并释放传输层资源。
307 ///
308 /// # 错误
309 ///
310 /// 如果传输层关闭失败,返回错误。
311 pub async fn stop(&self) -> Result<()> {
312 self.running.store(false, Ordering::SeqCst);
313
314 let transport = self.transport.write().take();
315 if let Some(transport) = transport {
316 transport.close().await?;
317 }
318
319 Ok(())
320 }
321
322 /// 发送请求并等待响应
323 ///
324 /// 向 MCP 服务端发送请求消息,并等待对应的响应。
325 ///
326 /// # 参数
327 ///
328 /// - `method`: MCP 方法名 (如 "tools/list", "tools/call")
329 /// - `params`: 可选的请求参数
330 ///
331 /// # 返回
332 ///
333 /// 返回对应的响应消息。
334 ///
335 /// # 错误
336 ///
337 /// - 如果传输层未初始化,返回 `Transport not initialized` 错误
338 /// - 如果请求超时,返回超时错误
339 pub async fn request(&self, method: &str, params: Option<Value>) -> Result<McpResponse> {
340 let id = self.next_request_id();
341 let timeout_duration = std::time::Duration::from_millis(self.config.request_timeout_ms);
342
343 // 创建响应接收通道
344 let (tx, mut rx) = mpsc::channel::<McpResponse>(1);
345
346 // 注册待处理请求
347 {
348 self.pending_responses.write().insert(id.clone(), tx);
349 }
350
351 let request = McpRequest {
352 id: id.clone(),
353 method: method.to_string(),
354 params,
355 };
356
357 let message = McpMessage::Request(request);
358
359 // 发送请求
360 {
361 let transport_guard = self.transport.read();
362 let transport = transport_guard
363 .as_ref()
364 .ok_or_else(|| anyhow!("Transport not initialized"))?;
365 transport.send(&message).await?;
366 }
367
368 // 等待响应,带超时
369 let result = tokio::time::timeout(timeout_duration, rx.recv()).await;
370
371 // 清理待处理请求(无论成功还是失败)
372 self.pending_responses.write().remove(&id);
373
374 match result {
375 Ok(Some(response)) => Ok(response),
376 Ok(None) => Err(anyhow!("Response channel closed")),
377 Err(_) => Err(anyhow!(
378 "Request timeout after {}ms",
379 self.config.request_timeout_ms
380 )),
381 }
382 }
383
384 /// 发送通知消息
385 ///
386 /// 向 MCP 服务端发送通知消息,不等待响应。
387 ///
388 /// # 参数
389 ///
390 /// - `method`: 通知方法名 (如 "notifications/initialized")
391 /// - `params`: 可选的通知参数
392 ///
393 /// # 错误
394 ///
395 /// 如果传输层未初始化,返回 `Transport not initialized` 错误。
396 #[allow(clippy::await_holding_lock)]
397 pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
398 let notification = McpNotification {
399 method: method.to_string(),
400 params,
401 };
402
403 let message = McpMessage::Notification(notification);
404
405 let transport_guard = self.transport.read();
406 let transport = transport_guard
407 .as_ref()
408 .ok_or_else(|| anyhow!("Transport not initialized"))?;
409 transport.send(&message).await?;
410
411 Ok(())
412 }
413
414 /// 列出可用工具
415 ///
416 /// 请求 MCP 服务端列出所有可用工具。
417 ///
418 /// # 返回
419 ///
420 /// 返回工具定义列表。
421 ///
422 /// # 错误
423 ///
424 /// 如果请求失败或响应解析失败,返回错误。
425 pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
426 let response = self.request("tools/list", None).await?;
427
428 if let Some(result) = response.result {
429 let tools: Vec<ToolDefinition> = serde_json::from_value(
430 result.get("tools").cloned().unwrap_or(Value::Array(vec![])),
431 )?;
432 Ok(tools)
433 } else {
434 Ok(vec![])
435 }
436 }
437
438 /// 调用工具
439 ///
440 /// 调用 MCP 服务端上的指定工具。
441 ///
442 /// # 参数
443 ///
444 /// - `name`: 工具名称
445 /// - `arguments`: 工具参数 (JSON 格式)
446 ///
447 /// # 返回
448 ///
449 /// 返回工具执行结果。
450 ///
451 /// # 错误
452 ///
453 /// - 如果工具不存在或执行失败,返回错误
454 /// - 如果响应解析失败,返回错误
455 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolResult> {
456 let params = serde_json::json!({
457 "name": name,
458 "arguments": arguments
459 });
460
461 let response = self.request("tools/call", Some(params)).await?;
462
463 if let Some(result) = response.result {
464 let tool_result: ToolResult = serde_json::from_value(result)?;
465 Ok(tool_result)
466 } else if let Some(error) = response.error {
467 Err(anyhow!("Tool call error: {}", error.message))
468 } else {
469 Err(anyhow!("Unknown error"))
470 }
471 }
472
473 /// 初始化 MCP 连接
474 ///
475 /// 与 MCP 服务端进行握手,交换协议版本和能力信息。
476 ///
477 /// # 参数
478 ///
479 /// - `client_info`: 客户端名称
480 /// - `version`: 客户端版本号
481 ///
482 /// # 流程
483 ///
484 /// 1. 发送 `initialize` 请求
485 /// 2. 接收服务端响应
486 /// 3. 发送 `notifications/initialized` 通知
487 ///
488 /// # 错误
489 ///
490 /// 如果初始化请求失败,返回 `Initialize failed` 错误。
491 pub async fn initialize(&self, client_info: &str, version: &str) -> Result<()> {
492 let params = serde_json::json!({
493 "protocol_version": "2024-11-05",
494 "capabilities": {},
495 "client_info": {
496 "name": client_info,
497 "version": version
498 }
499 });
500
501 let response = self.request("initialize", Some(params)).await?;
502
503 if response.error.is_some() {
504 return Err(anyhow!("Initialize failed"));
505 }
506
507 // 发送 initialized 通知
508 self.notify("notifications/initialized", None).await?;
509
510 Ok(())
511 }
512
513 /// 检查桥接器是否正在运行
514 ///
515 /// # 返回
516 ///
517 /// 如果桥接器已启动且正在处理消息,返回 `true`;否则返回 `false`。
518 pub fn is_running(&self) -> bool {
519 self.running.load(Ordering::SeqCst)
520 }
521}
522
523#[cfg(test)]
524mod tests {
525 use super::*;
526 use crate::mcp_bridge::protocol::ContentBlock;
527
528 #[tokio::test]
529 async fn test_bridge_creation() {
530 let config = McpBridgeConfig::default();
531 let bridge = McpBridge::new(config);
532
533 assert!(!bridge.is_running());
534 }
535
536 #[tokio::test]
537 async fn test_register_tool() {
538 let bridge = McpBridge::new(McpBridgeConfig::default());
539
540 bridge.register_simple_tool("test_tool", "A test tool", |_name, _args| {
541 Ok(ToolResult {
542 is_error: false,
543 content: vec![ContentBlock::Text {
544 text: "OK".to_string(),
545 }],
546 })
547 });
548
549 // 工具已注册
550 }
551
552 #[tokio::test]
553 async fn test_next_request_id() {
554 let bridge = McpBridge::new(McpBridgeConfig::default());
555
556 let id1 = bridge.next_request_id();
557 let id2 = bridge.next_request_id();
558
559 assert_ne!(id1, id2);
560 }
561
562 #[test]
563 fn test_config_default() {
564 let config = McpBridgeConfig::default();
565 assert_eq!(config.server_name, "Continuum");
566 assert_eq!(config.request_timeout_ms, 30000);
567 }
568}