Skip to main content

tandem_runtime/
mcp.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::ToolResult;
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// Protocol version sent as `protocolVersion` in the `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Client name reported as `clientInfo.name` in the `initialize` request.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version reported as `clientInfo.version`; read from `CARGO_PKG_VERSION` at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17
/// One tool advertised by a server, as stored in that server's `tool_cache`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name exactly as reported by the server's `tools/list` response.
    pub tool_name: String,
    /// Tool description from `tools/list` (empty when the server sent none).
    pub description: String,
    /// JSON schema for the tool arguments; defaults to `Value::default()` when absent on load.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (milliseconds since the Unix epoch) of the refresh that produced this entry.
    pub fetched_at_ms: u64,
    /// Hex digest produced by `schema_hash` over `input_schema`.
    pub schema_hash: String,
}
27
/// Persisted and runtime state for a single registered MCP server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Server name; also the key under which the registry stores this entry.
    pub name: String,
    /// Transport descriptor: `stdio:<command>` or an `http://` / `https://` endpoint
    /// (optionally prefixed with `http:` / `https:`).
    pub transport: String,
    /// Whether the server may be connected to; defaults to `true` when missing on load.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Whether the registry currently considers the server connected.
    pub connected: bool,
    /// Process id of the spawned stdio child, when one is running.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    /// Message of the most recent connect/refresh failure, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<String>,
    /// Authorization challenge extracted from the most recent tool call, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_auth_challenge: Option<McpAuthChallenge>,
    /// Value of the `mcp-session-id` response header, echoed back on later requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mcp_session_id: Option<String>,
    /// Extra HTTP headers sent with every request to this server.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Tools discovered by the last successful refresh.
    #[serde(default)]
    pub tool_cache: Vec<McpToolCacheEntry>,
    /// Timestamp (milliseconds since the Unix epoch) of the last successful refresh.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools_fetched_at_ms: Option<u64>,
}
50
/// An authorization prompt extracted from an MCP tool response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAuthChallenge {
    /// Short identifier derived by `stable_id_seed` from the tool name and authorization URL.
    pub challenge_id: String,
    /// Name of the tool whose call triggered the challenge.
    pub tool_name: String,
    /// URL the user should visit to authorize.
    pub authorization_url: String,
    /// Human-readable instructions taken from the response, or a generic fallback.
    pub message: String,
    /// Timestamp (milliseconds since the Unix epoch) at which the challenge was extracted.
    pub requested_at_ms: u64,
    /// Challenge state; set to `"pending"` when the challenge is created.
    pub status: String,
}
60
/// A tool exposed by a registered server, in the form handed to callers of the registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the server that owns the tool.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Name of the form `mcp.<server_slug>.<tool_slug>` built by `server_tool_rows`.
    pub namespaced_name: String,
    /// Tool description (may be empty).
    pub description: String,
    /// JSON schema for the tool arguments; defaults to `Value::default()` when absent on load.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (milliseconds since the Unix epoch) at which the tool was fetched.
    pub fetched_at_ms: u64,
    /// Hex digest of the input schema (empty for rows produced directly by discovery).
    pub schema_hash: String,
}
72
/// Registry of MCP servers with their spawned stdio processes and on-disk state file.
///
/// All fields are behind `Arc`, so clones of the registry share the same state.
#[derive(Clone)]
pub struct McpRegistry {
    /// Registered servers keyed by name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes spawned for stdio servers, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the JSON file the server map is loaded from and persisted to.
    state_file: Arc<PathBuf>,
}
79
80impl McpRegistry {
81    pub fn new() -> Self {
82        Self::new_with_state_file(resolve_state_file())
83    }
84
85    pub fn new_with_state_file(state_file: PathBuf) -> Self {
86        let loaded = load_state(&state_file)
87            .into_iter()
88            .map(|(k, mut v)| {
89                v.connected = false;
90                v.pid = None;
91                if v.name.trim().is_empty() {
92                    v.name = k.clone();
93                }
94                if v.headers.is_empty() {
95                    v.headers = HashMap::new();
96                }
97                (k, v)
98            })
99            .collect::<HashMap<_, _>>();
100        Self {
101            servers: Arc::new(RwLock::new(loaded)),
102            processes: Arc::new(Mutex::new(HashMap::new())),
103            state_file: Arc::new(state_file),
104        }
105    }
106
107    pub async fn list(&self) -> HashMap<String, McpServer> {
108        self.servers.read().await.clone()
109    }
110
111    pub async fn add(&self, name: String, transport: String) {
112        self.add_or_update(name, transport, HashMap::new(), true)
113            .await;
114    }
115
116    pub async fn add_or_update(
117        &self,
118        name: String,
119        transport: String,
120        headers: HashMap<String, String>,
121        enabled: bool,
122    ) {
123        let mut servers = self.servers.write().await;
124        let existing = servers.get(&name).cloned();
125        let preserve_cache = existing
126            .as_ref()
127            .is_some_and(|row| row.transport == transport && row.headers == headers);
128        let existing_tool_cache = if preserve_cache {
129            existing
130                .as_ref()
131                .map(|row| row.tool_cache.clone())
132                .unwrap_or_default()
133        } else {
134            Vec::new()
135        };
136        let existing_fetched_at = if preserve_cache {
137            existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
138        } else {
139            None
140        };
141        let server = McpServer {
142            name: name.clone(),
143            transport,
144            enabled,
145            connected: false,
146            pid: None,
147            last_error: None,
148            last_auth_challenge: None,
149            mcp_session_id: None,
150            headers,
151            tool_cache: existing_tool_cache,
152            tools_fetched_at_ms: existing_fetched_at,
153        };
154        servers.insert(name, server);
155        drop(servers);
156        self.persist_state().await;
157    }
158
159    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
160        let mut servers = self.servers.write().await;
161        let Some(server) = servers.get_mut(name) else {
162            return false;
163        };
164        server.enabled = enabled;
165        if !enabled {
166            server.connected = false;
167            server.pid = None;
168            server.last_auth_challenge = None;
169            server.mcp_session_id = None;
170        }
171        drop(servers);
172        if !enabled {
173            if let Some(mut child) = self.processes.lock().await.remove(name) {
174                let _ = child.kill().await;
175                let _ = child.wait().await;
176            }
177        }
178        self.persist_state().await;
179        true
180    }
181
182    pub async fn remove(&self, name: &str) -> bool {
183        let removed = {
184            let mut servers = self.servers.write().await;
185            servers.remove(name).is_some()
186        };
187        if !removed {
188            return false;
189        }
190
191        if let Some(mut child) = self.processes.lock().await.remove(name) {
192            let _ = child.kill().await;
193            let _ = child.wait().await;
194        }
195        self.persist_state().await;
196        true
197    }
198
199    pub async fn connect(&self, name: &str) -> bool {
200        let server = {
201            let servers = self.servers.read().await;
202            let Some(server) = servers.get(name) else {
203                return false;
204            };
205            server.clone()
206        };
207
208        if !server.enabled {
209            let mut servers = self.servers.write().await;
210            if let Some(entry) = servers.get_mut(name) {
211                entry.connected = false;
212                entry.pid = None;
213                entry.last_error = Some("MCP server is disabled".to_string());
214                entry.last_auth_challenge = None;
215                entry.mcp_session_id = None;
216            }
217            drop(servers);
218            self.persist_state().await;
219            return false;
220        }
221
222        if let Some(command_text) = parse_stdio_transport(&server.transport) {
223            return self.connect_stdio(name, command_text).await;
224        }
225
226        if parse_remote_endpoint(&server.transport).is_some() {
227            return self.refresh(name).await.is_ok();
228        }
229
230        let mut servers = self.servers.write().await;
231        if let Some(entry) = servers.get_mut(name) {
232            entry.connected = true;
233            entry.pid = None;
234            entry.last_error = None;
235            entry.last_auth_challenge = None;
236            entry.mcp_session_id = None;
237        }
238        drop(servers);
239        self.persist_state().await;
240        true
241    }
242
243    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
244        let server = {
245            let servers = self.servers.read().await;
246            let Some(server) = servers.get(name) else {
247                return Err("MCP server not found".to_string());
248            };
249            server.clone()
250        };
251
252        if !server.enabled {
253            return Err("MCP server is disabled".to_string());
254        }
255
256        let endpoint = parse_remote_endpoint(&server.transport)
257            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
258
259        let (tools, session_id) = match self.discover_remote_tools(&endpoint, &server.headers).await
260        {
261            Ok(result) => result,
262            Err(err) => {
263                let mut servers = self.servers.write().await;
264                if let Some(entry) = servers.get_mut(name) {
265                    entry.connected = false;
266                    entry.pid = None;
267                    entry.last_error = Some(err.clone());
268                    entry.last_auth_challenge = None;
269                    entry.mcp_session_id = None;
270                    entry.tool_cache.clear();
271                    entry.tools_fetched_at_ms = None;
272                }
273                drop(servers);
274                self.persist_state().await;
275                return Err(err);
276            }
277        };
278
279        let now = now_ms();
280        let cache = tools
281            .iter()
282            .map(|tool| McpToolCacheEntry {
283                tool_name: tool.tool_name.clone(),
284                description: tool.description.clone(),
285                input_schema: tool.input_schema.clone(),
286                fetched_at_ms: now,
287                schema_hash: schema_hash(&tool.input_schema),
288            })
289            .collect::<Vec<_>>();
290
291        let mut servers = self.servers.write().await;
292        if let Some(entry) = servers.get_mut(name) {
293            entry.connected = true;
294            entry.pid = None;
295            entry.last_error = None;
296            entry.last_auth_challenge = None;
297            entry.mcp_session_id = session_id;
298            entry.tool_cache = cache;
299            entry.tools_fetched_at_ms = Some(now);
300        }
301        drop(servers);
302        self.persist_state().await;
303        Ok(self.server_tools(name).await)
304    }
305
306    pub async fn disconnect(&self, name: &str) -> bool {
307        if let Some(mut child) = self.processes.lock().await.remove(name) {
308            let _ = child.kill().await;
309            let _ = child.wait().await;
310        }
311        let mut servers = self.servers.write().await;
312        if let Some(server) = servers.get_mut(name) {
313            server.connected = false;
314            server.pid = None;
315            server.last_auth_challenge = None;
316            server.mcp_session_id = None;
317            drop(servers);
318            self.persist_state().await;
319            return true;
320        }
321        false
322    }
323
324    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
325        let mut out = self
326            .servers
327            .read()
328            .await
329            .values()
330            .filter(|server| server.enabled && server.connected)
331            .flat_map(server_tool_rows)
332            .collect::<Vec<_>>();
333        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
334        out
335    }
336
337    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
338        let Some(server) = self.servers.read().await.get(name).cloned() else {
339            return Vec::new();
340        };
341        let mut rows = server_tool_rows(&server);
342        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
343        rows
344    }
345
    /// Invokes `tool_name` on `server_name` through a JSON-RPC `tools/call` request.
    ///
    /// The arguments are passed through `normalize_mcp_tool_args` before being
    /// sent. The stored session id (if any) accompanies the request, and a
    /// session id returned by the server replaces it.
    ///
    /// When the response (either its `error` or its `result`) contains an
    /// authorization URL, the call returns `Ok` with a prompt asking the user
    /// to authorize, and the challenge is stored on the server entry and
    /// mirrored under `mcpAuth` in the metadata.
    ///
    /// # Errors
    /// Returns an error string when the server is unknown, disabled, not
    /// connected, not an HTTP(S) transport, when the request itself fails, or
    /// when the response carries a JSON-RPC error without an auth challenge.
    pub async fn call_tool(
        &self,
        server_name: &str,
        tool_name: &str,
        args: Value,
    ) -> Result<ToolResult, String> {
        // Work on a snapshot so no lock is held across the network request.
        let server = {
            let servers = self.servers.read().await;
            let Some(server) = servers.get(server_name) else {
                return Err(format!("MCP server '{server_name}' not found"));
            };
            server.clone()
        };

        if !server.enabled {
            return Err(format!("MCP server '{server_name}' is disabled"));
        }
        if !server.connected {
            return Err(format!("MCP server '{server_name}' is not connected"));
        }

        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
            "MCP tools/call currently supports HTTP/S transports only".to_string()
        })?;
        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);

        let request = json!({
            "jsonrpc": "2.0",
            "id": format!("call-{}-{}", server_name, now_ms()),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": normalized_args
            }
        });
        let (response, session_id) = post_json_rpc_with_session(
            &endpoint,
            &server.headers,
            request,
            server.mcp_session_id.as_deref(),
        )
        .await?;
        // Only overwrite the stored session id when the server actually sent one.
        if session_id.is_some() {
            let mut servers = self.servers.write().await;
            if let Some(row) = servers.get_mut(server_name) {
                row.mcp_session_id = session_id;
            }
            drop(servers);
            self.persist_state().await;
        }

        if let Some(err) = response.get("error") {
            // A JSON-RPC error that carries an authorization URL is surfaced as
            // a successful result containing the authorization prompt.
            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
                let output = format!(
                    "{}\n\nAuthorize here: {}",
                    challenge.message, challenge.authorization_url
                );
                {
                    let mut servers = self.servers.write().await;
                    if let Some(row) = servers.get_mut(server_name) {
                        row.last_auth_challenge = Some(challenge.clone());
                        row.last_error = None;
                    }
                }
                self.persist_state().await;
                return Ok(ToolResult {
                    output,
                    metadata: json!({
                        "server": server_name,
                        "tool": tool_name,
                        "result": Value::Null,
                        "mcpAuth": {
                            "required": true,
                            "challengeId": challenge.challenge_id,
                            "tool": challenge.tool_name,
                            "authorizationUrl": challenge.authorization_url,
                            "message": challenge.message,
                            "status": challenge.status
                        }
                    }),
                });
            }
            let message = err
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("MCP tools/call failed");
            return Err(message.to_string());
        }

        let result = response.get("result").cloned().unwrap_or(Value::Null);
        // A challenge may also be embedded in an otherwise successful result.
        let auth_challenge = extract_auth_challenge(&result, tool_name);
        let output = if let Some(challenge) = auth_challenge.as_ref() {
            format!(
                "{}\n\nAuthorize here: {}",
                challenge.message, challenge.authorization_url
            )
        } else {
            // Preference order: rendered `content`, then `output`, then the
            // whole result serialized as JSON.
            // NOTE(review): a string-valued `output` is rendered through
            // `Value::to_string`, i.e. with JSON quotes - confirm this is intended.
            result
                .get("content")
                .map(render_mcp_content)
                .or_else(|| result.get("output").map(|v| v.to_string()))
                .unwrap_or_else(|| result.to_string())
        };

        // Store the new challenge, or clear a previous one when there is none.
        {
            let mut servers = self.servers.write().await;
            if let Some(row) = servers.get_mut(server_name) {
                row.last_auth_challenge = auth_challenge.clone();
            }
        }
        self.persist_state().await;

        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
            json!({
                "required": true,
                "challengeId": challenge.challenge_id,
                "tool": challenge.tool_name,
                "authorizationUrl": challenge.authorization_url,
                "message": challenge.message,
                "status": challenge.status
            })
        });

        Ok(ToolResult {
            output,
            metadata: json!({
                "server": server_name,
                "tool": tool_name,
                "result": result,
                "mcpAuth": auth_metadata
            }),
        })
    }
