Skip to main content

tandem_runtime/mcp_parts/
part01.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::{LocalImplicitTenant, SecretRef, TenantContext, ToolResult};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
14use crate::mcp_ready::{EnsureReadyPolicy, McpReadyError};
15
// Protocol version string and client identity for this MCP client.
// NOTE(review): presumably sent during the MCP handshake; the usage is not visible in this chunk.
16const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
17const MCP_CLIENT_NAME: &str = "tandem";
// Resolved at compile time from the crate's Cargo package version.
18const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
// Cooldown in milliseconds handed to `pending_auth_short_circuit` by `call_tool`.
19const MCP_AUTH_REPROBE_COOLDOWN_MS: u64 = 15_000;
// NOTE(review): the placeholder is an empty string in this view; confirm that is intended
// and not a value that was lost when the file was extracted.
20const MCP_SECRET_PLACEHOLDER: &str = "";
21
/// One cached tool description, stored in `McpServer::tool_cache`.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct McpToolCacheEntry {
24    pub tool_name: String,
25    pub description: String,
    /// Falls back to `Value::default()` when the field is missing on deserialization.
26    #[serde(default)]
27    pub input_schema: Value,
    /// Millisecond timestamp; `McpRegistry::refresh` sets it from `now_ms()`.
28    pub fetched_at_ms: u64,
    /// `McpRegistry::refresh` fills this with `schema_hash(&input_schema)`.
29    pub schema_hash: String,
30}
31
/// Persisted configuration plus runtime status for one registered MCP server.
///
/// Most optional or empty fields are omitted from the serialized form via
/// `skip_serializing_if`, and take their default when absent on load.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct McpServer {
34    pub name: String,
    /// Transport string; the registry inspects it with `parse_stdio_transport`
    /// and `parse_remote_endpoint` to decide how to connect.
35    pub transport: String,
36    #[serde(default, skip_serializing_if = "String::is_empty")]
37    pub auth_kind: String,
38    #[serde(default = "default_enabled")]
39    pub enabled: bool,
    /// Runtime status; reset to `false` whenever state is loaded by
    /// `McpRegistry::new_with_state_file`.
40    pub connected: bool,
    /// Reset to `None` whenever state is loaded by `McpRegistry::new_with_state_file`.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub pid: Option<u32>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub last_error: Option<String>,
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub last_auth_challenge: Option<McpAuthChallenge>,
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub mcp_session_id: Option<String>,
49    #[serde(default)]
50    pub headers: HashMap<String, String>,
    /// Header name -> reference to where the header's secret value is kept.
51    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
52    pub secret_headers: HashMap<String, McpSecretRef>,
    /// Tool descriptions captured by the last successful `McpRegistry::refresh`.
53    #[serde(default)]
54    pub tool_cache: Vec<McpToolCacheEntry>,
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub tools_fetched_at_ms: Option<u64>,
    /// Keyed by the result of `canonical_tool_key` for the tool name.
57    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
58    pub pending_auth_by_tool: HashMap<String, PendingMcpAuth>,
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub allowed_tools: Option<Vec<String>>,
    /// Stored trimmed and ASCII-lowercased by `McpRegistry::set_grounding_metadata`.
61    #[serde(default, skip_serializing_if = "String::is_empty")]
62    pub purpose: String,
63    #[serde(default)]
64    pub grounding_required: bool,
    /// Resolved secret header values. Never serialized or deserialized (`skip`);
    /// rebuilt from `secret_headers` when state is loaded.
65    #[serde(default, skip)]
66    pub secret_header_values: HashMap<String, String>,
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub oauth: Option<McpOAuthConfig>,
69}
70
/// Reference to a secret value, kept in persisted state in place of the value itself.
///
/// Serialized as an internally tagged enum: the variant is stored in a `type`
/// field using snake_case names.
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
72#[serde(tag = "type", rename_all = "snake_case")]
73pub enum McpSecretRef {
    /// Secret identified by `secret_id` together with a tenant context; this is
    /// the variant created by `McpRegistry::set_bearer_token`.
74    Store {
75        secret_id: String,
76        #[serde(default)]
77        tenant_context: TenantContext,
78    },
    // NOTE(review): presumably `env` names an environment variable holding the
    // value; the resolution code is outside this chunk.
79    Env {
80        env: String,
81    },
    // NOTE(review): presumably like `Env` but yielding a bearer-style value; confirm
    // against `resolve_secret_ref_value`.
82    BearerEnv {
83        env: String,
84    },
85}
86
/// Authorization challenge recorded for a server (see `McpServer::last_auth_challenge`).
///
/// `message` doubles as the server's `last_error` text when a challenge is
/// recorded by `McpRegistry::refresh`.
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct McpAuthChallenge {
89    pub challenge_id: String,
    /// Tool the challenge was raised for; canonicalized with `canonical_tool_key`
    /// when used as a key in `McpServer::pending_auth_by_tool`.
90    pub tool_name: String,
91    pub authorization_url: String,
92    pub message: String,
93    pub requested_at_ms: u64,
94    pub status: String,
95}
96
/// Per-tool pending authorization record stored in `McpServer::pending_auth_by_tool`.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct PendingMcpAuth {
99    pub challenge_id: String,
100    pub authorization_url: String,
101    pub message: String,
102    pub status: String,
103    pub first_seen_ms: u64,
    /// Updated to the current time by `McpRegistry::call_tool` when the tool is probed again.
104    pub last_probe_ms: u64,
105}
106
/// OAuth settings attached to a server by `McpRegistry::set_oauth_refresh_config`.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct McpOAuthConfig {
109    pub provider_id: String,
110    pub token_endpoint: String,
111    pub client_id: String,
    /// Reference to the stored client secret; omitted from serialized state when `None`.
112    #[serde(default, skip_serializing_if = "Option::is_none")]
113    pub client_secret_ref: Option<McpSecretRef>,
    /// Resolved client secret. Never serialized or deserialized (`skip`); it is
    /// re-resolved from `client_secret_ref` when state is loaded.
