1pub mod health;
8pub mod oauth;
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use tokio::sync::RwLock;
17
18use synaptic_core::{SynapticError, Tool};
19
20pub use health::{McpHealthHandle, McpHealthMonitor};
21pub use oauth::{McpOAuthConfig, OAuthTokenManager};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct StdioConnection {
30 pub command: String,
31 pub args: Vec<String>,
32 #[serde(default)]
33 pub env: HashMap<String, String>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct SseConnection {
39 pub url: String,
40 #[serde(default)]
41 pub headers: HashMap<String, String>,
42 #[serde(default)]
44 pub oauth: Option<McpOAuthConfig>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct HttpConnection {
50 pub url: String,
51 #[serde(default)]
52 pub headers: HashMap<String, String>,
53 #[serde(default)]
55 pub oauth: Option<McpOAuthConfig>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60#[serde(tag = "type")]
61pub enum McpConnection {
62 Stdio(StdioConnection),
63 Sse(SseConnection),
64 Http(HttpConnection),
65}
66
67struct McpTool {
73 tool_name: &'static str,
74 tool_description: &'static str,
75 tool_parameters: Value,
76 #[expect(dead_code)]
77 server_name: String,
78 connection: McpConnection,
79 client: reqwest::Client,
80 oauth_manager: Option<Arc<OAuthTokenManager>>,
82}
83
/// Turns an owned `String` into a `&'static str` by deliberately leaking its
/// heap allocation.
///
/// The memory is never freed, so every call permanently consumes the length
/// of `s`.
fn leak_string(s: String) -> &'static str {
    let boxed: Box<str> = s.into_boxed_str();
    Box::leak(boxed)
}
91
92#[async_trait]
93impl Tool for McpTool {
94 fn name(&self) -> &'static str {
95 self.tool_name
96 }
97
98 fn description(&self) -> &'static str {
99 self.tool_description
100 }
101
102 fn parameters(&self) -> Option<Value> {
103 Some(self.tool_parameters.clone())
104 }
105
106 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
107 match &self.connection {
108 McpConnection::Http(conn) => {
109 let headers = self.headers_with_oauth(&conn.headers).await?;
110 call_http(&self.client, &conn.url, &headers, self.tool_name, &args).await
111 }
112 McpConnection::Sse(conn) => {
113 let headers = self.headers_with_oauth(&conn.headers).await?;
115 call_http(&self.client, &conn.url, &headers, self.tool_name, &args).await
116 }
117 McpConnection::Stdio(conn) => call_stdio(conn, self.tool_name, &args).await,
118 }
119 }
120}
121
122impl McpTool {
123 async fn headers_with_oauth(
126 &self,
127 base_headers: &HashMap<String, String>,
128 ) -> Result<HashMap<String, String>, SynapticError> {
129 let mut headers = base_headers.clone();
130 if let Some(ref mgr) = self.oauth_manager {
131 let token = mgr.get_token().await?;
132 headers.insert("Authorization".to_string(), format!("Bearer {}", token));
133 }
134 Ok(headers)
135 }
136}
137
138async fn call_http(
144 client: &reqwest::Client,
145 url: &str,
146 headers: &HashMap<String, String>,
147 tool_name: &str,
148 args: &Value,
149) -> Result<Value, SynapticError> {
150 let request_body = serde_json::json!({
151 "jsonrpc": "2.0",
152 "method": "tools/call",
153 "params": {
154 "name": tool_name,
155 "arguments": args,
156 },
157 "id": 1
158 });
159
160 let mut builder = client.post(url);
161 for (key, value) in headers {
162 builder = builder.header(key.as_str(), value.as_str());
163 }
164 builder = builder.header("Content-Type", "application/json");
165
166 let resp = builder
167 .json(&request_body)
168 .send()
169 .await
170 .map_err(|e| SynapticError::Mcp(format!("HTTP request failed: {}", e)))?;
171
172 let body: Value = resp
173 .json()
174 .await
175 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
176
177 if let Some(error) = body.get("error") {
178 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
179 }
180
181 body.get("result")
182 .cloned()
183 .ok_or_else(|| SynapticError::Mcp("No result in MCP response".to_string()))
184}
185
186async fn call_stdio(
189 conn: &StdioConnection,
190 tool_name: &str,
191 args: &Value,
192) -> Result<Value, SynapticError> {
193 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
194 use tokio::process::Command;
195
196 let request_body = serde_json::json!({
197 "jsonrpc": "2.0",
198 "method": "tools/call",
199 "params": {
200 "name": tool_name,
201 "arguments": args,
202 },
203 "id": 1
204 });
205
206 let mut child = Command::new(&conn.command)
207 .args(&conn.args)
208 .envs(&conn.env)
209 .stdin(std::process::Stdio::piped())
210 .stdout(std::process::Stdio::piped())
211 .stderr(std::process::Stdio::null())
212 .spawn()
213 .map_err(|e| SynapticError::Mcp(format!("Failed to spawn process: {}", e)))?;
214
215 let stdin = child
216 .stdin
217 .as_mut()
218 .ok_or_else(|| SynapticError::Mcp("Failed to open stdin".to_string()))?;
219
220 let msg =
221 serde_json::to_string(&request_body).map_err(|e| SynapticError::Mcp(e.to_string()))?;
222
223 stdin
224 .write_all(msg.as_bytes())
225 .await
226 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
227 stdin
228 .write_all(b"\n")
229 .await
230 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
231 stdin
232 .flush()
233 .await
234 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
235
236 drop(child.stdin.take());
238
239 let stdout = child
240 .stdout
241 .take()
242 .ok_or_else(|| SynapticError::Mcp("Failed to open stdout".to_string()))?;
243 let mut reader = BufReader::new(stdout);
244 let mut line = String::new();
245 reader
246 .read_line(&mut line)
247 .await
248 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
249
250 let body: Value = serde_json::from_str(&line)
251 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
252
253 let _ = child.kill().await;
254
255 if let Some(error) = body.get("error") {
256 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
257 }
258
259 body.get("result")
260 .cloned()
261 .ok_or_else(|| SynapticError::Mcp("No result in MCP response".to_string()))
262}
263
264async fn list_tools_http(
267 client: &reqwest::Client,
268 url: &str,
269 headers: &HashMap<String, String>,
270) -> Result<Value, SynapticError> {
271 let request_body = serde_json::json!({
272 "jsonrpc": "2.0",
273 "method": "tools/list",
274 "params": {},
275 "id": 1
276 });
277
278 let mut builder = client.post(url);
279 for (key, value) in headers {
280 builder = builder.header(key.as_str(), value.as_str());
281 }
282 builder = builder.header("Content-Type", "application/json");
283
284 let resp = builder
285 .json(&request_body)
286 .send()
287 .await
288 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
289
290 let body: Value = resp
291 .json()
292 .await
293 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
294
295 Ok(body
296 .get("result")
297 .and_then(|r| r.get("tools"))
298 .cloned()
299 .unwrap_or(Value::Array(vec![])))
300}
301
302async fn list_tools_stdio(conn: &StdioConnection) -> Result<Value, SynapticError> {
305 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
306 use tokio::process::Command;
307
308 let request_body = serde_json::json!({
309 "jsonrpc": "2.0",
310 "method": "tools/list",
311 "params": {},
312 "id": 1
313 });
314
315 let mut child = Command::new(&conn.command)
316 .args(&conn.args)
317 .envs(&conn.env)
318 .stdin(std::process::Stdio::piped())
319 .stdout(std::process::Stdio::piped())
320 .stderr(std::process::Stdio::null())
321 .spawn()
322 .map_err(|e| SynapticError::Mcp(format!("Failed to spawn process: {}", e)))?;
323
324 let stdin = child
325 .stdin
326 .as_mut()
327 .ok_or_else(|| SynapticError::Mcp("Failed to open stdin".to_string()))?;
328
329 let msg =
330 serde_json::to_string(&request_body).map_err(|e| SynapticError::Mcp(e.to_string()))?;
331
332 stdin
333 .write_all(msg.as_bytes())
334 .await
335 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
336 stdin
337 .write_all(b"\n")
338 .await
339 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
340 stdin
341 .flush()
342 .await
343 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
344
345 drop(child.stdin.take());
347
348 let stdout = child
349 .stdout
350 .take()
351 .ok_or_else(|| SynapticError::Mcp("Failed to open stdout".to_string()))?;
352 let mut reader = BufReader::new(stdout);
353 let mut line = String::new();
354 reader
355 .read_line(&mut line)
356 .await
357 .map_err(|e| SynapticError::Mcp(e.to_string()))?;
358
359 let body: Value = serde_json::from_str(&line)
360 .map_err(|e| SynapticError::Mcp(format!("Failed to parse response: {}", e)))?;
361
362 let _ = child.kill().await;
363
364 if let Some(error) = body.get("error") {
365 return Err(SynapticError::Mcp(format!("MCP error: {}", error)));
366 }
367
368 Ok(body
369 .get("result")
370 .and_then(|r| r.get("tools"))
371 .cloned()
372 .unwrap_or(Value::Array(vec![])))
373}
374
375pub struct MultiServerMcpClient {
381 servers: HashMap<String, McpConnection>,
382 prefix_tool_names: bool,
383 tools: Arc<RwLock<Vec<Arc<dyn Tool>>>>,
384 oauth_managers: Arc<RwLock<HashMap<String, Arc<OAuthTokenManager>>>>,
386}
387
388impl MultiServerMcpClient {
389 pub fn new(servers: HashMap<String, McpConnection>) -> Self {
391 Self {
392 servers,
393 prefix_tool_names: true,
394 tools: Arc::new(RwLock::new(Vec::new())),
395 oauth_managers: Arc::new(RwLock::new(HashMap::new())),
396 }
397 }
398
399 pub fn with_prefix(mut self, prefix: bool) -> Self {
402 self.prefix_tool_names = prefix;
403 self
404 }
405
406 pub async fn connect(&self) -> Result<(), SynapticError> {
408 let client = reqwest::Client::new();
409 let mut all_tools = Vec::new();
410
411 let mut managers = self.oauth_managers.write().await;
413 for (server_name, connection) in &self.servers {
414 let oauth_config = match connection {
415 McpConnection::Http(conn) => conn.oauth.as_ref(),
416 McpConnection::Sse(conn) => conn.oauth.as_ref(),
417 McpConnection::Stdio(_) => None,
418 };
419 if let Some(config) = oauth_config {
420 if !managers.contains_key(server_name) {
421 managers.insert(
422 server_name.clone(),
423 Arc::new(OAuthTokenManager::new(config.clone())),
424 );
425 }
426 }
427 }
428 drop(managers);
429
430 for (server_name, connection) in &self.servers {
431 let oauth_manager = self.oauth_managers.read().await.get(server_name).cloned();
432 let tools = self
433 .discover_tools(server_name, connection, &client, oauth_manager)
434 .await?;
435 all_tools.extend(tools);
436 }
437
438 *self.tools.write().await = all_tools;
439 Ok(())
440 }
441
442 async fn discover_tools(
444 &self,
445 server_name: &str,
446 connection: &McpConnection,
447 client: &reqwest::Client,
448 oauth_manager: Option<Arc<OAuthTokenManager>>,
449 ) -> Result<Vec<Arc<dyn Tool>>, SynapticError> {
450 let tools_list = match connection {
451 McpConnection::Http(conn) => list_tools_http(client, &conn.url, &conn.headers).await?,
452 McpConnection::Sse(conn) => list_tools_http(client, &conn.url, &conn.headers).await?,
453 McpConnection::Stdio(conn) => list_tools_stdio(conn).await?,
454 };
455
456 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
457
458 if let Value::Array(tool_arr) = tools_list {
459 for tool_def in tool_arr {
460 let name = tool_def
461 .get("name")
462 .and_then(|n| n.as_str())
463 .unwrap_or("")
464 .to_string();
465 let description = tool_def
466 .get("description")
467 .and_then(|d| d.as_str())
468 .unwrap_or("")
469 .to_string();
470 let parameters = tool_def
471 .get("inputSchema")
472 .cloned()
473 .unwrap_or(serde_json::json!({"type": "object"}));
474
475 let tool_name = if self.prefix_tool_names {
476 format!("{}_{}", server_name, name)
477 } else {
478 name
479 };
480
481 tools.push(Arc::new(McpTool {
482 tool_name: leak_string(tool_name),
483 tool_description: leak_string(description),
484 tool_parameters: parameters,
485 server_name: server_name.to_string(),
486 connection: connection.clone(),
487 client: client.clone(),
488 oauth_manager: oauth_manager.clone(),
489 }));
490 }
491 }
492
493 Ok(tools)
494 }
495
496 pub async fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
498 self.tools.read().await.clone()
499 }
500}
501
502pub async fn load_mcp_tools(
509 client: &MultiServerMcpClient,
510) -> Result<Vec<Arc<dyn Tool>>, SynapticError> {
511 client.connect().await?;
512 Ok(client.get_tools().await)
513}