479
480    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
481        match spawn_stdio_process(command_text).await {
482            Ok(child) => {
483                let pid = child.id();
484                self.processes.lock().await.insert(name.to_string(), child);
485                let mut servers = self.servers.write().await;
486                if let Some(server) = servers.get_mut(name) {
487                    server.connected = true;
488                    server.pid = pid;
489                    server.last_error = None;
490                }
491                drop(servers);
492                self.persist_state().await;
493                true
494            }
495            Err(err) => {
496                let mut servers = self.servers.write().await;
497                if let Some(server) = servers.get_mut(name) {
498                    server.connected = false;
499                    server.pid = None;
500                    server.last_error = Some(err);
501                }
502                drop(servers);
503                self.persist_state().await;
504                false
505            }
506        }
507    }
508
509    async fn discover_remote_tools(
510        &self,
511        endpoint: &str,
512        headers: &HashMap<String, String>,
513    ) -> Result<(Vec<McpRemoteTool>, Option<String>), String> {
514        let initialize = json!({
515            "jsonrpc": "2.0",
516            "id": "initialize-1",
517            "method": "initialize",
518            "params": {
519                "protocolVersion": MCP_PROTOCOL_VERSION,
520                "capabilities": {},
521                "clientInfo": {
522                    "name": MCP_CLIENT_NAME,
523                    "version": MCP_CLIENT_VERSION,
524                }
525            }
526        });
527        let (init_response, mut session_id) =
528            post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
529        if let Some(err) = init_response.get("error") {
530            let message = err
531                .get("message")
532                .and_then(|v| v.as_str())
533                .unwrap_or("MCP initialize failed");
534            return Err(message.to_string());
535        }
536
537        let tools_list = json!({
538            "jsonrpc": "2.0",
539            "id": "tools-list-1",
540            "method": "tools/list",
541            "params": {}
542        });
543        let (tools_response, next_session_id) =
544            post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
545                .await?;
546        if next_session_id.is_some() {
547            session_id = next_session_id;
548        }
549        if let Some(err) = tools_response.get("error") {
550            let message = err
551                .get("message")
552                .and_then(|v| v.as_str())
553                .unwrap_or("MCP tools/list failed");
554            return Err(message.to_string());
555        }
556
557        let tools = tools_response
558            .get("result")
559            .and_then(|v| v.get("tools"))
560            .and_then(|v| v.as_array())
561            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
562
563        let now = now_ms();
564        let mut out = Vec::new();
565        for row in tools {
566            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
567                continue;
568            };
569            let description = row
570                .get("description")
571                .and_then(|v| v.as_str())
572                .unwrap_or("")
573                .to_string();
574            let mut input_schema = row
575                .get("inputSchema")
576                .or_else(|| row.get("input_schema"))
577                .cloned()
578                .unwrap_or_else(|| json!({"type":"object"}));
579            normalize_tool_input_schema(&mut input_schema);
580            out.push(McpRemoteTool {
581                server_name: String::new(),
582                tool_name: tool_name.to_string(),
583                namespaced_name: String::new(),
584                description,
585                input_schema,
586                fetched_at_ms: now,
587                schema_hash: String::new(),
588            });
589        }
590
591        Ok((out, session_id))
592    }
593
594    async fn persist_state(&self) {
595        let snapshot = self.servers.read().await.clone();
596        if let Some(parent) = self.state_file.parent() {
597            let _ = tokio::fs::create_dir_all(parent).await;
598        }
599        if let Ok(payload) = serde_json::to_string_pretty(&snapshot) {
600            let _ = tokio::fs::write(self.state_file.as_path(), payload).await;
601        }
602    }
603}
604
605impl Default for McpRegistry {
606    fn default() -> Self {
607        Self::new()
608    }
609}
610
/// Serde default for `McpServer::enabled`: entries stored without the field
/// are treated as enabled.
fn default_enabled() -> bool {
    true
}
614
615fn resolve_state_file() -> PathBuf {
616    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
617        return PathBuf::from(path);
618    }
619    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
620        let trimmed = state_dir.trim();
621        if !trimmed.is_empty() {
622            return PathBuf::from(trimmed).join("mcp_servers.json");
623        }
624    }
625    if let Some(data_dir) = dirs::data_dir() {
626        return data_dir
627            .join("tandem")
628            .join("data")
629            .join("mcp_servers.json");
630    }
631    dirs::home_dir()
632        .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
633        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
634}
635
636fn load_state(path: &Path) -> HashMap<String, McpServer> {
637    let Ok(raw) = std::fs::read_to_string(path) else {
638        return HashMap::new();
639    };
640    serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default()
641}
642
/// Extracts the command text from a `stdio:<command>` transport descriptor.
///
/// Whitespace around the descriptor and around the command is ignored, which
/// matches how `parse_remote_endpoint` trims its input. Returns `None` when
/// the descriptor does not start with `stdio:`; an empty command yields
/// `Some("")`.
fn parse_stdio_transport(transport: &str) -> Option<&str> {
    transport.trim_start().strip_prefix("stdio:").map(str::trim)
}
646
/// Extracts an HTTP(S) endpoint URL from a transport descriptor.
///
/// Accepts either a bare `http://` / `https://` URL or such a URL prefixed
/// with an `http:` / `https:` tag (for example `http: http://localhost:3000`).
/// Surrounding whitespace is ignored. Returns `None` for anything else.
fn parse_remote_endpoint(transport: &str) -> Option<String> {
    fn is_http_url(candidate: &str) -> bool {
        candidate.starts_with("http://") || candidate.starts_with("https://")
    }

    let descriptor = transport.trim();
    if is_http_url(descriptor) {
        return Some(descriptor.to_owned());
    }
    ["http:", "https:"]
        .iter()
        .filter_map(|tag| descriptor.strip_prefix(*tag))
        .map(str::trim)
        .find(|rest| is_http_url(rest))
        .map(str::to_owned)
}
662
663fn server_tool_rows(server: &McpServer) -> Vec<McpRemoteTool> {
664    let server_slug = sanitize_namespace_segment(&server.name);
665    server
666        .tool_cache
667        .iter()
668        .map(|tool| {
669            let tool_slug = sanitize_namespace_segment(&tool.tool_name);
670            McpRemoteTool {
671                server_name: server.name.clone(),
672                tool_name: tool.tool_name.clone(),
673                namespaced_name: format!("mcp.{server_slug}.{tool_slug}"),
674                description: tool.description.clone(),
675                input_schema: tool.input_schema.clone(),
676                fetched_at_ms: tool.fetched_at_ms,
677                schema_hash: tool.schema_hash.clone(),
678            }
679        })
680        .collect()
681}
682
/// Turns an arbitrary name into a lowercase slug usable inside a namespaced
/// tool name.
///
/// ASCII letters and digits are kept (lowercased); every run of other
/// characters collapses into a single `_`, and leading/trailing underscores
/// are dropped. A name with no ASCII alphanumerics becomes `"tool"`.
fn sanitize_namespace_segment(raw: &str) -> String {
    // Splitting on every non-alphanumeric character and discarding the empty
    // pieces collapses separator runs and trims both ends in one pass.
    let slug = raw
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join("_");
    if slug.is_empty() {
        String::from("tool")
    } else {
        slug
    }
}
702
703fn schema_hash(schema: &Value) -> String {
704    let payload = serde_json::to_vec(schema).unwrap_or_default();
705    let mut hasher = Sha256::new();
706    hasher.update(payload);
707    format!("{:x}", hasher.finalize())
708}
709
710fn extract_auth_challenge(result: &Value, tool_name: &str) -> Option<McpAuthChallenge> {
711    let authorization_url = find_string_by_any_key(
712        result,
713        &["authorization_url", "authorizationUrl", "auth_url"],
714    )?;
715    let message = find_string_by_any_key(result, &["llm_instructions", "message", "text"])
716        .unwrap_or_else(|| "This tool requires authorization before it can run.".to_string());
717    let challenge_id = stable_id_seed(&format!("{tool_name}:{authorization_url}"));
718    Some(McpAuthChallenge {
719        challenge_id,
720        tool_name: tool_name.to_string(),
721        authorization_url,
722        message,
723        requested_at_ms: now_ms(),
724        status: "pending".to_string(),
725    })
726}
727
728fn find_string_by_any_key(value: &Value, keys: &[&str]) -> Option<String> {
729    match value {
730        Value::Object(map) => {
731            for key in keys {
732                if let Some(s) = map.get(*key).and_then(|v| v.as_str()) {
733                    let trimmed = s.trim();
734                    if !trimmed.is_empty() {
735                        return Some(trimmed.to_string());
736                    }
737                }
738            }
739            for child in map.values() {
740                if let Some(found) = find_string_by_any_key(child, keys) {
741                    return Some(found);
742                }
743            }
744            None
745        }
746        Value::Array(items) => items
747            .iter()
748            .find_map(|item| find_string_by_any_key(item, keys)),
749        _ => None,
750    }
751}
752
753fn stable_id_seed(seed: &str) -> String {
754    let mut hasher = Sha256::new();
755    hasher.update(seed.as_bytes());
756    let encoded = format!("{:x}", hasher.finalize());
757    encoded.chars().take(16).collect()
758}
759
760fn normalize_tool_input_schema(schema: &mut Value) {
761    normalize_schema_node(schema);
762}
763
764fn normalize_schema_node(node: &mut Value) {
765    let Some(obj) = node.as_object_mut() else {
766        return;
767    };
768
769    // Some MCP servers publish enums on non-string/object/array fields, which
770    // OpenAI-compatible providers may reject (e.g. Gemini via OpenRouter).
771    // Keep enum only when values are all strings and schema type is string-like.
772    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
773        let all_strings = enum_values.iter().all(|v| v.is_string());
774        let string_like_type = schema_type_allows_string_enum(obj.get("type"));
775        if !all_strings || !string_like_type {
776            obj.remove("enum");
777        }
778    }
779
780    if let Some(properties) = obj.get_mut("properties").and_then(|v| v.as_object_mut()) {
781        for value in properties.values_mut() {
782            normalize_schema_node(value);
783        }
784    }
785
786    if let Some(items) = obj.get_mut("items") {
787        normalize_schema_node(items);
788    }
789
790    for key in ["anyOf", "oneOf", "allOf"] {
791        if let Some(array) = obj.get_mut(key).and_then(|v| v.as_array_mut()) {
792            for child in array.iter_mut() {
793                normalize_schema_node(child);
794            }
795        }
796    }
797
798    if let Some(additional) = obj.get_mut("additionalProperties") {
799        normalize_schema_node(additional);
800    }
801}
802
803fn schema_type_allows_string_enum(schema_type: Option<&Value>) -> bool {
804    let Some(schema_type) = schema_type else {
805        // No explicit type: keep enum to avoid over-normalizing loosely-typed schemas.
806        return true;
807    };
808
809    if let Some(kind) = schema_type.as_str() {
810        return kind == "string";
811    }
812
813    if let Some(kinds) = schema_type.as_array() {
814        let mut saw_string = false;
815        for kind in kinds {
816            let Some(kind) = kind.as_str() else {
817                return false;
818            };
819            if kind == "string" {
820                saw_string = true;
821                continue;
822            }
823            if kind != "null" {
824                return false;
825            }
826        }
827        return saw_string;
828    }
829
830    false
831}
832
/// Returns the current wall-clock time in milliseconds since the Unix epoch,
/// or `0` when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
839
840fn build_headers(headers: &HashMap<String, String>) -> Result<HeaderMap, String> {
841    let mut map = HeaderMap::new();
842    map.insert(
843        ACCEPT,
844        HeaderValue::from_static("application/json, text/event-stream"),
845    );
846    map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
847    for (key, value) in headers {
848        let name = HeaderName::from_bytes(key.trim().as_bytes())
849            .map_err(|e| format!("Invalid header name '{key}': {e}"))?;
850        let header = HeaderValue::from_str(value.trim())
851            .map_err(|e| format!("Invalid header value for '{key}': {e}"))?;
852        map.insert(name, header);
853    }
854    Ok(map)
855}
856
857async fn post_json_rpc_with_session(
858    endpoint: &str,
859    headers: &HashMap<String, String>,
860    request: Value,
861    session_id: Option<&str>,
862) -> Result<(Value, Option<String>), String> {
863    let client = reqwest::Client::builder()
864        .timeout(std::time::Duration::from_secs(12))
865        .build()
866        .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
867    let mut req = client.post(endpoint).headers(build_headers(headers)?);
868    if let Some(id) = session_id {
869        let trimmed = id.trim();
870        if !trimmed.is_empty() {
871            req = req.header("Mcp-Session-Id", trimmed);
872        }
873    }
874    let response = req
875        .json(&request)
876        .send()
877        .await
878        .map_err(|e| format!("MCP request failed: {e}"))?;
879    let response_session_id = response
880        .headers()
881        .get("mcp-session-id")
882        .and_then(|v| v.to_str().ok())
883        .map(|v| v.trim().to_string())
884        .filter(|v| !v.is_empty());
885    let status = response.status();
886    let payload = response
887        .text()
888        .await
889        .map_err(|e| format!("Failed to read MCP response: {e}"))?;
890    if !status.is_success() {
891        return Err(format!(
892            "MCP endpoint returned HTTP {}: {}",
893            status.as_u16(),
894            payload.chars().take(400).collect::<String>()
895        ));
896    }
897    let value = serde_json::from_str::<Value>(&payload)
898        .map_err(|e| format!("Invalid MCP JSON response: {e}"))?;
899    Ok((value, response_session_id))
900}
901
902fn render_mcp_content(value: &Value) -> String {
903    let Some(items) = value.as_array() else {
904        return value.to_string();
905    };
906    let mut chunks = Vec::new();
907    for item in items {
908        if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
909            chunks.push(text.to_string());
910            continue;
911        }
912        chunks.push(item.to_string());
913    }
914    if chunks.is_empty() {
915        value.to_string()
916    } else {
917        chunks.join("\n")
918    }
919}
920
921fn normalize_mcp_tool_args(server: &McpServer, tool_name: &str, raw_args: Value) -> Value {
922    let Some(schema) = server
923        .tool_cache
924        .iter()
925        .find(|row| row.tool_name.eq_ignore_ascii_case(tool_name))
926        .map(|row| &row.input_schema)
927    else {
928        return raw_args;
929    };
930
931    let mut args_obj = match raw_args {
932        Value::Object(obj) => obj,
933        other => return other,
934    };
935
936    let properties = schema
937        .get("properties")
938        .and_then(|v| v.as_object())
939        .cloned()
940        .unwrap_or_default();
941    if properties.is_empty() {
942        return Value::Object(args_obj);
943    }
944
945    // Build a normalized-key lookup so taskTitle -> task_title and list-id -> list_id resolve.
946    let mut normalized_existing: HashMap<String, String> = HashMap::new();
947    for key in args_obj.keys() {
948        normalized_existing.insert(normalize_arg_key(key), key.clone());
949    }
950
951    // Copy values from normalized aliases to canonical schema property names.
952    let canonical_keys = properties.keys().cloned().collect::<Vec<_>>();
953    for canonical in &canonical_keys {
954        if args_obj.contains_key(canonical) {
955            continue;
956        }
957        if let Some(existing_key) = normalized_existing.get(&normalize_arg_key(canonical)) {
958            if let Some(value) = args_obj.get(existing_key).cloned() {
959                args_obj.insert(canonical.clone(), value);
960            }
961        }
962    }
963
964    // Fill required fields using conservative aliases when models choose common alternatives.
965    let required = schema
966        .get("required")
967        .and_then(|v| v.as_array())
968        .map(|arr| {
969            arr.iter()
970                .filter_map(|v| v.as_str().map(str::to_string))
971                .collect::<Vec<_>>()
972        })
973        .unwrap_or_default();
974
975    for required_key in required {
976        if args_obj.contains_key(&required_key) {
977            continue;
978        }
979        if let Some(alias_value) = find_required_alias_value(&required_key, &args_obj) {
980            args_obj.insert(required_key, alias_value);
981        }
982    }
983
984    Value::Object(args_obj)
985}
986
987fn find_required_alias_value(
988    required_key: &str,
989    args_obj: &serde_json::Map<String, Value>,
990) -> Option<Value> {
991    let mut alias_candidates = vec![
992        required_key.to_string(),
993        required_key.to_ascii_lowercase(),
994        required_key.replace('_', ""),
995    ];
996
997    // Common fallback for fields like task_title where models often send `name`.
998    if required_key.contains("title") {
999        alias_candidates.extend([
1000            "name".to_string(),
1001            "title".to_string(),
1002            "task_name".to_string(),
1003            "taskname".to_string(),
1004        ]);
1005    }
1006
1007    // Common fallback for *_id fields where models emit `<base>` or `<base>Id`.
1008    if let Some(base) = required_key.strip_suffix("_id") {
1009        alias_candidates.extend([base.to_string(), format!("{base}id"), format!("{base}_id")]);
1010    }
1011
1012    let mut by_normalized: HashMap<String, &Value> = HashMap::new();
1013    for (key, value) in args_obj {
1014        by_normalized.insert(normalize_arg_key(key), value);
1015    }
1016
1017    alias_candidates
1018        .into_iter()
1019        .find_map(|candidate| by_normalized.get(&normalize_arg_key(&candidate)).cloned())
1020        .cloned()
1021}
1022
/// Reduces an argument name to a comparison key: ASCII letters and digits only, lowercased.
///
/// Separators, whitespace and non-ASCII characters are dropped, so `task_title`,
/// `taskTitle` and `task-title` all map to `tasktitle`.
fn normalize_arg_key(key: &str) -> String {
    let mut normalized = String::with_capacity(key.len());
    for ch in key.chars() {
        if ch.is_ascii_alphanumeric() {
            normalized.push(ch.to_ascii_lowercase());
        }
    }
    normalized
}
1029
1030async fn spawn_stdio_process(command_text: &str) -> Result<Child, String> {
1031    if command_text.is_empty() {
1032        return Err("Missing stdio command".to_string());
1033    }
1034    #[cfg(windows)]
1035    let mut command = {
1036        let mut cmd = Command::new("powershell");
1037        cmd.args(["-NoProfile", "-Command", command_text]);
1038        cmd
1039    };
1040    #[cfg(not(windows))]
1041    let mut command = {
1042        let mut cmd = Command::new("sh");
1043        cmd.args(["-lc", command_text]);
1044        cmd
1045    };
1046    command
1047        .stdin(std::process::Stdio::null())
1048        .stdout(std::process::Stdio::null())
1049        .stderr(std::process::Stdio::null());
1050    command.spawn().map_err(|e| e.to_string())
1051}
1052
// Unit tests for the MCP registry and its argument/schema normalization helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // Registers a non-stdio (`sse:`) server against a uniquely named state file in the
    // temp directory, then checks that connect succeeds, the listing reports the server
    // as connected, and disconnect succeeds.
    #[tokio::test]
    async fn add_connect_disconnect_non_stdio_server() {
        let file = std::env::temp_dir().join(format!("mcp-test-{}.json", Uuid::new_v4()));
        let registry = McpRegistry::new_with_state_file(file);
        registry
            .add("example".to_string(), "sse:https://example.com".to_string())
            .await;
        assert!(registry.connect("example").await);
        let listed = registry.list().await;
        assert!(listed.get("example").map(|s| s.connected).unwrap_or(false));
        assert!(registry.disconnect("example").await);
    }

    // A bare https URL is returned unchanged, and an `http:`-prefixed transport string
    // yields the URL that follows the prefix.
    #[test]
    fn parse_remote_endpoint_supports_http_prefixes() {
        assert_eq!(
            parse_remote_endpoint("https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
        assert_eq!(
            parse_remote_endpoint("http:https://mcp.example.com/mcp"),
            Some("https://mcp.example.com/mcp".to_string())
        );
    }

    // After schema normalization, `enum` must survive only on string (or string|null)
    // properties with all-string values; it must be absent for object/array/number typed
    // properties, for mixed-type value lists, and for nested properties as well.
    #[test]
    fn normalize_schema_removes_non_string_enums_recursively() {
        let mut schema = json!({
            "type": "object",
            "properties": {
                "good": { "type": "string", "enum": ["a", "b"] },
                "good_nullable": { "type": ["string", "null"], "enum": ["asc", "desc"] },
                "bad_object": { "type": "object", "enum": ["asc", "desc"] },
                "bad_array": { "type": "array", "enum": ["asc", "desc"] },
                "bad_number": { "type": "number", "enum": [1, 2] },
                "bad_mixed": { "enum": ["ok", 1] },
                "nested": {
                    "type": "object",
                    "properties": {
                        "child": { "enum": [true, false] }
                    }
                }
            }
        });

        normalize_tool_input_schema(&mut schema);

        assert!(
            schema["properties"]["good"]["enum"].is_array(),
            "string enums should be preserved"
        );
        assert!(
            schema["properties"]["good_nullable"]["enum"].is_array(),
            "string|null enums should be preserved"
        );
        // Indexing a missing key yields `Value::Null`, so `is_null()` checks for removal.
        assert!(
            schema["properties"]["bad_object"]["enum"].is_null(),
            "object enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_array"]["enum"].is_null(),
            "array enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_number"]["enum"].is_null(),
            "non-string enums should be dropped"
        );
        assert!(
            schema["properties"]["bad_mixed"]["enum"].is_null(),
            "mixed enums should be dropped"
        );
        assert!(
            schema["properties"]["nested"]["properties"]["child"]["enum"].is_null(),
            "recursive non-string enums should be dropped"
        );
    }

    // A content item carrying an `authorization_url` must be reported as an auth
    // challenge for the given tool, with status "pending".
    #[test]
    fn extract_auth_challenge_from_result_payload() {
        let payload = json!({
            "content": [
                {
                    "type": "text",
                    "llm_instructions": "Authorize Gmail access first.",
                    "authorization_url": "https://example.com/oauth/start"
                }
            ]
        });
        let challenge = extract_auth_challenge(&payload, "gmail_whoami")
            .expect("auth challenge should be detected");
        assert_eq!(challenge.tool_name, "gmail_whoami");
        assert_eq!(
            challenge.authorization_url,
            "https://example.com/oauth/start"
        );
        assert_eq!(challenge.status, "pending");
    }

    // Plain text content without any authorization URL must not produce a challenge.
    #[test]
    fn extract_auth_challenge_returns_none_without_url() {
        let payload = json!({
            "content": [
                {"type":"text","text":"No authorization needed"}
            ]
        });
        assert!(extract_auth_challenge(&payload, "gmail_whoami").is_none());
    }

    // With a cached schema requiring `list_id` and `task_title`, the camelCase `listId`
    // must be mapped onto `list_id` and the generic `name` onto `task_title`.
    #[test]
    fn normalize_mcp_tool_args_maps_clickup_aliases() {
        let server = McpServer {
            name: "arcade".to_string(),
            transport: "https://example.com/mcp".to_string(),
            enabled: true,
            connected: true,
            pid: None,
            last_error: None,
            last_auth_challenge: None,
            mcp_session_id: None,
            headers: HashMap::new(),
            tool_cache: vec![McpToolCacheEntry {
                tool_name: "Clickup_CreateTask".to_string(),
                description: "Create task".to_string(),
                input_schema: json!({
                    "type":"object",
                    "properties":{
                        "list_id":{"type":"string"},
                        "task_title":{"type":"string"}
                    },
                    "required":["list_id","task_title"]
                }),
                fetched_at_ms: 0,
                schema_hash: "x".to_string(),
            }],
            tools_fetched_at_ms: None,
        };

        let normalized = normalize_mcp_tool_args(
            &server,
            "Clickup_CreateTask",
            json!({
                "listId": "123",
                "name": "Prep fish"
            }),
        );
        assert_eq!(
            normalized.get("list_id").and_then(|v| v.as_str()),
            Some("123")
        );
        assert_eq!(
            normalized.get("task_title").and_then(|v| v.as_str()),
            Some("Prep fish")
        );
    }

    // snake_case, camelCase and kebab-case spellings must all collapse to the same key.
    #[test]
    fn normalize_arg_key_ignores_case_and_separators() {
        assert_eq!(normalize_arg_key("task_title"), "tasktitle");
        assert_eq!(normalize_arg_key("taskTitle"), "tasktitle");
        assert_eq!(normalize_arg_key("task-title"), "tasktitle");
    }
}