114    #[serde(default, skip)]
115    pub client_secret_value: Option<String>,
116}
117
/// Failure modes of remote tool discovery as handled by `McpRegistry::refresh`.
118#[derive(Debug, Clone)]
119enum DiscoverRemoteToolsError {
    /// Plain error text; `refresh` stores it as the server's `last_error`.
120    Message(String),
    /// The server asked for authorization; `refresh` records the challenge on the server.
121    AuthChallenge(McpAuthChallenge),
122}
123
124impl From<String> for DiscoverRemoteToolsError {
125    fn from(value: String) -> Self {
126        Self::Message(value)
127    }
128}
129
/// A tool exposed by a registered server, as returned by `McpRegistry::list_tools`
/// and `McpRegistry::server_tools`.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct McpRemoteTool {
132    pub server_name: String,
133    pub tool_name: String,
    /// Sort key used by `list_tools` and `server_tools`.
134    pub namespaced_name: String,
135    pub description: String,
    /// Falls back to `Value::default()` when the field is missing on deserialization.
136    #[serde(default)]
137    pub input_schema: Value,
138    pub fetched_at_ms: u64,
139    pub schema_hash: String,
140}
141
/// Registry of MCP servers. Cloning is cheap: every field is behind an `Arc`,
/// so clones share the same underlying state.
142#[derive(Clone)]
143pub struct McpRegistry {
    /// Server records keyed by server name.
144    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Spawned child processes keyed by server name.
145    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the file the registry state is loaded from and persisted to.
146    state_file: Arc<PathBuf>,
147}
148
149impl McpRegistry {
150    pub fn new() -> Self {
151        Self::new_with_state_file(resolve_state_file())
152    }
153
154    pub fn new_with_state_file(state_file: PathBuf) -> Self {
155        let (loaded_state, migrated) = load_state(&state_file);
156        let loaded = loaded_state
157            .into_iter()
158            .map(|(k, mut v)| {
159                v.connected = false;
160                v.pid = None;
161                if v.name.trim().is_empty() {
162                    v.name = k.clone();
163                }
164                if v.headers.is_empty() {
165                    v.headers = HashMap::new();
166                }
167                if v.secret_headers.is_empty() {
168                    v.secret_headers = HashMap::new();
169                }
170                let tenant_context = local_tenant_context();
171                v.secret_header_values =
172                    resolve_secret_header_values(&v.secret_headers, &tenant_context);
173                if let Some(oauth) = v.oauth.as_mut() {
174                    oauth.client_secret_value =
175                        oauth.client_secret_ref.as_ref().and_then(|secret_ref| {
176                            resolve_secret_ref_value(secret_ref, &tenant_context)
177                        });
178                }
179                (k, v)
180            })
181            .collect::<HashMap<_, _>>();
182        if migrated {
183            persist_state_blocking(&state_file, &loaded);
184        }
185        Self {
186            servers: Arc::new(RwLock::new(loaded)),
187            processes: Arc::new(Mutex::new(HashMap::new())),
188            state_file: Arc::new(state_file),
189        }
190    }
191
192    pub async fn list(&self) -> HashMap<String, McpServer> {
193        self.servers.read().await.clone()
194    }
195
196    pub async fn list_public(&self) -> HashMap<String, McpServer> {
197        self.servers
198            .read()
199            .await
200            .iter()
201            .map(|(name, server)| (name.clone(), redacted_server_view(server)))
202            .collect()
203    }
204
205    pub async fn add(&self, name: String, transport: String) {
206        self.add_or_update(name, transport, HashMap::new(), true)
207            .await;
208    }
209
210    pub async fn add_or_update(
211        &self,
212        name: String,
213        transport: String,
214        headers: HashMap<String, String>,
215        enabled: bool,
216    ) {
217        self.add_or_update_with_secret_refs(name, transport, headers, HashMap::new(), enabled)
218            .await;
219    }
220
221    pub async fn add_or_update_with_secret_refs(
222        &self,
223        name: String,
224        transport: String,
225        headers: HashMap<String, String>,
226        secret_headers: HashMap<String, McpSecretRef>,
227        enabled: bool,
228    ) {
229        let normalized_name = name.trim().to_string();
230        let tenant_context = local_tenant_context();
231        let (persisted_headers, persisted_secret_headers, secret_header_values) =
232            split_headers_for_storage(&normalized_name, headers, secret_headers, &tenant_context);
233        let mut servers = self.servers.write().await;
234        let existing = servers.get(&normalized_name).cloned();
235        let preserve_cache = existing.as_ref().is_some_and(|row| {
236            row.transport == transport
237                && effective_headers(row)
238                    == combine_headers(&persisted_headers, &secret_header_values)
239        });
240        let existing_tool_cache = if preserve_cache {
241            existing
242                .as_ref()
243                .map(|row| row.tool_cache.clone())
244                .unwrap_or_default()
245        } else {
246            Vec::new()
247        };
248        let existing_fetched_at = if preserve_cache {
249            existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
250        } else {
251            None
252        };
253        let server = McpServer {
254            name: normalized_name.clone(),
255            transport,
256            auth_kind: existing
257                .as_ref()
258                .map(|row| row.auth_kind.clone())
259                .unwrap_or_default(),
260            enabled,
261            connected: false,
262            pid: None,
263            last_error: None,
264            last_auth_challenge: None,
265            mcp_session_id: None,
266            headers: persisted_headers,
267            secret_headers: persisted_secret_headers,
268            tool_cache: existing_tool_cache,
269            tools_fetched_at_ms: existing_fetched_at,
270            pending_auth_by_tool: HashMap::new(),
271            allowed_tools: existing.as_ref().and_then(|row| row.allowed_tools.clone()),
272            purpose: existing
273                .as_ref()
274                .map(|row| row.purpose.clone())
275                .unwrap_or_default(),
276            grounding_required: existing
277                .as_ref()
278                .map(|row| row.grounding_required)
279                .unwrap_or(false),
280            secret_header_values,
281            oauth: existing.as_ref().and_then(|row| row.oauth.clone()),
282        };
283        servers.insert(normalized_name, server);
284        drop(servers);
285        self.persist_state().await;
286    }
287
288    pub async fn set_allowed_tools(&self, name: &str, allowed_tools: Option<Vec<String>>) -> bool {
289        let mut servers = self.servers.write().await;
290        let Some(server) = servers.get_mut(name) else {
291            return false;
292        };
293        let normalized = allowed_tools.map(normalize_allowed_tool_names);
294        if server.allowed_tools == normalized {
295            return true;
296        }
297        server.allowed_tools = normalized;
298        drop(servers);
299        self.persist_state().await;
300        true
301    }
302
303    pub async fn set_grounding_metadata(
304        &self,
305        name: &str,
306        purpose: Option<String>,
307        grounding_required: Option<bool>,
308    ) -> bool {
309        let mut servers = self.servers.write().await;
310        let Some(server) = servers.get_mut(name) else {
311            return false;
312        };
313        let mut changed = false;
314        if let Some(purpose) = purpose {
315            let normalized = purpose.trim().to_ascii_lowercase();
316            if server.purpose != normalized {
317                server.purpose = normalized;
318                changed = true;
319            }
320        }
321        if let Some(grounding_required) = grounding_required {
322            if server.grounding_required != grounding_required {
323                server.grounding_required = grounding_required;
324                changed = true;
325            }
326        }
327        drop(servers);
328        if changed {
329            self.persist_state().await;
330        }
331        true
332    }
333
334    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
335        let mut servers = self.servers.write().await;
336        let Some(server) = servers.get_mut(name) else {
337            return false;
338        };
339        server.enabled = enabled;
340        if !enabled {
341            server.connected = false;
342            server.pid = None;
343            server.last_auth_challenge = None;
344            server.mcp_session_id = None;
345            server.pending_auth_by_tool.clear();
346        }
347        drop(servers);
348        if !enabled {
349            if let Some(mut child) = self.processes.lock().await.remove(name) {
350                let _ = child.kill().await;
351                let _ = child.wait().await;
352            }
353        }
354        self.persist_state().await;
355        true
356    }
357
358    pub async fn remove(&self, name: &str) -> bool {
359        let removed_server = {
360            let mut servers = self.servers.write().await;
361            servers.remove(name)
362        };
363        let Some(server) = removed_server else {
364            return false;
365        };
366        let current_tenant = local_tenant_context();
367        delete_secret_header_refs(&server.secret_headers, &current_tenant);
368        delete_oauth_secret_ref(server.oauth.as_ref(), &current_tenant);
369
370        if let Some(mut child) = self.processes.lock().await.remove(name) {
371            let _ = child.kill().await;
372            let _ = child.wait().await;
373        }
374        self.persist_state().await;
375        true
376    }
377
378    pub async fn connect(&self, name: &str) -> bool {
379        let server = {
380            let servers = self.servers.read().await;
381            let Some(server) = servers.get(name) else {
382                return false;
383            };
384            server.clone()
385        };
386
387        if !server.enabled {
388            let mut servers = self.servers.write().await;
389            if let Some(entry) = servers.get_mut(name) {
390                entry.connected = false;
391                entry.pid = None;
392                entry.last_error = Some("MCP server is disabled".to_string());
393                entry.last_auth_challenge = None;
394                entry.mcp_session_id = None;
395                entry.pending_auth_by_tool.clear();
396            }
397            drop(servers);
398            self.persist_state().await;
399            return false;
400        }
401
402        if let Some(command_text) = parse_stdio_transport(&server.transport) {
403            return self.connect_stdio(name, command_text).await;
404        }
405
406        if parse_remote_endpoint(&server.transport).is_some() {
407            return self.refresh(name).await.is_ok();
408        }
409
410        let mut servers = self.servers.write().await;
411        if let Some(entry) = servers.get_mut(name) {
412            entry.connected = true;
413            entry.pid = None;
414            entry.last_error = None;
415            entry.last_auth_challenge = None;
416            entry.mcp_session_id = None;
417            entry.pending_auth_by_tool.clear();
418        }
419        drop(servers);
420        self.persist_state().await;
421        true
422    }
423
424    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
425        let server = {
426            let servers = self.servers.read().await;
427            let Some(server) = servers.get(name) else {
428                return Err("MCP server not found".to_string());
429            };
430            server.clone()
431        };
432
433        if !server.enabled {
434            return Err("MCP server is disabled".to_string());
435        }
436
437        let endpoint = parse_remote_endpoint(&server.transport)
438            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
439
440        let _ = self.ensure_oauth_bearer_token_fresh(name, false).await;
441        let server = {
442            let servers = self.servers.read().await;
443            let Some(server) = servers.get(name) else {
444                return Err("MCP server not found".to_string());
445            };
446            server.clone()
447        };
448        let request_headers = effective_headers(&server);
449        let discovery = self
450            .discover_remote_tools(name, &endpoint, &request_headers)
451            .await;
452        let (tools, session_id) = match discovery {
453            Ok(result) => result,
454            Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
455                let mut servers = self.servers.write().await;
456                if let Some(entry) = servers.get_mut(name) {
457                    entry.connected = false;
458                    entry.pid = None;
459                    entry.last_error = Some(challenge.message.clone());
460                    entry.last_auth_challenge = Some(challenge.clone());
461                    entry.mcp_session_id = None;
462                    entry.pending_auth_by_tool.clear();
463                    entry.tool_cache.clear();
464                    entry.tools_fetched_at_ms = None;
465                }
466                drop(servers);
467                self.persist_state().await;
468                return Err(format!(
469                    "MCP server '{name}' requires authorization: {}",
470                    challenge.message
471                ));
472            }
473            Err(DiscoverRemoteToolsError::Message(err)) => {
474                if should_retry_mcp_oauth_refresh(&server, &err)
475                    && self.ensure_oauth_bearer_token_fresh(name, true).await?
476                {
477                    let refreshed_server = {
478                        let servers = self.servers.read().await;
479                        servers
480                            .get(name)
481                            .cloned()
482                            .ok_or_else(|| "MCP server not found".to_string())?
483                    };
484                    match self
485                        .discover_remote_tools(
486                            name,
487                            &endpoint,
488                            &effective_headers(&refreshed_server),
489                        )
490                        .await
491                    {
492                        Ok(result) => result,
493                        Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
494                            let mut servers = self.servers.write().await;
495                            if let Some(entry) = servers.get_mut(name) {
496                                entry.connected = false;
497                                entry.pid = None;
498                                entry.last_error = Some(challenge.message.clone());
499                                entry.last_auth_challenge = Some(challenge.clone());
500                                entry.mcp_session_id = None;
501                                entry.pending_auth_by_tool.clear();
502                                entry.tool_cache.clear();
503                                entry.tools_fetched_at_ms = None;
504                            }
505                            drop(servers);
506                            self.persist_state().await;
507                            return Err(format!(
508                                "MCP server '{name}' requires authorization: {}",
509                                challenge.message
510                            ));
511                        }
512                        Err(DiscoverRemoteToolsError::Message(retry_err)) => {
513                            let mut servers = self.servers.write().await;
514                            if let Some(entry) = servers.get_mut(name) {
515                                entry.connected = false;
516                                entry.pid = None;
517                                entry.last_error = Some(retry_err.clone());
518                                entry.last_auth_challenge = None;
519                                entry.mcp_session_id = None;
520                                entry.pending_auth_by_tool.clear();
521                                entry.tool_cache.clear();
522                                entry.tools_fetched_at_ms = None;
523                            }
524                            drop(servers);
525                            self.persist_state().await;
526                            return Err(retry_err);
527                        }
528                    }
529                } else {
530                    let mut servers = self.servers.write().await;
531                    if let Some(entry) = servers.get_mut(name) {
532                        entry.connected = false;
533                        entry.pid = None;
534                        entry.last_error = Some(err.clone());
535                        entry.last_auth_challenge = None;
536                        entry.mcp_session_id = None;
537                        entry.pending_auth_by_tool.clear();
538                        entry.tool_cache.clear();
539                        entry.tools_fetched_at_ms = None;
540                    }
541                    drop(servers);
542                    self.persist_state().await;
543                    return Err(err);
544                }
545            }
546        };
547
548        let now = now_ms();
549        let cache = tools
550            .iter()
551            .map(|tool| McpToolCacheEntry {
552                tool_name: tool.tool_name.clone(),
553                description: tool.description.clone(),
554                input_schema: tool.input_schema.clone(),
555                fetched_at_ms: now,
556                schema_hash: schema_hash(&tool.input_schema),
557            })
558            .collect::<Vec<_>>();
559
560        let mut servers = self.servers.write().await;
561        if let Some(entry) = servers.get_mut(name) {
562            entry.connected = true;
563            entry.pid = None;
564            entry.last_error = None;
565            entry.last_auth_challenge = None;
566            entry.mcp_session_id = session_id;
567            entry.tool_cache = cache;
568            entry.tools_fetched_at_ms = Some(now);
569            entry.pending_auth_by_tool.clear();
570        }
571        drop(servers);
572        self.persist_state().await;
573        Ok(self.server_tools(name).await)
574    }
575
576    pub async fn disconnect(&self, name: &str) -> bool {
577        if let Some(mut child) = self.processes.lock().await.remove(name) {
578            let _ = child.kill().await;
579            let _ = child.wait().await;
580        }
581        let mut servers = self.servers.write().await;
582        if let Some(server) = servers.get_mut(name) {
583            server.connected = false;
584            server.pid = None;
585            server.last_auth_challenge = None;
586            server.mcp_session_id = None;
587            server.pending_auth_by_tool.clear();
588            drop(servers);
589            self.persist_state().await;
590            return true;
591        }
592        false
593    }
594
595    pub async fn complete_auth(&self, name: &str) -> bool {
596        let mut servers = self.servers.write().await;
597        let Some(server) = servers.get_mut(name) else {
598            return false;
599        };
600        server.last_error = None;
601        server.last_auth_challenge = None;
602        server.pending_auth_by_tool.clear();
603        drop(servers);
604        self.persist_state().await;
605        true
606    }
607
608    pub async fn set_auth_kind(&self, name: &str, auth_kind: String) -> bool {
609        let normalized = normalize_auth_kind(&auth_kind);
610        let mut servers = self.servers.write().await;
611        let Some(server) = servers.get_mut(name) else {
612            return false;
613        };
614        server.auth_kind = normalized;
615        drop(servers);
616        self.persist_state().await;
617        true
618    }
619
620    pub async fn record_server_auth_challenge(
621        &self,
622        name: &str,
623        challenge: McpAuthChallenge,
624        last_error: Option<String>,
625    ) -> bool {
626        let mut servers = self.servers.write().await;
627        let Some(server) = servers.get_mut(name) else {
628            return false;
629        };
630        let tool_key = canonical_tool_key(&challenge.tool_name);
631        server.connected = false;
632        server.pid = None;
633        server.last_error = last_error.or_else(|| Some(challenge.message.clone()));
634        server.last_auth_challenge = Some(challenge.clone());
635        server.mcp_session_id = None;
636        server.pending_auth_by_tool.clear();
637        server
638            .pending_auth_by_tool
639            .insert(tool_key, pending_auth_from_challenge(&challenge));
640        drop(servers);
641        self.persist_state().await;
642        true
643    }
644
645    pub async fn clear_server_auth_challenge(&self, name: &str) -> bool {
646        let mut servers = self.servers.write().await;
647        let Some(server) = servers.get_mut(name) else {
648            return false;
649        };
650        server.last_auth_challenge = None;
651        server.pending_auth_by_tool.clear();
652        drop(servers);
653        self.persist_state().await;
654        true
655    }
656
657    pub async fn set_bearer_token(&self, name: &str, token: &str) -> Result<bool, String> {
658        let trimmed = token.trim();
659        if trimmed.is_empty() {
660            return Err("oauth access token cannot be empty".to_string());
661        }
662        let current_tenant = local_tenant_context();
663        let mut servers = self.servers.write().await;
664        let Some(server) = servers.get_mut(name) else {
665            return Ok(false);
666        };
667        let header_name = "Authorization".to_string();
668        let secret_id = mcp_header_secret_id(name, &header_name);
669        tandem_core::set_provider_auth(&secret_id, &format!("Bearer {trimmed}"))
670            .map_err(|error| error.to_string())?;
671        server.secret_headers.insert(
672            header_name.clone(),
673            McpSecretRef::Store {
674                secret_id: secret_id.clone(),
675                tenant_context: current_tenant,
676            },
677        );
678        server
679            .secret_header_values
680            .insert(header_name.clone(), format!("Bearer {trimmed}"));
681        server.headers.remove(&header_name);
682        drop(servers);
683        self.persist_state().await;
684        Ok(true)
685    }
686
687    pub async fn set_oauth_refresh_config(
688        &self,
689        name: &str,
690        provider_id: String,
691        token_endpoint: String,
692        client_id: String,
693        client_secret: Option<String>,
694    ) -> Result<bool, String> {
695        let current_tenant = local_tenant_context();
696        let mut servers = self.servers.write().await;
697        let Some(server) = servers.get_mut(name) else {
698            return Ok(false);
699        };
700
701        let client_secret_ref = client_secret
702            .as_deref()
703            .map(str::trim)
704            .filter(|value| !value.is_empty())
705            .map(|value| -> Result<McpSecretRef, String> {
706                let secret_id = mcp_oauth_client_secret_id(name);
707                tandem_core::set_provider_auth(&secret_id, value)
708                    .map_err(|error| error.to_string())?;
709                Ok(McpSecretRef::Store {
710                    secret_id,
711                    tenant_context: current_tenant.clone(),
712                })
713            })
714            .transpose()?;
715        if client_secret_ref.is_none() {
716            let secret_id = mcp_oauth_client_secret_id(name);
717            let _ = tandem_core::delete_provider_auth(&secret_id);
718        }
719
720        server.oauth = Some(McpOAuthConfig {
721            provider_id,
722            token_endpoint,
723            client_id,
724            client_secret_ref,
725            client_secret_value: client_secret
726                .map(|value| value.trim().to_string())
727                .filter(|value| !value.is_empty()),
728        });
729        drop(servers);
730        self.persist_state().await;
731        Ok(true)
732    }
733
734    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
735        let mut out = self
736            .servers
737            .read()
738            .await
739            .values()
740            .filter(|server| server.enabled && server.connected)
741            .flat_map(server_tool_rows)
742            .collect::<Vec<_>>();
743        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
744        out
745    }
746
747    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
748        let Some(server) = self.servers.read().await.get(name).cloned() else {
749            return Vec::new();
750        };
751        let mut rows = server_tool_rows(&server);
752        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
753        rows
754    }
755
756    pub async fn call_tool(
757        &self,
758        server_name: &str,
759        tool_name: &str,
760        args: Value,
761    ) -> Result<ToolResult, String> {
762        // Single readiness gate (Invariant 2 of `docs/SPINE.md`): one
763        // attempt, no backoff — same shape as the previous inline check.
764        let server = match self
765            .ensure_ready(server_name, EnsureReadyPolicy::default())
766            .await
767        {
768            Ok(server) => server,
769            Err(McpReadyError::NotFound) => {
770                return Err(format!("MCP server '{server_name}' not found"));
771            }
772            Err(McpReadyError::Disabled) => {
773                return Err(format!("MCP server '{server_name}' is disabled"));
774            }
775            Err(McpReadyError::PermanentlyFailed { last_error }) => {
776                let detail = last_error.unwrap_or_else(|| "reconnect attempt failed".to_string());
777                return Err(format!(
778                    "MCP server '{server_name}' is not connected: {detail}"
779                ));
780            }
781        };
782
783        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
784            "MCP tools/call currently supports HTTP/S transports only".to_string()
785        })?;
786        let canonical_tool = canonical_tool_key(tool_name);
787        let now = now_ms();
788        let _ = self
789            .ensure_oauth_bearer_token_fresh(server_name, false)
790            .await;
791        let server = {
792            let servers = self.servers.read().await;
793            let Some(server) = servers.get(server_name) else {
794                return Err(format!("MCP server '{server_name}' not found"));
795            };
796            server.clone()
797        };
798        if let Some(blocked) = pending_auth_short_circuit(
799            &server,
800            &canonical_tool,
801            tool_name,
802            now,
803            MCP_AUTH_REPROBE_COOLDOWN_MS,
804        ) {
805            return Ok(ToolResult {
806                output: blocked.output,
807                metadata: json!({
808                    "server": server_name,
809                    "tool": tool_name,
810                    "result": Value::Null,
811                    "mcpAuth": blocked.mcp_auth
812                }),
813            });
814        }
815        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);
816
817        {
818            let mut servers = self.servers.write().await;
819            if let Some(row) = servers.get_mut(server_name) {
820                if let Some(pending) = row.pending_auth_by_tool.get_mut(&canonical_tool) {
821                    pending.last_probe_ms = now;
822                }
823            }
824        }
825
826        let request = json!({
827            "jsonrpc": "2.0",
828            "id": format!("call-{}-{}", server_name, now_ms()),
829            "method": "tools/call",
830            "params": {
831                "name": tool_name,
832                "arguments": normalized_args
833            }
834        });
835        let (response, session_id) = match post_json_rpc_with_session(
836            &endpoint,
837            &effective_headers(&server),
838            request.clone(),
839            server.mcp_session_id.as_deref(),
840        )
841        .await
842        {
843            Ok(result) => result,
844            Err(error) => {
845                if should_retry_mcp_oauth_refresh(&server, &error)
846                    && self
847                        .ensure_oauth_bearer_token_fresh(server_name, true)
848                        .await?
849                {
850                    let refreshed_server = {
851                        let servers = self.servers.read().await;
852                        servers
853                            .get(server_name)
854                            .cloned()
855                            .ok_or_else(|| format!("MCP server '{server_name}' not found"))?
856                    };
857                    post_json_rpc_with_session(
858                        &endpoint,
859                        &effective_headers(&refreshed_server),
860                        request,
861                        refreshed_server.mcp_session_id.as_deref(),
862                    )
863                    .await?
864                } else {
865                    return Err(error);
866                }
867            }
868        };
869        if session_id.is_some() {
870            let mut servers = self.servers.write().await;
871            if let Some(row) = servers.get_mut(server_name) {
872                row.mcp_session_id = session_id;
873            }
874            drop(servers);
875            self.persist_state().await;
876        }
877
878        if let Some(err) = response.get("error") {
879            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
880                let output = format!(
881                    "{}\n\nAuthorize here: {}",
882                    challenge.message, challenge.authorization_url
883                );
884                {
885                    let mut servers = self.servers.write().await;
886                    if let Some(row) = servers.get_mut(server_name) {
887                        row.last_auth_challenge = Some(challenge.clone());
888                        row.last_error = None;
889                        row.pending_auth_by_tool.insert(
890                            canonical_tool.clone(),
891                            pending_auth_from_challenge(&challenge),
892                        );
893                    }
894                }
895                self.persist_state().await;
896                return Ok(ToolResult {
897                    output,
898                    metadata: json!({
899                        "server": server_name,
900                        "tool": tool_name,
901                        "result": Value::Null,
902                        "mcpAuth": {
903                            "required": true,
904                            "challengeId": challenge.challenge_id,
905                            "tool": challenge.tool_name,
906                            "authorizationUrl": challenge.authorization_url,
907                            "message": challenge.message,
908                            "status": challenge.status
909                        }
910                    }),
911                });
912            }
913            let message = err
914                .get("message")
915                .and_then(|v| v.as_str())
916                .unwrap_or("MCP tools/call failed");
917            return Err(message.to_string());
918        }
919
920        let result = response.get("result").cloned().unwrap_or(Value::Null);
921        let auth_challenge = extract_auth_challenge(&result, tool_name);
922        let output = if let Some(challenge) = auth_challenge.as_ref() {
923            format!(
924                "{}\n\nAuthorize here: {}",
925                challenge.message, challenge.authorization_url
926            )
927        } else {
928            result
929                .get("content")
930                .map(render_mcp_content)
931                .or_else(|| result.get("output").map(|v| v.to_string()))
932                .unwrap_or_else(|| result.to_string())
933        };
934
935        {
936            let mut servers = self.servers.write().await;
937            if let Some(row) = servers.get_mut(server_name) {
938                row.last_auth_challenge = auth_challenge.clone();
939                if let Some(challenge) = auth_challenge.as_ref() {
940                    row.pending_auth_by_tool.insert(
941                        canonical_tool.clone(),
942                        pending_auth_from_challenge(challenge),
943                    );
944                } else {
945                    row.pending_auth_by_tool.remove(&canonical_tool);
946                }
947            }
948        }
949        self.persist_state().await;
950
951        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
952            json!({
953                "required": true,
954                "challengeId": challenge.challenge_id,
955                "tool": challenge.tool_name,
956                "authorizationUrl": challenge.authorization_url,
957                "message": challenge.message,
958                "status": challenge.status
959            })
960        });
961
962        Ok(ToolResult {
963            output,
964            metadata: json!({
965                "server": server_name,
966                "tool": tool_name,
967                "result": result,
968                "mcpAuth": auth_metadata
969            }),
970        })
971    }
972
973    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
974        match spawn_stdio_process(command_text).await {
975            Ok(child) => {
976                let pid = child.id();
977                self.processes.lock().await.insert(name.to_string(), child);
978                let mut servers = self.servers.write().await;
979                if let Some(server) = servers.get_mut(name) {
980                    server.connected = true;
981                    server.pid = pid;
982                    server.last_error = None;
983                    server.last_auth_challenge = None;
984                    server.pending_auth_by_tool.clear();
985                }
986                drop(servers);
987                self.persist_state().await;
988                true
989            }
990            Err(err) => {
991                let mut servers = self.servers.write().await;
992                if let Some(server) = servers.get_mut(name) {
993                    server.connected = false;
994                    server.pid = None;
995                    server.last_error = Some(err);
996                    server.last_auth_challenge = None;
997                    server.pending_auth_by_tool.clear();
998                }
999                drop(servers);
1000                self.persist_state().await;
1001                false
1002            }
1003        }
1004    }
1005
1006    async fn discover_remote_tools(
1007        &self,
1008        server_name: &str,
1009        endpoint: &str,
1010        headers: &HashMap<String, String>,
1011    ) -> Result<(Vec<McpRemoteTool>, Option<String>), DiscoverRemoteToolsError> {
1012        let initialize = json!({
1013            "jsonrpc": "2.0",
1014            "id": "initialize-1",
1015            "method": "initialize",
1016            "params": {
1017                "protocolVersion": MCP_PROTOCOL_VERSION,
1018                "capabilities": {},
1019                "clientInfo": {
1020                    "name": MCP_CLIENT_NAME,
1021                    "version": MCP_CLIENT_VERSION,
1022                }
1023            }
1024        });
1025        let (init_response, mut session_id) =
1026            post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
1027        if let Some(err) = init_response.get("error") {
1028            if let Some(challenge) = extract_auth_challenge(err, server_name) {
1029                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1030            }
1031            let message = err
1032                .get("message")
1033                .and_then(|v| v.as_str())
1034                .unwrap_or("MCP initialize failed");
1035            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1036        }
1037
1038        let tools_list = json!({
1039            "jsonrpc": "2.0",
1040            "id": "tools-list-1",
1041            "method": "tools/list",
1042            "params": {}
1043        });
1044        let (tools_response, next_session_id) =
1045            post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
1046                .await?;
1047        if next_session_id.is_some() {
1048            session_id = next_session_id;
1049        }
1050        if let Some(err) = tools_response.get("error") {
1051            if let Some(challenge) = extract_auth_challenge(err, server_name) {
1052                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1053            }
1054            let message = err
1055                .get("message")
1056                .and_then(|v| v.as_str())
1057                .unwrap_or("MCP tools/list failed");
1058            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1059        }
1060
1061        let tools = tools_response
1062            .get("result")
1063            .and_then(|v| v.get("tools"))
1064            .and_then(|v| v.as_array())
1065            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
1066
1067        let now = now_ms();
1068        let mut out = Vec::new();
1069        for row in tools {
1070            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
1071                continue;
1072            };
1073            let description = row
1074                .get("description")
1075                .and_then(|v| v.as_str())
1076                .unwrap_or("")
1077                .to_string();
1078            let mut input_schema = row
1079                .get("inputSchema")
1080                .or_else(|| row.get("input_schema"))
1081                .cloned()
1082                .unwrap_or_else(|| json!({"type":"object"}));
1083            normalize_tool_input_schema(&mut input_schema);
1084            out.push(McpRemoteTool {
1085                server_name: String::new(),
1086                tool_name: tool_name.to_string(),
1087                namespaced_name: String::new(),
1088                description,
1089                input_schema,
1090                fetched_at_ms: now,
1091                schema_hash: String::new(),
1092            });
1093        }
1094
1095        Ok((out, session_id))
1096    }
1097
1098    async fn persist_state(&self) {
1099        let snapshot = self.servers.read().await.clone();
1100        persist_state_blocking(self.state_file.as_path(), &snapshot);
1101    }
1102
1103    async fn ensure_oauth_bearer_token_fresh(
1104        &self,
1105        name: &str,
1106        force: bool,
1107    ) -> Result<bool, String> {
1108        let server = {
1109            let servers = self.servers.read().await;
1110            servers.get(name).cloned()
1111        }
1112        .ok_or_else(|| format!("MCP server '{name}' not found"))?;
1113        let Some(oauth) = server.oauth.clone() else {
1114            return Ok(false);
1115        };
1116        let Some(credential) = tandem_core::load_provider_oauth_credential(&oauth.provider_id)
1117        else {
1118            return Ok(false);
1119        };
1120
1121        let should_refresh = force
1122            || credential.expires_at_ms <= now_ms().saturating_add(60_000)
1123            || credential.access_token.trim().is_empty();
1124        if !should_refresh {
1125            return Ok(false);
1126        }
1127
1128        let refreshed = refresh_mcp_oauth_credential(&oauth, &credential).await?;
1129        self.set_bearer_token(name, &refreshed.access_token).await?;
1130        tandem_core::set_provider_oauth_credential(&oauth.provider_id, refreshed)
1131            .map_err(|error| error.to_string())?;
1132        Ok(true)
1133    }
1134}
1135
1136impl Default for McpRegistry {
1137    fn default() -> Self {
1138        Self::new()
1139    }
1140}
1141
/// Serde default for `McpServer::enabled`: a server whose serialized form
/// omits the field is treated as enabled.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
1145
/// Trims every entry, drops blank ones, and removes duplicates while keeping
/// the first occurrence of each name in its original position.
///
/// Comparison is exact (case-sensitive) on the trimmed text.
fn normalize_allowed_tool_names(raw: Vec<String>) -> Vec<String> {
    let mut already_seen = std::collections::HashSet::new();
    raw.into_iter()
        .map(|tool| tool.trim().to_string())
        // `insert` returns false for a repeat, which filters it out.
        .filter(|name| !name.is_empty() && already_seen.insert(name.clone()))
        .collect()
}
1158
1159fn persist_state_blocking(path: &Path, snapshot: &HashMap<String, McpServer>) {
1160    if let Some(parent) = path.parent() {
1161        let _ = std::fs::create_dir_all(parent);
1162    }
1163    if let Ok(payload) = serde_json::to_string_pretty(snapshot) {
1164        let _ = std::fs::write(path, payload);
1165    }
1166}
1167
1168fn resolve_state_file() -> PathBuf {
1169    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
1170        return PathBuf::from(path);
1171    }
1172    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1173        let trimmed = state_dir.trim();
1174        if !trimmed.is_empty() {
1175            let base = PathBuf::from(trimmed);
1176            return if base.file_name().and_then(|value| value.to_str()) == Some("data") {
1177                base.join("mcp").join("mcp_servers.json")
1178            } else {
1179                base.join("data").join("mcp").join("mcp_servers.json")
1180            };
1181        }
1182    }
1183    if let Some(data_dir) = dirs::data_dir() {
1184        return data_dir
1185            .join("tandem")
1186            .join("data")
1187            .join("mcp")
1188            .join("mcp_servers.json");
1189    }
1190    dirs::home_dir()
1191        .map(|home| {
1192            home.join(".tandem")
1193                .join("data")
1194                .join("mcp")
1195                .join("mcp_servers.json")
1196        })
1197        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1198}
1199
1200fn load_state(path: &Path) -> (HashMap<String, McpServer>, bool) {
1201    let read_path = if path.exists() {
1202        path.to_path_buf()
1203    } else {
1204        legacy_mcp_registry_path()
1205    };
1206    let Ok(raw) = std::fs::read_to_string(read_path) else {
1207        return (HashMap::new(), false);
1208    };
1209    let mut migrated = false;
1210    let mut parsed = serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default();
1211    for (name, server) in parsed.iter_mut() {
1212        let tenant_context = local_tenant_context();
1213        let (headers, secret_headers, secret_header_values, server_migrated) =
1214            migrate_server_headers(name, server, &tenant_context);
1215        migrated = migrated || server_migrated;
1216        server.headers = headers;
1217        server.secret_headers = secret_headers;
1218        server.secret_header_values = secret_header_values;
1219    }
1220    (parsed, migrated)
1221}
1222
1223fn legacy_mcp_registry_path() -> PathBuf {
1224    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1225        let trimmed = state_dir.trim();
1226        if !trimmed.is_empty() {
1227            let base = PathBuf::from(trimmed);
1228            if base.file_name().and_then(|value| value.to_str()) != Some("data") {
1229                return base.join("mcp_servers.json");
1230            }
1231            return base
1232                .parent()
1233                .map(|parent| parent.join("mcp_servers.json"))
1234                .unwrap_or_else(|| base.join("mcp_servers.json"));
1235        }
1236    }
1237    dirs::data_dir()
1238        .map(|base| base.join("tandem").join("mcp_servers.json"))
1239        .or_else(|| dirs::home_dir().map(|home| home.join(".tandem").join("mcp_servers.json")))
1240        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1241}
1242
1243fn migrate_server_headers(
1244    server_name: &str,
1245    server: &McpServer,
1246    current_tenant: &TenantContext,
1247) -> (
1248    HashMap<String, String>,
1249    HashMap<String, McpSecretRef>,
1250    HashMap<String, String>,
1251    bool,
1252) {
1253    let original_effective = effective_headers(server);
1254    let mut persisted_secret_headers = server.secret_headers.clone();
1255    let mut secret_header_values =
1256        resolve_secret_header_values(&persisted_secret_headers, current_tenant);
1257    let mut persisted_headers = server.headers.clone();
1258    let mut migrated = false;
1259
1260    let header_keys = persisted_headers.keys().cloned().collect::<Vec<_>>();
1261    for header_name in header_keys {
1262        let Some(value) = persisted_headers.get(&header_name).cloned() else {
1263            continue;
1264        };
1265        if persisted_secret_headers.contains_key(&header_name) {
1266            continue;
1267        }
1268        if let Some(secret_ref) = parse_secret_header_reference(value.trim()) {
1269            persisted_headers.remove(&header_name);
1270            let resolved =
1271                resolve_secret_ref_value(&secret_ref, current_tenant).unwrap_or_default();
1272            persisted_secret_headers.insert(header_name.clone(), secret_ref);
1273            if !resolved.is_empty() {
1274                secret_header_values.insert(header_name.clone(), resolved);
1275            }
1276            migrated = true;
1277            continue;
1278        }
1279        if header_name_is_sensitive(&header_name) && !value.trim().is_empty() {
1280            let secret_id = mcp_header_secret_id(server_name, &header_name);
1281            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1282                persisted_headers.remove(&header_name);
1283                persisted_secret_headers.insert(
1284                    header_name.clone(),
1285                    McpSecretRef::Store {
1286                        secret_id: secret_id.clone(),
1287                        tenant_context: current_tenant.clone(),
1288                    },
1289                );
1290                secret_header_values.insert(header_name.clone(), value);
1291                migrated = true;
1292            }
1293        }
1294    }
1295
1296    if !migrated {
1297        let effective = combine_headers(&persisted_headers, &secret_header_values);
1298        migrated = effective != original_effective;
1299    }
1300
1301    (
1302        persisted_headers,
1303        persisted_secret_headers,
1304        secret_header_values,
1305        migrated,
1306    )
1307}
1308
1309fn split_headers_for_storage(
1310    server_name: &str,
1311    headers: HashMap<String, String>,
1312    explicit_secret_headers: HashMap<String, McpSecretRef>,
1313    current_tenant: &TenantContext,
1314) -> (
1315    HashMap<String, String>,
1316    HashMap<String, McpSecretRef>,
1317    HashMap<String, String>,
1318) {
1319    let mut persisted_headers = HashMap::new();
1320    let mut persisted_secret_headers = HashMap::new();
1321    let mut secret_header_values = HashMap::new();
1322
1323    for (header_name, raw_value) in headers {
1324        let value = raw_value.trim().to_string();
1325        if value.is_empty() {
1326            continue;
1327        }
1328        if let Some(secret_ref) = parse_secret_header_reference(&value) {
1329            if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1330                secret_header_values.insert(header_name.clone(), resolved);
1331            }
1332            persisted_secret_headers.insert(header_name, secret_ref);
1333            continue;
1334        }
1335        if header_name_is_sensitive(&header_name) {
1336            let secret_id = mcp_header_secret_id(server_name, &header_name);
1337            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1338                persisted_secret_headers.insert(
1339                    header_name.clone(),
1340                    McpSecretRef::Store {
1341                        secret_id: secret_id.clone(),
1342                        tenant_context: current_tenant.clone(),
1343                    },
1344                );
1345                secret_header_values.insert(header_name, value);
1346                continue;
1347            }
1348        }
1349        persisted_headers.insert(header_name, value);
1350    }
1351
1352    for (header_name, secret_ref) in explicit_secret_headers {
1353        if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1354            secret_header_values.insert(header_name.clone(), resolved);
1355        }
1356        persisted_headers.remove(&header_name);
1357        persisted_secret_headers.insert(header_name, secret_ref);
1358    }
1359
1360    (
1361        persisted_headers,
1362        persisted_secret_headers,
1363        secret_header_values,
1364    )
1365}
1366
/// Overlays resolved secret values on top of the plain headers.
///
/// A secret value replaces a plain header of the same name. Secret values that
/// are blank after trimming are ignored, leaving any plain header untouched.
fn combine_headers(
    headers: &HashMap<String, String>,
    secret_header_values: &HashMap<String, String>,
) -> HashMap<String, String> {
    let mut merged = headers.clone();
    merged.extend(
        secret_header_values
            .iter()
            .filter(|(_, secret)| !secret.trim().is_empty())
            .map(|(name, secret)| (name.clone(), secret.clone())),
    );
    merged
}
1379
1380fn effective_headers(server: &McpServer) -> HashMap<String, String> {
1381    combine_headers(&server.headers, &server.secret_header_values)
1382}
1383
1384fn redacted_server_view(server: &McpServer) -> McpServer {
1385    let mut clone = server.clone();
1386    for (header_name, secret_ref) in &clone.secret_headers {
1387        clone.headers.insert(
1388            header_name.clone(),
1389            redacted_secret_header_value(secret_ref),
1390        );
1391    }
1392    clone.secret_header_values.clear();
1393    if let Some(oauth) = clone.oauth.as_mut() {
1394        oauth.client_secret_ref = None;
1395        oauth.client_secret_value = None;
1396    }
1397    clone
1398}
1399
/// Canonicalizes an auth-kind string.
///
/// The input is trimmed and ASCII-lowercased; the result is returned when it
/// is one of `oauth`, `auto`, `bearer`, `x-api-key`, `custom` or `none`.
/// Anything else yields an empty string.
fn normalize_auth_kind(raw: &str) -> String {
    let lowered = raw.trim().to_ascii_lowercase();
    let recognized = matches!(
        lowered.as_str(),
        "oauth" | "auto" | "bearer" | "x-api-key" | "custom" | "none"
    );
    if recognized {
        lowered
    } else {
        String::new()
    }
}
1408
1409fn redacted_secret_header_value(secret_ref: &McpSecretRef) -> String {
1410    match secret_ref {
1411        McpSecretRef::BearerEnv { .. } => "Bearer ".to_string(),
1412        McpSecretRef::Env { .. } | McpSecretRef::Store { .. } => MCP_SECRET_PLACEHOLDER.to_string(),
1413    }
1414}
1415
1416fn resolve_secret_header_values(
1417    secret_headers: &HashMap<String, McpSecretRef>,
1418    current_tenant: &TenantContext,
1419) -> HashMap<String, String> {
1420    let mut out = HashMap::new();
1421    for (header_name, secret_ref) in secret_headers {
1422        if let Some(value) = resolve_secret_ref_value(secret_ref, current_tenant) {
1423            if !value.trim().is_empty() {
1424                out.insert(header_name.clone(), value);
1425            }
1426        }
1427    }
1428    out
1429}
1430
1431fn delete_secret_header_refs(
1432    secret_headers: &HashMap<String, McpSecretRef>,
1433    current_tenant: &TenantContext,
1434) {
1435    for secret_ref in secret_headers.values() {
1436        if let McpSecretRef::Store {
1437            secret_id,
1438            tenant_context,
1439        } = secret_ref
1440        {
1441            if tenant_context != current_tenant {
1442                continue;
1443            }
1444            let _ = tandem_core::delete_provider_auth(secret_id);
1445        }
1446    }
1447}
1448
1449fn delete_oauth_secret_ref(oauth: Option<&McpOAuthConfig>, current_tenant: &TenantContext) {
1450    let Some(secret_ref) = oauth.and_then(|oauth| oauth.client_secret_ref.as_ref()) else {
1451        return;
1452    };
1453    if let McpSecretRef::Store {
1454        secret_id,
1455        tenant_context,
1456    } = secret_ref
1457    {
1458        if tenant_context == current_tenant {
1459            let _ = tandem_core::delete_provider_auth(secret_id);
1460        }
1461    }
1462}