Skip to main content

tandem_runtime/
mcp.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use tokio::process::{Child, Command};
7use tokio::sync::{Mutex, RwLock};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct McpServer {
11    pub name: String,
12    pub transport: String,
13    pub connected: bool,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub pid: Option<u32>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub last_error: Option<String>,
18}
19
20#[derive(Clone)]
21pub struct McpRegistry {
22    servers: Arc<RwLock<HashMap<String, McpServer>>>,
23    processes: Arc<Mutex<HashMap<String, Child>>>,
24    state_file: Arc<PathBuf>,
25}
26
27impl McpRegistry {
28    pub fn new() -> Self {
29        Self::new_with_state_file(resolve_state_file())
30    }
31
32    pub fn new_with_state_file(state_file: PathBuf) -> Self {
33        let loaded = load_state(&state_file)
34            .into_iter()
35            .map(|(k, mut v)| {
36                v.connected = false;
37                v.pid = None;
38                (k, v)
39            })
40            .collect::<HashMap<_, _>>();
41        Self {
42            servers: Arc::new(RwLock::new(loaded)),
43            processes: Arc::new(Mutex::new(HashMap::new())),
44            state_file: Arc::new(state_file),
45        }
46    }
47
48    pub async fn list(&self) -> HashMap<String, McpServer> {
49        self.servers.read().await.clone()
50    }
51
52    pub async fn add(&self, name: String, transport: String) {
53        self.servers.write().await.insert(
54            name.clone(),
55            McpServer {
56                name,
57                transport,
58                connected: false,
59                pid: None,
60                last_error: None,
61            },
62        );
63        self.persist_state().await;
64    }
65
66    pub async fn connect(&self, name: &str) -> bool {
67        let transport = {
68            let servers = self.servers.read().await;
69            let Some(server) = servers.get(name) else {
70                return false;
71            };
72            server.transport.clone()
73        };
74
75        if let Some(command_text) = parse_stdio_transport(&transport) {
76            match spawn_stdio_process(command_text).await {
77                Ok(child) => {
78                    let pid = child.id();
79                    self.processes.lock().await.insert(name.to_string(), child);
80                    let mut servers = self.servers.write().await;
81                    if let Some(server) = servers.get_mut(name) {
82                        server.connected = true;
83                        server.pid = pid;
84                        server.last_error = None;
85                    }
86                    drop(servers);
87                    self.persist_state().await;
88                    true
89                }
90                Err(err) => {
91                    let mut servers = self.servers.write().await;
92                    if let Some(server) = servers.get_mut(name) {
93                        server.connected = false;
94                        server.pid = None;
95                        server.last_error = Some(err);
96                    }
97                    drop(servers);
98                    self.persist_state().await;
99                    false
100                }
101            }
102        } else {
103            let mut servers = self.servers.write().await;
104            if let Some(server) = servers.get_mut(name) {
105                server.connected = true;
106                server.pid = None;
107                server.last_error = None;
108            }
109            drop(servers);
110            self.persist_state().await;
111            true
112        }
113    }
114
115    pub async fn disconnect(&self, name: &str) -> bool {
116        if let Some(mut child) = self.processes.lock().await.remove(name) {
117            let _ = child.kill().await;
118            let _ = child.wait().await;
119        }
120        let mut servers = self.servers.write().await;
121        if let Some(server) = servers.get_mut(name) {
122            server.connected = false;
123            server.pid = None;
124            drop(servers);
125            self.persist_state().await;
126            return true;
127        }
128        false
129    }
130
131    async fn persist_state(&self) {
132        let snapshot = self.servers.read().await.clone();
133        if let Some(parent) = self.state_file.parent() {
134            let _ = tokio::fs::create_dir_all(parent).await;
135        }
136        if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
137            let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
138        }
139    }
140}
141
142impl Default for McpRegistry {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
/// Returns the path of the registry state file.
///
/// The `TANDEM_MCP_REGISTRY` environment variable takes precedence when it is
/// set to a non-empty value; otherwise the relative path
/// `.tandem/mcp_servers.json` is used.
fn resolve_state_file() -> PathBuf {
    // `var_os` keeps paths that are not valid Unicode, which `var` would reject.
    // An empty value cannot name a file, so it is treated the same as unset.
    match std::env::var_os("TANDEM_MCP_REGISTRY") {
        Some(value) if !value.is_empty() => PathBuf::from(value),
        _ => PathBuf::from(".tandem").join("mcp_servers.json"),
    }
}
154
155fn load_state(path: &Path) -> HashMap<String, McpServer> {
156    let Ok(raw) = std::fs::read_to_string(path) else {
157        return HashMap::new();
158    };
159    serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
160}
161
/// Extracts the command from a `stdio:` transport descriptor.
///
/// Returns the text after the `stdio:` prefix with surrounding whitespace
/// trimmed, or `None` when `transport` does not start with that prefix.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    let command_text = transport.strip_prefix("stdio:")?;
    Some(command_text.trim())
}
165
166async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
167    if command_text.is_empty() {
168        return Err("Missing stdio command".to_string());
169    }
170    #[cfg(windows)]
171    let mut command = {
172        let mut cmd = Command::new("powershell");
173        cmd.args(["-NoProfile", "-Command", command_text]);
174        cmd
175    };
176    #[cfg(not(windows))]
177    let mut command = {
178        let mut cmd = Command::new("sh");
179        cmd.args(["-lc", command_text]);
180        cmd
181    };
182    command
183        .stdin(std::process::Stdio::null())
184        .stdout(std::process::Stdio::null())
185        .stderr(std::process::Stdio::null());
186    command.spawn().map_err(|e| e.to_string())
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Deletes the wrapped file when dropped, so the test cleans up its state
    /// file even when an assertion fails.
    struct TempStateFile(std::path::PathBuf);

    impl Drop for TempStateFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// A non-stdio server can be added, connected and disconnected, and the
    /// listed status follows each step.
    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let state = TempStateFile(
            std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4())),
        );
        let registry = McpRegistry::new_with_state_file(state.0.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;

        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert_eq!(listed.get("example").map(|s| s.connected), Some(true));

        assert!(registry.disconnect("example").await);
        let listed = registry.list().await;
        assert_eq!(listed.get("example").map(|s| s.connected), Some(false));
    }
}