1use std::collections::HashMap;
2use std::fs;
3
4use anyhow::{Context, Result};
5
6use crate::network_policy::NetworkPolicyDecider;
7
8use super::config::McpConfig;
9
/// Summary of what changed when an `McpPool`'s configuration was reloaded.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct McpReloadReport {
    /// Servers present in the previous config but absent from the new one; their
    /// connections are shut down.
    pub removed: Vec<String>,
    /// Servers present in both configs whose settings differ; their existing connections
    /// are shut down so they can be re-established with the new settings.
    pub updated: Vec<String>,
    /// Servers that have a ready connection once the reload finishes.
    pub connected: Vec<String>,
    /// `(server name, error message)` pairs for connection failures during the reload
    /// (only populated when reconnection was requested).
    pub connect_errors: Vec<(String, String)>,
}
18use super::connection::McpConnection;
19use super::types::{McpPrompt, McpResource, McpResourceTemplate, McpTool};
20
/// Set of MCP server connections, opened on demand from an [`McpConfig`] and reconciled
/// when that config is reloaded.
pub struct McpPool {
    /// Open connections keyed by server name (names come from `config.servers`).
    pub(super) connections: HashMap<String, McpConnection>,
    /// Current server configuration; replaced wholesale on reload.
    config: McpConfig,
    /// Optional policy handed to `McpConnection::connect_with_policy` for every new
    /// connection.
    network_policy: Option<NetworkPolicyDecider>,
}
29
30impl McpPool {
31 pub fn new(config: McpConfig) -> Self {
33 Self {
34 connections: HashMap::new(),
35 config,
36 network_policy: None,
37 }
38 }
39
40 pub fn from_config_path(path: &std::path::Path) -> Result<Self> {
42 let config = if path.exists() {
43 let contents = fs::read_to_string(path)
44 .with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
45 serde_json::from_str(&contents)
46 .with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
47 } else {
48 McpConfig::default()
49 };
50 Ok(Self::new(config))
51 }
52
53 pub fn with_network_policy(mut self, policy: NetworkPolicyDecider) -> Self {
56 self.network_policy = Some(policy);
57 self
58 }
59
60 pub async fn get_or_connect(&mut self, server_name: &str) -> Result<&mut McpConnection> {
62 let is_ready = self
63 .connections
64 .get(server_name)
65 .map(|conn| conn.is_ready())
66 .unwrap_or(false);
67 if is_ready {
68 return self
69 .connections
70 .get_mut(server_name)
71 .ok_or_else(|| anyhow::anyhow!("MCP connection disappeared for {server_name}"));
72 }
73
74 self.connections.remove(server_name);
75
76 let server_config = self
77 .config
78 .servers
79 .get(server_name)
80 .ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))?
81 .clone();
82
83 if !server_config.is_enabled() {
84 anyhow::bail!("Failed to connect MCP server '{server_name}': server is disabled");
85 }
86
87 let connection = McpConnection::connect_with_policy(
88 server_name.to_string(),
89 server_config,
90 &self.config.timeouts,
91 self.network_policy.as_ref(),
92 )
93 .await?;
94
95 self.connections.insert(server_name.to_string(), connection);
96 self.connections
97 .get_mut(server_name)
98 .ok_or_else(|| anyhow::anyhow!("Failed to store MCP connection for {server_name}"))
99 }
100
101 pub async fn connect_all(&mut self) -> Vec<(String, anyhow::Error)> {
103 let mut errors = Vec::new();
104 let names: Vec<String> = self
105 .config
106 .servers
107 .keys()
108 .filter(|n| self.config.servers[*n].is_enabled())
109 .cloned()
110 .collect();
111
112 for name in names {
113 if let Err(e) = self.get_or_connect(&name).await {
114 errors.push((name, e));
115 }
116 }
117
118 for (name, server_cfg) in &self.config.servers {
119 if server_cfg.required
120 && server_cfg.is_enabled()
121 && !self
122 .connections
123 .get(name)
124 .is_some_and(McpConnection::is_ready)
125 {
126 errors.push((
127 name.clone(),
128 anyhow::anyhow!("required MCP server failed to initialize"),
129 ));
130 }
131 }
132
133 errors
134 }
135
136 pub fn all_tools(&self) -> Vec<(String, &McpTool)> {
138 let mut tools = Vec::new();
139 for (server, conn) in &self.connections {
140 for tool in conn.tools() {
141 if !conn.config().is_tool_enabled(&tool.name) {
142 continue;
143 }
144 tools.push((format!("mcp_{}_{}", server, tool.name), tool));
146 }
147 }
148 tools
149 }
150
151 pub fn all_resources(&self) -> Vec<(String, &McpResource)> {
153 let mut resources = Vec::new();
154 for (server, conn) in &self.connections {
155 for resource in conn.resources() {
156 let safe_name = resource.name.replace(' ', "_").to_lowercase();
159 resources.push((format!("mcp_{}_{}", server, safe_name), resource));
160 }
161 }
162 resources
163 }
164
165 #[allow(dead_code)] pub fn all_resource_templates(&self) -> Vec<(String, &McpResourceTemplate)> {
168 let mut templates = Vec::new();
169 for (server, conn) in &self.connections {
170 for template in conn.resource_templates() {
171 let safe_name = template.name.replace(' ', "_").to_lowercase();
172 templates.push((format!("mcp_{}_{}", server, safe_name), template));
173 }
174 }
175 templates
176 }
177
178 async fn list_resources(&mut self, server: Option<String>) -> Result<Vec<serde_json::Value>> {
179 if let Some(server_name) = server {
180 let conn = self.get_or_connect(&server_name).await?;
181 let resources = conn
182 .resources()
183 .iter()
184 .map(|resource| {
185 serde_json::json!({
186 "server": server_name.clone(),
187 "uri": resource.uri,
188 "name": resource.name,
189 "description": resource.description,
190 "mime_type": resource.mime_type,
191 })
192 })
193 .collect();
194 return Ok(resources);
195 }
196
197 let _ = self.connect_all().await;
198 let mut items = Vec::new();
199 for (server, conn) in &self.connections {
200 for resource in conn.resources() {
201 items.push(serde_json::json!({
202 "server": server,
203 "uri": resource.uri,
204 "name": resource.name,
205 "description": resource.description,
206 "mime_type": resource.mime_type,
207 }));
208 }
209 }
210 Ok(items)
211 }
212
213 async fn list_resource_templates(
214 &mut self,
215 server: Option<String>,
216 ) -> Result<Vec<serde_json::Value>> {
217 if let Some(server_name) = server {
218 let conn = self.get_or_connect(&server_name).await?;
219 let templates = conn
220 .resource_templates()
221 .iter()
222 .map(|template| {
223 serde_json::json!({
224 "server": server_name.clone(),
225 "uri_template": template.uri_template,
226 "name": template.name,
227 "description": template.description,
228 "mime_type": template.mime_type,
229 })
230 })
231 .collect();
232 return Ok(templates);
233 }
234
235 let _ = self.connect_all().await;
236 let mut items = Vec::new();
237 for (server, conn) in &self.connections {
238 for template in conn.resource_templates() {
239 items.push(serde_json::json!({
240 "server": server,
241 "uri_template": template.uri_template,
242 "name": template.name,
243 "description": template.description,
244 "mime_type": template.mime_type,
245 }));
246 }
247 }
248 Ok(items)
249 }
250
251 pub fn all_prompts(&self) -> Vec<(String, &McpPrompt)> {
253 let mut prompts = Vec::new();
254 for (server, conn) in &self.connections {
255 for prompt in conn.prompts() {
256 prompts.push((format!("mcp_{}_{}", server, prompt.name), prompt));
258 }
259 }
260 prompts
261 }
262
263 pub async fn read_resource(
265 &mut self,
266 server_name: &str,
267 uri: &str,
268 ) -> Result<serde_json::Value> {
269 let global_timeouts = self.config.timeouts;
270 let conn = self.get_or_connect(server_name).await?;
271 let timeout = conn.config().effective_read_timeout(&global_timeouts);
272 conn.read_resource(uri, timeout).await
273 }
274
275 pub async fn get_prompt(
277 &mut self,
278 server_name: &str,
279 prompt_name: &str,
280 arguments: serde_json::Value,
281 ) -> Result<serde_json::Value> {
282 let global_timeouts = self.config.timeouts;
283 let conn = self.get_or_connect(server_name).await?;
284 let timeout = conn.config().effective_execute_timeout(&global_timeouts);
285 conn.get_prompt(prompt_name, arguments, timeout).await
286 }
287
288 pub(super) fn parse_prefixed_name<'a>(
297 &self,
298 prefixed_name: &'a str,
299 ) -> Result<(&'a str, &'a str)> {
300 let rest = prefixed_name
301 .strip_prefix("mcp_")
302 .ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name: {prefixed_name}"))?;
303
304 let mut servers: Vec<&str> = self.config.servers.keys().map(String::as_str).collect();
305 servers.sort_by_key(|name| std::cmp::Reverse(name.len()));
306 for server in servers {
307 if let Some(tool) = rest
308 .strip_prefix(server)
309 .and_then(|tail| tail.strip_prefix('_'))
310 && !tool.is_empty()
311 {
312 return Ok((&rest[..server.len()], tool));
313 }
314 }
315
316 rest.split_once('_')
317 .filter(|(server, tool)| !server.is_empty() && !tool.is_empty())
318 .ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name format: {prefixed_name}"))
319 }
320
321 pub fn to_api_tools(&self) -> Vec<crate::models::Tool> {
323 let mut api_tools = Vec::new();
324
325 for (name, tool) in self.all_tools() {
327 api_tools.push(crate::models::Tool {
328 tool_type: None,
329 name,
330 description: tool.description.clone().unwrap_or_default(),
331 input_schema: tool.input_schema.clone(),
332 allowed_callers: Some(vec!["direct".to_string()]),
333 defer_loading: Some(false),
334 input_examples: None,
335 strict: None,
336 cache_control: None,
337 });
338 }
339
340 if !self.config.servers.is_empty() {
341 api_tools.push(crate::models::Tool {
342 tool_type: None,
343 name: "list_mcp_resources".to_string(),
344 description: "List available MCP resources across servers (optionally filtered by server).".to_string(),
345 input_schema: serde_json::json!({
346 "type": "object",
347 "properties": {
348 "server": { "type": "string", "description": "Optional MCP server name to filter by" }
349 }
350 }),
351 allowed_callers: Some(vec!["direct".to_string()]),
352 defer_loading: Some(false),
353 input_examples: None,
354 strict: None,
355 cache_control: None,
356 });
357 api_tools.push(crate::models::Tool {
358 tool_type: None,
359 name: "list_mcp_resource_templates".to_string(),
360 description: "List available MCP resource templates across servers (optionally filtered by server).".to_string(),
361 input_schema: serde_json::json!({
362 "type": "object",
363 "properties": {
364 "server": { "type": "string", "description": "Optional MCP server name to filter by" }
365 }
366 }),
367 allowed_callers: Some(vec!["direct".to_string()]),
368 defer_loading: Some(false),
369 input_examples: None,
370 strict: None,
371 cache_control: None,
372 });
373 }
374
375 let resources = self.all_resources();
377 if !resources.is_empty() {
378 api_tools.push(crate::models::Tool {
379 tool_type: None,
380 name: "mcp_read_resource".to_string(),
381 description: "Read a resource from an MCP server using its URI".to_string(),
382 input_schema: serde_json::json!({
383 "type": "object",
384 "properties": {
385 "server": { "type": "string", "description": "The name of the MCP server" },
386 "uri": { "type": "string", "description": "The URI of the resource to read" }
387 },
388 "required": ["server", "uri"]
389 }),
390 allowed_callers: Some(vec!["direct".to_string()]),
391 defer_loading: Some(false),
392 input_examples: None,
393 strict: None,
394 cache_control: None,
395 });
396 api_tools.push(crate::models::Tool {
397 tool_type: None,
398 name: "read_mcp_resource".to_string(),
399 description: "Alias for mcp_read_resource.".to_string(),
400 input_schema: serde_json::json!({
401 "type": "object",
402 "properties": {
403 "server": { "type": "string", "description": "The name of the MCP server" },
404 "uri": { "type": "string", "description": "The URI of the resource to read" }
405 },
406 "required": ["server", "uri"]
407 }),
408 allowed_callers: Some(vec!["direct".to_string()]),
409 defer_loading: Some(false),
410 input_examples: None,
411 strict: None,
412 cache_control: None,
413 });
414 }
415
416 let prompts = self.all_prompts();
418 if !prompts.is_empty() {
419 api_tools.push(crate::models::Tool {
420 tool_type: None,
421 name: "mcp_get_prompt".to_string(),
422 description: "Get a prompt from an MCP server".to_string(),
423 input_schema: serde_json::json!({
424 "type": "object",
425 "properties": {
426 "server": { "type": "string", "description": "The name of the MCP server" },
427 "name": { "type": "string", "description": "The name of the prompt" },
428 "arguments": {
429 "type": "object",
430 "description": "Optional arguments for the prompt",
431 "additionalProperties": { "type": "string" }
432 }
433 },
434 "required": ["server", "name"]
435 }),
436 allowed_callers: Some(vec!["direct".to_string()]),
437 defer_loading: Some(false),
438 input_examples: None,
439 strict: None,
440 cache_control: None,
441 });
442 }
443
444 api_tools
445 }
446
447 pub async fn call_tool(
449 &mut self,
450 prefixed_name: &str,
451 arguments: serde_json::Value,
452 ) -> Result<serde_json::Value> {
453 if prefixed_name == "list_mcp_resources" {
454 let server = arguments
455 .get("server")
456 .and_then(|v| v.as_str())
457 .map(str::to_string);
458 let resources = self.list_resources(server).await?;
459 return Ok(serde_json::json!({ "resources": resources }));
460 }
461
462 if prefixed_name == "list_mcp_resource_templates" {
463 let server = arguments
464 .get("server")
465 .and_then(|v| v.as_str())
466 .map(str::to_string);
467 let templates = self.list_resource_templates(server).await?;
468 return Ok(serde_json::json!({ "templates": templates }));
469 }
470
471 if prefixed_name == "mcp_read_resource" {
472 let server_name = arguments
473 .get("server")
474 .and_then(|v| v.as_str())
475 .context("Missing 'server' argument")?;
476 let uri = arguments
477 .get("uri")
478 .and_then(|v| v.as_str())
479 .context("Missing 'uri' argument")?;
480 return self.read_resource(server_name, uri).await;
481 }
482
483 if prefixed_name == "read_mcp_resource" {
484 let server_name = arguments
485 .get("server")
486 .and_then(|v| v.as_str())
487 .context("Missing 'server' argument")?;
488 let uri = arguments
489 .get("uri")
490 .and_then(|v| v.as_str())
491 .context("Missing 'uri' argument")?;
492 return self.read_resource(server_name, uri).await;
493 }
494
495 if prefixed_name == "mcp_get_prompt" {
496 let server_name = arguments
497 .get("server")
498 .and_then(|v| v.as_str())
499 .context("Missing 'server' argument")?;
500 let name = arguments
501 .get("name")
502 .and_then(|v| v.as_str())
503 .context("Missing 'name' argument")?;
504 let args = arguments
505 .get("arguments")
506 .cloned()
507 .unwrap_or(serde_json::json!({}));
508 return self.get_prompt(server_name, name, args).await;
509 }
510
511 let (server_name, tool_name) = self.parse_prefixed_name(prefixed_name)?;
512 let global_timeouts = self.config.timeouts;
514 let conn = self.get_or_connect(server_name).await?;
515 if !conn.config().is_tool_enabled(tool_name) {
516 anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
517 }
518 let timeout = conn.config().effective_execute_timeout(&global_timeouts);
519 let started = std::time::Instant::now();
520 let result = conn.call_tool(tool_name, arguments, timeout).await;
521 let duration_ms = started.elapsed().as_millis() as u64;
522 let (success, err_msg, result_bytes) = match &result {
523 Ok(value) => (
524 true,
525 None,
526 serde_json::to_string(value).map(|s| s.len()).unwrap_or(0),
527 ),
528 Err(err) => (false, Some(err.to_string()), 0),
529 };
530 super::observability::record_mcp_call(
531 server_name,
532 format!("tools/call:{tool_name}"),
533 duration_ms,
534 success,
535 err_msg,
536 result_bytes,
537 );
538 result
539 }
540
541 #[allow(dead_code)] pub fn server_names(&self) -> Vec<&str> {
544 self.config
545 .servers
546 .keys()
547 .map(std::string::String::as_str)
548 .collect()
549 }
550
551 pub fn connected_servers(&self) -> Vec<&str> {
553 self.connections
554 .iter()
555 .filter(|(_, c)| c.is_ready())
556 .map(|(n, _)| n.as_str())
557 .collect()
558 }
559
560 #[allow(dead_code)] pub fn disconnect_all(&mut self) {
563 self.connections.clear();
564 }
565
566 pub async fn reload_from_path(&mut self, path: &std::path::Path) -> Result<McpReloadReport> {
568 let config = if path.exists() {
569 let contents = fs::read_to_string(path)
570 .with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
571 serde_json::from_str(&contents)
572 .with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
573 } else {
574 McpConfig::default()
575 };
576 Ok(self.reload_config(config, true).await)
577 }
578
579 pub async fn reload_config(
582 &mut self,
583 new_config: McpConfig,
584 reconnect: bool,
585 ) -> McpReloadReport {
586 let old_config = std::mem::replace(&mut self.config, new_config);
587 let mut removed = Vec::new();
588 let mut updated = Vec::new();
589
590 let old_names: std::collections::HashSet<_> = old_config.servers.keys().collect();
591 let new_names: std::collections::HashSet<_> = self.config.servers.keys().collect();
592
593 for name in old_names.difference(&new_names) {
594 removed.push((*name).clone());
595 if let Some(mut conn) = self.connections.remove(*name) {
596 conn.transport.shutdown().await;
597 }
598 }
599
600 for name in old_names.intersection(&new_names) {
601 if old_config.servers[*name] != self.config.servers[*name] {
602 updated.push((*name).clone());
603 if let Some(mut conn) = self.connections.remove(*name) {
604 conn.transport.shutdown().await;
605 }
606 }
607 }
608
609 let disabled_or_missing: Vec<String> = self
610 .connections
611 .keys()
612 .filter(|name| {
613 self.config
614 .servers
615 .get(*name)
616 .is_none_or(|cfg| !cfg.is_enabled())
617 })
618 .cloned()
619 .collect();
620 for name in disabled_or_missing {
621 if let Some(mut conn) = self.connections.remove(&name) {
622 conn.transport.shutdown().await;
623 }
624 }
625
626 let mut connect_errors = Vec::new();
627 if reconnect {
628 connect_errors = self
629 .connect_all()
630 .await
631 .into_iter()
632 .map(|(name, err)| (name, err.to_string()))
633 .collect();
634 }
635
636 let connected = self
637 .connected_servers()
638 .into_iter()
639 .map(str::to_string)
640 .collect();
641
642 McpReloadReport {
643 removed,
644 updated,
645 connected,
646 connect_errors,
647 }
648 }
649
650 #[allow(dead_code)] pub async fn shutdown_all(&mut self) {
660 let names: Vec<String> = self.connections.keys().cloned().collect();
661 for name in names {
662 if let Some(conn) = self.connections.get_mut(&name) {
663 conn.transport.shutdown().await;
664 }
665 }
666 self.connections.clear();
667 }
668
669 #[allow(dead_code)] pub fn config(&self) -> &McpConfig {
672 &self.config
673 }
674
675 pub fn is_mcp_tool(name: &str) -> bool {
677 name.starts_with("mcp_")
678 || matches!(
679 name,
680 "list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource"
681 )
682 }
683}