1use std::collections::HashMap;
7use std::sync::Arc;
8
9use futures_util::future::join_all;
10use serde_json::Value;
11use synwire_core::mcp::traits::{McpServerStatus, McpTransport};
12use tokio::sync::RwLock;
13
14use crate::callbacks::McpCallbacks;
15use crate::error::McpAdapterError;
16use crate::session::McpClientSession;
17
/// How to reach a single MCP server.
///
/// `Debug` is implemented by hand so that credentials never end up in logs:
/// auth tokens are printed as `Some("<redacted>")` and, for `Stdio`, only the
/// *names* of the environment variables are shown.
#[derive(Clone)]
#[non_exhaustive]
pub enum Connection {
    /// Launch a local process and talk to it through `StdioMcpTransport`.
    Stdio {
        /// Executable to launch.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Environment variables handed to the transport (may contain secrets).
        env: HashMap<String, String>,
    },

    /// Server-Sent Events endpoint. Currently served by the same
    /// `HttpMcpTransport` as [`Connection::StreamableHttp`].
    Sse {
        /// Endpoint URL.
        url: String,
        /// Optional auth token forwarded to the transport.
        auth_token: Option<String>,
        /// Optional timeout in seconds forwarded to the transport.
        timeout_secs: Option<u64>,
    },

    /// Streamable-HTTP endpoint, served by `HttpMcpTransport`.
    StreamableHttp {
        /// Endpoint URL.
        url: String,
        /// Optional auth token forwarded to the transport.
        auth_token: Option<String>,
        /// Optional timeout in seconds forwarded to the transport.
        timeout_secs: Option<u64>,
    },

    /// WebSocket endpoint, served by `WebSocketMcpTransport`.
    WebSocket {
        /// Endpoint URL.
        url: String,
        /// Optional auth token forwarded to the transport.
        auth_token: Option<String>,
    },
}

impl std::fmt::Debug for Connection {
    /// Formats the connection with secrets redacted (see the type-level docs).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any present secret.
        const REDACTED: &str = "<redacted>";

