1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tokio::sync::RwLock;
14
15use synaptic_core::{SynapticError, Tool};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct StdioConnection {
24 pub command: String,
25 pub args: Vec<String>,
26 #[serde(default)]
27 pub env: HashMap<String, String>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct SseConnection {
33 pub url: String,
34 #[serde(default)]
35 pub headers: HashMap<String, String>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct HttpConnection {
41 pub url: String,
42 #[serde(default)]
43 pub headers: HashMap<String, String>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "type")]
49pub enum McpConnection {
50 Stdio(StdioConnection),
51 Sse(SseConnection),
52 Http(HttpConnection),
53}
54
55struct McpTool {
61 tool_name: &'static str,
62 tool_description: &'static str,
63 tool_parameters: Value,
64 #[expect(dead_code)]
65 server_name: String,
66 connection: McpConnection,
67 client: reqwest::Client,
68}
69
/// Turns an owned `String` into a `&'static str` by leaking its allocation.
///
/// The memory is never reclaimed, so every call permanently grows the
/// process footprint by the length of `s`. Reserve this for values that must
/// outlive everything, such as strings handed out through APIs that demand
/// `&'static str`.
fn leak_string(s: String) -> &'static str {
    let boxed: Box<str> = s.into_boxed_str();
    Box::leak(boxed)
}
77
78#[async_trait]
79impl Tool for McpTool {
80 fn name(&self) -> &'static str {
81 self.tool_name
82 }
83
84 fn description(&self) -> &'static str {
85 self.tool_description
86 }
87
88 fn parameters(&self) -> Option<Value> {
89 Some(self.tool_parameters.clone())
90 }
91
92 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
93 match &self.connection {
94 McpConnection::Http(conn) => {
95 call_http(
96 &self.client,
97 &conn.url,
98 &conn.headers,
99 self.tool_name,
100 &args,
101 )
102 .await
103 }
104 McpConnection::Sse(conn) => {
105 call_http(
107 &self.client,
108 &conn.url,
109 &conn.headers,
110 self.tool_name,
111 &args,
112 )
113 .await
114 }
115 McpConnection::Stdio(conn) => call_stdio(conn, self.tool_name, &args).await,
116 }
117 }
118}
119
120async fn call_http(
126 client: &reqwest::Client,
127 url: &str,
128 headers: &HashMap<String, String>,
129 tool_name: &str,
130 args: &Value,
131) -> Result<Value, SynapticError> {
132 let request_body = serde_json::json!({
133 "jsonrpc": "2.0",
134 "method": "tools/call",
135 "params": {
136 "name": tool_name,
137 "arguments": args,
138 },
139 "id": 1
140 });
141
142 let mut builder = client.post(url);
143 for (key, value) in headers {
144 builder = builder.header(key.as_str(), value.as_str());
145 }
146 builder = builder.header("Content-Type", "application/json");
147
148 let resp = builder
149 .json(&request_body)
150 .send()
151 .await
152 .map_err(|e| SynapticError::Mcp(format!("HTTP request failed: {}", e)))?;
153
154 let body: Value = resp
155 .json()
156 .await
157 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
158
159 if let Some(error) = body.get("error") {
160 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
161 }
162
163 body.get("result")
164 .cloned()
165 .ok_or_else(|| SynapticError::Mcp("No result in MCP response".to_string()))
166}
167
168async fn call_stdio(
171 conn: &StdioConnection,
172 tool_name: &str,
173 args: &Value,
174) -> Result<Value, SynapticError> {
175 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
176 use tokio::process::Command;
177
178 let request_body = serde_json::json!({
179 "jsonrpc": "2.0",
180 "method": "tools/call",
181 "params": {
182 "name": tool_name,
183 "arguments": args,
184 },
185 "id": 1
186 });
187
188 let mut child = Command::new(&conn.command)
189 .args(&conn.args)
190 .envs(&conn.env)
191 .stdin(std::process::Stdio::piped())
192 .stdout(std::process::Stdio::piped())
193 .stderr(std::process::Stdio::null())
194 .spawn()
195 .map_err(|e| SynapticError::Mcp(format!("Failed to spawn process: {}", e)))?;
196
197 let stdin = child
198 .stdin
199 .as_mut()
200 .ok_or_else(|| SynapticError::Mcp("Failed to open stdin".to_string()))?;
201
202 let msg =
203 serde_json::to_string(&request_body).map_err(|e| SynapticError::Mcp(e.to_string()))?;
204
205 stdin
206 .write_all(msg.as_bytes())
207 .await
208 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
209 stdin
210 .write_all(b"\n")
211 .await
212 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
213 stdin
214 .flush()
215 .await
216 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
217
218 drop(child.stdin.take());
220
221 let stdout = child
222 .stdout
223 .take()
224 .ok_or_else(|| SynapticError::Mcp("Failed to open stdout".to_string()))?;
225 let mut reader = BufReader::new(stdout);
226 let mut line = String::new();
227 reader
228 .read_line(&mut line)
229 .await
230 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
231
232 let body: Value = serde_json::from_str(&line)
233 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
234
235 let _ = child.kill().await;
236
237 if let Some(error) = body.get("error") {
238 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
239 }
240
241 body.get("result")
242 .cloned()
243 .ok_or_else(|| SynapticError::Mcp("No result in MCP response".to_string()))
244}
245
246async fn list_tools_http(
249 client: &reqwest::Client,
250 url: &str,
251 headers: &HashMap<String, String>,
252) -> Result<Value, SynapticError> {
253 let request_body = serde_json::json!({
254 "jsonrpc": "2.0",
255 "method": "tools/list",
256 "params": {},
257 "id": 1
258 });
259
260 let mut builder = client.post(url);
261 for (key, value) in headers {
262 builder = builder.header(key.as_str(), value.as_str());
263 }
264 builder = builder.header("Content-Type", "application/json");
265
266 let resp = builder
267 .json(&request_body)
268 .send()
269 .await
270 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
271
272 let body: Value = resp
273 .json()
274 .await
275 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
276
277 Ok(body
278 .get("result")
279 .and_then(|r| r.get("tools"))
280 .cloned()
281 .unwrap_or(Value::Array(vec![])))
282}
283
284async fn list_tools_stdio(conn: &StdioConnection) -> Result<Value, SynapticError> {
287 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
288 use tokio::process::Command;
289
290 let request_body = serde_json::json!({
291 "jsonrpc": "2.0",
292 "method": "tools/list",
293 "params": {},
294 "id": 1
295 });
296
297 let mut child = Command::new(&conn.command)
298 .args(&conn.args)
299 .envs(&conn.env)
300 .stdin(std::process::Stdio::piped())
301 .stdout(std::process::Stdio::piped())
302 .stderr(std::process::Stdio::null())
303 .spawn()
304 .map_err(|e| SynapticError::Mcp(format!("Failed to spawn process: {}", e)))?;
305
306 let stdin = child
307 .stdin
308 .as_mut()
309 .ok_or_else(|| SynapticError::Mcp("Failed to open stdin".to_string()))?;
310
311 let msg =
312 serde_json::to_string(&request_body).map_err(|e| SynapticError::Mcp(e.to_string()))?;
313
314 stdin
315 .write_all(msg.as_bytes())
316 .await
317 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
318 stdin
319 .write_all(b"\n")
320 .await
321 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
322 stdin
323 .flush()
324 .await
325 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
326
327 drop(child.stdin.take());
329
330 let stdout = child
331 .stdout
332 .take()
333 .ok_or_else(|| SynapticError::Mcp("Failed to open stdout".to_string()))?;
334 let mut reader = BufReader::new(stdout);
335 let mut line = String::new();
336 reader
337 .read_line(&mut line)
338 .await
339 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
340
341 let body: Value = serde_json::from_str(&line)
342 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
343
344 let _ = child.kill().await;
345
346 if let Some(error) = body.get("error") {
347 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
348 }
349
350 Ok(body
351 .get("result")
352 .and_then(|r| r.get("tools"))
353 .cloned()
354 .unwrap_or(Value::Array(vec![])))
355}
356
357pub struct MultiServerMcpClient {
363 servers: HashMap<String, McpConnection>,
364 prefix_tool_names: bool,
365 tools: Arc<RwLock<Vec<Arc<dyn Tool>>>>,
366}
367
368impl MultiServerMcpClient {
369 pub fn new(servers: HashMap<String, McpConnection>) -> Self {
371 Self {
372 servers,
373 prefix_tool_names: true,
374 tools: Arc::new(RwLock::new(Vec::new())),
375 }
376 }
377
378 pub fn with_prefix(mut self, prefix: bool) -> Self {
381 self.prefix_tool_names = prefix;
382 self
383 }
384
385 pub async fn connect(&self) -> Result<(), SynapticError> {
387 let client = reqwest::Client::new();
388 let mut all_tools = Vec::new();
389
390 for (server_name, connection) in &self.servers {
391 let tools = self
392 .discover_tools(server_name, connection, &client)
393 .await?;
394 all_tools.extend(tools);
395 }
396
397 *self.tools.write().await = all_tools;
398 Ok(())
399 }
400
401 async fn discover_tools(
403 &self,
404 server_name: &str,
405 connection: &McpConnection,
406 client: &reqwest::Client,
407 ) -> Result<Vec<Arc<dyn Tool>>, SynapticError> {
408 let tools_list = match connection {
409 McpConnection::Http(conn) => list_tools_http(client, &conn.url, &conn.headers).await?,
410 McpConnection::Sse(conn) => list_tools_http(client, &conn.url, &conn.headers).await?,
411 McpConnection::Stdio(conn) => list_tools_stdio(conn).await?,
412 };
413
414 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
415
416 if let Value::Array(tool_arr) = tools_list {
417 for tool_def in tool_arr {
418 let name = tool_def
419 .get("name")
420 .and_then(|n| n.as_str())
421 .unwrap_or("")
422 .to_string();
423 let description = tool_def
424 .get("description")
425 .and_then(|d| d.as_str())
426 .unwrap_or("")
427 .to_string();
428 let parameters = tool_def
429 .get("inputSchema")
430 .cloned()
431 .unwrap_or(serde_json::json!({"type": "object"}));
432
433 let tool_name = if self.prefix_tool_names {
434 format!("{}_{}", server_name, name)
435 } else {
436 name
437 };
438
439 tools.push(Arc::new(McpTool {
440 tool_name: leak_string(tool_name),
441 tool_description: leak_string(description),
442 tool_parameters: parameters,
443 server_name: server_name.to_string(),
444 connection: connection.clone(),
445 client: client.clone(),
446 }));
447 }
448 }
449
450 Ok(tools)
451 }
452
453 pub async fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
455 self.tools.read().await.clone()
456 }
457}
458
459pub async fn load_mcp_tools(
466 client: &MultiServerMcpClient,
467) -> Result<Vec<Arc<dyn Tool>>, SynapticError> {
468 client.connect().await?;
469 Ok(client.get_tools().await)
470}