// synwire_agent/mcp/lifecycle.rs
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::RwLock;
use tokio::time::sleep;

use synwire_core::agents::error::AgentError;
use synwire_core::mcp::traits::{McpConnectionState, McpServerStatus, McpTransport};

16struct ManagedServer {
21 transport: Box<dyn McpTransport>,
22 enabled: bool,
23 reconnect_delay: Duration,
24}
25
26impl std::fmt::Debug for ManagedServer {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("ManagedServer")
29 .field("enabled", &self.enabled)
30 .field("reconnect_delay", &self.reconnect_delay)
31 .finish_non_exhaustive()
32 }
33}
34
35#[derive(Debug, Default)]
41pub struct McpLifecycleManager {
42 servers: Arc<RwLock<HashMap<String, ManagedServer>>>,
43}
44
45impl McpLifecycleManager {
46 #[must_use]
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub async fn register(
54 &self,
55 name: impl Into<String>,
56 transport: impl McpTransport + 'static,
57 reconnect_delay: Duration,
58 ) {
59 let _ = self.servers.write().await.insert(
60 name.into(),
61 ManagedServer {
62 transport: Box::new(transport),
63 enabled: true,
64 reconnect_delay,
65 },
66 );
67 }
68
69 pub async fn start_all(&self) -> Result<(), AgentError> {
71 let names: Vec<String> = self
74 .servers
75 .read()
76 .await
77 .iter()
78 .filter(|(_, server)| server.enabled)
79 .map(|(name, _)| name.clone())
80 .collect();
81 for name in names {
82 let guard = self.servers.read().await;
83 if let Some(server) = guard.get(&name) {
84 tracing::info!(%name, "Connecting MCP server");
85 server.transport.connect().await?;
86 }
87 }
88 Ok(())
89 }
90
91 pub async fn stop_all(&self) -> Result<(), AgentError> {
93 let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
94 for name in names {
95 let guard = self.servers.read().await;
96 if let Some(server) = guard.get(&name) {
97 tracing::info!(%name, "Disconnecting MCP server");
98 let _ = server.transport.disconnect().await;
99 }
100 }
101 Ok(())
102 }
103
104 pub async fn enable(&self, name: &str) -> Result<(), AgentError> {
106 let guard = self.servers.read().await;
107 if let Some(server) = guard.get(name)
108 && !server.enabled
109 {
110 drop(guard);
111 let _ = self
112 .servers
113 .write()
114 .await
115 .get_mut(name)
116 .map(|s| s.enabled = true);
117 let guard = self.servers.read().await;
118 if let Some(server) = guard.get(name) {
119 server.transport.connect().await?;
120 }
121 }
122 Ok(())
123 }
124
125 pub async fn disable(&self, name: &str) -> Result<(), AgentError> {
127 let found = {
129 let mut guard = self.servers.write().await;
130 if let Some(server) = guard.get_mut(name) {
131 server.enabled = false;
132 true
133 } else {
134 false
135 }
136 };
137 if found {
138 let guard = self.servers.read().await;
139 if let Some(server) = guard.get(name) {
140 server.transport.disconnect().await?;
141 }
142 }
143 Ok(())
144 }
145
146 pub async fn all_status(&self) -> Vec<McpServerStatus> {
148 let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
149 let mut statuses = Vec::new();
150 for name in names {
151 let guard = self.servers.read().await;
152 if let Some(server) = guard.get(&name) {
153 statuses.push(server.transport.status().await);
154 }
155 }
156 statuses
157 }
158
159 #[allow(clippy::significant_drop_tightening)]
161 pub async fn list_tools(
162 &self,
163 server_name: &str,
164 ) -> Result<Vec<synwire_core::mcp::traits::McpToolDescriptor>, AgentError> {
165 let guard = self.servers.read().await;
166 let server = guard
167 .get(server_name)
168 .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
169 server.transport.list_tools().await
170 }
171
172 #[allow(clippy::significant_drop_tightening)]
174 pub async fn call_tool(
175 &self,
176 server_name: &str,
177 tool_name: &str,
178 arguments: serde_json::Value,
179 ) -> Result<serde_json::Value, AgentError> {
180 let (enabled, needs_reconnect) = {
182 let guard = self.servers.read().await;
183 let server = guard
184 .get(server_name)
185 .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
186 let status = server.transport.status().await;
187 (
188 server.enabled,
189 status.state != McpConnectionState::Connected,
190 )
191 };
192
193 if !enabled {
194 return Err(AgentError::Vfs(format!(
195 "MCP server {server_name} is disabled"
196 )));
197 }
198
199 if needs_reconnect {
200 tracing::warn!(%server_name, "MCP server not connected — attempting reconnect");
201 let guard = self.servers.read().await;
202 if let Some(server) = guard.get(server_name) {
203 server.transport.reconnect().await?;
204 }
205 }
206
207 let guard = self.servers.read().await;
208 let server = guard
209 .get(server_name)
210 .ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
211 server.transport.call_tool(tool_name, arguments).await
212 }
213
214 #[allow(clippy::significant_drop_tightening)]
219 pub fn spawn_health_monitor(self: Arc<Self>, interval: Duration) {
220 drop(tokio::spawn(async move {
221 loop {
222 sleep(interval).await;
223 let disconnected: Option<(String, Duration)> = {
225 let guard = self.servers.read().await;
226 let mut found = None;
227 for (name, server) in guard.iter() {
228 if !server.enabled {
229 continue;
230 }
231 let status = server.transport.status().await;
232 if status.state == McpConnectionState::Disconnected {
233 tracing::warn!(%name, "MCP server disconnected — scheduling reconnect");
234 found = Some((name.clone(), server.reconnect_delay));
235 break;
236 }
237 }
238 found
239 };
240 if let Some((name, delay)) = disconnected {
241 sleep(delay).await;
242 let guard = self.servers.read().await;
243 if let Some(server) = guard.get(&name)
244 && let Err(e) = server.transport.reconnect().await
245 {
246 tracing::error!(%name, %e, "MCP reconnect failed");
247 }
248 }
249 }
250 }));
251 }
252}