Skip to main content

zagens_runtime_adapters/mcp/
pool.rs

1use std::collections::HashMap;
2use std::fs;
3
4use anyhow::{Context, Result};
5
6use crate::network_policy::NetworkPolicyDecider;
7
8use super::config::McpConfig;
9
10/// Summary of an in-place MCP pool reload (config diff + optional reconnect).
11#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
12pub struct McpReloadReport {
13    pub removed: Vec<String>,
14    pub updated: Vec<String>,
15    pub connected: Vec<String>,
16    pub connect_errors: Vec<(String, String)>,
17}
18use super::connection::McpConnection;
19use super::types::{McpPrompt, McpResource, McpResourceTemplate, McpTool};
20
21// === McpPool - Connection Pool Management ===
22
23/// Pool of MCP connections for reuse
24pub struct McpPool {
25    pub(super) connections: HashMap<String, McpConnection>,
26    config: McpConfig,
27    network_policy: Option<NetworkPolicyDecider>,
28}
29
30impl McpPool {
31    /// Create a new pool with the given configuration
32    pub fn new(config: McpConfig) -> Self {
33        Self {
34            connections: HashMap::new(),
35            config,
36            network_policy: None,
37        }
38    }
39
40    /// Create a pool from a configuration file path
41    pub fn from_config_path(path: &std::path::Path) -> Result<Self> {
42        let config = if path.exists() {
43            let contents = fs::read_to_string(path)
44                .with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
45            serde_json::from_str(&contents)
46                .with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
47        } else {
48            McpConfig::default()
49        };
50        Ok(Self::new(config))
51    }
52
53    /// Attach a per-domain network policy (#135). When set, HTTP/SSE
54    /// transports are gated through it; STDIO transports are unaffected.
55    pub fn with_network_policy(mut self, policy: NetworkPolicyDecider) -> Self {
56        self.network_policy = Some(policy);
57        self
58    }
59
60    /// Get or create a connection to a server
61    pub async fn get_or_connect(&mut self, server_name: &str) -> Result<&mut McpConnection> {
62        let is_ready = self
63            .connections
64            .get(server_name)
65            .map(|conn| conn.is_ready())
66            .unwrap_or(false);
67        if is_ready {
68            return self
69                .connections
70                .get_mut(server_name)
71                .ok_or_else(|| anyhow::anyhow!("MCP connection disappeared for {server_name}"));
72        }
73
74        self.connections.remove(server_name);
75
76        let server_config = self
77            .config
78            .servers
79            .get(server_name)
80            .ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))?
81            .clone();
82
83        if !server_config.is_enabled() {
84            anyhow::bail!("Failed to connect MCP server '{server_name}': server is disabled");
85        }
86
87        let connection = McpConnection::connect_with_policy(
88            server_name.to_string(),
89            server_config,
90            &self.config.timeouts,
91            self.network_policy.as_ref(),
92        )
93        .await?;
94
95        self.connections.insert(server_name.to_string(), connection);
96        self.connections
97            .get_mut(server_name)
98            .ok_or_else(|| anyhow::anyhow!("Failed to store MCP connection for {server_name}"))
99    }
100
101    /// Connect to all enabled servers, returning errors for failed connections
102    pub async fn connect_all(&mut self) -> Vec<(String, anyhow::Error)> {
103        let mut errors = Vec::new();
104        let names: Vec<String> = self
105            .config
106            .servers
107            .keys()
108            .filter(|n| self.config.servers[*n].is_enabled())
109            .cloned()
110            .collect();
111
112        for name in names {
113            if let Err(e) = self.get_or_connect(&name).await {
114                errors.push((name, e));
115            }
116        }
117
118        for (name, server_cfg) in &self.config.servers {
119            if server_cfg.required
120                && server_cfg.is_enabled()
121                && !self
122                    .connections
123                    .get(name)
124                    .is_some_and(McpConnection::is_ready)
125            {
126                errors.push((
127                    name.clone(),
128                    anyhow::anyhow!("required MCP server failed to initialize"),
129                ));
130            }
131        }
132
133        errors
134    }
135
136    /// Get all discovered tools with server-prefixed names
137    pub fn all_tools(&self) -> Vec<(String, &McpTool)> {
138        let mut tools = Vec::new();
139        for (server, conn) in &self.connections {
140            for tool in conn.tools() {
141                if !conn.config().is_tool_enabled(&tool.name) {
142                    continue;
143                }
144                // Format: mcp_{server}_{tool}
145                tools.push((format!("mcp_{}_{}", server, tool.name), tool));
146            }
147        }
148        tools
149    }
150
151    /// Get all discovered resources with server-prefixed names
152    pub fn all_resources(&self) -> Vec<(String, &McpResource)> {
153        let mut resources = Vec::new();
154        for (server, conn) in &self.connections {
155            for resource in conn.resources() {
156                // Format: mcp_{server}_{resource_name}
157                // Note: resource names might contain spaces, we should probably slugify them
158                let safe_name = resource.name.replace(' ', "_").to_lowercase();
159                resources.push((format!("mcp_{}_{}", server, safe_name), resource));
160            }
161        }
162        resources
163    }
164
165    /// Get all discovered resource templates with server-prefixed names
166    #[allow(dead_code)] // Public API for MCP resource discovery
167    pub fn all_resource_templates(&self) -> Vec<(String, &McpResourceTemplate)> {
168        let mut templates = Vec::new();
169        for (server, conn) in &self.connections {
170            for template in conn.resource_templates() {
171                let safe_name = template.name.replace(' ', "_").to_lowercase();
172                templates.push((format!("mcp_{}_{}", server, safe_name), template));
173            }
174        }
175        templates
176    }
177
178    async fn list_resources(&mut self, server: Option<String>) -> Result<Vec<serde_json::Value>> {
179        if let Some(server_name) = server {
180            let conn = self.get_or_connect(&server_name).await?;
181            let resources = conn
182                .resources()
183                .iter()
184                .map(|resource| {
185                    serde_json::json!({
186                        "server": server_name.clone(),
187                        "uri": resource.uri,
188                        "name": resource.name,
189                        "description": resource.description,
190                        "mime_type": resource.mime_type,
191                    })
192                })
193                .collect();
194            return Ok(resources);
195        }
196
197        let _ = self.connect_all().await;
198        let mut items = Vec::new();
199        for (server, conn) in &self.connections {
200            for resource in conn.resources() {
201                items.push(serde_json::json!({
202                    "server": server,
203                    "uri": resource.uri,
204                    "name": resource.name,
205                    "description": resource.description,
206                    "mime_type": resource.mime_type,
207                }));
208            }
209        }
210        Ok(items)
211    }
212
213    async fn list_resource_templates(
214        &mut self,
215        server: Option<String>,
216    ) -> Result<Vec<serde_json::Value>> {
217        if let Some(server_name) = server {
218            let conn = self.get_or_connect(&server_name).await?;
219            let templates = conn
220                .resource_templates()
221                .iter()
222                .map(|template| {
223                    serde_json::json!({
224                        "server": server_name.clone(),
225                        "uri_template": template.uri_template,
226                        "name": template.name,
227                        "description": template.description,
228                        "mime_type": template.mime_type,
229                    })
230                })
231                .collect();
232            return Ok(templates);
233        }
234
235        let _ = self.connect_all().await;
236        let mut items = Vec::new();
237        for (server, conn) in &self.connections {
238            for template in conn.resource_templates() {
239                items.push(serde_json::json!({
240                    "server": server,
241                    "uri_template": template.uri_template,
242                    "name": template.name,
243                    "description": template.description,
244                    "mime_type": template.mime_type,
245                }));
246            }
247        }
248        Ok(items)
249    }
250
251    /// Get all discovered prompts with server-prefixed names
252    pub fn all_prompts(&self) -> Vec<(String, &McpPrompt)> {
253        let mut prompts = Vec::new();
254        for (server, conn) in &self.connections {
255            for prompt in conn.prompts() {
256                // Format: mcp_{server}_{prompt}
257                prompts.push((format!("mcp_{}_{}", server, prompt.name), prompt));
258            }
259        }
260        prompts
261    }
262
263    /// Read a resource from a specific server
264    pub async fn read_resource(
265        &mut self,
266        server_name: &str,
267        uri: &str,
268    ) -> Result<serde_json::Value> {
269        let global_timeouts = self.config.timeouts;
270        let conn = self.get_or_connect(server_name).await?;
271        let timeout = conn.config().effective_read_timeout(&global_timeouts);
272        conn.read_resource(uri, timeout).await
273    }
274
275    /// Get a prompt from a specific server
276    pub async fn get_prompt(
277        &mut self,
278        server_name: &str,
279        prompt_name: &str,
280        arguments: serde_json::Value,
281    ) -> Result<serde_json::Value> {
282        let global_timeouts = self.config.timeouts;
283        let conn = self.get_or_connect(server_name).await?;
284        let timeout = conn.config().effective_execute_timeout(&global_timeouts);
285        conn.get_prompt(prompt_name, arguments, timeout).await
286    }
287
288    /// Parse a prefixed name `mcp_{server}_{tool}` into (server_name, tool_name).
289    ///
290    /// Server names may themselves contain underscores (e.g. `github_mcp`), so a
291    /// naive `split_once('_')` misattributes the boundary. We match against the
292    /// configured server names, longest first, and split the tool off the
293    /// remainder — this is symmetric with the `mcp_{server}_{tool}` formatting in
294    /// [`Self::all_tools`]. Falls back to `split_once('_')` for names whose
295    /// server isn't in the current config (preserves prior behavior).
296    pub(super) fn parse_prefixed_name<'a>(
297        &self,
298        prefixed_name: &'a str,
299    ) -> Result<(&'a str, &'a str)> {
300        let rest = prefixed_name
301            .strip_prefix("mcp_")
302            .ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name: {prefixed_name}"))?;
303
304        let mut servers: Vec<&str> = self.config.servers.keys().map(String::as_str).collect();
305        servers.sort_by_key(|name| std::cmp::Reverse(name.len()));
306        for server in servers {
307            if let Some(tool) = rest
308                .strip_prefix(server)
309                .and_then(|tail| tail.strip_prefix('_'))
310                && !tool.is_empty()
311            {
312                return Ok((&rest[..server.len()], tool));
313            }
314        }
315
316        rest.split_once('_')
317            .filter(|(server, tool)| !server.is_empty() && !tool.is_empty())
318            .ok_or_else(|| anyhow::anyhow!("Invalid MCP tool name format: {prefixed_name}"))
319    }
320
321    /// Convert discovered tools to API Tool format
322    pub fn to_api_tools(&self) -> Vec<crate::models::Tool> {
323        let mut api_tools = Vec::new();
324
325        // Add regular tools
326        for (name, tool) in self.all_tools() {
327            api_tools.push(crate::models::Tool {
328                tool_type: None,
329                name,
330                description: tool.description.clone().unwrap_or_default(),
331                input_schema: tool.input_schema.clone(),
332                allowed_callers: Some(vec!["direct".to_string()]),
333                defer_loading: Some(false),
334                input_examples: None,
335                strict: None,
336                cache_control: None,
337            });
338        }
339
340        if !self.config.servers.is_empty() {
341            api_tools.push(crate::models::Tool {
342                tool_type: None,
343                name: "list_mcp_resources".to_string(),
344                description: "List available MCP resources across servers (optionally filtered by server).".to_string(),
345                input_schema: serde_json::json!({
346                    "type": "object",
347                    "properties": {
348                        "server": { "type": "string", "description": "Optional MCP server name to filter by" }
349                    }
350                }),
351                allowed_callers: Some(vec!["direct".to_string()]),
352                defer_loading: Some(false),
353                input_examples: None,
354                strict: None,
355                cache_control: None,
356            });
357            api_tools.push(crate::models::Tool {
358                tool_type: None,
359                name: "list_mcp_resource_templates".to_string(),
360                description: "List available MCP resource templates across servers (optionally filtered by server).".to_string(),
361                input_schema: serde_json::json!({
362                    "type": "object",
363                    "properties": {
364                        "server": { "type": "string", "description": "Optional MCP server name to filter by" }
365                    }
366                }),
367                allowed_callers: Some(vec!["direct".to_string()]),
368                defer_loading: Some(false),
369                input_examples: None,
370                strict: None,
371                cache_control: None,
372            });
373        }
374
375        // Add resource reading tools if resources exist
376        let resources = self.all_resources();
377        if !resources.is_empty() {
378            api_tools.push(crate::models::Tool {
379                tool_type: None,
380                name: "mcp_read_resource".to_string(),
381                description: "Read a resource from an MCP server using its URI".to_string(),
382                input_schema: serde_json::json!({
383                    "type": "object",
384                    "properties": {
385                        "server": { "type": "string", "description": "The name of the MCP server" },
386                        "uri": { "type": "string", "description": "The URI of the resource to read" }
387                    },
388                    "required": ["server", "uri"]
389                }),
390                allowed_callers: Some(vec!["direct".to_string()]),
391                defer_loading: Some(false),
392                input_examples: None,
393                strict: None,
394                cache_control: None,
395            });
396            api_tools.push(crate::models::Tool {
397                tool_type: None,
398                name: "read_mcp_resource".to_string(),
399                description: "Alias for mcp_read_resource.".to_string(),
400                input_schema: serde_json::json!({
401                    "type": "object",
402                    "properties": {
403                        "server": { "type": "string", "description": "The name of the MCP server" },
404                        "uri": { "type": "string", "description": "The URI of the resource to read" }
405                    },
406                    "required": ["server", "uri"]
407                }),
408                allowed_callers: Some(vec!["direct".to_string()]),
409                defer_loading: Some(false),
410                input_examples: None,
411                strict: None,
412                cache_control: None,
413            });
414        }
415
416        // Add prompt getting tools if prompts exist
417        let prompts = self.all_prompts();
418        if !prompts.is_empty() {
419            api_tools.push(crate::models::Tool {
420                tool_type: None,
421                name: "mcp_get_prompt".to_string(),
422                description: "Get a prompt from an MCP server".to_string(),
423                input_schema: serde_json::json!({
424                    "type": "object",
425                    "properties": {
426                        "server": { "type": "string", "description": "The name of the MCP server" },
427                        "name": { "type": "string", "description": "The name of the prompt" },
428                        "arguments": {
429                            "type": "object",
430                            "description": "Optional arguments for the prompt",
431                            "additionalProperties": { "type": "string" }
432                        }
433                    },
434                    "required": ["server", "name"]
435                }),
436                allowed_callers: Some(vec!["direct".to_string()]),
437                defer_loading: Some(false),
438                input_examples: None,
439                strict: None,
440                cache_control: None,
441            });
442        }
443
444        api_tools
445    }
446
447    /// Call a tool by its prefixed name (mcp_{server}_{tool})
448    pub async fn call_tool(
449        &mut self,
450        prefixed_name: &str,
451        arguments: serde_json::Value,
452    ) -> Result<serde_json::Value> {
453        if prefixed_name == "list_mcp_resources" {
454            let server = arguments
455                .get("server")
456                .and_then(|v| v.as_str())
457                .map(str::to_string);
458            let resources = self.list_resources(server).await?;
459            return Ok(serde_json::json!({ "resources": resources }));
460        }
461
462        if prefixed_name == "list_mcp_resource_templates" {
463            let server = arguments
464                .get("server")
465                .and_then(|v| v.as_str())
466                .map(str::to_string);
467            let templates = self.list_resource_templates(server).await?;
468            return Ok(serde_json::json!({ "templates": templates }));
469        }
470
471        if prefixed_name == "mcp_read_resource" {
472            let server_name = arguments
473                .get("server")
474                .and_then(|v| v.as_str())
475                .context("Missing 'server' argument")?;
476            let uri = arguments
477                .get("uri")
478                .and_then(|v| v.as_str())
479                .context("Missing 'uri' argument")?;
480            return self.read_resource(server_name, uri).await;
481        }
482
483        if prefixed_name == "read_mcp_resource" {
484            let server_name = arguments
485                .get("server")
486                .and_then(|v| v.as_str())
487                .context("Missing 'server' argument")?;
488            let uri = arguments
489                .get("uri")
490                .and_then(|v| v.as_str())
491                .context("Missing 'uri' argument")?;
492            return self.read_resource(server_name, uri).await;
493        }
494
495        if prefixed_name == "mcp_get_prompt" {
496            let server_name = arguments
497                .get("server")
498                .and_then(|v| v.as_str())
499                .context("Missing 'server' argument")?;
500            let name = arguments
501                .get("name")
502                .and_then(|v| v.as_str())
503                .context("Missing 'name' argument")?;
504            let args = arguments
505                .get("arguments")
506                .cloned()
507                .unwrap_or(serde_json::json!({}));
508            return self.get_prompt(server_name, name, args).await;
509        }
510
511        let (server_name, tool_name) = self.parse_prefixed_name(prefixed_name)?;
512        // Copy the global timeouts to avoid borrow conflict
513        let global_timeouts = self.config.timeouts;
514        let conn = self.get_or_connect(server_name).await?;
515        if !conn.config().is_tool_enabled(tool_name) {
516            anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
517        }
518        let timeout = conn.config().effective_execute_timeout(&global_timeouts);
519        let started = std::time::Instant::now();
520        let result = conn.call_tool(tool_name, arguments, timeout).await;
521        let duration_ms = started.elapsed().as_millis() as u64;
522        let (success, err_msg, result_bytes) = match &result {
523            Ok(value) => (
524                true,
525                None,
526                serde_json::to_string(value).map(|s| s.len()).unwrap_or(0),
527            ),
528            Err(err) => (false, Some(err.to_string()), 0),
529        };
530        super::observability::record_mcp_call(
531            server_name,
532            format!("tools/call:{tool_name}"),
533            duration_ms,
534            success,
535            err_msg,
536            result_bytes,
537        );
538        result
539    }
540
541    /// Get list of configured server names
542    #[allow(dead_code)] // Public API for MCP consumers
543    pub fn server_names(&self) -> Vec<&str> {
544        self.config
545            .servers
546            .keys()
547            .map(std::string::String::as_str)
548            .collect()
549    }
550
551    /// Get list of connected server names
552    pub fn connected_servers(&self) -> Vec<&str> {
553        self.connections
554            .iter()
555            .filter(|(_, c)| c.is_ready())
556            .map(|(n, _)| n.as_str())
557            .collect()
558    }
559
560    /// Disconnect all connections
561    #[allow(dead_code)] // Public API for MCP lifecycle management
562    pub fn disconnect_all(&mut self) {
563        self.connections.clear();
564    }
565
566    /// Reload pool configuration from disk and reconcile live connections.
567    pub async fn reload_from_path(&mut self, path: &std::path::Path) -> Result<McpReloadReport> {
568        let config = if path.exists() {
569            let contents = fs::read_to_string(path)
570                .with_context(|| format!("Failed to read MCP config: {}", path.display()))?;
571            serde_json::from_str(&contents)
572                .with_context(|| format!("Failed to parse MCP config: {}", path.display()))?
573        } else {
574            McpConfig::default()
575        };
576        Ok(self.reload_config(config, true).await)
577    }
578
579    /// Apply a new config: drop removed/changed/disabled connections, swap
580    /// config, then optionally reconnect all enabled servers.
581    pub async fn reload_config(
582        &mut self,
583        new_config: McpConfig,
584        reconnect: bool,
585    ) -> McpReloadReport {
586        let old_config = std::mem::replace(&mut self.config, new_config);
587        let mut removed = Vec::new();
588        let mut updated = Vec::new();
589
590        let old_names: std::collections::HashSet<_> = old_config.servers.keys().collect();
591        let new_names: std::collections::HashSet<_> = self.config.servers.keys().collect();
592
593        for name in old_names.difference(&new_names) {
594            removed.push((*name).clone());
595            if let Some(mut conn) = self.connections.remove(*name) {
596                conn.transport.shutdown().await;
597            }
598        }
599
600        for name in old_names.intersection(&new_names) {
601            if old_config.servers[*name] != self.config.servers[*name] {
602                updated.push((*name).clone());
603                if let Some(mut conn) = self.connections.remove(*name) {
604                    conn.transport.shutdown().await;
605                }
606            }
607        }
608
609        let disabled_or_missing: Vec<String> = self
610            .connections
611            .keys()
612            .filter(|name| {
613                self.config
614                    .servers
615                    .get(*name)
616                    .is_none_or(|cfg| !cfg.is_enabled())
617            })
618            .cloned()
619            .collect();
620        for name in disabled_or_missing {
621            if let Some(mut conn) = self.connections.remove(&name) {
622                conn.transport.shutdown().await;
623            }
624        }
625
626        let mut connect_errors = Vec::new();
627        if reconnect {
628            connect_errors = self
629                .connect_all()
630                .await
631                .into_iter()
632                .map(|(name, err)| (name, err.to_string()))
633                .collect();
634        }
635
636        let connected = self
637            .connected_servers()
638            .into_iter()
639            .map(str::to_string)
640            .collect();
641
642        McpReloadReport {
643            removed,
644            updated,
645            connected,
646            connect_errors,
647        }
648    }
649
650    /// Graceful shutdown of every connection in the pool: send SIGTERM to
651    /// each stdio child and give them a short grace period before drop
652    /// fires SIGKILL. Whalescale#420.
653    ///
654    /// Call from the TUI exit path *before* dropping the pool to give
655    /// MCP servers a chance to flush state. The fallback Drop on
656    /// `StdioTransport` still sends SIGTERM if this never runs, so even
657    /// abnormal exits avoid leaking PIDs without a signal.
658    #[allow(dead_code)] // Wired in by callers that want graceful shutdown
659    pub async fn shutdown_all(&mut self) {
660        let names: Vec<String> = self.connections.keys().cloned().collect();
661        for name in names {
662            if let Some(conn) = self.connections.get_mut(&name) {
663                conn.transport.shutdown().await;
664            }
665        }
666        self.connections.clear();
667    }
668
669    /// Get the underlying configuration
670    #[allow(dead_code)] // Public API for MCP consumers
671    pub fn config(&self) -> &McpConfig {
672        &self.config
673    }
674
675    /// Check if a tool name is an MCP tool
676    pub fn is_mcp_tool(name: &str) -> bool {
677        name.starts_with("mcp_")
678            || matches!(
679                name,
680                "list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource"
681            )
682    }
683}