1use async_trait::async_trait;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use super::protocol::{
11 error_codes, Implementation, InitializeParams, InitializeResult, McpErrorData, McpNotification,
12 McpRequest, McpResponse, RequestId, ServerCapabilities, ToolDefinition, ToolResult,
13 MCP_VERSION,
14};
15use anyhow::{anyhow, Result};
16
17#[async_trait]
19pub trait McpHandler: Send + Sync {
20 async fn handle(&self, request: &McpRequest) -> Result<McpResponse>;
22
23 async fn handle_notification(&self, notification: &McpNotification) -> Result<()>;
25}
26
27pub struct DefaultHandler {
29 server_info: Implementation,
31 tools: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
33 tool_executors: Arc<parking_lot::RwLock<HashMap<String, Arc<dyn ToolExecutor>>>>,
35}
36
37#[async_trait]
39pub trait ToolExecutor: Send + Sync {
40 async fn execute(&self, name: &str, arguments: Value) -> Result<ToolResult>;
42}
43
44impl DefaultHandler {
45 pub fn new(name: &str, version: &str) -> Self {
46 Self {
47 server_info: Implementation {
48 name: name.to_string(),
49 version: version.to_string(),
50 },
51 tools: Arc::new(parking_lot::RwLock::new(HashMap::new())),
52 tool_executors: Arc::new(parking_lot::RwLock::new(HashMap::new())),
53 }
54 }
55
56 pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
58 let name = tool.name.clone();
59 self.tools.write().insert(name.clone(), tool);
60 self.tool_executors.write().insert(name.clone(), executor);
61 }
62
63 fn handle_initialize(&self, _params: &InitializeParams) -> Result<McpResponse> {
65 let result = InitializeResult {
66 protocol_version: MCP_VERSION.to_string(),
67 capabilities: ServerCapabilities {
68 tools: Some(Default::default()),
69 resources: Some(Default::default()),
70 prompts: Some(Default::default()),
71 ..Default::default()
72 },
73 server_info: self.server_info.clone(),
74 instructions: Some("Continuum MCP Server".to_string()),
75 };
76
77 Ok(McpResponse {
78 id: RequestId::Number(0),
79 result: Some(serde_json::to_value(result)?),
80 error: None,
81 })
82 }
83
84 fn handle_list_tools(&self, id: &RequestId) -> Result<McpResponse> {
86 let tools: Vec<ToolDefinition> = self.tools.read().values().cloned().collect();
87 Ok(McpResponse {
88 id: id.clone(),
89 result: Some(serde_json::json!({ "tools": tools })),
90 error: None,
91 })
92 }
93
94 async fn handle_call_tool(
96 &self,
97 id: &RequestId,
98 params: Option<&Value>,
99 ) -> Result<McpResponse> {
100 let params = params.ok_or_else(|| anyhow!("Missing params"))?;
101
102 let name = params
103 .get("name")
104 .and_then(|v| v.as_str())
105 .ok_or_else(|| anyhow!("Missing tool name"))?;
106
107 let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
108
109 let executor = {
110 let executors = self.tool_executors.read();
111 executors
112 .get(name)
113 .ok_or_else(|| anyhow!("Tool not found: {}", name))?
114 .clone()
115 };
116
117 match executor.execute(name, arguments).await {
118 Ok(result) => Ok(McpResponse {
119 id: id.clone(),
120 result: Some(serde_json::to_value(result)?),
121 error: None,
122 }),
123 Err(e) => Ok(McpResponse {
124 id: id.clone(),
125 result: None,
126 error: Some(McpErrorData {
127 code: error_codes::INTERNAL_ERROR,
128 message: e.to_string(),
129 data: None,
130 }),
131 }),
132 }
133 }
134}
135
136#[async_trait]
137impl McpHandler for DefaultHandler {
138 async fn handle(&self, request: &McpRequest) -> Result<McpResponse> {
139 match request.method.as_str() {
140 "initialize" => {
141 let params = request
142 .params
143 .as_ref()
144 .map(|p| serde_json::from_value(p.clone()))
145 .transpose()?;
146
147 if let Some(params) = params {
148 self.handle_initialize(¶ms)
149 } else {
150 Ok(McpResponse {
151 id: request.id.clone(),
152 result: None,
153 error: Some(McpErrorData {
154 code: error_codes::INVALID_PARAMS,
155 message: "Missing initialize params".to_string(),
156 data: None,
157 }),
158 })
159 }
160 }
161 "tools/list" => self.handle_list_tools(&request.id),
162 "tools/call" => {
163 self.handle_call_tool(&request.id, request.params.as_ref())
164 .await
165 }
166 "shutdown" => Ok(McpResponse {
167 id: request.id.clone(),
168 result: Some(Value::Null),
169 error: None,
170 }),
171 _ => Ok(McpResponse {
172 id: request.id.clone(),
173 result: None,
174 error: Some(McpErrorData {
175 code: error_codes::METHOD_NOT_FOUND,
176 message: format!("Method not found: {}", request.method),
177 data: None,
178 }),
179 }),
180 }
181 }
182
183 async fn handle_notification(&self, _notification: &McpNotification) -> Result<()> {
184 Ok(())
186 }
187}
188
189pub struct SimpleToolExecutor<F>(pub F)
194where
195 F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync;
196
197#[async_trait]
198impl<F> ToolExecutor for SimpleToolExecutor<F>
199where
200 F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync,
201{
202 async fn execute(&self, name: &str, arguments: Value) -> Result<ToolResult> {
203 (self.0)(name, arguments)
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_bridge::protocol::ContentBlock;

    /// Registers a tool named `test_tool` whose executor always succeeds with `OK`.
    fn register_ok_tool(handler: &DefaultHandler) {
        handler.register_tool(
            ToolDefinition {
                name: "test_tool".to_string(),
                description: Some("A test tool".to_string()),
                input_schema: None,
            },
            Arc::new(SimpleToolExecutor(|_name, _args| {
                Ok(ToolResult {
                    is_error: false,
                    content: vec![ContentBlock::Text {
                        text: "OK".to_string(),
                    }],
                })
            })),
        );
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        let request = McpRequest {
            id: RequestId::Number(1),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocol_version": "2024-11-05",
                "capabilities": {},
                "client_info": { "name": "test-client", "version": "1.0.0" }
            })),
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_initialize_missing_params() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        let request = McpRequest {
            id: RequestId::Number(4),
            method: "initialize".to_string(),
            params: None,
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.result.is_none());
        assert_eq!(response.error.unwrap().code, error_codes::INVALID_PARAMS);
    }

    #[tokio::test]
    async fn test_handle_list_tools() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        register_ok_tool(&handler);

        let request = McpRequest {
            id: RequestId::Number(2),
            method: "tools/list".to_string(),
            params: None,
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());
        // The registered tool must actually be advertised, not merely "no error".
        let result = response.result.expect("tools/list should return a result");
        let tools = result["tools"]
            .as_array()
            .expect("`tools` should be a JSON array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "test_tool");
    }

    #[tokio::test]
    async fn test_handle_call_tool() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        handler.register_tool(
            ToolDefinition {
                name: "echo".to_string(),
                description: None,
                input_schema: None,
            },
            Arc::new(SimpleToolExecutor(|name, args| {
                // The executor must receive the tool name and `arguments` verbatim.
                assert_eq!(name, "echo");
                assert_eq!(args, serde_json::json!({ "x": 1 }));
                Ok(ToolResult {
                    is_error: false,
                    content: vec![ContentBlock::Text {
                        text: "done".to_string(),
                    }],
                })
            })),
        );

        let request = McpRequest {
            id: RequestId::Number(5),
            method: "tools/call".to_string(),
            params: Some(serde_json::json!({
                "name": "echo",
                "arguments": { "x": 1 }
            })),
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_none());
        assert!(response.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let handler = DefaultHandler::new("test-server", "1.0.0");
        let request = McpRequest {
            id: RequestId::Number(3),
            method: "unknown_method".to_string(),
            params: None,
        };

        let response = handler.handle(&request).await.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, error_codes::METHOD_NOT_FOUND);
    }
}