1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use tokio::process::{Child, Command};
7use tokio::sync::{Mutex, RwLock};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct McpServer {
11 pub name: String,
12 pub transport: String,
13 pub connected: bool,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub pid: Option<u32>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub last_error: Option<String>,
18}
19
20#[derive(Clone)]
21pub struct McpRegistry {
22 servers: Arc<RwLock<HashMap<String, McpServer>>>,
23 processes: Arc<Mutex<HashMap<String, Child>>>,
24 state_file: Arc<PathBuf>,
25}
26
27impl McpRegistry {
28 pub fn new() -> Self {
29 Self::new_with_state_file(resolve_state_file())
30 }
31
32 pub fn new_with_state_file(state_file: PathBuf) -> Self {
33 let loaded = load_state(&state_file)
34 .into_iter()
35 .map(|(k, mut v)| {
36 v.connected = false;
37 v.pid = None;
38 (k, v)
39 })
40 .collect::<HashMap<_, _>>();
41 Self {
42 servers: Arc::new(RwLock::new(loaded)),
43 processes: Arc::new(Mutex::new(HashMap::new())),
44 state_file: Arc::new(state_file),
45 }
46 }
47
48 pub async fn list(&self) -> HashMap<String, McpServer> {
49 self.servers.read().await.clone()
50 }
51
52 pub async fn add(&self, name: String, transport: String) {
53 self.servers.write().await.insert(
54 name.clone(),
55 McpServer {
56 name,
57 transport,
58 connected: false,
59 pid: None,
60 last_error: None,
61 },
62 );
63 self.persist_state().await;
64 }
65
66 pub async fn connect(&self, name: &str) -> bool {
67 let transport = {
68 let servers = self.servers.read().await;
69 let Some(server) = servers.get(name) else {
70 return false;
71 };
72 server.transport.clone()
73 };
74
75 if let Some(command_text) = parse_stdio_transport(&transport) {
76 match spawn_stdio_process(command_text).await {
77 Ok(child) => {
78 let pid = child.id();
79 self.processes.lock().await.insert(name.to_string(), child);
80 let mut servers = self.servers.write().await;
81 if let Some(server) = servers.get_mut(name) {
82 server.connected = true;
83 server.pid = pid;
84 server.last_error = None;
85 }
86 drop(servers);
87 self.persist_state().await;
88 true
89 }
90 Err(err) => {
91 let mut servers = self.servers.write().await;
92 if let Some(server) = servers.get_mut(name) {
93 server.connected = false;
94 server.pid = None;
95 server.last_error = Some(err);
96 }
97 drop(servers);
98 self.persist_state().await;
99 false
100 }
101 }
102 } else {
103 let mut servers = self.servers.write().await;
104 if let Some(server) = servers.get_mut(name) {
105 server.connected = true;
106 server.pid = None;
107 server.last_error = None;
108 }
109 drop(servers);
110 self.persist_state().await;
111 true
112 }
113 }
114
115 pub async fn disconnect(&self, name: &str) -> bool {
116 if let Some(mut child) = self.processes.lock().await.remove(name) {
117 let _ = child.kill().await;
118 let _ = child.wait().await;
119 }
120 let mut servers = self.servers.write().await;
121 if let Some(server) = servers.get_mut(name) {
122 server.connected = false;
123 server.pid = None;
124 drop(servers);
125 self.persist_state().await;
126 return true;
127 }
128 false
129 }
130
131 async fn persist_state(&self) {
132 let snapshot = self.servers.read().await.clone();
133 if let Some(parent) = self.state_file.parent() {
134 let _ = tokio::fs::create_dir_all(parent).await;
135 }
136 if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
137 let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
138 }
139 }
140}
141
142impl Default for McpRegistry {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
/// Returns the path of the JSON file used to persist the registry.
///
/// If `TANDEM_MCP_REGISTRY` is set to a non-empty value, that value is used.
/// Otherwise the path is `.tandem/mcp_servers.json`, resolved relative to the
/// current working directory.
fn resolve_state_file() -> PathBuf {
    // `var_os` (unlike `var`) also accepts paths that are not valid UTF-8.
    // An empty value is treated as unset because `""` is not a usable file path.
    match std::env::var_os("TANDEM_MCP_REGISTRY") {
        Some(path) if !path.is_empty() => PathBuf::from(path),
        _ => PathBuf::from(".tandem").join("mcp_servers.json"),
    }
}
154
155fn load_state(path: &Path) -> HashMap<String, McpServer> {
156 let Ok(raw) = std::fs::read_to_string(path) else {
157 return HashMap::new();
158 };
159 serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
160}
161
/// Extracts the command from a `stdio:<command>` transport string.
///
/// The command is returned with surrounding whitespace trimmed. Returns `None`
/// when the transport does not start with `stdio:`.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    let command = transport.strip_prefix("stdio:")?;
    Some(command.trim())
}
165
166async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
167 if command_text.is_empty() {
168 return Err("Missing stdio command".to_string());
169 }
170 #[cfg(windows)]
171 let mut command = {
172 let mut cmd = Command::new("powershell");
173 cmd.args(["-NoProfile", "-Command", command_text]);
174 cmd
175 };
176 #[cfg(not(windows))]
177 let mut command = {
178 let mut cmd = Command::new("sh");
179 cmd.args(["-lc", command_text]);
180 cmd
181 };
182 command
183 .stdin(std::process::Stdio::null())
184 .stdout(std::process::Stdio::null())
185 .stderr(std::process::Stdio::null());
186 command.spawn().map_err(|e| e.to_string())
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Returns a fresh, unique state-file path in the system temp directory.
    fn temp_state_file() -> PathBuf {
        std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()))
    }

    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);

        // Persisted servers reload from disk but always start disconnected.
        let reloaded = McpRegistry::new_with_state_file(file.clone()).list().await;
        let server = reloaded.get("example").expect("server should be persisted");
        assert_eq!(server.transport, "sse:https://example.com");
        assert!(!server.connected);
        assert!(server.pid.is_none());

        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn unknown_server_cannot_connect_or_disconnect() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        assert!(!registry.connect("missing").await);
        assert!(!registry.disconnect("missing").await);
        let _ = std::fs::remove_file(&file);
    }

    #[tokio::test]
    async fn empty_stdio_command_records_error() {
        let file = temp_state_file();
        let registry = McpRegistry::new_with_state_file(file.clone());
        registry
            .add("broken".to_string(), "stdio:   ".to_string())
            .await;
        assert!(!registry.connect("broken").await);

        let listed = registry.list().await;
        let server = listed.get("broken").expect("server should exist");
        assert!(!server.connected);
        assert!(server.pid.is_none());
        assert_eq!(server.last_error.as_deref(), Some("Missing stdio command"));

        let _ = std::fs::remove_file(&file);
    }
}