1use super::protocol::{McpMessage, McpRequest, RequestId, ToolDefinition, ToolResult};
8#[cfg(unix)]
9use super::transport::UnixSocketTransport;
10use super::transport::{McpTransport, McpTransportType, StdioTransport, TcpTransport};
11use anyhow::{anyhow, Result};
12use parking_lot::RwLock as ParkingRwLock;
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use tracing::{debug, info, warn};
18
/// Configuration for a single MCP server, stored by name in `McpClientManager`.
#[derive(Debug, Clone)]
pub struct McpServerConfig {
    /// Unique server name; used as the lookup key by `add_server`, `connect`,
    /// `disconnect` and the status/listing helpers.
    pub name: String,
    /// How to reach the server: a stdio child process (`command` + `args`),
    /// a TCP address, or (Unix only) a socket path.
    pub transport: McpTransportType,
    /// Whether the server should be reconnected automatically.
    /// NOTE(review): not read by any code in this file — confirm a reconnect
    /// loop elsewhere actually honours it.
    pub auto_reconnect: bool,
    /// Delay between reconnect attempts, in milliseconds (per the `_ms` suffix).
    /// NOTE(review): not read by any code in this file — confirm it is used elsewhere.
    pub reconnect_interval_ms: u64,
}
31
/// A live connection to one MCP server.
struct ConnectedServer {
    /// The configuration the connection was opened with.
    // Stored for bookkeeping but never read by code in this file, which is why
    // the dead_code lint is silenced here.
    #[allow(dead_code)]
    config: McpServerConfig,
    /// Shared transport handle; cloned out of the map so requests can be sent
    /// without holding the manager's lock.
    transport: Arc<dyn McpTransport>,
    /// Tools advertised by the server via `tools/list`, or replaced through
    /// `register_tools`.
    tools: Vec<ToolDefinition>,
}
39
/// Manages MCP server configurations, their live connections, and the routing
/// of tool names to the server that provides them.
///
/// All mutable state sits behind `parking_lot` read-write locks, so every
/// method takes `&self`.
pub struct McpClientManager {
    /// Server configurations keyed by server name (filled by `add_server`).
    configs: ParkingRwLock<HashMap<String, McpServerConfig>>,
    /// Live connections keyed by server name (filled by `connect`).
    servers: ParkingRwLock<HashMap<String, ConnectedServer>>,
    /// Tool name -> name of the server providing it; `call_tool` routes with it.
    tool_mapping: ParkingRwLock<HashMap<String, String>>,
    /// Monotonic source of JSON-RPC request ids (`new` starts it at 1).
    request_id_counter: AtomicU64,
}
51
52impl Default for McpClientManager {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl McpClientManager {
59 pub fn new() -> Self {
61 Self {
62 configs: ParkingRwLock::new(HashMap::new()),
63 servers: ParkingRwLock::new(HashMap::new()),
64 tool_mapping: ParkingRwLock::new(HashMap::new()),
65 request_id_counter: AtomicU64::new(1),
66 }
67 }
68
69 fn next_request_id(&self) -> RequestId {
71 RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
72 }
73
74 pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
76 let name = config.name.clone();
77 let mut configs = self.configs.write();
78 configs.insert(name, config);
79 Ok(())
80 }
81
82 pub async fn connect(&self, name: &str) -> Result<()> {
84 let config = {
85 let configs = self.configs.read();
86 configs
87 .get(name)
88 .ok_or_else(|| anyhow!("Server not found: {}", name))?
89 .clone()
90 };
91
92 info!(server = %name, transport = ?config.transport, "Connecting to MCP server");
93
94 let transport: Arc<dyn McpTransport> = match &config.transport {
96 McpTransportType::Stdio { command, args } => {
97 let stdio_transport = StdioTransport::new(command, args)?;
99 stdio_transport.start(command, args).await?;
100 info!(server = %name, command = %command, "Stdio transport started");
101 Arc::new(stdio_transport)
102 }
103 McpTransportType::Tcp { addr } => {
104 let tcp_transport = TcpTransport::connect(addr).await?;
106 info!(server = %name, addr = %addr, "TCP transport connected");
107 Arc::new(tcp_transport)
108 }
109 #[cfg(unix)]
110 McpTransportType::Unix { path } => {
111 let unix_transport = UnixSocketTransport::connect(path).await?;
113 info!(server = %name, path = %path, "Unix socket transport connected");
114 Arc::new(unix_transport)
115 }
116 };
117
118 let init_params = serde_json::json!({
120 "protocolVersion": "2024-11-05",
121 "capabilities": {
122 "roots": {
123 "listChanged": true
124 },
125 "sampling": {}
126 },
127 "clientInfo": {
128 "name": "continuum",
129 "version": env!("CARGO_PKG_VERSION")
130 }
131 });
132
133 let request = McpRequest {
134 id: self.next_request_id(),
135 method: "initialize".to_string(),
136 params: Some(init_params),
137 };
138
139 debug!(server = %name, "Sending initialize request");
140 transport.send(&McpMessage::Request(request)).await?;
141
142 let response =
144 tokio::time::timeout(std::time::Duration::from_secs(30), transport.receive())
145 .await
146 .map_err(|_| anyhow!("Initialize timeout for server: {}", name))??;
147
148 match response {
149 Some(McpMessage::Response(response)) => {
150 if let Some(error) = &response.error {
151 warn!(server = %name, code = ?error.code, message = %error.message, "Initialize failed");
152 return Err(anyhow!(
153 "Initialize failed (code {}): {}",
154 error.code,
155 error.message
156 ));
157 }
158
159 if let Some(result) = &response.result {
161 debug!(server = %name, result = ?result, "Server capabilities received");
162 if let Some(server_info) = result.get("serverInfo") {
163 info!(server = %name, server_info = ?server_info, "Connected to MCP server");
164 }
165 }
166 }
167 Some(McpMessage::Error(error)) => {
168 warn!(server = %name, error = ?error, "Received error response");
169 return Err(anyhow!("Server error: {:?}", error));
170 }
171 Some(other) => {
172 warn!(server = %name, message = ?other, "Unexpected message type");
173 return Err(anyhow!("Unexpected response type during initialization"));
174 }
175 None => {
176 warn!(server = %name, "No response received");
177 return Err(anyhow!("No response from server during initialization"));
178 }
179 }
180
181 let notification = McpMessage::Notification(super::protocol::McpNotification {
183 method: "notifications/initialized".to_string(),
184 params: None,
185 });
186 transport.send(¬ification).await?;
187 debug!(server = %name, "Sent initialized notification");
188
189 let list_tools_request = McpRequest {
191 id: self.next_request_id(),
192 method: "tools/list".to_string(),
193 params: None,
194 };
195 transport
196 .send(&McpMessage::Request(list_tools_request))
197 .await?;
198
199 let tools_response =
200 tokio::time::timeout(std::time::Duration::from_secs(10), transport.receive())
201 .await
202 .map_err(|_| anyhow!("Tools list timeout for server: {}", name))??;
203
204 let tools = match tools_response {
205 Some(McpMessage::Response(response)) => {
206 if let Some(result) = &response.result {
207 if let Some(tools_array) = result.get("tools") {
208 match serde_json::from_value::<Vec<ToolDefinition>>(tools_array.clone()) {
209 Ok(t) => {
210 info!(server = %name, tool_count = t.len(), "Tools discovered");
211 t
212 }
213 Err(e) => {
214 warn!(server = %name, error = %e, "Failed to parse tools");
215 Vec::new()
216 }
217 }
218 } else {
219 Vec::new()
220 }
221 } else {
222 Vec::new()
223 }
224 }
225 _ => Vec::new(),
226 };
227
228 let mut servers = self.servers.write();
230 servers.insert(
231 name.to_string(),
232 ConnectedServer {
233 config,
234 transport,
235 tools,
236 },
237 );
238
239 info!(server = %name, "MCP connection established successfully");
240 Ok(())
241 }
242
243 pub async fn connect_all(&self) -> Result<Vec<String>> {
245 let names: Vec<String> = self.configs.read().keys().cloned().collect();
246 let mut results = Vec::new();
247
248 for name in &names {
249 if self.connect(name).await.is_ok() {
250 results.push(name.clone());
251 }
252 }
253
254 Ok(results)
255 }
256
257 pub async fn disconnect(&self, name: &str) -> Result<()> {
259 let server = {
260 let mut servers = self.servers.write();
261 servers.remove(name)
262 };
263 if let Some(s) = server {
264 s.transport.close().await?;
265 }
266 Ok(())
267 }
268
269 pub fn get_server_status(&self, name: &str) -> Option<bool> {
271 let servers = self.servers.read();
272 servers.get(name).map(|_| true)
273 }
274
275 pub fn list_servers(&self) -> Vec<(String, bool)> {
277 let servers = self.servers.read();
278 let configs = self.configs.read();
279
280 let mut result = Vec::new();
281 for name in configs.keys() {
282 let connected = servers.contains_key(name);
283 result.push((name.clone(), connected));
284 }
285 result
286 }
287
288 pub fn list_all_tools(&self) -> Vec<(String, ToolDefinition)> {
290 let servers = self.servers.read();
291 let mut tools = Vec::new();
292
293 for (server_name, server) in servers.iter() {
294 for tool in &server.tools {
295 tools.push((server_name.clone(), tool.clone()));
296 }
297 }
298
299 tools
300 }
301
302 pub async fn call_tool(&self, tool_name: &str, arguments: Value) -> Result<ToolResult> {
304 let server_name = {
306 let tool_mapping = self.tool_mapping.read();
307 tool_mapping
308 .get(tool_name)
309 .ok_or_else(|| anyhow!("Tool not found: {}", tool_name))?
310 .clone()
311 };
312
313 let (transport, request_id) = {
315 let servers = self.servers.read();
316 let server = servers
317 .get(&server_name)
318 .ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
319
320 (server.transport.clone(), self.next_request_id())
321 };
322
323 let params = serde_json::json!({
325 "name": tool_name,
326 "arguments": arguments
327 });
328
329 let request = McpRequest {
330 id: request_id,
331 method: "tools/call".to_string(),
332 params: Some(params),
333 };
334
335 transport.send(&McpMessage::Request(request)).await?;
337
338 match transport.receive().await? {
340 Some(McpMessage::Response(response)) => {
341 if let Some(error) = response.error {
342 return Err(anyhow!("Tool call error: {}", error.message));
343 }
344
345 if let Some(result) = response.result {
346 let tool_result: ToolResult =
347 serde_json::from_value(result).unwrap_or_else(|_| ToolResult {
348 is_error: false,
349 content: vec![super::protocol::ContentBlock::Text {
350 text: "Tool executed successfully".to_string(),
351 }],
352 });
353 Ok(tool_result)
354 } else {
355 Err(anyhow!("Empty response"))
356 }
357 }
358 Some(McpMessage::Error(error)) => Err(anyhow!("Error: {:?}", error)),
359 _ => Err(anyhow!("Unexpected response type")),
360 }
361 }
362
363 pub async fn register_tools(
365 &self,
366 server_name: &str,
367 tools: Vec<ToolDefinition>,
368 ) -> Result<()> {
369 let mut servers = self.servers.write();
370 let server = servers
371 .get_mut(server_name)
372 .ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
373
374 let mut tool_mapping = self.tool_mapping.write();
375 for tool in &tools {
376 tool_mapping.insert(tool.name.clone(), server_name.to_string());
377 }
378
379 server.tools = tools;
380 Ok(())
381 }
382
383 pub fn render_status(&self) -> String {
385 let servers = self.servers.read();
386 let configs = self.configs.read();
387 let mut output = String::new();
388
389 output.push_str("MCP Servers:\n");
390
391 if configs.is_empty() {
392 output.push_str(" No servers configured\n");
393 } else {
394 for name in configs.keys() {
395 let server = servers.get(name);
396 let status = if server.is_some() { "🟢" } else { "🔴" };
397 let tool_count = server.map(|s| s.tools.len()).unwrap_or(0);
398 output.push_str(&format!(" {} {} ({} tools)\n", status, name, tool_count));
399 }
400 }
401
402 output
403 }
404}
405
406pub fn preset_servers() -> Vec<McpServerConfig> {
408 vec![
409 McpServerConfig {
411 name: "filesystem".to_string(),
412 transport: McpTransportType::Stdio {
413 command: "mcp-server-filesystem".to_string(),
414 args: vec!["--root".to_string(), ".".to_string()],
415 },
416 auto_reconnect: true,
417 reconnect_interval_ms: 5000,
418 },
419 McpServerConfig {
421 name: "github".to_string(),
422 transport: McpTransportType::Stdio {
423 command: "mcp-server-github".to_string(),
424 args: vec![],
425 },
426 auto_reconnect: true,
427 reconnect_interval_ms: 5000,
428 },
429 McpServerConfig {
431 name: "playwright".to_string(),
432 transport: McpTransportType::Stdio {
433 command: "npx".to_string(),
434 args: vec![
435 "@playwright/mcp@latest".to_string(),
436 "--headless".to_string(),
437 "--browser".to_string(),
438 "chrome".to_string(),
439 ],
440 },
441 auto_reconnect: true,
442 reconnect_interval_ms: 5000,
443 },
444 ]
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stdio config that is never actually launched by these tests.
    fn stdio_config(name: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportType::Stdio {
                command: "test-command".to_string(),
                args: vec![],
            },
            auto_reconnect: false,
            reconnect_interval_ms: 1000,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = McpClientManager::new();
        let servers = manager.list_servers();
        assert!(servers.is_empty());
    }

    #[tokio::test]
    async fn test_add_server() {
        let manager = McpClientManager::new();

        manager.add_server(stdio_config("test")).await.unwrap();
        let servers = manager.list_servers();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0].0, "test");
        assert!(!servers[0].1);
    }

    #[tokio::test]
    async fn test_add_server_replaces_config_with_same_name() {
        let manager = McpClientManager::new();
        manager.add_server(stdio_config("test")).await.unwrap();
        manager.add_server(stdio_config("test")).await.unwrap();
        assert_eq!(manager.list_servers().len(), 1);
    }

    #[test]
    fn test_render_status_without_servers() {
        let status = McpClientManager::new().render_status();
        assert!(status.starts_with("MCP Servers:\n"));
        assert!(status.contains("No servers configured"));
    }

    #[tokio::test]
    async fn test_render_status_shows_disconnected_server() {
        let manager = McpClientManager::new();
        manager.add_server(stdio_config("test")).await.unwrap();
        let status = manager.render_status();
        assert!(status.contains("🔴"));
        assert!(status.contains("test (0 tools)"));
    }

    #[test]
    fn test_status_of_unknown_server_is_none() {
        let manager = McpClientManager::new();
        assert_eq!(manager.get_server_status("missing"), None);
        assert!(manager.list_all_tools().is_empty());
    }

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpClientManager::new();
        assert!(manager.connect("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_noop() {
        let manager = McpClientManager::new();
        assert!(manager.disconnect("missing").await.is_ok());
    }

    #[tokio::test]
    async fn test_call_unknown_tool_fails() {
        let manager = McpClientManager::new();
        let result = manager
            .call_tool("no-such-tool", serde_json::Value::Null)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_register_tools_requires_connected_server() {
        let manager = McpClientManager::new();
        manager.add_server(stdio_config("test")).await.unwrap();
        assert!(manager.register_tools("test", Vec::new()).await.is_err());
    }

    #[test]
    fn test_preset_servers() {
        let presets = preset_servers();
        assert!(!presets.is_empty());
        assert!(presets.iter().any(|s| s.name == "filesystem"));
        assert!(presets.iter().any(|s| s.name == "github"));
    }

    #[test]
    fn test_request_id_generation() {
        let manager = McpClientManager::new();
        let id1 = manager.next_request_id();
        let id2 = manager.next_request_id();
        assert_ne!(id1, id2);
    }
}