1use async_trait::async_trait;
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use thulp_core::{ToolDefinition, ToolResult};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8
9#[async_trait]
11pub trait ToolHandler: Send + Sync {
12 async fn call(&self, arguments: Value) -> ToolResult;
14}
15
16pub struct McpServer {
18 name: String,
19 version: String,
20 tools: HashMap<String, (ToolDefinition, Box<dyn ToolHandler>)>,
21}
22
23pub struct McpServerBuilder {
25 name: String,
26 version: String,
27 tools: HashMap<String, (ToolDefinition, Box<dyn ToolHandler>)>,
28}
29
30impl McpServerBuilder {
31 fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
32 Self {
33 name: name.into(),
34 version: version.into(),
35 tools: HashMap::new(),
36 }
37 }
38
39 pub fn tool(
41 mut self,
42 name: impl Into<String>,
43 definition: ToolDefinition,
44 handler: Box<dyn ToolHandler>,
45 ) -> Self {
46 self.tools.insert(name.into(), (definition, handler));
47 self
48 }
49
50 pub fn build(self) -> McpServer {
52 McpServer {
53 name: self.name,
54 version: self.version,
55 tools: self.tools,
56 }
57 }
58}
59
60impl McpServer {
61 pub fn builder(name: impl Into<String>, version: impl Into<String>) -> McpServerBuilder {
63 McpServerBuilder::new(name, version)
64 }
65
66 pub async fn handle_request(&self, request: Value) -> Value {
68 let id = request.get("id").cloned().unwrap_or(Value::Null);
69 let method = request
70 .get("method")
71 .and_then(|m| m.as_str())
72 .unwrap_or("");
73
74 match method {
75 "initialize" => self.handle_initialize(id),
76 "tools/list" => self.handle_tools_list(id),
77 "tools/call" => self.handle_tools_call(id, &request).await,
78 _ => json_rpc_error(id, -32601, &format!("Method not found: {method}")),
79 }
80 }
81
82 fn handle_initialize(&self, id: Value) -> Value {
83 json!({
84 "jsonrpc": "2.0",
85 "id": id,
86 "result": {
87 "protocolVersion": "2024-11-05",
88 "capabilities": {
89 "tools": {}
90 },
91 "serverInfo": {
92 "name": self.name,
93 "version": self.version
94 }
95 }
96 })
97 }
98
99 fn handle_tools_list(&self, id: Value) -> Value {
100 let tools: Vec<Value> = self
101 .tools
102 .values()
103 .map(|(def, _)| {
104 json!({
105 "name": def.name,
106 "description": def.description,
107 "inputSchema": {
108 "type": "object",
109 "properties": {},
110 }
111 })
112 })
113 .collect();
114
115 json!({
116 "jsonrpc": "2.0",
117 "id": id,
118 "result": { "tools": tools }
119 })
120 }
121
122 async fn handle_tools_call(&self, id: Value, request: &Value) -> Value {
123 let params = request.get("params").cloned().unwrap_or(json!({}));
124 let tool_name = params
125 .get("name")
126 .and_then(|n| n.as_str())
127 .unwrap_or("");
128 let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
129
130 match self.tools.get(tool_name) {
131 Some((_, handler)) => {
132 let result = handler.call(arguments).await;
133 let content = if result.success {
134 let text = result
135 .data
136 .map(|d| d.to_string())
137 .unwrap_or_default();
138 vec![json!({"type": "text", "text": text})]
139 } else {
140 let text = result.error.unwrap_or_else(|| "unknown error".into());
141 vec![json!({"type": "text", "text": text})]
142 };
143 json!({
144 "jsonrpc": "2.0",
145 "id": id,
146 "result": {
147 "content": content,
148 "isError": !result.success
149 }
150 })
151 }
152 None => json_rpc_error(id, -32602, &format!("Unknown tool: {tool_name}")),
153 }
154 }
155
156 pub async fn serve_stdio(&self) -> std::io::Result<()> {
159 let stdin = tokio::io::stdin();
160 let mut stdout = tokio::io::stdout();
161 let mut reader = BufReader::new(stdin);
162 let mut line = String::new();
163
164 loop {
165 line.clear();
166 let n = reader.read_line(&mut line).await?;
167 if n == 0 {
168 break; }
170 let trimmed = line.trim();
171 if trimmed.is_empty() {
172 continue;
173 }
174 let request: Value = match serde_json::from_str(trimmed) {
175 Ok(v) => v,
176 Err(_) => {
177 let err = json_rpc_error(Value::Null, -32700, "Parse error");
178 let mut out = serde_json::to_string(&err).unwrap();
179 out.push('\n');
180 stdout.write_all(out.as_bytes()).await?;
181 stdout.flush().await?;
182 continue;
183 }
184 };
185
186 let response = self.handle_request(request).await;
187 let mut out = serde_json::to_string(&response).unwrap();
188 out.push('\n');
189 stdout.write_all(out.as_bytes()).await?;
190 stdout.flush().await?;
191 }
192 Ok(())
193 }
194}
195
196fn json_rpc_error(id: Value, code: i32, message: &str) -> Value {
197 json!({
198 "jsonrpc": "2.0",
199 "id": id,
200 "error": {
201 "code": code,
202 "message": message
203 }
204 })
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that succeeds and hands its arguments straight back.
    struct EchoHandler;

    #[async_trait]
    impl ToolHandler for EchoHandler {
        async fn call(&self, arguments: Value) -> ToolResult {
            ToolResult::success(arguments)
        }
    }

    /// Server named `test-server` (version `0.1.0`) with a single `echo` tool.
    fn test_server() -> McpServer {
        McpServer::builder("test-server", "0.1.0")
            .tool("echo", ToolDefinition::new("echo"), Box::new(EchoHandler))
            .build()
    }

    /// Sends `request` to a fresh test server and returns its response.
    async fn roundtrip(request: Value) -> Value {
        test_server().handle_request(request).await
    }

    #[tokio::test]
    async fn test_initialize() {
        let resp = roundtrip(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {}
        }))
        .await;

        assert_eq!(resp["id"], 1);
        assert_eq!(resp["result"]["serverInfo"]["name"], "test-server");
        assert!(resp["result"]["capabilities"]["tools"].is_object());
    }

    #[tokio::test]
    async fn test_tools_list() {
        let resp = roundtrip(json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list",
            "params": {}
        }))
        .await;

        let listed = resp["result"]["tools"].as_array().unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_tools_call() {
        let resp = roundtrip(json!({
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {
                "name": "echo",
                "arguments": {"msg": "hello"}
            }
        }))
        .await;

        assert_eq!(resp["id"], 3);
        assert_eq!(resp["result"]["isError"], false);
        let first_text = resp["result"]["content"].as_array().unwrap()[0]["text"]
            .as_str()
            .unwrap();
        assert!(first_text.contains("hello"));
    }

    #[tokio::test]
    async fn test_unknown_method() {
        let resp = roundtrip(json!({
            "jsonrpc": "2.0",
            "id": 4,
            "method": "unknown/method",
            "params": {}
        }))
        .await;

        assert!(resp.get("error").is_some());
        assert_eq!(resp["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn test_unknown_tool() {
        let resp = roundtrip(json!({
            "jsonrpc": "2.0",
            "id": 5,
            "method": "tools/call",
            "params": {"name": "nonexistent", "arguments": {}}
        }))
        .await;

        assert!(resp.get("error").is_some());
        assert_eq!(resp["error"]["code"], -32602);
    }
}