        match self {
            Self::Stdio { command, args, env } => {
                // Values may hold API keys; show only the (sorted, for stable output) names.
                let mut env_keys: Vec<&str> = env.keys().map(String::as_str).collect();
                env_keys.sort_unstable();
                f.debug_struct("Stdio")
                    .field("command", command)
                    .field("args", args)
                    .field("env_keys", &env_keys)
                    .finish()
            }
            Self::Sse {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("Sse")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::StreamableHttp {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("StreamableHttp")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::WebSocket { url, auth_token } => f
                .debug_struct("WebSocket")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .finish(),
        }
    }
}
66
67impl Connection {
68 pub fn into_transport(
78 self,
79 name: &str,
80 ) -> Result<Box<dyn McpTransport>, synwire_core::agents::error::AgentError> {
81 match self {
82 Self::Stdio { command, args, env } => Ok(Box::new(
83 synwire_agent::mcp::StdioMcpTransport::new(name, command, args, env),
84 )),
85 Self::Sse {
86 url,
87 auth_token,
88 timeout_secs,
89 }
90 | Self::StreamableHttp {
91 url,
92 auth_token,
93 timeout_secs,
94 } => Ok(Box::new(synwire_agent::mcp::HttpMcpTransport::try_new(
95 name,
96 url,
97 auth_token,
98 timeout_secs,
99 )?)),
100 Self::WebSocket { url, auth_token } => Ok(Box::new(
101 crate::transport::WebSocketMcpTransport::new(name, url, auth_token),
102 )),
103 }
104 }
105}
106
107struct ServerEntry {
112 session: McpClientSession,
113 tool_name_prefix: Option<String>,
115}
116
117impl std::fmt::Debug for ServerEntry {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("ServerEntry")
120 .field("session", &self.session)
121 .field("tool_name_prefix", &self.tool_name_prefix)
122 .finish()
123 }
124}
125
126#[derive(Debug, Default)]
132pub struct MultiServerMcpClientConfig {
133 pub servers: HashMap<String, Connection>,
135 pub global_tool_prefix: Option<String>,
139 pub server_prefixes: HashMap<String, String>,
141}
142
143impl MultiServerMcpClientConfig {
144 #[must_use]
146 pub fn new() -> Self {
147 Self::default()
148 }
149
150 #[must_use]
152 pub fn with_server(mut self, name: impl Into<String>, connection: Connection) -> Self {
153 let _ = self.servers.insert(name.into(), connection);
154 self
155 }
156
157 #[must_use]
159 pub fn with_server_prefix(
160 mut self,
161 server_name: impl Into<String>,
162 prefix: impl Into<String>,
163 ) -> Self {
164 let _ = self
165 .server_prefixes
166 .insert(server_name.into(), prefix.into());
167 self
168 }
169
170 #[must_use]
173 pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
174 self.global_tool_prefix = Some(prefix.into());
175 self
176 }
177}
178
179pub struct MultiServerMcpClient {
193 servers: Arc<RwLock<HashMap<String, ServerEntry>>>,
194 callbacks: Arc<McpCallbacks>,
196}
197
198impl std::fmt::Debug for MultiServerMcpClient {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("MultiServerMcpClient")
201 .field("callbacks", &self.callbacks)
202 .finish_non_exhaustive()
203 }
204}
205
206impl MultiServerMcpClient {
207 pub async fn connect(
219 config: MultiServerMcpClientConfig,
220 callbacks: McpCallbacks,
221 ) -> Result<Self, McpAdapterError> {
222 let callbacks = Arc::new(callbacks);
223
224 let MultiServerMcpClientConfig {
226 servers,
227 server_prefixes,
228 global_tool_prefix,
229 } = config;
230
231 let connect_futures: Vec<_> = servers
233 .into_iter()
234 .map(|(name, conn)| {
235 let prefix = server_prefixes
236 .get(&name)
237 .cloned()
238 .or_else(|| global_tool_prefix.clone());
239 let transport_result = conn.into_transport(&name);
240 async move {
241 let transport: Arc<dyn McpTransport> = match transport_result {
242 Ok(t) => Arc::from(t),
243 Err(e) => {
244 tracing::error!(server = %name, error = %e, "Failed to build transport");
245 return None;
246 }
247 };
248 match McpClientSession::connect(name.clone(), transport).await {
249 Ok(mut session) => {
250 if let Err(e) = session.populate_tool_cache().await {
252 tracing::warn!(
253 server = %name,
254 error = %e,
255 "Failed to populate tool cache"
256 );
257 }
258 Some((
259 name,
260 ServerEntry {
261 session,
262 tool_name_prefix: prefix,
263 },
264 ))
265 }
266 Err(e) => {
267 tracing::error!(
268 server = %name,
269 error = %e,
270 "Failed to connect to MCP server"
271 );
272 None
273 }
274 }
275 }
276 })
277 .collect();
278
279 let results = join_all(connect_futures).await;
280 let servers: HashMap<String, ServerEntry> = results.into_iter().flatten().collect();
281
282 tracing::info!(connected = servers.len(), "MultiServerMcpClient connected");
283
284 Ok(Self {
285 servers: Arc::new(RwLock::new(servers)),
286 callbacks,
287 })
288 }
289
290 pub async fn get_tool_descriptors(&self) -> Vec<AggregatedToolDescriptor> {
294 let servers = self.servers.read().await;
295 let mut tools = Vec::new();
296
297 for (server_name, entry) in servers.iter() {
298 for descriptor in entry.session.cached_tools() {
299 let exposed_name = entry.tool_name_prefix.as_ref().map_or_else(
300 || descriptor.name.clone(),
301 |prefix| format!("{prefix}/{}", descriptor.name),
302 );
303 tools.push(AggregatedToolDescriptor {
304 exposed_name,
305 server_name: server_name.clone(),
306 original_name: descriptor.name.clone(),
307 description: descriptor.description.clone(),
308 input_schema: descriptor.input_schema.clone(),
309 });
310 }
311 }
312 drop(servers);
313
314 tools
315 }
316
317 #[allow(clippy::significant_drop_tightening)]
319 pub async fn health(&self) -> Vec<McpServerStatus> {
320 let servers = self.servers.read().await;
321 let status_futures: Vec<_> = servers
322 .values()
323 .map(|entry| entry.session.status())
324 .collect();
325 join_all(status_futures).await
326 }
327
328 pub async fn call_tool(
335 &self,
336 exposed_tool_name: &str,
337 arguments: Value,
338 ) -> Result<Value, McpAdapterError> {
339 let (server_name, original_name, transport) = {
342 let servers = self.servers.read().await;
343
344 let routing = servers.iter().find_map(|(server_name, entry)| {
345 for descriptor in entry.session.cached_tools() {
346 let exposed = entry.tool_name_prefix.as_ref().map_or_else(
347 || descriptor.name.clone(),
348 |prefix| format!("{prefix}/{}", descriptor.name),
349 );
350 if exposed == exposed_tool_name {
351 return Some((server_name.clone(), descriptor.name.clone()));
352 }
353 }
354 None
355 });
356
357 let (server_name, original_name) =
358 routing.ok_or_else(|| McpAdapterError::ToolNotFound {
359 name: exposed_tool_name.to_owned(),
360 })?;
361
362 let transport = servers
363 .get(&server_name)
364 .ok_or_else(|| McpAdapterError::ServerNotFound {
365 name: server_name.clone(),
366 })?
367 .session
368 .transport()
369 .clone();
370 drop(servers);
371
372 (server_name, original_name, transport)
373 };
374
375 transport
376 .call_tool(&original_name, arguments)
377 .await
378 .map_err(|e| McpAdapterError::Transport {
379 message: format!("Tool '{original_name}' on server '{server_name}' failed: {e}"),
380 })
381 }
382
383 #[must_use]
385 pub fn callbacks(&self) -> &McpCallbacks {
386 &self.callbacks
387 }
388}
389
390#[derive(Debug, Clone)]
396pub struct AggregatedToolDescriptor {
397 pub exposed_name: String,
399 pub server_name: String,
401 pub original_name: String,
403 pub description: String,
405 pub input_schema: Value,
407}
408
#[cfg(test)]
// Panicking helpers are acceptable in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::pagination::PaginationCursor;

    #[test]
    fn connection_enum_variants_exist() {
        let stdio = Connection::Stdio {
            command: "mcp-server".into(),
            args: vec![],
            env: HashMap::new(),
        };
        let ws = Connection::WebSocket {
            url: "ws://localhost:3000".into(),
            auth_token: None,
        };
        let sse = Connection::Sse {
            url: "http://localhost:3000/sse".into(),
            auth_token: None,
            timeout_secs: None,
        };
        let http = Connection::StreamableHttp {
            url: "http://localhost:3000".into(),
            auth_token: None,
            timeout_secs: None,
        };

        // Cloning must keep both the variant and its payload.
        assert!(matches!(
            stdio.clone(),
            Connection::Stdio { ref command, .. } if command == "mcp-server"
        ));
        assert!(matches!(
            ws.clone(),
            Connection::WebSocket { ref url, .. } if url == "ws://localhost:3000"
        ));
        assert!(matches!(
            sse.clone(),
            Connection::Sse { ref url, timeout_secs: None, .. } if url == "http://localhost:3000/sse"
        ));
        assert!(matches!(
            http.clone(),
            Connection::StreamableHttp { ref url, .. } if url == "http://localhost:3000"
        ));
    }

    #[test]
    fn config_builder() {
        let config = MultiServerMcpClientConfig::new()
            .with_server(
                "s1",
                Connection::WebSocket {
                    url: "ws://localhost:3000".into(),
                    auth_token: None,
                },
            )
            .with_server_prefix("s1", "srv1")
            .with_global_prefix("global");

        assert!(config.servers.contains_key("s1"));
        assert_eq!(config.server_prefixes.get("s1"), Some(&"srv1".to_owned()));
        assert_eq!(config.global_tool_prefix, Some("global".to_owned()));
    }

    #[test]
    fn config_default_is_empty() {
        let config = MultiServerMcpClientConfig::new();
        assert!(config.servers.is_empty());
        assert!(config.server_prefixes.is_empty());
        assert_eq!(config.global_tool_prefix, None);
    }

    #[test]
    fn config_builder_later_registration_replaces_earlier() {
        let config = MultiServerMcpClientConfig::new()
            .with_server(
                "s1",
                Connection::WebSocket {
                    url: "ws://first".into(),
                    auth_token: None,
                },
            )
            .with_server(
                "s1",
                Connection::WebSocket {
                    url: "ws://second".into(),
                    auth_token: None,
                },
            )
            .with_server_prefix("s1", "a")
            .with_server_prefix("s1", "b");

        assert_eq!(config.servers.len(), 1);
        assert!(matches!(
            config.servers.get("s1"),
            Some(Connection::WebSocket { url, .. }) if url == "ws://second"
        ));
        assert_eq!(
            config.server_prefixes.get("s1").map(String::as_str),
            Some("b")
        );
    }

    #[test]
    fn pagination_used_in_client_context() {
        let mut cursor = PaginationCursor::new();
        assert!(cursor.advance(Some("token1".into())));
        assert!(!cursor.advance(None));
    }
}