Skip to main content

tandem_runtime/mcp_parts/
part01.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::{LocalImplicitTenant, SecretRef, TenantContext, ToolResult};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
14const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
15const MCP_CLIENT_NAME: &str = "tandem";
16const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
17const MCP_AUTH_REPROBE_COOLDOWN_MS: u64 = 15_000;
18const MCP_SECRET_PLACEHOLDER: &str = "";
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct McpToolCacheEntry {
22    pub tool_name: String,
23    pub description: String,
24    #[serde(default)]
25    pub input_schema: Value,
26    pub fetched_at_ms: u64,
27    pub schema_hash: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct McpServer {
32    pub name: String,
33    pub transport: String,
34    #[serde(default, skip_serializing_if = "String::is_empty")]
35    pub auth_kind: String,
36    #[serde(default = "default_enabled")]
37    pub enabled: bool,
38    pub connected: bool,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub pid: Option<u32>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub last_error: Option<String>,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub last_auth_challenge: Option<McpAuthChallenge>,
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub mcp_session_id: Option<String>,
47    #[serde(default)]
48    pub headers: HashMap<String, String>,
49    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
50    pub secret_headers: HashMap<String, McpSecretRef>,
51    #[serde(default)]
52    pub tool_cache: Vec<McpToolCacheEntry>,
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub tools_fetched_at_ms: Option<u64>,
55    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
56    pub pending_auth_by_tool: HashMap<String, PendingMcpAuth>,
57    #[serde(default, skip_serializing_if = "Option::is_none")]
58    pub allowed_tools: Option<Vec<String>>,
59    #[serde(default, skip)]
60    pub secret_header_values: HashMap<String, String>,
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    pub oauth: Option<McpOAuthConfig>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66#[serde(tag = "type", rename_all = "snake_case")]
67pub enum McpSecretRef {
68    Store {
69        secret_id: String,
70        #[serde(default)]
71        tenant_context: TenantContext,
72    },
73    Env {
74        env: String,
75    },
76    BearerEnv {
77        env: String,
78    },
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct McpAuthChallenge {
83    pub challenge_id: String,
84    pub tool_name: String,
85    pub authorization_url: String,
86    pub message: String,
87    pub requested_at_ms: u64,
88    pub status: String,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct PendingMcpAuth {
93    pub challenge_id: String,
94    pub authorization_url: String,
95    pub message: String,
96    pub status: String,
97    pub first_seen_ms: u64,
98    pub last_probe_ms: u64,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct McpOAuthConfig {
103    pub provider_id: String,
104    pub token_endpoint: String,
105    pub client_id: String,
106    #[serde(default, skip_serializing_if = "Option::is_none")]
107    pub client_secret_ref: Option<McpSecretRef>,
108    #[serde(default, skip)]
109    pub client_secret_value: Option<String>,
110}
111
112#[derive(Debug, Clone)]
113enum DiscoverRemoteToolsError {
114    Message(String),
115    AuthChallenge(McpAuthChallenge),
116}
117
118impl From<String> for DiscoverRemoteToolsError {
119    fn from(value: String) -> Self {
120        Self::Message(value)
121    }
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct McpRemoteTool {
126    pub server_name: String,
127    pub tool_name: String,
128    pub namespaced_name: String,
129    pub description: String,
130    #[serde(default)]
131    pub input_schema: Value,
132    pub fetched_at_ms: u64,
133    pub schema_hash: String,
134}
135
136#[derive(Clone)]
137pub struct McpRegistry {
138    servers: Arc<RwLock<HashMap<String, McpServer>>>,
139    processes: Arc<Mutex<HashMap<String, Child>>>,
140    state_file: Arc<PathBuf>,
141}
142
143impl McpRegistry {
144    pub fn new() -> Self {
145        Self::new_with_state_file(resolve_state_file())
146    }
147
148    pub fn new_with_state_file(state_file: PathBuf) -> Self {
149        let (loaded_state, migrated) = load_state(&state_file);
150        let loaded = loaded_state
151            .into_iter()
152            .map(|(k, mut v)| {
153                v.connected = false;
154                v.pid = None;
155                if v.name.trim().is_empty() {
156                    v.name = k.clone();
157                }
158                if v.headers.is_empty() {
159                    v.headers = HashMap::new();
160                }
161                if v.secret_headers.is_empty() {
162                    v.secret_headers = HashMap::new();
163                }
164                let tenant_context = local_tenant_context();
165                v.secret_header_values =
166                    resolve_secret_header_values(&v.secret_headers, &tenant_context);
167                if let Some(oauth) = v.oauth.as_mut() {
168                    oauth.client_secret_value =
169                        oauth.client_secret_ref.as_ref().and_then(|secret_ref| {
170                            resolve_secret_ref_value(secret_ref, &tenant_context)
171                        });
172                }
173                (k, v)
174            })
175            .collect::<HashMap<_, _>>();
176        if migrated {
177            persist_state_blocking(&state_file, &loaded);
178        }
179        Self {
180            servers: Arc::new(RwLock::new(loaded)),
181            processes: Arc::new(Mutex::new(HashMap::new())),
182            state_file: Arc::new(state_file),
183        }
184    }
185
186    pub async fn list(&self) -> HashMap<String, McpServer> {
187        self.servers.read().await.clone()
188    }
189
190    pub async fn list_public(&self) -> HashMap<String, McpServer> {
191        self.servers
192            .read()
193            .await
194            .iter()
195            .map(|(name, server)| (name.clone(), redacted_server_view(server)))
196            .collect()
197    }
198
199    pub async fn add(&self, name: String, transport: String) {
200        self.add_or_update(name, transport, HashMap::new(), true)
201            .await;
202    }
203
204    pub async fn add_or_update(
205        &self,
206        name: String,
207        transport: String,
208        headers: HashMap<String, String>,
209        enabled: bool,
210    ) {
211        self.add_or_update_with_secret_refs(name, transport, headers, HashMap::new(), enabled)
212            .await;
213    }
214
215    pub async fn add_or_update_with_secret_refs(
216        &self,
217        name: String,
218        transport: String,
219        headers: HashMap<String, String>,
220        secret_headers: HashMap<String, McpSecretRef>,
221        enabled: bool,
222    ) {
223        let normalized_name = name.trim().to_string();
224        let tenant_context = local_tenant_context();
225        let (persisted_headers, persisted_secret_headers, secret_header_values) =
226            split_headers_for_storage(&normalized_name, headers, secret_headers, &tenant_context);
227        let mut servers = self.servers.write().await;
228        let existing = servers.get(&normalized_name).cloned();
229        let preserve_cache = existing.as_ref().is_some_and(|row| {
230            row.transport == transport
231                && effective_headers(row)
232                    == combine_headers(&persisted_headers, &secret_header_values)
233        });
234        let existing_tool_cache = if preserve_cache {
235            existing
236                .as_ref()
237                .map(|row| row.tool_cache.clone())
238                .unwrap_or_default()
239        } else {
240            Vec::new()
241        };
242        let existing_fetched_at = if preserve_cache {
243            existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
244        } else {
245            None
246        };
247        let server = McpServer {
248            name: normalized_name.clone(),
249            transport,
250            auth_kind: existing
251                .as_ref()
252                .map(|row| row.auth_kind.clone())
253                .unwrap_or_default(),
254            enabled,
255            connected: false,
256            pid: None,
257            last_error: None,
258            last_auth_challenge: None,
259            mcp_session_id: None,
260            headers: persisted_headers,
261            secret_headers: persisted_secret_headers,
262            tool_cache: existing_tool_cache,
263            tools_fetched_at_ms: existing_fetched_at,
264            pending_auth_by_tool: HashMap::new(),
265            allowed_tools: existing.as_ref().and_then(|row| row.allowed_tools.clone()),
266            secret_header_values,
267            oauth: existing.as_ref().and_then(|row| row.oauth.clone()),
268        };
269        servers.insert(normalized_name, server);
270        drop(servers);
271        self.persist_state().await;
272    }
273
274    pub async fn set_allowed_tools(&self, name: &str, allowed_tools: Option<Vec<String>>) -> bool {
275        let mut servers = self.servers.write().await;
276        let Some(server) = servers.get_mut(name) else {
277            return false;
278        };
279        let normalized = allowed_tools.map(normalize_allowed_tool_names);
280        if server.allowed_tools == normalized {
281            return true;
282        }
283        server.allowed_tools = normalized;
284        drop(servers);
285        self.persist_state().await;
286        true
287    }
288
289    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
290        let mut servers = self.servers.write().await;
291        let Some(server) = servers.get_mut(name) else {
292            return false;
293        };
294        server.enabled = enabled;
295        if !enabled {
296            server.connected = false;
297            server.pid = None;
298            server.last_auth_challenge = None;
299            server.mcp_session_id = None;
300            server.pending_auth_by_tool.clear();
301        }
302        drop(servers);
303        if !enabled {
304            if let Some(mut child) = self.processes.lock().await.remove(name) {
305                let _ = child.kill().await;
306                let _ = child.wait().await;
307            }
308        }
309        self.persist_state().await;
310        true
311    }
312
313    pub async fn remove(&self, name: &str) -> bool {
314        let removed_server = {
315            let mut servers = self.servers.write().await;
316            servers.remove(name)
317        };
318        let Some(server) = removed_server else {
319            return false;
320        };
321        let current_tenant = local_tenant_context();
322        delete_secret_header_refs(&server.secret_headers, &current_tenant);
323        delete_oauth_secret_ref(server.oauth.as_ref(), &current_tenant);
324
325        if let Some(mut child) = self.processes.lock().await.remove(name) {
326            let _ = child.kill().await;
327            let _ = child.wait().await;
328        }
329        self.persist_state().await;
330        true
331    }
332
333    pub async fn connect(&self, name: &str) -> bool {
334        let server = {
335            let servers = self.servers.read().await;
336            let Some(server) = servers.get(name) else {
337                return false;
338            };
339            server.clone()
340        };
341
342        if !server.enabled {
343            let mut servers = self.servers.write().await;
344            if let Some(entry) = servers.get_mut(name) {
345                entry.connected = false;
346                entry.pid = None;
347                entry.last_error = Some("MCP server is disabled".to_string());
348                entry.last_auth_challenge = None;
349                entry.mcp_session_id = None;
350                entry.pending_auth_by_tool.clear();
351            }
352            drop(servers);
353            self.persist_state().await;
354            return false;
355        }
356
357        if let Some(command_text) = parse_stdio_transport(&server.transport) {
358            return self.connect_stdio(name, command_text).await;
359        }
360
361        if parse_remote_endpoint(&server.transport).is_some() {
362            return self.refresh(name).await.is_ok();
363        }
364
365        let mut servers = self.servers.write().await;
366        if let Some(entry) = servers.get_mut(name) {
367            entry.connected = true;
368            entry.pid = None;
369            entry.last_error = None;
370            entry.last_auth_challenge = None;
371            entry.mcp_session_id = None;
372            entry.pending_auth_by_tool.clear();
373        }
374        drop(servers);
375        self.persist_state().await;
376        true
377    }
378
379    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
380        let server = {
381            let servers = self.servers.read().await;
382            let Some(server) = servers.get(name) else {
383                return Err("MCP server not found".to_string());
384            };
385            server.clone()
386        };
387
388        if !server.enabled {
389            return Err("MCP server is disabled".to_string());
390        }
391
392        let endpoint = parse_remote_endpoint(&server.transport)
393            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
394
395        let _ = self.ensure_oauth_bearer_token_fresh(name, false).await;
396        let server = {
397            let servers = self.servers.read().await;
398            let Some(server) = servers.get(name) else {
399                return Err("MCP server not found".to_string());
400            };
401            server.clone()
402        };
403        let request_headers = effective_headers(&server);
404        let discovery = self
405            .discover_remote_tools(name, &endpoint, &request_headers)
406            .await;
407        let (tools, session_id) = match discovery {
408            Ok(result) => result,
409            Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
410                let mut servers = self.servers.write().await;
411                if let Some(entry) = servers.get_mut(name) {
412                    entry.connected = false;
413                    entry.pid = None;
414                    entry.last_error = Some(challenge.message.clone());
415                    entry.last_auth_challenge = Some(challenge.clone());
416                    entry.mcp_session_id = None;
417                    entry.pending_auth_by_tool.clear();
418                    entry.tool_cache.clear();
419                    entry.tools_fetched_at_ms = None;
420                }
421                drop(servers);
422                self.persist_state().await;
423                return Err(format!(
424                    "MCP server '{name}' requires authorization: {}",
425                    challenge.message
426                ));
427            }
428            Err(DiscoverRemoteToolsError::Message(err)) => {
429                if should_retry_mcp_oauth_refresh(&server, &err)
430                    && self.ensure_oauth_bearer_token_fresh(name, true).await?
431                {
432                    let refreshed_server = {
433                        let servers = self.servers.read().await;
434                        servers
435                            .get(name)
436                            .cloned()
437                            .ok_or_else(|| "MCP server not found".to_string())?
438                    };
439                    match self
440                        .discover_remote_tools(
441                            name,
442                            &endpoint,
443                            &effective_headers(&refreshed_server),
444                        )
445                        .await
446                    {
447                        Ok(result) => result,
448                        Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
449                            let mut servers = self.servers.write().await;
450                            if let Some(entry) = servers.get_mut(name) {
451                                entry.connected = false;
452                                entry.pid = None;
453                                entry.last_error = Some(challenge.message.clone());
454                                entry.last_auth_challenge = Some(challenge.clone());
455                                entry.mcp_session_id = None;
456                                entry.pending_auth_by_tool.clear();
457                                entry.tool_cache.clear();
458                                entry.tools_fetched_at_ms = None;
459                            }
460                            drop(servers);
461                            self.persist_state().await;
462                            return Err(format!(
463                                "MCP server '{name}' requires authorization: {}",
464                                challenge.message
465                            ));
466                        }
467                        Err(DiscoverRemoteToolsError::Message(retry_err)) => {
468                            let mut servers = self.servers.write().await;
469                            if let Some(entry) = servers.get_mut(name) {
470                                entry.connected = false;
471                                entry.pid = None;
472                                entry.last_error = Some(retry_err.clone());
473                                entry.last_auth_challenge = None;
474                                entry.mcp_session_id = None;
475                                entry.pending_auth_by_tool.clear();
476                                entry.tool_cache.clear();
477                                entry.tools_fetched_at_ms = None;
478                            }
479                            drop(servers);
480                            self.persist_state().await;
481                            return Err(retry_err);
482                        }
483                    }
484                } else {
485                    let mut servers = self.servers.write().await;
486                    if let Some(entry) = servers.get_mut(name) {
487                        entry.connected = false;
488                        entry.pid = None;
489                        entry.last_error = Some(err.clone());
490                        entry.last_auth_challenge = None;
491                        entry.mcp_session_id = None;
492                        entry.pending_auth_by_tool.clear();
493                        entry.tool_cache.clear();
494                        entry.tools_fetched_at_ms = None;
495                    }
496                    drop(servers);
497                    self.persist_state().await;
498                    return Err(err);
499                }
500            }
501        };
502
503        let now = now_ms();
504        let cache = tools
505            .iter()
506            .map(|tool| McpToolCacheEntry {
507                tool_name: tool.tool_name.clone(),
508                description: tool.description.clone(),
509                input_schema: tool.input_schema.clone(),
510                fetched_at_ms: now,
511                schema_hash: schema_hash(&tool.input_schema),
512            })
513            .collect::<Vec<_>>();
514
515        let mut servers = self.servers.write().await;
516        if let Some(entry) = servers.get_mut(name) {
517            entry.connected = true;
518            entry.pid = None;
519            entry.last_error = None;
520            entry.last_auth_challenge = None;
521            entry.mcp_session_id = session_id;
522            entry.tool_cache = cache;
523            entry.tools_fetched_at_ms = Some(now);
524            entry.pending_auth_by_tool.clear();
525        }
526        drop(servers);
527        self.persist_state().await;
528        Ok(self.server_tools(name).await)
529    }
530
531    pub async fn disconnect(&self, name: &str) -> bool {
532        if let Some(mut child) = self.processes.lock().await.remove(name) {
533            let _ = child.kill().await;
534            let _ = child.wait().await;
535        }
536        let mut servers = self.servers.write().await;
537        if let Some(server) = servers.get_mut(name) {
538            server.connected = false;
539            server.pid = None;
540            server.last_auth_challenge = None;
541            server.mcp_session_id = None;
542            server.pending_auth_by_tool.clear();
543            drop(servers);
544            self.persist_state().await;
545            return true;
546        }
547        false
548    }
549
550    pub async fn complete_auth(&self, name: &str) -> bool {
551        let mut servers = self.servers.write().await;
552        let Some(server) = servers.get_mut(name) else {
553            return false;
554        };
555        server.last_error = None;
556        server.last_auth_challenge = None;
557        server.pending_auth_by_tool.clear();
558        drop(servers);
559        self.persist_state().await;
560        true
561    }
562
563    pub async fn set_auth_kind(&self, name: &str, auth_kind: String) -> bool {
564        let normalized = normalize_auth_kind(&auth_kind);
565        let mut servers = self.servers.write().await;
566        let Some(server) = servers.get_mut(name) else {
567            return false;
568        };
569        server.auth_kind = normalized;
570        drop(servers);
571        self.persist_state().await;
572        true
573    }
574
575    pub async fn record_server_auth_challenge(
576        &self,
577        name: &str,
578        challenge: McpAuthChallenge,
579        last_error: Option<String>,
580    ) -> bool {
581        let mut servers = self.servers.write().await;
582        let Some(server) = servers.get_mut(name) else {
583            return false;
584        };
585        let tool_key = canonical_tool_key(&challenge.tool_name);
586        server.connected = false;
587        server.pid = None;
588        server.last_error = last_error.or_else(|| Some(challenge.message.clone()));
589        server.last_auth_challenge = Some(challenge.clone());
590        server.mcp_session_id = None;
591        server.pending_auth_by_tool.clear();
592        server
593            .pending_auth_by_tool
594            .insert(tool_key, pending_auth_from_challenge(&challenge));
595        drop(servers);
596        self.persist_state().await;
597        true
598    }
599
600    pub async fn clear_server_auth_challenge(&self, name: &str) -> bool {
601        let mut servers = self.servers.write().await;
602        let Some(server) = servers.get_mut(name) else {
603            return false;
604        };
605        server.last_auth_challenge = None;
606        server.pending_auth_by_tool.clear();
607        drop(servers);
608        self.persist_state().await;
609        true
610    }
611
612    pub async fn set_bearer_token(&self, name: &str, token: &str) -> Result<bool, String> {
613        let trimmed = token.trim();
614        if trimmed.is_empty() {
615            return Err("oauth access token cannot be empty".to_string());
616        }
617        let current_tenant = local_tenant_context();
618        let mut servers = self.servers.write().await;
619        let Some(server) = servers.get_mut(name) else {
620            return Ok(false);
621        };
622        let header_name = "Authorization".to_string();
623        let secret_id = mcp_header_secret_id(name, &header_name);
624        tandem_core::set_provider_auth(&secret_id, &format!("Bearer {trimmed}"))
625            .map_err(|error| error.to_string())?;
626        server.secret_headers.insert(
627            header_name.clone(),
628            McpSecretRef::Store {
629                secret_id: secret_id.clone(),
630                tenant_context: current_tenant,
631            },
632        );
633        server
634            .secret_header_values
635            .insert(header_name.clone(), format!("Bearer {trimmed}"));
636        server.headers.remove(&header_name);
637        drop(servers);
638        self.persist_state().await;
639        Ok(true)
640    }
641
642    pub async fn set_oauth_refresh_config(
643        &self,
644        name: &str,
645        provider_id: String,
646        token_endpoint: String,
647        client_id: String,
648        client_secret: Option<String>,
649    ) -> Result<bool, String> {
650        let current_tenant = local_tenant_context();
651        let mut servers = self.servers.write().await;
652        let Some(server) = servers.get_mut(name) else {
653            return Ok(false);
654        };
655
656        let client_secret_ref = client_secret
657            .as_deref()
658            .map(str::trim)
659            .filter(|value| !value.is_empty())
660            .map(|value| -> Result<McpSecretRef, String> {
661                let secret_id = mcp_oauth_client_secret_id(name);
662                tandem_core::set_provider_auth(&secret_id, value)
663                    .map_err(|error| error.to_string())?;
664                Ok(McpSecretRef::Store {
665                    secret_id,
666                    tenant_context: current_tenant.clone(),
667                })
668            })
669            .transpose()?;
670        if client_secret_ref.is_none() {
671            let secret_id = mcp_oauth_client_secret_id(name);
672            let _ = tandem_core::delete_provider_auth(&secret_id);
673        }
674
675        server.oauth = Some(McpOAuthConfig {
676            provider_id,
677            token_endpoint,
678            client_id,
679            client_secret_ref,
680            client_secret_value: client_secret
681                .map(|value| value.trim().to_string())
682                .filter(|value| !value.is_empty()),
683        });
684        drop(servers);
685        self.persist_state().await;
686        Ok(true)
687    }
688
689    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
690        let mut out = self
691            .servers
692            .read()
693            .await
694            .values()
695            .filter(|server| server.enabled && server.connected)
696            .flat_map(server_tool_rows)
697            .collect::<Vec<_>>();
698        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
699        out
700    }
701
702    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
703        let Some(server) = self.servers.read().await.get(name).cloned() else {
704            return Vec::new();
705        };
706        let mut rows = server_tool_rows(&server);
707        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
708        rows
709    }
710
711    pub async fn call_tool(
712        &self,
713        server_name: &str,
714        tool_name: &str,
715        args: Value,
716    ) -> Result<ToolResult, String> {
717        let server = {
718            let servers = self.servers.read().await;
719            let Some(server) = servers.get(server_name) else {
720                return Err(format!("MCP server '{server_name}' not found"));
721            };
722            server.clone()
723        };
724
725        if !server.enabled {
726            return Err(format!("MCP server '{server_name}' is disabled"));
727        }
728        if !server.connected {
729            return Err(format!("MCP server '{server_name}' is not connected"));
730        }
731
732        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
733            "MCP tools/call currently supports HTTP/S transports only".to_string()
734        })?;
735        let canonical_tool = canonical_tool_key(tool_name);
736        let now = now_ms();
737        let _ = self
738            .ensure_oauth_bearer_token_fresh(server_name, false)
739            .await;
740        let server = {
741            let servers = self.servers.read().await;
742            let Some(server) = servers.get(server_name) else {
743                return Err(format!("MCP server '{server_name}' not found"));
744            };
745            server.clone()
746        };
747        if let Some(blocked) = pending_auth_short_circuit(
748            &server,
749            &canonical_tool,
750            tool_name,
751            now,
752            MCP_AUTH_REPROBE_COOLDOWN_MS,
753        ) {
754            return Ok(ToolResult {
755                output: blocked.output,
756                metadata: json!({
757                    "server": server_name,
758                    "tool": tool_name,
759                    "result": Value::Null,
760                    "mcpAuth": blocked.mcp_auth
761                }),
762            });
763        }
764        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);
765
766        {
767            let mut servers = self.servers.write().await;
768            if let Some(row) = servers.get_mut(server_name) {
769                if let Some(pending) = row.pending_auth_by_tool.get_mut(&canonical_tool) {
770                    pending.last_probe_ms = now;
771                }
772            }
773        }
774
775        let request = json!({
776            "jsonrpc": "2.0",
777            "id": format!("call-{}-{}", server_name, now_ms()),
778            "method": "tools/call",
779            "params": {
780                "name": tool_name,
781                "arguments": normalized_args
782            }
783        });
784        let (response, session_id) = match post_json_rpc_with_session(
785            &endpoint,
786            &effective_headers(&server),
787            request.clone(),
788            server.mcp_session_id.as_deref(),
789        )
790        .await
791        {
792            Ok(result) => result,
793            Err(error) => {
794                if should_retry_mcp_oauth_refresh(&server, &error)
795                    && self
796                        .ensure_oauth_bearer_token_fresh(server_name, true)
797                        .await?
798                {
799                    let refreshed_server = {
800                        let servers = self.servers.read().await;
801                        servers
802                            .get(server_name)
803                            .cloned()
804                            .ok_or_else(|| format!("MCP server '{server_name}' not found"))?
805                    };
806                    post_json_rpc_with_session(
807                        &endpoint,
808                        &effective_headers(&refreshed_server),
809                        request,
810                        refreshed_server.mcp_session_id.as_deref(),
811                    )
812                    .await?
813                } else {
814                    return Err(error);
815                }
816            }
817        };
818        if session_id.is_some() {
819            let mut servers = self.servers.write().await;
820            if let Some(row) = servers.get_mut(server_name) {
821                row.mcp_session_id = session_id;
822            }
823            drop(servers);
824            self.persist_state().await;
825        }
826
827        if let Some(err) = response.get("error") {
828            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
829                let output = format!(
830                    "{}\n\nAuthorize here: {}",
831                    challenge.message, challenge.authorization_url
832                );
833                {
834                    let mut servers = self.servers.write().await;
835                    if let Some(row) = servers.get_mut(server_name) {
836                        row.last_auth_challenge = Some(challenge.clone());
837                        row.last_error = None;
838                        row.pending_auth_by_tool.insert(
839                            canonical_tool.clone(),
840                            pending_auth_from_challenge(&challenge),
841                        );
842                    }
843                }
844                self.persist_state().await;
845                return Ok(ToolResult {
846                    output,
847                    metadata: json!({
848                        "server": server_name,
849                        "tool": tool_name,
850                        "result": Value::Null,
851                        "mcpAuth": {
852                            "required": true,
853                            "challengeId": challenge.challenge_id,
854                            "tool": challenge.tool_name,
855                            "authorizationUrl": challenge.authorization_url,
856                            "message": challenge.message,
857                            "status": challenge.status
858                        }
859                    }),
860                });
861            }
862            let message = err
863                .get("message")
864                .and_then(|v| v.as_str())
865                .unwrap_or("MCP tools/call failed");
866            return Err(message.to_string());
867        }
868
869        let result = response.get("result").cloned().unwrap_or(Value::Null);
870        let auth_challenge = extract_auth_challenge(&result, tool_name);
871        let output = if let Some(challenge) = auth_challenge.as_ref() {
872            format!(
873                "{}\n\nAuthorize here: {}",
874                challenge.message, challenge.authorization_url
875            )
876        } else {
877            result
878                .get("content")
879                .map(render_mcp_content)
880                .or_else(|| result.get("output").map(|v| v.to_string()))
881                .unwrap_or_else(|| result.to_string())
882        };
883
884        {
885            let mut servers = self.servers.write().await;
886            if let Some(row) = servers.get_mut(server_name) {
887                row.last_auth_challenge = auth_challenge.clone();
888                if let Some(challenge) = auth_challenge.as_ref() {
889                    row.pending_auth_by_tool.insert(
890                        canonical_tool.clone(),
891                        pending_auth_from_challenge(challenge),
892                    );
893                } else {
894                    row.pending_auth_by_tool.remove(&canonical_tool);
895                }
896            }
897        }
898        self.persist_state().await;
899
900        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
901            json!({
902                "required": true,
903                "challengeId": challenge.challenge_id,
904                "tool": challenge.tool_name,
905                "authorizationUrl": challenge.authorization_url,
906                "message": challenge.message,
907                "status": challenge.status
908            })
909        });
910
911        Ok(ToolResult {
912            output,
913            metadata: json!({
914                "server": server_name,
915                "tool": tool_name,
916                "result": result,
917                "mcpAuth": auth_metadata
918            }),
919        })
920    }
921
922    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
923        match spawn_stdio_process(command_text).await {
924            Ok(child) => {
925                let pid = child.id();
926                self.processes.lock().await.insert(name.to_string(), child);
927                let mut servers = self.servers.write().await;
928                if let Some(server) = servers.get_mut(name) {
929                    server.connected = true;
930                    server.pid = pid;
931                    server.last_error = None;
932                    server.last_auth_challenge = None;
933                    server.pending_auth_by_tool.clear();
934                }
935                drop(servers);
936                self.persist_state().await;
937                true
938            }
939            Err(err) => {
940                let mut servers = self.servers.write().await;
941                if let Some(server) = servers.get_mut(name) {
942                    server.connected = false;
943                    server.pid = None;
944                    server.last_error = Some(err);
945                    server.last_auth_challenge = None;
946                    server.pending_auth_by_tool.clear();
947                }
948                drop(servers);
949                self.persist_state().await;
950                false
951            }
952        }
953    }
954
955    async fn discover_remote_tools(
956        &self,
957        server_name: &str,
958        endpoint: &str,
959        headers: &HashMap<String, String>,
960    ) -> Result<(Vec<McpRemoteTool>, Option<String>), DiscoverRemoteToolsError> {
961        let initialize = json!({
962            "jsonrpc": "2.0",
963            "id": "initialize-1",
964            "method": "initialize",
965            "params": {
966                "protocolVersion": MCP_PROTOCOL_VERSION,
967                "capabilities": {},
968                "clientInfo": {
969                    "name": MCP_CLIENT_NAME,
970                    "version": MCP_CLIENT_VERSION,
971                }
972            }
973        });
974        let (init_response, mut session_id) =
975            post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
976        if let Some(err) = init_response.get("error") {
977            if let Some(challenge) = extract_auth_challenge(err, server_name) {
978                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
979            }
980            let message = err
981                .get("message")
982                .and_then(|v| v.as_str())
983                .unwrap_or("MCP initialize failed");
984            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
985        }
986
987        let tools_list = json!({
988            "jsonrpc": "2.0",
989            "id": "tools-list-1",
990            "method": "tools/list",
991            "params": {}
992        });
993        let (tools_response, next_session_id) =
994            post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
995                .await?;
996        if next_session_id.is_some() {
997            session_id = next_session_id;
998        }
999        if let Some(err) = tools_response.get("error") {
1000            if let Some(challenge) = extract_auth_challenge(err, server_name) {
1001                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1002            }
1003            let message = err
1004                .get("message")
1005                .and_then(|v| v.as_str())
1006                .unwrap_or("MCP tools/list failed");
1007            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1008        }
1009
1010        let tools = tools_response
1011            .get("result")
1012            .and_then(|v| v.get("tools"))
1013            .and_then(|v| v.as_array())
1014            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
1015
1016        let now = now_ms();
1017        let mut out = Vec::new();
1018        for row in tools {
1019            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
1020                continue;
1021            };
1022            let description = row
1023                .get("description")
1024                .and_then(|v| v.as_str())
1025                .unwrap_or("")
1026                .to_string();
1027            let mut input_schema = row
1028                .get("inputSchema")
1029                .or_else(|| row.get("input_schema"))
1030                .cloned()
1031                .unwrap_or_else(|| json!({"type":"object"}));
1032            normalize_tool_input_schema(&mut input_schema);
1033            out.push(McpRemoteTool {
1034                server_name: String::new(),
1035                tool_name: tool_name.to_string(),
1036                namespaced_name: String::new(),
1037                description,
1038                input_schema,
1039                fetched_at_ms: now,
1040                schema_hash: String::new(),
1041            });
1042        }
1043
1044        Ok((out, session_id))
1045    }
1046
1047    async fn persist_state(&self) {
1048        let snapshot = self.servers.read().await.clone();
1049        persist_state_blocking(self.state_file.as_path(), &snapshot);
1050    }
1051
1052    async fn ensure_oauth_bearer_token_fresh(
1053        &self,
1054        name: &str,
1055        force: bool,
1056    ) -> Result<bool, String> {
1057        let server = {
1058            let servers = self.servers.read().await;
1059            servers.get(name).cloned()
1060        }
1061        .ok_or_else(|| format!("MCP server '{name}' not found"))?;
1062        let Some(oauth) = server.oauth.clone() else {
1063            return Ok(false);
1064        };
1065        let Some(credential) = tandem_core::load_provider_oauth_credential(&oauth.provider_id)
1066        else {
1067            return Ok(false);
1068        };
1069
1070        let should_refresh = force
1071            || credential.expires_at_ms <= now_ms().saturating_add(60_000)
1072            || credential.access_token.trim().is_empty();
1073        if !should_refresh {
1074            return Ok(false);
1075        }
1076
1077        let refreshed = refresh_mcp_oauth_credential(&oauth, &credential).await?;
1078        self.set_bearer_token(name, &refreshed.access_token).await?;
1079        tandem_core::set_provider_oauth_credential(&oauth.provider_id, refreshed)
1080            .map_err(|error| error.to_string())?;
1081        Ok(true)
1082    }
1083}
1084
1085impl Default for McpRegistry {
1086    fn default() -> Self {
1087        Self::new()
1088    }
1089}
1090
/// Serde default for `McpServer::enabled`: a persisted server without an `enabled` field
/// is treated as enabled.
fn default_enabled() -> bool {
    true
}
1094
/// Cleans a tool allow-list.
///
/// Each entry is trimmed and blank entries are dropped. Duplicates are removed, keeping
/// the position of the first occurrence.
fn normalize_allowed_tool_names(raw: Vec<String>) -> Vec<String> {
    let mut already_listed = std::collections::HashSet::new();
    raw.into_iter()
        .map(|entry| entry.trim().to_string())
        .filter(|name| !name.is_empty() && already_listed.insert(name.clone()))
        .collect()
}
1107
1108fn persist_state_blocking(path: &Path, snapshot: &HashMap<String, McpServer>) {
1109    if let Some(parent) = path.parent() {
1110        let _ = std::fs::create_dir_all(parent);
1111    }
1112    if let Ok(payload) = serde_json::to_string_pretty(snapshot) {
1113        let _ = std::fs::write(path, payload);
1114    }
1115}
1116
1117fn resolve_state_file() -> PathBuf {
1118    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
1119        return PathBuf::from(path);
1120    }
1121    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1122        let trimmed = state_dir.trim();
1123        if !trimmed.is_empty() {
1124            return PathBuf::from(trimmed).join("mcp_servers.json");
1125        }
1126    }
1127    if let Some(data_dir) = dirs::data_dir() {
1128        return data_dir
1129            .join("tandem")
1130            .join("data")
1131            .join("mcp_servers.json");
1132    }
1133    dirs::home_dir()
1134        .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
1135        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1136}
1137
1138fn load_state(path: &Path) -> (HashMap<String, McpServer>, bool) {
1139    let Ok(raw) = std::fs::read_to_string(path) else {
1140        return (HashMap::new(), false);
1141    };
1142    let mut migrated = false;
1143    let mut parsed = serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default();
1144    for (name, server) in parsed.iter_mut() {
1145        let tenant_context = local_tenant_context();
1146        let (headers, secret_headers, secret_header_values, server_migrated) =
1147            migrate_server_headers(name, server, &tenant_context);
1148        migrated = migrated || server_migrated;
1149        server.headers = headers;
1150        server.secret_headers = secret_headers;
1151        server.secret_header_values = secret_header_values;
1152    }
1153    (parsed, migrated)
1154}
1155
1156fn migrate_server_headers(
1157    server_name: &str,
1158    server: &McpServer,
1159    current_tenant: &TenantContext,
1160) -> (
1161    HashMap<String, String>,
1162    HashMap<String, McpSecretRef>,
1163    HashMap<String, String>,
1164    bool,
1165) {
1166    let original_effective = effective_headers(server);
1167    let mut persisted_secret_headers = server.secret_headers.clone();
1168    let mut secret_header_values =
1169        resolve_secret_header_values(&persisted_secret_headers, current_tenant);
1170    let mut persisted_headers = server.headers.clone();
1171    let mut migrated = false;
1172
1173    let header_keys = persisted_headers.keys().cloned().collect::<Vec<_>>();
1174    for header_name in header_keys {
1175        let Some(value) = persisted_headers.get(&header_name).cloned() else {
1176            continue;
1177        };
1178        if persisted_secret_headers.contains_key(&header_name) {
1179            continue;
1180        }
1181        if let Some(secret_ref) = parse_secret_header_reference(value.trim()) {
1182            persisted_headers.remove(&header_name);
1183            let resolved =
1184                resolve_secret_ref_value(&secret_ref, current_tenant).unwrap_or_default();
1185            persisted_secret_headers.insert(header_name.clone(), secret_ref);
1186            if !resolved.is_empty() {
1187                secret_header_values.insert(header_name.clone(), resolved);
1188            }
1189            migrated = true;
1190            continue;
1191        }
1192        if header_name_is_sensitive(&header_name) && !value.trim().is_empty() {
1193            let secret_id = mcp_header_secret_id(server_name, &header_name);
1194            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1195                persisted_headers.remove(&header_name);
1196                persisted_secret_headers.insert(
1197                    header_name.clone(),
1198                    McpSecretRef::Store {
1199                        secret_id: secret_id.clone(),
1200                        tenant_context: current_tenant.clone(),
1201                    },
1202                );
1203                secret_header_values.insert(header_name.clone(), value);
1204                migrated = true;
1205            }
1206        }
1207    }
1208
1209    if !migrated {
1210        let effective = combine_headers(&persisted_headers, &secret_header_values);
1211        migrated = effective != original_effective;
1212    }
1213
1214    (
1215        persisted_headers,
1216        persisted_secret_headers,
1217        secret_header_values,
1218        migrated,
1219    )
1220}
1221
1222fn split_headers_for_storage(
1223    server_name: &str,
1224    headers: HashMap<String, String>,
1225    explicit_secret_headers: HashMap<String, McpSecretRef>,
1226    current_tenant: &TenantContext,
1227) -> (
1228    HashMap<String, String>,
1229    HashMap<String, McpSecretRef>,
1230    HashMap<String, String>,
1231) {
1232    let mut persisted_headers = HashMap::new();
1233    let mut persisted_secret_headers = HashMap::new();
1234    let mut secret_header_values = HashMap::new();
1235
1236    for (header_name, raw_value) in headers {
1237        let value = raw_value.trim().to_string();
1238        if value.is_empty() {
1239            continue;
1240        }
1241        if let Some(secret_ref) = parse_secret_header_reference(&value) {
1242            if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1243                secret_header_values.insert(header_name.clone(), resolved);
1244            }
1245            persisted_secret_headers.insert(header_name, secret_ref);
1246            continue;
1247        }
1248        if header_name_is_sensitive(&header_name) {
1249            let secret_id = mcp_header_secret_id(server_name, &header_name);
1250            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1251                persisted_secret_headers.insert(
1252                    header_name.clone(),
1253                    McpSecretRef::Store {
1254                        secret_id: secret_id.clone(),
1255                        tenant_context: current_tenant.clone(),
1256                    },
1257                );
1258                secret_header_values.insert(header_name, value);
1259                continue;
1260            }
1261        }
1262        persisted_headers.insert(header_name, value);
1263    }
1264
1265    for (header_name, secret_ref) in explicit_secret_headers {
1266        if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1267            secret_header_values.insert(header_name.clone(), resolved);
1268        }
1269        persisted_headers.remove(&header_name);
1270        persisted_secret_headers.insert(header_name, secret_ref);
1271    }
1272
1273    (
1274        persisted_headers,
1275        persisted_secret_headers,
1276        secret_header_values,
1277    )
1278}
1279
/// Overlays resolved secret header values on top of the plain headers.
///
/// A secret value replaces a plain header of the same name unless the secret value is
/// blank (empty or whitespace only), in which case it is ignored.
fn combine_headers(
    headers: &HashMap<String, String>,
    secret_header_values: &HashMap<String, String>,
) -> HashMap<String, String> {
    let overrides = secret_header_values
        .iter()
        .filter(|(_, value)| !value.trim().is_empty())
        .map(|(key, value)| (key.clone(), value.clone()));
    let mut merged = headers.clone();
    merged.extend(overrides);
    merged
}
1292
1293fn effective_headers(server: &McpServer) -> HashMap<String, String> {
1294    combine_headers(&server.headers, &server.secret_header_values)
1295}
1296
1297fn redacted_server_view(server: &McpServer) -> McpServer {
1298    let mut clone = server.clone();
1299    for (header_name, secret_ref) in &clone.secret_headers {
1300        clone.headers.insert(
1301            header_name.clone(),
1302            redacted_secret_header_value(secret_ref),
1303        );
1304    }
1305    clone.secret_header_values.clear();
1306    if let Some(oauth) = clone.oauth.as_mut() {
1307        oauth.client_secret_ref = None;
1308        oauth.client_secret_value = None;
1309    }
1310    clone
1311}
1312
/// Normalizes a user-supplied auth kind.
///
/// The input is trimmed and lowercased. Recognised kinds (`oauth`, `auto`, `bearer`,
/// `x-api-key`, `custom`, `none`) are returned in that form; anything else becomes an
/// empty string.
fn normalize_auth_kind(raw: &str) -> String {
    let kind = raw.trim().to_ascii_lowercase();
    let recognised = matches!(
        kind.as_str(),
        "oauth" | "auto" | "bearer" | "x-api-key" | "custom" | "none"
    );
    if recognised {
        kind
    } else {
        String::new()
    }
}
1321
1322fn redacted_secret_header_value(secret_ref: &McpSecretRef) -> String {
1323    match secret_ref {
1324        McpSecretRef::BearerEnv { .. } => "Bearer ".to_string(),
1325        McpSecretRef::Env { .. } | McpSecretRef::Store { .. } => MCP_SECRET_PLACEHOLDER.to_string(),
1326    }
1327}
1328
1329fn resolve_secret_header_values(
1330    secret_headers: &HashMap<String, McpSecretRef>,
1331    current_tenant: &TenantContext,
1332) -> HashMap<String, String> {
1333    let mut out = HashMap::new();
1334    for (header_name, secret_ref) in secret_headers {
1335        if let Some(value) = resolve_secret_ref_value(secret_ref, current_tenant) {
1336            if !value.trim().is_empty() {
1337                out.insert(header_name.clone(), value);
1338            }
1339        }
1340    }
1341    out
1342}
1343
1344fn delete_secret_header_refs(
1345    secret_headers: &HashMap<String, McpSecretRef>,
1346    current_tenant: &TenantContext,
1347) {
1348    for secret_ref in secret_headers.values() {
1349        if let McpSecretRef::Store {
1350            secret_id,
1351            tenant_context,
1352        } = secret_ref
1353        {
1354            if tenant_context != current_tenant {
1355                continue;
1356            }
1357            let _ = tandem_core::delete_provider_auth(secret_id);
1358        }
1359    }
1360}
1361
1362fn delete_oauth_secret_ref(oauth: Option<&McpOAuthConfig>, current_tenant: &TenantContext) {
1363    let Some(secret_ref) = oauth.and_then(|oauth| oauth.client_secret_ref.as_ref()) else {
1364        return;
1365    };
1366    if let McpSecretRef::Store {
1367        secret_id,
1368        tenant_context,
1369    } = secret_ref
1370    {
1371        if tenant_context == current_tenant {
1372            let _ = tandem_core::delete_provider_auth(secret_id);
1373        }
1374    }
1375}