Skip to main content

synwire_agent/mcp/
lifecycle.rs

1//! MCP server lifecycle manager.
2//!
3//! Manages a set of named MCP servers: connects them on start, reconnects them
4//! when their connection drops, monitors health, and supports runtime enable/disable.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use tokio::sync::RwLock;
11use tokio::time::sleep;
12
13use synwire_core::agents::error::AgentError;
14use synwire_core::mcp::traits::{McpConnectionState, McpServerStatus, McpTransport};
15
16// ---------------------------------------------------------------------------
17// Managed server entry
18// ---------------------------------------------------------------------------
19
20struct ManagedServer {
21    transport: Box<dyn McpTransport>,
22    enabled: bool,
23    reconnect_delay: Duration,
24}
25
26impl std::fmt::Debug for ManagedServer {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("ManagedServer")
29            .field("enabled", &self.enabled)
30            .field("reconnect_delay", &self.reconnect_delay)
31            .finish_non_exhaustive()
32    }
33}
34
35// ---------------------------------------------------------------------------
36// McpLifecycleManager
37// ---------------------------------------------------------------------------
38
39/// Manages the lifecycle of multiple MCP server connections.
40#[derive(Debug, Default)]
41pub struct McpLifecycleManager {
42    servers: Arc<RwLock<HashMap<String, ManagedServer>>>,
43}
44
45impl McpLifecycleManager {
46    /// Create an empty lifecycle manager.
47    #[must_use]
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Register an MCP server under the given name.
53    pub async fn register(
54        &self,
55        name: impl Into<String>,
56        transport: impl McpTransport + 'static,
57        reconnect_delay: Duration,
58    ) {
59        let _ = self.servers.write().await.insert(
60            name.into(),
61            ManagedServer {
62                transport: Box::new(transport),
63                enabled: true,
64                reconnect_delay,
65            },
66        );
67    }
68
69    /// Connect all registered, enabled servers.
70    pub async fn start_all(&self) -> Result<(), AgentError> {
71        // Collect enabled server names first, then release the read lock before
72        // performing async operations to avoid holding the lock across awaits.
73        let names: Vec<String> = self
74            .servers
75            .read()
76            .await
77            .iter()
78            .filter(|(_, server)| server.enabled)
79            .map(|(name, _)| name.clone())
80            .collect();
81        for name in names {
82            let guard = self.servers.read().await;
83            if let Some(server) = guard.get(&name) {
84                tracing::info!(%name, "Connecting MCP server");
85                server.transport.connect().await?;
86            }
87        }
88        Ok(())
89    }
90
91    /// Disconnect all servers cleanly.
92    pub async fn stop_all(&self) -> Result<(), AgentError> {
93        let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
94        for name in names {
95            let guard = self.servers.read().await;
96            if let Some(server) = guard.get(&name) {
97                tracing::info!(%name, "Disconnecting MCP server");
98                let _ = server.transport.disconnect().await;
99            }
100        }
101        Ok(())
102    }
103
104    /// Enable a specific server (connects if not already connected).
105    pub async fn enable(&self, name: &str) -> Result<(), AgentError> {
106        let guard = self.servers.read().await;
107        if let Some(server) = guard.get(name)
108            && !server.enabled
109        {
110            drop(guard);
111            let _ = self
112                .servers
113                .write()
114                .await
115                .get_mut(name)
116                .map(|s| s.enabled = true);
117            let guard = self.servers.read().await;
118            if let Some(server) = guard.get(name) {
119                server.transport.connect().await?;
120            }
121        }
122        Ok(())
123    }
124
125    /// Disable a specific server (disconnects immediately).
126    pub async fn disable(&self, name: &str) -> Result<(), AgentError> {
127        // Set enabled = false under the write lock, then drop before async disconnect.
128        let found = {
129            let mut guard = self.servers.write().await;
130            if let Some(server) = guard.get_mut(name) {
131                server.enabled = false;
132                true
133            } else {
134                false
135            }
136        };
137        if found {
138            let guard = self.servers.read().await;
139            if let Some(server) = guard.get(name) {
140                server.transport.disconnect().await?;
141            }
142        }
143        Ok(())
144    }
145
146    /// Return current status for all managed servers.
147    pub async fn all_status(&self) -> Vec<McpServerStatus> {
148        let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
149        let mut statuses = Vec::new();
150        for name in names {
151            let guard = self.servers.read().await;
152            if let Some(server) = guard.get(&name) {
153                statuses.push(server.transport.status().await);
154            }
155        }
156        statuses
157    }
158
159    /// List tools available from a named server.
160    #[allow(clippy::significant_drop_tightening)]
161    pub async fn list_tools(
162        &self,
163        server_name: &str,
164    ) -> Result<Vec<synwire_core::mcp::traits::McpToolDescriptor>, AgentError> {
165        let guard = self.servers.read().await;
166        let server = guard
167            .get(server_name)
168            .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
169        server.transport.list_tools().await
170    }
171
172    /// Invoke a tool on a named server, reconnecting if needed.
173    #[allow(clippy::significant_drop_tightening)]
174    pub async fn call_tool(
175        &self,
176        server_name: &str,
177        tool_name: &str,
178        arguments: serde_json::Value,
179    ) -> Result<serde_json::Value, AgentError> {
180        // Check enabled state and connection status with a short-lived guard.
181        let (enabled, needs_reconnect) = {
182            let guard = self.servers.read().await;
183            let server = guard
184                .get(server_name)
185                .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
186            let status = server.transport.status().await;
187            (
188                server.enabled,
189                status.state != McpConnectionState::Connected,
190            )
191        };
192
193        if !enabled {
194            return Err(AgentError::Vfs(format!(
195                "MCP server {server_name} is disabled"
196            )));
197        }
198
199        if needs_reconnect {
200            tracing::warn!(%server_name, "MCP server not connected — attempting reconnect");
201            let guard = self.servers.read().await;
202            if let Some(server) = guard.get(server_name) {
203                server.transport.reconnect().await?;
204            }
205        }
206
207        let guard = self.servers.read().await;
208        let server = guard
209            .get(server_name)
210            .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
211        server.transport.call_tool(tool_name, arguments).await
212    }
213
214    /// Spawn a background health-monitor task that reconnects servers that drop.
215    ///
216    /// The task polls every `interval` and attempts reconnection with the
217    /// server's configured `reconnect_delay`.
218    #[allow(clippy::significant_drop_tightening)]
219    pub fn spawn_health_monitor(self: Arc<Self>, interval: Duration) {
220        drop(tokio::spawn(async move {
221            loop {
222                sleep(interval).await;
223                // Collect disconnected servers with a short-lived guard.
224                let disconnected: Option<(String, Duration)> = {
225                    let guard = self.servers.read().await;
226                    let mut found = None;
227                    for (name, server) in guard.iter() {
228                        if !server.enabled {
229                            continue;
230                        }
231                        let status = server.transport.status().await;
232                        if status.state == McpConnectionState::Disconnected {
233                            tracing::warn!(%name, "MCP server disconnected — scheduling reconnect");
234                            found = Some((name.clone(), server.reconnect_delay));
235                            break;
236                        }
237                    }
238                    found
239                };
240                if let Some((name, delay)) = disconnected {
241                    sleep(delay).await;
242                    let guard = self.servers.read().await;
243                    if let Some(server) = guard.get(&name)
244                        && let Err(e) = server.transport.reconnect().await
245                    {
246                        tracing::error!(%name, %e, "MCP reconnect failed");
247                    }
248                }
249            }
250        }));
251    }
252}