// sh_layer4/mcp_bridge/handler.rs

//! MCP message handlers.
//!
//! Processes the various kinds of MCP messages (requests and notifications).
4
5use async_trait::async_trait;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use super::protocol::{
11    error_codes, Implementation, InitializeParams, InitializeResult, McpErrorData, McpNotification,
12    McpRequest, McpResponse, RequestId, ServerCapabilities, ToolDefinition, ToolResult,
13    MCP_VERSION,
14};
15use anyhow::{anyhow, Result};
16
17/// MCP 消息处理器 trait
18#[async_trait]
19pub trait McpHandler: Send + Sync {
20    /// 处理请求
21    async fn handle(&self, request: &McpRequest) -> Result<McpResponse>;
22
23    /// 处理通知
24    async fn handle_notification(&self, notification: &McpNotification) -> Result<()>;
25}
26
27/// 默认处理器实现
28pub struct DefaultHandler {
29    /// 服务端信息
30    server_info: Implementation,
31    /// 已注册的工具
32    tools: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
33    /// 工具执行器
34    tool_executors: Arc<parking_lot::RwLock<HashMap<String, Arc<dyn ToolExecutor>>>>,
35}
36
37/// 工具执行器 trait
38#[async_trait]
39pub trait ToolExecutor: Send + Sync {
40    /// 执行工具
41    async fn execute(&self, name: &str, arguments: Value) -> Result<ToolResult>;
42}
43
44impl DefaultHandler {
45    pub fn new(name: &str, version: &str) -> Self {
46        Self {
47            server_info: Implementation {
48                name: name.to_string(),
49                version: version.to_string(),
50            },
51            tools: Arc::new(parking_lot::RwLock::new(HashMap::new())),
52            tool_executors: Arc::new(parking_lot::RwLock::new(HashMap::new())),
53        }
54    }
55
56    /// 注册工具
57    pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
58        let name = tool.name.clone();
59        self.tools.write().insert(name.clone(), tool);
60        self.tool_executors.write().insert(name.clone(), executor);
61    }
62
63    /// 处理初始化请求
64    fn handle_initialize(&self, _params: &InitializeParams) -> Result<McpResponse> {
65        let result = InitializeResult {
66            protocol_version: MCP_VERSION.to_string(),
67            capabilities: ServerCapabilities {
68                tools: Some(Default::default()),
69                resources: Some(Default::default()),
70                prompts: Some(Default::default()),
71                ..Default::default()
72            },
73            server_info: self.server_info.clone(),
74            instructions: Some("Continuum MCP Server".to_string()),
75        };
76
77        Ok(McpResponse {
78            id: RequestId::Number(0),
79            result: Some(serde_json::to_value(result)?),
80            error: None,
81        })
82    }
83
84    /// 处理列出工具请求
85    fn handle_list_tools(&self, id: &RequestId) -> Result<McpResponse> {
86        let tools: Vec<ToolDefinition> = self.tools.read().values().cloned().collect();
87        Ok(McpResponse {
88            id: id.clone(),
89            result: Some(serde_json::json!({ "tools": tools })),
90            error: None,
91        })
92    }
93
94    /// 处理调用工具请求
95    async fn handle_call_tool(
96        &self,
97        id: &RequestId,
98        params: Option<&Value>,
99    ) -> Result<McpResponse> {
100        let params = params.ok_or_else(|| anyhow!("Missing params"))?;
101
102        let name = params
103            .get("name")
104            .and_then(|v| v.as_str())
105            .ok_or_else(|| anyhow!("Missing tool name"))?;
106
107        let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
108
109        let executor = {
110            let executors = self.tool_executors.read();
111            executors
112                .get(name)
113                .ok_or_else(|| anyhow!("Tool not found: {}", name))?
114                .clone()
115        };
116
117        match executor.execute(name, arguments).await {
118            Ok(result) => Ok(McpResponse {
119                id: id.clone(),
120                result: Some(serde_json::to_value(result)?),
121                error: None,
122            }),
123            Err(e) => Ok(McpResponse {
124                id: id.clone(),
125                result: None,
126                error: Some(McpErrorData {
127                    code: error_codes::INTERNAL_ERROR,
128                    message: e.to_string(),
129                    data: None,
130                }),
131            }),
132        }
133    }
134}
135
136#[async_trait]
137impl McpHandler for DefaultHandler {
138    async fn handle(&self, request: &McpRequest) -> Result<McpResponse> {
139        match request.method.as_str() {
140            "initialize" => {
141                let params = request
142                    .params
143                    .as_ref()
144                    .map(|p| serde_json::from_value(p.clone()))
145                    .transpose()?;
146
147                if let Some(params) = params {
148                    self.handle_initialize(&params)
149                } else {
150                    Ok(McpResponse {
151                        id: request.id.clone(),
152                        result: None,
153                        error: Some(McpErrorData {
154                            code: error_codes::INVALID_PARAMS,
155                            message: "Missing initialize params".to_string(),
156                            data: None,
157                        }),
158                    })
159                }
160            }
161            "tools/list" => self.handle_list_tools(&request.id),
162            "tools/call" => {
163                self.handle_call_tool(&request.id, request.params.as_ref())
164                    .await
165            }
166            "shutdown" => Ok(McpResponse {
167                id: request.id.clone(),
168                result: Some(Value::Null),
169                error: None,
170            }),
171            _ => Ok(McpResponse {
172                id: request.id.clone(),
173                result: None,
174                error: Some(McpErrorData {
175                    code: error_codes::METHOD_NOT_FOUND,
176                    message: format!("Method not found: {}", request.method),
177                    data: None,
178                }),
179            }),
180        }
181    }
182
183    async fn handle_notification(&self, _notification: &McpNotification) -> Result<()> {
184        // Notifications are handled asynchronously and don't require responses
185        Ok(())
186    }
187}
188
189/// 简单工具执行器
190///
191/// 将闭包包装为 `ToolExecutor` trait 实现。
192/// 适用于不需要异步执行的简单工具。
193pub struct SimpleToolExecutor<F>(pub F)
194where
195    F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync;
196
197#[async_trait]
198impl<F> ToolExecutor for SimpleToolExecutor<F>
199where
200    F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync,
201{
202    async fn execute(&self, name: &str, arguments: Value) -> Result<ToolResult> {
203        (self.0)(name, arguments)
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_bridge::protocol::ContentBlock;

    /// Builds an executor that always succeeds with a single "OK" text block.
    fn ok_executor() -> Arc<dyn ToolExecutor> {
        Arc::new(SimpleToolExecutor(|_name, _args| {
            Ok(ToolResult {
                is_error: false,
                content: vec![ContentBlock::Text {
                    text: "OK".to_string(),
                }],
            })
        }))
    }

    /// Builds a handler with exactly one registered tool, named `test_tool`.
    fn handler_with_test_tool() -> DefaultHandler {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        handler.register_tool(
            ToolDefinition {
                name: "test_tool".to_string(),
                description: Some("A test tool".to_string()),
                input_schema: None,
            },
            ok_executor(),
        );
        handler
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        let request = McpRequest {
            id: RequestId::Number(1),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocol_version": "2024-11-05",
                "capabilities": {},
                "client_info": { "name": "test-client", "version": "1.0.0" }
            })),
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_list_tools() {
        let handler = handler_with_test_tool();
        let request = McpRequest {
            id: RequestId::Number(2),
            method: "tools/list".to_string(),
            params: None,
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());

        let result = response.result.expect("tools/list should return a result");
        let tools = result["tools"]
            .as_array()
            .expect("`tools` should be a JSON array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "test_tool");
    }

    #[tokio::test]
    async fn test_handle_call_tool() {
        let handler = handler_with_test_tool();
        let request = McpRequest {
            id: RequestId::Number(4),
            method: "tools/call".to_string(),
            params: Some(serde_json::json!({ "name": "test_tool", "arguments": {} })),
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        let request = McpRequest {
            id: RequestId::Number(3),
            method: "unknown_method".to_string(),
            params: None,
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, error_codes::METHOD_NOT_FOUND);
    }
}