Skip to main content

tandem_runtime/
mcp.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// MCP protocol revision sent as `protocolVersion` in the `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported in `clientInfo` during `initialize`.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version reported in `clientInfo`: this crate's version at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
/// A tool discovered on an MCP server, cached on the server record and
/// persisted in the registry state file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name exactly as reported by the server.
    pub tool_name: String,
    /// Human-readable description (empty when the server sent none).
    pub description: String,
    /// JSON Schema of the tool's arguments; deserializes as `null` when the
    /// field is missing from stored state.
    #[serde(default)]
    pub input_schema: Value,
    /// Unix time in milliseconds at which the tool list was fetched.
    pub fetched_at_ms: u64,
    /// Hex SHA-256 digest of the serialized `input_schema`.
    pub schema_hash: String,
}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct McpServer {
30    pub name: String,
31    pub transport: String,
32    #[serde(default = "default_enabled")]
33    pub enabled: bool,
34    pub connected: bool,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub pid: Option<u32>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub last_error: Option<String>,
39    #[serde(default)]
40    pub headers: HashMap<String, String>,
41    #[serde(default)]
42    pub tool_cache: Vec<McpToolCacheEntry>,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub tools_fetched_at_ms: Option<u64>,
45}
46
/// A cached MCP tool as exposed to callers, tagged with its owning server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the registry entry that provides the tool.
    pub server_name: String,
    /// Tool name exactly as reported by the server.
    pub tool_name: String,
    /// `mcp.<server>.<tool>` with both segments sanitized into lowercase
    /// ASCII/underscore slugs.
    pub namespaced_name: String,
    /// Human-readable description (may be empty).
    pub description: String,
    /// JSON Schema of the tool's arguments; `null` when missing on input.
    #[serde(default)]
    pub input_schema: Value,
    /// Unix time in milliseconds at which the tool list was fetched.
    pub fetched_at_ms: u64,
    /// Hex SHA-256 digest of the serialized `input_schema`.
    pub schema_hash: String,
}
58
/// Registry of configured MCP servers and the stdio processes spawned for them.
///
/// Clones share the same state: every field is behind an `Arc`.
#[derive(Clone)]
pub struct McpRegistry {
    /// Server records keyed by server name; persisted to `state_file`.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes spawned for `stdio:` transports, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// JSON file the server map is written to after state changes.
    state_file: Arc<PathBuf>,
}
65
66impl McpRegistry {
67    pub fn new() -> Self {
68        Self::new_with_state_file(resolve_state_file())
69    }
70
71    pub fn new_with_state_file(state_file: PathBuf) -> Self {
72        let loaded = load_state(&state_file)
73            .into_iter()
74            .map(|(k, mut v)| {
75                v.connected = false;
76                v.pid = None;
77                if v.name.trim().is_empty() {
78                    v.name = k.clone();
79                }
80                if v.headers.is_empty() {
81                    v.headers = HashMap::new();
82                }
83                (k, v)
84            })
85            .collect::<HashMap<_, _>>();
86        Self {
87            servers: Arc::new(RwLock::new(loaded)),
88            processes: Arc::new(Mutex::new(HashMap::new())),
89            state_file: Arc::new(state_file),
90        }
91    }
92
93    pub async fn list(&self) -> HashMap<String, McpServer> {
94        self.servers.read().await.clone()
95    }
96
97    pub async fn add(&self, name: String, transport: String) {
98        self.add_or_update(name, transport, HashMap::new(), true).await;
99    }
100
101    pub async fn add_or_update(
102        &self,
103        name: String,
104        transport: String,
105        headers: HashMap<String, String>,
106        enabled: bool,
107    ) {
108        let mut servers = self.servers.write().await;
109        let existing = servers.get(&name).cloned();
110        let existing_tool_cache = existing
111            .as_ref()
112            .map(|row| row.tool_cache.clone())
113            .unwrap_or_default();
114        let existing_fetched_at = existing.as_ref().and_then(|row| row.tools_fetched_at_ms);
115        let server = McpServer {
116            name: name.clone(),
117            transport,
118            enabled,
119            connected: false,
120            pid: None,
121            last_error: None,
122            headers,
123            tool_cache: existing_tool_cache,
124            tools_fetched_at_ms: existing_fetched_at,
125        };
126        servers.insert(name, server);
127        drop(servers);
128        self.persist_state().await;
129    }
130
131    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
132        let mut servers = self.servers.write().await;
133        let Some(server) = servers.get_mut(name) else {
134            return false;
135        };
136        server.enabled = enabled;
137        if !enabled {
138            server.connected = false;
139            server.pid = None;
140        }
141        drop(servers);
142        if !enabled {
143            if let Some(mut child) = self.processes.lock().await.remove(name) {
144                let _ = child.kill().await;
145                let _ = child.wait().await;
146            }
147        }
148        self.persist_state().await;
149        true
150    }
151
152    pub async fn connect(&self, name: &str) -> bool {
153        let server = {
154            let servers = self.servers.read().await;
155            let Some(server) = servers.get(name) else {
156                return false;
157            };
158            server.clone()
159        };
160
161        if !server.enabled {
162            let mut servers = self.servers.write().await;
163            if let Some(entry) = servers.get_mut(name) {
164                entry.connected = false;
165                entry.pid = None;
166                entry.last_error = Some("MCP server is disabled".to_string());
167            }
168            drop(servers);
169            self.persist_state().await;
170            return false;
171        }
172
173        if let Some(command_text) = parse_stdio_transport(&server.transport) {
174            return self.connect_stdio(name, command_text).await;
175        }
176
177        if parse_remote_endpoint(&server.transport).is_some() {
178            return self.refresh(name).await.is_ok();
179        }
180
181        let mut servers = self.servers.write().await;
182        if let Some(entry) = servers.get_mut(name) {
183            entry.connected = true;
184            entry.pid = None;
185            entry.last_error = None;
186        }
187        drop(servers);
188        self.persist_state().await;
189        true
190    }
191
192    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
193        let server = {
194            let servers = self.servers.read().await;
195            let Some(server) = servers.get(name) else {
196                return Err("MCP server not found".to_string());
197            };
198            server.clone()
199        };
200
201        if !server.enabled {
202            return Err("MCP server is disabled".to_string());
203        }
204
205        let endpoint = parse_remote_endpoint(&server.transport)
206            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
207
208        let tools = match self.discover_remote_tools(&endpoint, &server.headers).await {
209            Ok(tools) => tools,
210            Err(err) => {
211                let mut servers = self.servers.write().await;
212                if let Some(entry) = servers.get_mut(name) {
213                    entry.connected = false;
214                    entry.pid = None;
215                    entry.last_error = Some(err.clone());
216                }
217                drop(servers);
218                self.persist_state().await;
219                return Err(err);
220            }
221        };
222
223        let now = now_ms();
224        let cache = tools
225            .iter()
226            .map(|tool| McpToolCacheEntry {
227                tool_name: tool.tool_name.clone(),
228                description: tool.description.clone(),
229                input_schema: tool.input_schema.clone(),
230                fetched_at_ms: now,
231                schema_hash: schema_hash(&tool.input_schema),
232            })
233            .collect::<Vec<_>>();
234
235        let mut servers = self.servers.write().await;
236        if let Some(entry) = servers.get_mut(name) {
237            entry.connected = true;
238            entry.pid = None;
239            entry.last_error = None;
240            entry.tool_cache = cache;
241            entry.tools_fetched_at_ms = Some(now);
242        }
243        drop(servers);
244        self.persist_state().await;
245        Ok(self.server_tools(name).await)
246    }
247
248    pub async fn disconnect(&self, name: &str) -> bool {
249        if let Some(mut child) = self.processes.lock().await.remove(name) {
250            let _ = child.kill().await;
251            let _ = child.wait().await;
252        }
253        let mut servers = self.servers.write().await;
254        if let Some(server) = servers.get_mut(name) {
255            server.connected = false;
256            server.pid = None;
257            drop(servers);
258            self.persist_state().await;
259            return true;
260        }
261        false
262    }
263
264    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
265        let mut out = self
266            .servers
267            .read()
268            .await
269            .values()
270            .filter(|server| server.enabled && server.connected)
271            .flat_map(server_tool_rows)
272            .collect::<Vec<_>>();
273        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
274        out
275    }
276
277    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
278        let Some(server) = self.servers.read().await.get(name).cloned() else {
279            return Vec::new();
280        };
281        let mut rows = server_tool_rows(&server);
282        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
283        rows
284    }
285
286    pub async fn call_tool(
287        &self,
288        server_name: &str,
289        tool_name: &str,
290        args: Value,
291    ) -> Result<ToolResult, String> {
292        let server = {
293            let servers = self.servers.read().await;
294            let Some(server) = servers.get(server_name) else {
295                return Err(format!("MCP server '{server_name}' not found"));
296            };
297            server.clone()
298        };
299
300        if !server.enabled {
301            return Err(format!("MCP server '{server_name}' is disabled"));
302        }
303        if !server.connected {
304            return Err(format!("MCP server '{server_name}' is not connected"));
305        }
306
307        let endpoint = parse_remote_endpoint(&server.transport)
308            .ok_or_else(|| "MCP tools/call currently supports HTTP/S transports only".to_string())?;
309
310        let request = json!({
311            "jsonrpc": "2.0",
312            "id": format!("call-{}-{}", server_name, now_ms()),
313            "method": "tools/call",
314            "params": {
315                "name": tool_name,
316                "arguments": args
317            }
318        });
319        let response = post_json_rpc(&endpoint, &server.headers, request).await?;
320
321        if let Some(err) = response.get("error") {
322            let message = err
323                .get("message")
324                .and_then(|v| v.as_str())
325                .unwrap_or("MCP tools/call failed");
326            return Err(message.to_string());
327        }
328
329        let result = response.get("result").cloned().unwrap_or(Value::Null);
330        let output = result
331            .get("content")
332            .map(render_mcp_content)
333            .or_else(|| result.get("output").map(|v| v.to_string()))
334            .unwrap_or_else(|| result.to_string());
335
336        Ok(ToolResult {
337            output,
338            metadata: json!({
339                "server": server_name,
340                "tool": tool_name,
341                "result": result
342            }),
343        })
344    }
345
346    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
347        match spawn_stdio_process(command_text).await {
348            Ok(child) => {
349                let pid = child.id();
350                self.processes.lock().await.insert(name.to_string(), child);
351                let mut servers = self.servers.write().await;
352                if let Some(server) = servers.get_mut(name) {
353                    server.connected = true;
354                    server.pid = pid;
355                    server.last_error = None;
356                }
357                drop(servers);
358                self.persist_state().await;
359                true
360            }
361            Err(err) => {
362                let mut servers = self.servers.write().await;
363                if let Some(server) = servers.get_mut(name) {
364                    server.connected = false;
365                    server.pid = None;
366                    server.last_error = Some(err);
367                }
368                drop(servers);
369                self.persist_state().await;
370                false
371            }
372        }
373    }
374
375    async fn discover_remote_tools(
376        &self,
377        endpoint: &str,
378        headers: &HashMap<String, String>,
379    ) -> Result<Vec<McpRemoteTool>, String> {
380        let initialize = json!({
381            "jsonrpc": "2.0",
382            "id": "initialize-1",
383            "method": "initialize",
384            "params": {
385                "protocolVersion": MCP_PROTOCOL_VERSION,
386                "capabilities": {},
387                "clientInfo": {
388                    "name": MCP_CLIENT_NAME,
389                    "version": MCP_CLIENT_VERSION,
390                }
391            }
392        });
393        let init_response = post_json_rpc(endpoint, headers, initialize).await?;
394        if let Some(err) = init_response.get("error") {
395            let message = err
396                .get("message")
397                .and_then(|v| v.as_str())
398                .unwrap_or("MCP initialize failed");
399            return Err(message.to_string());
400        }
401
402        let tools_list = json!({
403            "jsonrpc": "2.0",
404            "id": "tools-list-1",
405            "method": "tools/list",
406            "params": {}
407        });
408        let tools_response = post_json_rpc(endpoint, headers, tools_list).await?;
409        if let Some(err) = tools_response.get("error") {
410            let message = err
411                .get("message")
412                .and_then(|v| v.as_str())
413                .unwrap_or("MCP tools/list failed");
414            return Err(message.to_string());
415        }
416
417        let tools = tools_response
418            .get("result")
419            .and_then(|v| v.get("tools"))
420            .and_then(|v| v.as_array())
421            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
422
423        let now = now_ms();
424        let mut out = Vec::new();
425        for row in tools {
426            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
427                continue;
428            };
429            let description = row
430                .get("description")
431                .and_then(|v| v.as_str())
432                .unwrap_or("")
433                .to_string();
434            let input_schema = row
435                .get("inputSchema")
436                .or_else(|| row.get("input_schema"))
437                .cloned()
438                .unwrap_or_else(|| json!({"type":"object"}));
439            out.push(McpRemoteTool {
440                server_name: String::new(),
441                tool_name: tool_name.to_string(),
442                namespaced_name: String::new(),
443                description,
444                input_schema,
445                fetched_at_ms: now,
446                schema_hash: String::new(),
447            });
448        }
449
450        Ok(out)
451    }
452
453    async fn persist_state(&self) {
454        let snapshot = self.servers.read().await.clone();
455        if let Some(parent) = self.state_file.parent() {
456            let _ = tokio::fs::create_dir_all(parent).await;
457        }
458        if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
459            let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
460        }
461    }
462}
463
464impl Default for McpRegistry {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
/// Serde default for `McpServer::enabled`: a stored server without the flag
/// is treated as enabled.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
473
/// Chooses where the registry persists its server map.
///
/// Uses `TANDEM_MCP_REGISTRY` when it is set to a non-blank value. Otherwise
/// the path is `.tandem/mcp_servers.json`, relative to the working directory.
/// A blank variable is treated as unset: the empty path it would produce can
/// be neither read nor written, so state would silently never persist.
fn resolve_state_file() -> PathBuf {
    match std::env::var("TANDEM_MCP_REGISTRY") {
        Ok(path) if !path.trim().is_empty() => PathBuf::from(path),
        _ => PathBuf::from(".tandem").join("mcp_servers.json"),
    }
}
480
481fn load_state(path: &Path) -> HashMap<String, McpServer> {
482    let Ok(raw) = std::fs::read_to_string(path) else {
483        return HashMap::new();
484    };
485    serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
486}
487
/// Returns the shell command of a `stdio:<command>` transport, trimmed.
///
/// Leading whitespace before `stdio:` is ignored, as `parse_remote_endpoint`
/// does for URLs. Otherwise `" stdio:cmd"` would not be recognized and the
/// server would be marked connected without anything being spawned.
/// Returns `None` for any other transport.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport
        .trim_start()
        .strip_prefix("stdio:")
        .map(str::trim)
}
491
/// Extracts an HTTP(S) endpoint URL from a transport string.
///
/// Accepts a bare `http://` / `https://` URL, or the same URL behind an
/// `http:` or `https:` label (e.g. `http:https://host/mcp`). Surrounding
/// whitespace is ignored. Anything else yields `None`.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    fn is_http_url(candidate: &str) -> bool {
        candidate.starts_with("http://") || candidate.starts_with("https://")
    }

    let candidate = transport.trim();
    if is_http_url(candidate) {
        return Some(candidate.to_owned());
    }
    ["http:", "https:"]
        .iter()
        .copied()
        .filter_map(|label| candidate.strip_prefix(label))
        .map(str::trim)
        .find(|rest| is_http_url(rest))
        .map(str::to_owned)
}
507
508fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
509    let server_slug = sanitize_namespace_segment(&server.name);
510    server
511        .tool_cache
512        .iter()
513        .map(|tool| {
514            let tool_slug = sanitize_namespace_segment(&tool.tool_name);
515            McpRemoteTool {
516                server_name: server.name.clone(),
517                tool_name: tool.tool_name.clone(),
518                namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
519                description: tool.description.clone(),
520                input_schema: tool.input_schema.clone(),
521                fetched_at_ms: tool.fetched_at_ms,
522                schema_hash: tool.schema_hash.clone(),
523            }
524        })
525        .collect()
526}
527
/// Turns an arbitrary name into a lowercase ASCII slug for tool namespacing.
///
/// ASCII letters and digits are kept (lowercased). Every run of other
/// characters, including `_` and non-ASCII characters, collapses into one `_`.
/// Leading and trailing underscores are dropped. A name with no ASCII
/// alphanumerics becomes `"tool"`.
fn sanitize_namespace_segment(raw: &str) -> String {
    let mut slug = String::with_capacity(raw.len());
    for c in raw.trim().chars() {
        match c {
            c if c.is_ascii_alphanumeric() => slug.push(c.to_ascii_lowercase()),
            // Only separators push '_', so a trailing '_' means we are
            // already inside a separator run.
            _ if slug.ends_with('_') => {}
            _ => slug.push('_'),
        }
    }
    match slug.trim_matches('_') {
        "" => "tool".to_string(),
        cleaned => cleaned.to_string(),
    }
}
547
548fn schema_hash(schema: &Value) -> String {
549    let payload = serde_json::to_vec(schema).unwrap_or_default();
550    let mut hasher = Sha256::new();
551    hasher.update(payload);
552    format!("{:x}", hasher.finalize())
553}
554
/// Milliseconds since the Unix epoch. Returns 0 if the system clock reads
/// before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
561
562fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
563    let mut map = HeaderMap::new();
564    map.insert(
565        ACCEPT,
566        HeaderValue::from_static("application/json, text/event-stream"),
567    );
568    map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
569    for (key, value) in headers {
570        let name = HeaderName::from_bytes(key.trim().as_bytes())
571            .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
572        let header = HeaderValue::from_str(value.trim())
573            .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
574        map.insert(name, header);
575    }
576    Ok(map)
577}
578
579async fn post_json_rpc(
580    endpoint: &str,
581    headers: &HashMap<String, String>,
582    request: Value,
583) -> Result<Value, String> {
584    let client = reqwest::Client::builder()
585        .timeout(std::time::Duration::from_secs(12))
586        .build()
587        .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
588    let response = client
589        .post(endpoint)
590        .headers(build_headers(headers)?)
591        .json(&request)
592        .send()
593        .await
594        .map_err(|e| format!("MCP request failed: {e}"))?;
595    let status = response.status();
596    let payload = response
597        .text()
598        .await
599        .map_err(|e| format!("Failed to read MCP response: {e}"))?;
600    if !status.is_success() {
601        return Err(format!(
602            "MCP endpoint returned HTTP {}: {}",
603            status.as_u16(),
604            payload.chars().take(400).collect::<String>()
605        ));
606    }
607    serde_json::from_str::<Value>(&payload)
608        .map_err(|e| format!("Invalid MCP JSON response: {e}"))
609}
610
611fn render_mcp_content(value: &Value) -> String {
612    let Some(items) = value.as_array() else {
613        return value.to_string();
614    };
615    let mut chunks = Vec::new();
616    for item in items {
617        if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
618            chunks.push(text.to_string());
619            continue;
620        }
621        chunks.push(item.to_string());
622    }
623    if chunks.is_empty() {
624        value.to_string()
625    } else {
626        chunks.join("\n")
627    }
628}
629
630async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
631    if command_text.is_empty() {
632        return Err("Missing stdio command".to_string());
633    }
634    #[cfg(windows)]
635    let mut command = {
636        let mut cmd = Command::new("powershell");
637        cmd.args(["-NoProfile", "-Command", command_text]);
638        cmd
639    };
640    #[cfg(not(windows))]
641    let mut command = {
642        let mut cmd = Command::new("sh");
643        cmd.args(["-lc", command_text]);
644        cmd
645    };
646    command
647        .stdin(std::process::Stdio::null())
648        .stdout(std::process::Stdio::null())
649        .stderr(std::process::Stdio::null());
650    command.spawn().map_err(|e| e.to_string())
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// A unique state-file path that is deleted when dropped, so tests do not
    /// leave `mcp-test-*.json` files behind, even when an assertion fails.
    struct TempStateFile(std::path::PathBuf);

    impl TempStateFile {
        fn new() -> Self {
            Self(std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4())))
        }
    }

    impl Drop for TempStateFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = TempStateFile::new();
        let registry = McpRegistry::new_with_state_file(file.0.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
        assert!(!registry.disconnect("missing").await);
    }

    #[tokio::test]
    async fn persisted_servers_reload_disconnected() {
        let file = TempStateFile::new();
        let registry = McpRegistry::new_with_state_file(file.0.clone());
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);

        let reloaded = McpRegistry::new_with_state_file(file.0.clone());
        let servers = reloaded.list().await;
        let server = servers.get("example").expect("server should be persisted");
        assert!(!server.connected);
        assert_eq!(server.transport, "sse:https://example.com");
    }

    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(parse_remote_endpoint("sse:https://example.com"), None);
        assert_eq!(parse_remote_endpoint("stdio:npx server"), None);
    }

    #[test]
    fn parse_stdio_transport_trims_command() {
        assert_eq!(parse_stdio_transport("stdio: npx server "), Some("npx server"));
        assert_eq!(parse_stdio_transport("https://example.com"), None);
    }

    #[test]
    fn sanitize_namespace_segment_collapses_separators() {
        assert_eq!(sanitize_namespace_segment("  My Server!! "), "my_server");
        assert_eq!(sanitize_namespace_segment("a--b__c"), "a_b_c");
        assert_eq!(sanitize_namespace_segment("***"), "tool");
    }

    #[test]
    fn render_mcp_content_joins_text_items() {
        let content = serde_json::json!([
            {"type": "text", "text": "hello"},
            {"type": "text", "text": "world"}
        ]);
        assert_eq!(render_mcp_content(&content), "hello\nworld");
        assert_eq!(render_mcp_content(&serde_json::json!([])), "[]");
        assert_eq!(
            render_mcp_content(&serde_json::json!([{"type": "image"}])),
            r#"{"type":"image"}"#
        );
    }
}