Skip to main content

tandem_runtime/mcp_parts/
part01.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::{LocalImplicitTenant, SecretRef, TenantContext, ToolResult};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// MCP protocol revision string used by this client.
// NOTE(review): presumably sent during the initialize handshake; the use site is not in this section — confirm.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Name this client identifies itself with.
// NOTE(review): use site is not visible in this section — confirm where it is reported.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version, taken from the crate's `CARGO_PKG_VERSION` at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Cooldown in milliseconds that `call_tool` passes to `pending_auth_short_circuit`.
const MCP_AUTH_REPROBE_COOLDOWN_MS: u64 = 15_000;
/// Placeholder used in place of secret values.
// NOTE(review): the literal is empty here; if a visible marker was intended it may have been
// lost in extraction — confirm against the repository before relying on this value.
const MCP_SECRET_PLACEHOLDER: &str = "";
19
/// One cached tool description, stored in [`McpServer::tool_cache`].
///
/// `McpRegistry::refresh` rebuilds these entries from the tools returned by discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server (not namespaced).
    pub tool_name: String,
    /// Description text copied from the discovered tool.
    pub description: String,
    /// Input schema copied from the discovered tool; defaults to JSON `null` when absent on load.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (ms) taken from `now_ms()` when the entry was written.
    pub fetched_at_ms: u64,
    /// Result of `schema_hash(&input_schema)` at the time the entry was written.
    pub schema_hash: String,
}
29
/// Persisted configuration plus runtime status for one registered MCP server.
///
/// Fields marked `#[serde(skip)]` live only in memory. `connected` and `pid` are serialized,
/// but `McpRegistry::new_with_state_file` resets them whenever state is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Server name. The registry inserts entries under the trimmed name, and an empty name
    /// found on load is replaced by the map key.
    pub name: String,
    /// Transport descriptor; interpreted by `parse_stdio_transport` and `parse_remote_endpoint`.
    pub transport: String,
    /// Authentication kind, stored after `normalize_auth_kind`; omitted from output when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub auth_kind: String,
    /// Whether the server may be connected to; falls back to `default_enabled` when missing.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Runtime connection flag; forced to `false` on load.
    pub connected: bool,
    /// Process id; forced to `None` on load.
    // NOTE(review): presumably the pid of the stdio child process — it is set outside this section.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    /// Most recent error message recorded for this server, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<String>,
    /// Most recent authorization challenge recorded for this server, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_auth_challenge: Option<McpAuthChallenge>,
    /// Session id returned by tool discovery or by a tool call; cleared on disconnect/failure.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mcp_session_id: Option<String>,
    /// Headers persisted as plain values (the non-secret part produced by
    /// `split_headers_for_storage`).
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Header name -> reference describing where the secret value is kept.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub secret_headers: HashMap<String, McpSecretRef>,
    /// Tools found by the last successful refresh.
    #[serde(default)]
    pub tool_cache: Vec<McpToolCacheEntry>,
    /// Timestamp (ms) of the last successful refresh, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools_fetched_at_ms: Option<u64>,
    /// Outstanding authorization requests, keyed by `canonical_tool_key(tool_name)`.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub pending_auth_by_tool: HashMap<String, PendingMcpAuth>,
    /// Optional tool allow-list, stored after `normalize_allowed_tool_names`.
    // NOTE(review): `None` presumably means "no restriction" — enforcement is not in this section.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
    /// Purpose label; `set_grounding_metadata` stores it trimmed and ASCII-lowercased.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub purpose: String,
    /// Grounding flag set through `set_grounding_metadata`.
    #[serde(default)]
    pub grounding_required: bool,
    /// Resolved secret header values; never serialized, re-resolved from `secret_headers` on load.
    #[serde(default, skip)]
    pub secret_header_values: HashMap<String, String>,
    /// OAuth token-refresh settings, if configured via `set_oauth_refresh_config`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oauth: Option<McpOAuthConfig>,
}
68
/// Reference to a secret value kept outside the persisted server state.
///
/// Serialized as an internally tagged object: the `type` field carries the variant name in
/// snake_case (`store`, `env`, `bearer_env`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpSecretRef {
    /// Secret held in the provider-auth store under `secret_id`.
    Store {
        /// Identifier the secret was stored under (see `set_bearer_token`).
        secret_id: String,
        /// Tenant the secret was stored for; defaults when missing on load.
        #[serde(default)]
        tenant_context: TenantContext,
    },
    /// Secret named by the `env` field.
    // NOTE(review): presumably read from the environment variable `env` — resolver not shown here.
    Env {
        env: String,
    },
    /// Secret named by the `env` field, bearer flavour.
    // NOTE(review): presumably the variable's value is sent as a `Bearer` token — confirm in the
    // resolver, which is outside this section.
    BearerEnv {
        env: String,
    },
}
84
/// An authorization request raised by an MCP server, as recorded on [`McpServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAuthChallenge {
    /// Identifier of this challenge.
    pub challenge_id: String,
    /// Tool the challenge relates to; canonicalized with `canonical_tool_key` when it is
    /// recorded as pending.
    pub tool_name: String,
    /// URL the user is directed to in order to authorize.
    pub authorization_url: String,
    /// Human-readable explanation; also copied into `McpServer::last_error`.
    pub message: String,
    /// Timestamp in milliseconds.
    // NOTE(review): presumably when the challenge was first received — set outside this section.
    pub requested_at_ms: u64,
    /// Status label; the set of valid values is not visible in this section.
    pub status: String,
}
94
/// Per-tool record of an outstanding authorization request.
///
/// Stored in `McpServer::pending_auth_by_tool`; built from an [`McpAuthChallenge`] by
/// `pending_auth_from_challenge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingMcpAuth {
    /// Identifier of the originating challenge.
    pub challenge_id: String,
    /// URL the user is directed to in order to authorize.
    pub authorization_url: String,
    /// Human-readable explanation of the request.
    pub message: String,
    /// Status label; the set of valid values is not visible in this section.
    pub status: String,
    /// Timestamp in milliseconds.
    // NOTE(review): presumably when this tool was first blocked — populated by a helper not shown.
    pub first_seen_ms: u64,
    /// Timestamp (ms) of the latest probe; `call_tool` sets it to `now_ms()` before calling out.
    pub last_probe_ms: u64,
}
104
/// OAuth settings attached to a server by `McpRegistry::set_oauth_refresh_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpOAuthConfig {
    /// Identifier of the OAuth provider.
    pub provider_id: String,
    /// Token endpoint of the provider.
    // NOTE(review): presumably the URL used when refreshing the access token — the refresh
    // logic itself is outside this section.
    pub token_endpoint: String,
    /// OAuth client id.
    pub client_id: String,
    /// Where the client secret is kept, if one was supplied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_secret_ref: Option<McpSecretRef>,
    /// Resolved client secret; never serialized, re-resolved from `client_secret_ref` on load.
    #[serde(default, skip)]
    pub client_secret_value: Option<String>,
}
115
/// Failure modes of remote tool discovery, as consumed by `McpRegistry::refresh`.
#[derive(Debug, Clone)]
enum DiscoverRemoteToolsError {
    /// Plain failure carrying an error message.
    Message(String),
    /// The server answered with an authorization challenge instead of a tool list.
    AuthChallenge(McpAuthChallenge),
}
121
122impl From<String> for DiscoverRemoteToolsError {
123    fn from(value: String) -> Self {
124        Self::Message(value)
125    }
126}
127
/// A tool exposed by a registered MCP server, as returned by `list_tools` / `server_tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the server that provides the tool.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Name combining server and tool; `list_tools` and `server_tools` sort by this field.
    // NOTE(review): the exact format is produced by a helper outside this section.
    pub namespaced_name: String,
    /// Description text of the tool.
    pub description: String,
    /// Input schema of the tool; defaults to JSON `null` when absent on load.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp (ms) associated with the tool data.
    pub fetched_at_ms: u64,
    /// Hash string associated with `input_schema`.
    pub schema_hash: String,
}
139
/// Registry of configured MCP servers.
///
/// All fields are `Arc`s, so cloning the registry yields another handle to the same state.
#[derive(Clone)]
pub struct McpRegistry {
    /// Server records keyed by server name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes keyed by server name; entries are killed on disable/remove/disconnect.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the file the server map is loaded from and persisted to.
    state_file: Arc<PathBuf>,
}
146
147impl McpRegistry {
148    pub fn new() -> Self {
149        Self::new_with_state_file(resolve_state_file())
150    }
151
152    pub fn new_with_state_file(state_file: PathBuf) -> Self {
153        let (loaded_state, migrated) = load_state(&state_file);
154        let loaded = loaded_state
155            .into_iter()
156            .map(|(k, mut v)| {
157                v.connected = false;
158                v.pid = None;
159                if v.name.trim().is_empty() {
160                    v.name = k.clone();
161                }
162                if v.headers.is_empty() {
163                    v.headers = HashMap::new();
164                }
165                if v.secret_headers.is_empty() {
166                    v.secret_headers = HashMap::new();
167                }
168                let tenant_context = local_tenant_context();
169                v.secret_header_values =
170                    resolve_secret_header_values(&v.secret_headers, &tenant_context);
171                if let Some(oauth) = v.oauth.as_mut() {
172                    oauth.client_secret_value =
173                        oauth.client_secret_ref.as_ref().and_then(|secret_ref| {
174                            resolve_secret_ref_value(secret_ref, &tenant_context)
175                        });
176                }
177                (k, v)
178            })
179            .collect::<HashMap<_, _>>();
180        if migrated {
181            persist_state_blocking(&state_file, &loaded);
182        }
183        Self {
184            servers: Arc::new(RwLock::new(loaded)),
185            processes: Arc::new(Mutex::new(HashMap::new())),
186            state_file: Arc::new(state_file),
187        }
188    }
189
190    pub async fn list(&self) -> HashMap<String, McpServer> {
191        self.servers.read().await.clone()
192    }
193
194    pub async fn list_public(&self) -> HashMap<String, McpServer> {
195        self.servers
196            .read()
197            .await
198            .iter()
199            .map(|(name, server)| (name.clone(), redacted_server_view(server)))
200            .collect()
201    }
202
203    pub async fn add(&self, name: String, transport: String) {
204        self.add_or_update(name, transport, HashMap::new(), true)
205            .await;
206    }
207
208    pub async fn add_or_update(
209        &self,
210        name: String,
211        transport: String,
212        headers: HashMap<String, String>,
213        enabled: bool,
214    ) {
215        self.add_or_update_with_secret_refs(name, transport, headers, HashMap::new(), enabled)
216            .await;
217    }
218
219    pub async fn add_or_update_with_secret_refs(
220        &self,
221        name: String,
222        transport: String,
223        headers: HashMap<String, String>,
224        secret_headers: HashMap<String, McpSecretRef>,
225        enabled: bool,
226    ) {
227        let normalized_name = name.trim().to_string();
228        let tenant_context = local_tenant_context();
229        let (persisted_headers, persisted_secret_headers, secret_header_values) =
230            split_headers_for_storage(&normalized_name, headers, secret_headers, &tenant_context);
231        let mut servers = self.servers.write().await;
232        let existing = servers.get(&normalized_name).cloned();
233        let preserve_cache = existing.as_ref().is_some_and(|row| {
234            row.transport == transport
235                && effective_headers(row)
236                    == combine_headers(&persisted_headers, &secret_header_values)
237        });
238        let existing_tool_cache = if preserve_cache {
239            existing
240                .as_ref()
241                .map(|row| row.tool_cache.clone())
242                .unwrap_or_default()
243        } else {
244            Vec::new()
245        };
246        let existing_fetched_at = if preserve_cache {
247            existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
248        } else {
249            None
250        };
251        let server = McpServer {
252            name: normalized_name.clone(),
253            transport,
254            auth_kind: existing
255                .as_ref()
256                .map(|row| row.auth_kind.clone())
257                .unwrap_or_default(),
258            enabled,
259            connected: false,
260            pid: None,
261            last_error: None,
262            last_auth_challenge: None,
263            mcp_session_id: None,
264            headers: persisted_headers,
265            secret_headers: persisted_secret_headers,
266            tool_cache: existing_tool_cache,
267            tools_fetched_at_ms: existing_fetched_at,
268            pending_auth_by_tool: HashMap::new(),
269            allowed_tools: existing.as_ref().and_then(|row| row.allowed_tools.clone()),
270            purpose: existing
271                .as_ref()
272                .map(|row| row.purpose.clone())
273                .unwrap_or_default(),
274            grounding_required: existing
275                .as_ref()
276                .map(|row| row.grounding_required)
277                .unwrap_or(false),
278            secret_header_values,
279            oauth: existing.as_ref().and_then(|row| row.oauth.clone()),
280        };
281        servers.insert(normalized_name, server);
282        drop(servers);
283        self.persist_state().await;
284    }
285
286    pub async fn set_allowed_tools(&self, name: &str, allowed_tools: Option<Vec<String>>) -> bool {
287        let mut servers = self.servers.write().await;
288        let Some(server) = servers.get_mut(name) else {
289            return false;
290        };
291        let normalized = allowed_tools.map(normalize_allowed_tool_names);
292        if server.allowed_tools == normalized {
293            return true;
294        }
295        server.allowed_tools = normalized;
296        drop(servers);
297        self.persist_state().await;
298        true
299    }
300
301    pub async fn set_grounding_metadata(
302        &self,
303        name: &str,
304        purpose: Option<String>,
305        grounding_required: Option<bool>,
306    ) -> bool {
307        let mut servers = self.servers.write().await;
308        let Some(server) = servers.get_mut(name) else {
309            return false;
310        };
311        let mut changed = false;
312        if let Some(purpose) = purpose {
313            let normalized = purpose.trim().to_ascii_lowercase();
314            if server.purpose != normalized {
315                server.purpose = normalized;
316                changed = true;
317            }
318        }
319        if let Some(grounding_required) = grounding_required {
320            if server.grounding_required != grounding_required {
321                server.grounding_required = grounding_required;
322                changed = true;
323            }
324        }
325        drop(servers);
326        if changed {
327            self.persist_state().await;
328        }
329        true
330    }
331
332    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
333        let mut servers = self.servers.write().await;
334        let Some(server) = servers.get_mut(name) else {
335            return false;
336        };
337        server.enabled = enabled;
338        if !enabled {
339            server.connected = false;
340            server.pid = None;
341            server.last_auth_challenge = None;
342            server.mcp_session_id = None;
343            server.pending_auth_by_tool.clear();
344        }
345        drop(servers);
346        if !enabled {
347            if let Some(mut child) = self.processes.lock().await.remove(name) {
348                let _ = child.kill().await;
349                let _ = child.wait().await;
350            }
351        }
352        self.persist_state().await;
353        true
354    }
355
356    pub async fn remove(&self, name: &str) -> bool {
357        let removed_server = {
358            let mut servers = self.servers.write().await;
359            servers.remove(name)
360        };
361        let Some(server) = removed_server else {
362            return false;
363        };
364        let current_tenant = local_tenant_context();
365        delete_secret_header_refs(&server.secret_headers, &current_tenant);
366        delete_oauth_secret_ref(server.oauth.as_ref(), &current_tenant);
367
368        if let Some(mut child) = self.processes.lock().await.remove(name) {
369            let _ = child.kill().await;
370            let _ = child.wait().await;
371        }
372        self.persist_state().await;
373        true
374    }
375
376    pub async fn connect(&self, name: &str) -> bool {
377        let server = {
378            let servers = self.servers.read().await;
379            let Some(server) = servers.get(name) else {
380                return false;
381            };
382            server.clone()
383        };
384
385        if !server.enabled {
386            let mut servers = self.servers.write().await;
387            if let Some(entry) = servers.get_mut(name) {
388                entry.connected = false;
389                entry.pid = None;
390                entry.last_error = Some("MCP server is disabled".to_string());
391                entry.last_auth_challenge = None;
392                entry.mcp_session_id = None;
393                entry.pending_auth_by_tool.clear();
394            }
395            drop(servers);
396            self.persist_state().await;
397            return false;
398        }
399
400        if let Some(command_text) = parse_stdio_transport(&server.transport) {
401            return self.connect_stdio(name, command_text).await;
402        }
403
404        if parse_remote_endpoint(&server.transport).is_some() {
405            return self.refresh(name).await.is_ok();
406        }
407
408        let mut servers = self.servers.write().await;
409        if let Some(entry) = servers.get_mut(name) {
410            entry.connected = true;
411            entry.pid = None;
412            entry.last_error = None;
413            entry.last_auth_challenge = None;
414            entry.mcp_session_id = None;
415            entry.pending_auth_by_tool.clear();
416        }
417        drop(servers);
418        self.persist_state().await;
419        true
420    }
421
422    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
423        let server = {
424            let servers = self.servers.read().await;
425            let Some(server) = servers.get(name) else {
426                return Err("MCP server not found".to_string());
427            };
428            server.clone()
429        };
430
431        if !server.enabled {
432            return Err("MCP server is disabled".to_string());
433        }
434
435        let endpoint = parse_remote_endpoint(&server.transport)
436            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
437
438        let _ = self.ensure_oauth_bearer_token_fresh(name, false).await;
439        let server = {
440            let servers = self.servers.read().await;
441            let Some(server) = servers.get(name) else {
442                return Err("MCP server not found".to_string());
443            };
444            server.clone()
445        };
446        let request_headers = effective_headers(&server);
447        let discovery = self
448            .discover_remote_tools(name, &endpoint, &request_headers)
449            .await;
450        let (tools, session_id) = match discovery {
451            Ok(result) => result,
452            Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
453                let mut servers = self.servers.write().await;
454                if let Some(entry) = servers.get_mut(name) {
455                    entry.connected = false;
456                    entry.pid = None;
457                    entry.last_error = Some(challenge.message.clone());
458                    entry.last_auth_challenge = Some(challenge.clone());
459                    entry.mcp_session_id = None;
460                    entry.pending_auth_by_tool.clear();
461                    entry.tool_cache.clear();
462                    entry.tools_fetched_at_ms = None;
463                }
464                drop(servers);
465                self.persist_state().await;
466                return Err(format!(
467                    "MCP server '{name}' requires authorization: {}",
468                    challenge.message
469                ));
470            }
471            Err(DiscoverRemoteToolsError::Message(err)) => {
472                if should_retry_mcp_oauth_refresh(&server, &err)
473                    && self.ensure_oauth_bearer_token_fresh(name, true).await?
474                {
475                    let refreshed_server = {
476                        let servers = self.servers.read().await;
477                        servers
478                            .get(name)
479                            .cloned()
480                            .ok_or_else(|| "MCP server not found".to_string())?
481                    };
482                    match self
483                        .discover_remote_tools(
484                            name,
485                            &endpoint,
486                            &effective_headers(&refreshed_server),
487                        )
488                        .await
489                    {
490                        Ok(result) => result,
491                        Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
492                            let mut servers = self.servers.write().await;
493                            if let Some(entry) = servers.get_mut(name) {
494                                entry.connected = false;
495                                entry.pid = None;
496                                entry.last_error = Some(challenge.message.clone());
497                                entry.last_auth_challenge = Some(challenge.clone());
498                                entry.mcp_session_id = None;
499                                entry.pending_auth_by_tool.clear();
500                                entry.tool_cache.clear();
501                                entry.tools_fetched_at_ms = None;
502                            }
503                            drop(servers);
504                            self.persist_state().await;
505                            return Err(format!(
506                                "MCP server '{name}' requires authorization: {}",
507                                challenge.message
508                            ));
509                        }
510                        Err(DiscoverRemoteToolsError::Message(retry_err)) => {
511                            let mut servers = self.servers.write().await;
512                            if let Some(entry) = servers.get_mut(name) {
513                                entry.connected = false;
514                                entry.pid = None;
515                                entry.last_error = Some(retry_err.clone());
516                                entry.last_auth_challenge = None;
517                                entry.mcp_session_id = None;
518                                entry.pending_auth_by_tool.clear();
519                                entry.tool_cache.clear();
520                                entry.tools_fetched_at_ms = None;
521                            }
522                            drop(servers);
523                            self.persist_state().await;
524                            return Err(retry_err);
525                        }
526                    }
527                } else {
528                    let mut servers = self.servers.write().await;
529                    if let Some(entry) = servers.get_mut(name) {
530                        entry.connected = false;
531                        entry.pid = None;
532                        entry.last_error = Some(err.clone());
533                        entry.last_auth_challenge = None;
534                        entry.mcp_session_id = None;
535                        entry.pending_auth_by_tool.clear();
536                        entry.tool_cache.clear();
537                        entry.tools_fetched_at_ms = None;
538                    }
539                    drop(servers);
540                    self.persist_state().await;
541                    return Err(err);
542                }
543            }
544        };
545
546        let now = now_ms();
547        let cache = tools
548            .iter()
549            .map(|tool| McpToolCacheEntry {
550                tool_name: tool.tool_name.clone(),
551                description: tool.description.clone(),
552                input_schema: tool.input_schema.clone(),
553                fetched_at_ms: now,
554                schema_hash: schema_hash(&tool.input_schema),
555            })
556            .collect::<Vec<_>>();
557
558        let mut servers = self.servers.write().await;
559        if let Some(entry) = servers.get_mut(name) {
560            entry.connected = true;
561            entry.pid = None;
562            entry.last_error = None;
563            entry.last_auth_challenge = None;
564            entry.mcp_session_id = session_id;
565            entry.tool_cache = cache;
566            entry.tools_fetched_at_ms = Some(now);
567            entry.pending_auth_by_tool.clear();
568        }
569        drop(servers);
570        self.persist_state().await;
571        Ok(self.server_tools(name).await)
572    }
573
574    pub async fn disconnect(&self, name: &str) -> bool {
575        if let Some(mut child) = self.processes.lock().await.remove(name) {
576            let _ = child.kill().await;
577            let _ = child.wait().await;
578        }
579        let mut servers = self.servers.write().await;
580        if let Some(server) = servers.get_mut(name) {
581            server.connected = false;
582            server.pid = None;
583            server.last_auth_challenge = None;
584            server.mcp_session_id = None;
585            server.pending_auth_by_tool.clear();
586            drop(servers);
587            self.persist_state().await;
588            return true;
589        }
590        false
591    }
592
593    pub async fn complete_auth(&self, name: &str) -> bool {
594        let mut servers = self.servers.write().await;
595        let Some(server) = servers.get_mut(name) else {
596            return false;
597        };
598        server.last_error = None;
599        server.last_auth_challenge = None;
600        server.pending_auth_by_tool.clear();
601        drop(servers);
602        self.persist_state().await;
603        true
604    }
605
606    pub async fn set_auth_kind(&self, name: &str, auth_kind: String) -> bool {
607        let normalized = normalize_auth_kind(&auth_kind);
608        let mut servers = self.servers.write().await;
609        let Some(server) = servers.get_mut(name) else {
610            return false;
611        };
612        server.auth_kind = normalized;
613        drop(servers);
614        self.persist_state().await;
615        true
616    }
617
618    pub async fn record_server_auth_challenge(
619        &self,
620        name: &str,
621        challenge: McpAuthChallenge,
622        last_error: Option<String>,
623    ) -> bool {
624        let mut servers = self.servers.write().await;
625        let Some(server) = servers.get_mut(name) else {
626            return false;
627        };
628        let tool_key = canonical_tool_key(&challenge.tool_name);
629        server.connected = false;
630        server.pid = None;
631        server.last_error = last_error.or_else(|| Some(challenge.message.clone()));
632        server.last_auth_challenge = Some(challenge.clone());
633        server.mcp_session_id = None;
634        server.pending_auth_by_tool.clear();
635        server
636            .pending_auth_by_tool
637            .insert(tool_key, pending_auth_from_challenge(&challenge));
638        drop(servers);
639        self.persist_state().await;
640        true
641    }
642
643    pub async fn clear_server_auth_challenge(&self, name: &str) -> bool {
644        let mut servers = self.servers.write().await;
645        let Some(server) = servers.get_mut(name) else {
646            return false;
647        };
648        server.last_auth_challenge = None;
649        server.pending_auth_by_tool.clear();
650        drop(servers);
651        self.persist_state().await;
652        true
653    }
654
655    pub async fn set_bearer_token(&self, name: &str, token: &str) -> Result<bool, String> {
656        let trimmed = token.trim();
657        if trimmed.is_empty() {
658            return Err("oauth access token cannot be empty".to_string());
659        }
660        let current_tenant = local_tenant_context();
661        let mut servers = self.servers.write().await;
662        let Some(server) = servers.get_mut(name) else {
663            return Ok(false);
664        };
665        let header_name = "Authorization".to_string();
666        let secret_id = mcp_header_secret_id(name, &header_name);
667        tandem_core::set_provider_auth(&secret_id, &format!("Bearer {trimmed}"))
668            .map_err(|error| error.to_string())?;
669        server.secret_headers.insert(
670            header_name.clone(),
671            McpSecretRef::Store {
672                secret_id: secret_id.clone(),
673                tenant_context: current_tenant,
674            },
675        );
676        server
677            .secret_header_values
678            .insert(header_name.clone(), format!("Bearer {trimmed}"));
679        server.headers.remove(&header_name);
680        drop(servers);
681        self.persist_state().await;
682        Ok(true)
683    }
684
685    pub async fn set_oauth_refresh_config(
686        &self,
687        name: &str,
688        provider_id: String,
689        token_endpoint: String,
690        client_id: String,
691        client_secret: Option<String>,
692    ) -> Result<bool, String> {
693        let current_tenant = local_tenant_context();
694        let mut servers = self.servers.write().await;
695        let Some(server) = servers.get_mut(name) else {
696            return Ok(false);
697        };
698
699        let client_secret_ref = client_secret
700            .as_deref()
701            .map(str::trim)
702            .filter(|value| !value.is_empty())
703            .map(|value| -> Result<McpSecretRef, String> {
704                let secret_id = mcp_oauth_client_secret_id(name);
705                tandem_core::set_provider_auth(&secret_id, value)
706                    .map_err(|error| error.to_string())?;
707                Ok(McpSecretRef::Store {
708                    secret_id,
709                    tenant_context: current_tenant.clone(),
710                })
711            })
712            .transpose()?;
713        if client_secret_ref.is_none() {
714            let secret_id = mcp_oauth_client_secret_id(name);
715            let _ = tandem_core::delete_provider_auth(&secret_id);
716        }
717
718        server.oauth = Some(McpOAuthConfig {
719            provider_id,
720            token_endpoint,
721            client_id,
722            client_secret_ref,
723            client_secret_value: client_secret
724                .map(|value| value.trim().to_string())
725                .filter(|value| !value.is_empty()),
726        });
727        drop(servers);
728        self.persist_state().await;
729        Ok(true)
730    }
731
732    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
733        let mut out = self
734            .servers
735            .read()
736            .await
737            .values()
738            .filter(|server| server.enabled && server.connected)
739            .flat_map(server_tool_rows)
740            .collect::<Vec<_>>();
741        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
742        out
743    }
744
745    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
746        let Some(server) = self.servers.read().await.get(name).cloned() else {
747            return Vec::new();
748        };
749        let mut rows = server_tool_rows(&server);
750        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
751        rows
752    }
753
754    pub async fn call_tool(
755        &self,
756        server_name: &str,
757        tool_name: &str,
758        args: Value,
759    ) -> Result<ToolResult, String> {
760        let server = {
761            let servers = self.servers.read().await;
762            let Some(server) = servers.get(server_name) else {
763                return Err(format!("MCP server '{server_name}' not found"));
764            };
765            server.clone()
766        };
767
768        if !server.enabled {
769            return Err(format!("MCP server '{server_name}' is disabled"));
770        }
771        if !server.connected {
772            return Err(format!("MCP server '{server_name}' is not connected"));
773        }
774
775        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
776            "MCP tools/call currently supports HTTP/S transports only".to_string()
777        })?;
778        let canonical_tool = canonical_tool_key(tool_name);
779        let now = now_ms();
780        let _ = self
781            .ensure_oauth_bearer_token_fresh(server_name, false)
782            .await;
783        let server = {
784            let servers = self.servers.read().await;
785            let Some(server) = servers.get(server_name) else {
786                return Err(format!("MCP server '{server_name}' not found"));
787            };
788            server.clone()
789        };
790        if let Some(blocked) = pending_auth_short_circuit(
791            &server,
792            &canonical_tool,
793            tool_name,
794            now,
795            MCP_AUTH_REPROBE_COOLDOWN_MS,
796        ) {
797            return Ok(ToolResult {
798                output: blocked.output,
799                metadata: json!({
800                    "server": server_name,
801                    "tool": tool_name,
802                    "result": Value::Null,
803                    "mcpAuth": blocked.mcp_auth
804                }),
805            });
806        }
807        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);
808
809        {
810            let mut servers = self.servers.write().await;
811            if let Some(row) = servers.get_mut(server_name) {
812                if let Some(pending) = row.pending_auth_by_tool.get_mut(&canonical_tool) {
813                    pending.last_probe_ms = now;
814                }
815            }
816        }
817
818        let request = json!({
819            "jsonrpc": "2.0",
820            "id": format!("call-{}-{}", server_name, now_ms()),
821            "method": "tools/call",
822            "params": {
823                "name": tool_name,
824                "arguments": normalized_args
825            }
826        });
827        let (response, session_id) = match post_json_rpc_with_session(
828            &endpoint,
829            &effective_headers(&server),
830            request.clone(),
831            server.mcp_session_id.as_deref(),
832        )
833        .await
834        {
835            Ok(result) => result,
836            Err(error) => {
837                if should_retry_mcp_oauth_refresh(&server, &error)
838                    && self
839                        .ensure_oauth_bearer_token_fresh(server_name, true)
840                        .await?
841                {
842                    let refreshed_server = {
843                        let servers = self.servers.read().await;
844                        servers
845                            .get(server_name)
846                            .cloned()
847                            .ok_or_else(|| format!("MCP server '{server_name}' not found"))?
848                    };
849                    post_json_rpc_with_session(
850                        &endpoint,
851                        &effective_headers(&refreshed_server),
852                        request,
853                        refreshed_server.mcp_session_id.as_deref(),
854                    )
855                    .await?
856                } else {
857                    return Err(error);
858                }
859            }
860        };
861        if session_id.is_some() {
862            let mut servers = self.servers.write().await;
863            if let Some(row) = servers.get_mut(server_name) {
864                row.mcp_session_id = session_id;
865            }
866            drop(servers);
867            self.persist_state().await;
868        }
869
870        if let Some(err) = response.get("error") {
871            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
872                let output = format!(
873                    "{}\n\nAuthorize here: {}",
874                    challenge.message, challenge.authorization_url
875                );
876                {
877                    let mut servers = self.servers.write().await;
878                    if let Some(row) = servers.get_mut(server_name) {
879                        row.last_auth_challenge = Some(challenge.clone());
880                        row.last_error = None;
881                        row.pending_auth_by_tool.insert(
882                            canonical_tool.clone(),
883                            pending_auth_from_challenge(&challenge),
884                        );
885                    }
886                }
887                self.persist_state().await;
888                return Ok(ToolResult {
889                    output,
890                    metadata: json!({
891                        "server": server_name,
892                        "tool": tool_name,
893                        "result": Value::Null,
894                        "mcpAuth": {
895                            "required": true,
896                            "challengeId": challenge.challenge_id,
897                            "tool": challenge.tool_name,
898                            "authorizationUrl": challenge.authorization_url,
899                            "message": challenge.message,
900                            "status": challenge.status
901                        }
902                    }),
903                });
904            }
905            let message = err
906                .get("message")
907                .and_then(|v| v.as_str())
908                .unwrap_or("MCP tools/call failed");
909            return Err(message.to_string());
910        }
911
912        let result = response.get("result").cloned().unwrap_or(Value::Null);
913        let auth_challenge = extract_auth_challenge(&result, tool_name);
914        let output = if let Some(challenge) = auth_challenge.as_ref() {
915            format!(
916                "{}\n\nAuthorize here: {}",
917                challenge.message, challenge.authorization_url
918            )
919        } else {
920            result
921                .get("content")
922                .map(render_mcp_content)
923                .or_else(|| result.get("output").map(|v| v.to_string()))
924                .unwrap_or_else(|| result.to_string())
925        };
926
927        {
928            let mut servers = self.servers.write().await;
929            if let Some(row) = servers.get_mut(server_name) {
930                row.last_auth_challenge = auth_challenge.clone();
931                if let Some(challenge) = auth_challenge.as_ref() {
932                    row.pending_auth_by_tool.insert(
933                        canonical_tool.clone(),
934                        pending_auth_from_challenge(challenge),
935                    );
936                } else {
937                    row.pending_auth_by_tool.remove(&canonical_tool);
938                }
939            }
940        }
941        self.persist_state().await;
942
943        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
944            json!({
945                "required": true,
946                "challengeId": challenge.challenge_id,
947                "tool": challenge.tool_name,
948                "authorizationUrl": challenge.authorization_url,
949                "message": challenge.message,
950                "status": challenge.status
951            })
952        });
953
954        Ok(ToolResult {
955            output,
956            metadata: json!({
957                "server": server_name,
958                "tool": tool_name,
959                "result": result,
960                "mcpAuth": auth_metadata
961            }),
962        })
963    }
964
965    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
966        match spawn_stdio_process(command_text).await {
967            Ok(child) => {
968                let pid = child.id();
969                self.processes.lock().await.insert(name.to_string(), child);
970                let mut servers = self.servers.write().await;
971                if let Some(server) = servers.get_mut(name) {
972                    server.connected = true;
973                    server.pid = pid;
974                    server.last_error = None;
975                    server.last_auth_challenge = None;
976                    server.pending_auth_by_tool.clear();
977                }
978                drop(servers);
979                self.persist_state().await;
980                true
981            }
982            Err(err) => {
983                let mut servers = self.servers.write().await;
984                if let Some(server) = servers.get_mut(name) {
985                    server.connected = false;
986                    server.pid = None;
987                    server.last_error = Some(err);
988                    server.last_auth_challenge = None;
989                    server.pending_auth_by_tool.clear();
990                }
991                drop(servers);
992                self.persist_state().await;
993                false
994            }
995        }
996    }
997
998    async fn discover_remote_tools(
999        &self,
1000        server_name: &str,
1001        endpoint: &str,
1002        headers: &HashMap<String, String>,
1003    ) -> Result<(Vec<McpRemoteTool>, Option<String>), DiscoverRemoteToolsError> {
1004        let initialize = json!({
1005            "jsonrpc": "2.0",
1006            "id": "initialize-1",
1007            "method": "initialize",
1008            "params": {
1009                "protocolVersion": MCP_PROTOCOL_VERSION,
1010                "capabilities": {},
1011                "clientInfo": {
1012                    "name": MCP_CLIENT_NAME,
1013                    "version": MCP_CLIENT_VERSION,
1014                }
1015            }
1016        });
1017        let (init_response, mut session_id) =
1018            post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
1019        if let Some(err) = init_response.get("error") {
1020            if let Some(challenge) = extract_auth_challenge(err, server_name) {
1021                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1022            }
1023            let message = err
1024                .get("message")
1025                .and_then(|v| v.as_str())
1026                .unwrap_or("MCP initialize failed");
1027            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1028        }
1029
1030        let tools_list = json!({
1031            "jsonrpc": "2.0",
1032            "id": "tools-list-1",
1033            "method": "tools/list",
1034            "params": {}
1035        });
1036        let (tools_response, next_session_id) =
1037            post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
1038                .await?;
1039        if next_session_id.is_some() {
1040            session_id = next_session_id;
1041        }
1042        if let Some(err) = tools_response.get("error") {
1043            if let Some(challenge) = extract_auth_challenge(err, server_name) {
1044                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1045            }
1046            let message = err
1047                .get("message")
1048                .and_then(|v| v.as_str())
1049                .unwrap_or("MCP tools/list failed");
1050            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1051        }
1052
1053        let tools = tools_response
1054            .get("result")
1055            .and_then(|v| v.get("tools"))
1056            .and_then(|v| v.as_array())
1057            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
1058
1059        let now = now_ms();
1060        let mut out = Vec::new();
1061        for row in tools {
1062            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
1063                continue;
1064            };
1065            let description = row
1066                .get("description")
1067                .and_then(|v| v.as_str())
1068                .unwrap_or("")
1069                .to_string();
1070            let mut input_schema = row
1071                .get("inputSchema")
1072                .or_else(|| row.get("input_schema"))
1073                .cloned()
1074                .unwrap_or_else(|| json!({"type":"object"}));
1075            normalize_tool_input_schema(&mut input_schema);
1076            out.push(McpRemoteTool {
1077                server_name: String::new(),
1078                tool_name: tool_name.to_string(),
1079                namespaced_name: String::new(),
1080                description,
1081                input_schema,
1082                fetched_at_ms: now,
1083                schema_hash: String::new(),
1084            });
1085        }
1086
1087        Ok((out, session_id))
1088    }
1089
1090    async fn persist_state(&self) {
1091        let snapshot = self.servers.read().await.clone();
1092        persist_state_blocking(self.state_file.as_path(), &snapshot);
1093    }
1094
1095    async fn ensure_oauth_bearer_token_fresh(
1096        &self,
1097        name: &str,
1098        force: bool,
1099    ) -> Result<bool, String> {
1100        let server = {
1101            let servers = self.servers.read().await;
1102            servers.get(name).cloned()
1103        }
1104        .ok_or_else(|| format!("MCP server '{name}' not found"))?;
1105        let Some(oauth) = server.oauth.clone() else {
1106            return Ok(false);
1107        };
1108        let Some(credential) = tandem_core::load_provider_oauth_credential(&oauth.provider_id)
1109        else {
1110            return Ok(false);
1111        };
1112
1113        let should_refresh = force
1114            || credential.expires_at_ms <= now_ms().saturating_add(60_000)
1115            || credential.access_token.trim().is_empty();
1116        if !should_refresh {
1117            return Ok(false);
1118        }
1119
1120        let refreshed = refresh_mcp_oauth_credential(&oauth, &credential).await?;
1121        self.set_bearer_token(name, &refreshed.access_token).await?;
1122        tandem_core::set_provider_oauth_credential(&oauth.provider_id, refreshed)
1123            .map_err(|error| error.to_string())?;
1124        Ok(true)
1125    }
1126}
1127
1128impl Default for McpRegistry {
1129    fn default() -> Self {
1130        Self::new()
1131    }
1132}
1133
/// Serde default for `McpServer::enabled`: entries without the field are enabled.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
1137
/// Trims every tool name, drops blank ones, and removes duplicates while
/// keeping the first occurrence of each name in its original position.
fn normalize_allowed_tool_names(raw: Vec<String>) -> Vec<String> {
    let mut already_kept = std::collections::HashSet::new();
    raw.into_iter()
        .map(|tool| tool.trim().to_string())
        // `insert` returns false for a name that was already kept.
        .filter(|name| !name.is_empty() && already_kept.insert(name.clone()))
        .collect()
}
1150
1151fn persist_state_blocking(path: &Path, snapshot: &HashMap<String, McpServer>) {
1152    if let Some(parent) = path.parent() {
1153        let _ = std::fs::create_dir_all(parent);
1154    }
1155    if let Ok(payload) = serde_json::to_string_pretty(snapshot) {
1156        let _ = std::fs::write(path, payload);
1157    }
1158}
1159
1160fn resolve_state_file() -> PathBuf {
1161    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
1162        return PathBuf::from(path);
1163    }
1164    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1165        let trimmed = state_dir.trim();
1166        if !trimmed.is_empty() {
1167            return PathBuf::from(trimmed).join("mcp_servers.json");
1168        }
1169    }
1170    if let Some(data_dir) = dirs::data_dir() {
1171        return data_dir
1172            .join("tandem")
1173            .join("data")
1174            .join("mcp_servers.json");
1175    }
1176    dirs::home_dir()
1177        .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
1178        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1179}
1180
1181fn load_state(path: &Path) -> (HashMap<String, McpServer>, bool) {
1182    let Ok(raw) = std::fs::read_to_string(path) else {
1183        return (HashMap::new(), false);
1184    };
1185    let mut migrated = false;
1186    let mut parsed = serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default();
1187    for (name, server) in parsed.iter_mut() {
1188        let tenant_context = local_tenant_context();
1189        let (headers, secret_headers, secret_header_values, server_migrated) =
1190            migrate_server_headers(name, server, &tenant_context);
1191        migrated = migrated || server_migrated;
1192        server.headers = headers;
1193        server.secret_headers = secret_headers;
1194        server.secret_header_values = secret_header_values;
1195    }
1196    (parsed, migrated)
1197}
1198
1199fn migrate_server_headers(
1200    server_name: &str,
1201    server: &McpServer,
1202    current_tenant: &TenantContext,
1203) -> (
1204    HashMap<String, String>,
1205    HashMap<String, McpSecretRef>,
1206    HashMap<String, String>,
1207    bool,
1208) {
1209    let original_effective = effective_headers(server);
1210    let mut persisted_secret_headers = server.secret_headers.clone();
1211    let mut secret_header_values =
1212        resolve_secret_header_values(&persisted_secret_headers, current_tenant);
1213    let mut persisted_headers = server.headers.clone();
1214    let mut migrated = false;
1215
1216    let header_keys = persisted_headers.keys().cloned().collect::<Vec<_>>();
1217    for header_name in header_keys {
1218        let Some(value) = persisted_headers.get(&header_name).cloned() else {
1219            continue;
1220        };
1221        if persisted_secret_headers.contains_key(&header_name) {
1222            continue;
1223        }
1224        if let Some(secret_ref) = parse_secret_header_reference(value.trim()) {
1225            persisted_headers.remove(&header_name);
1226            let resolved =
1227                resolve_secret_ref_value(&secret_ref, current_tenant).unwrap_or_default();
1228            persisted_secret_headers.insert(header_name.clone(), secret_ref);
1229            if !resolved.is_empty() {
1230                secret_header_values.insert(header_name.clone(), resolved);
1231            }
1232            migrated = true;
1233            continue;
1234        }
1235        if header_name_is_sensitive(&header_name) && !value.trim().is_empty() {
1236            let secret_id = mcp_header_secret_id(server_name, &header_name);
1237            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1238                persisted_headers.remove(&header_name);
1239                persisted_secret_headers.insert(
1240                    header_name.clone(),
1241                    McpSecretRef::Store {
1242                        secret_id: secret_id.clone(),
1243                        tenant_context: current_tenant.clone(),
1244                    },
1245                );
1246                secret_header_values.insert(header_name.clone(), value);
1247                migrated = true;
1248            }
1249        }
1250    }
1251
1252    if !migrated {
1253        let effective = combine_headers(&persisted_headers, &secret_header_values);
1254        migrated = effective != original_effective;
1255    }
1256
1257    (
1258        persisted_headers,
1259        persisted_secret_headers,
1260        secret_header_values,
1261        migrated,
1262    )
1263}
1264
1265fn split_headers_for_storage(
1266    server_name: &str,
1267    headers: HashMap<String, String>,
1268    explicit_secret_headers: HashMap<String, McpSecretRef>,
1269    current_tenant: &TenantContext,
1270) -> (
1271    HashMap<String, String>,
1272    HashMap<String, McpSecretRef>,
1273    HashMap<String, String>,
1274) {
1275    let mut persisted_headers = HashMap::new();
1276    let mut persisted_secret_headers = HashMap::new();
1277    let mut secret_header_values = HashMap::new();
1278
1279    for (header_name, raw_value) in headers {
1280        let value = raw_value.trim().to_string();
1281        if value.is_empty() {
1282            continue;
1283        }
1284        if let Some(secret_ref) = parse_secret_header_reference(&value) {
1285            if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1286                secret_header_values.insert(header_name.clone(), resolved);
1287            }
1288            persisted_secret_headers.insert(header_name, secret_ref);
1289            continue;
1290        }
1291        if header_name_is_sensitive(&header_name) {
1292            let secret_id = mcp_header_secret_id(server_name, &header_name);
1293            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1294                persisted_secret_headers.insert(
1295                    header_name.clone(),
1296                    McpSecretRef::Store {
1297                        secret_id: secret_id.clone(),
1298                        tenant_context: current_tenant.clone(),
1299                    },
1300                );
1301                secret_header_values.insert(header_name, value);
1302                continue;
1303            }
1304        }
1305        persisted_headers.insert(header_name, value);
1306    }
1307
1308    for (header_name, secret_ref) in explicit_secret_headers {
1309        if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1310            secret_header_values.insert(header_name.clone(), resolved);
1311        }
1312        persisted_headers.remove(&header_name);
1313        persisted_secret_headers.insert(header_name, secret_ref);
1314    }
1315
1316    (
1317        persisted_headers,
1318        persisted_secret_headers,
1319        secret_header_values,
1320    )
1321}
1322
/// Returns `headers` overlaid with every non-blank entry of
/// `secret_header_values`; secret values win on a name clash, blank secret
/// values are ignored.
fn combine_headers(
    headers: &HashMap<String, String>,
    secret_header_values: &HashMap<String, String>,
) -> HashMap<String, String> {
    let mut merged = headers.clone();
    merged.extend(
        secret_header_values
            .iter()
            .filter(|(_, secret)| !secret.trim().is_empty())
            .map(|(name, secret)| (name.clone(), secret.clone())),
    );
    merged
}
1335
1336fn effective_headers(server: &McpServer) -> HashMap<String, String> {
1337    combine_headers(&server.headers, &server.secret_header_values)
1338}
1339
1340fn redacted_server_view(server: &McpServer) -> McpServer {
1341    let mut clone = server.clone();
1342    for (header_name, secret_ref) in &clone.secret_headers {
1343        clone.headers.insert(
1344            header_name.clone(),
1345            redacted_secret_header_value(secret_ref),
1346        );
1347    }
1348    clone.secret_header_values.clear();
1349    if let Some(oauth) = clone.oauth.as_mut() {
1350        oauth.client_secret_ref = None;
1351        oauth.client_secret_value = None;
1352    }
1353    clone
1354}
1355
/// Trims and lower-cases `raw`, returning it when it names a known auth kind
/// and an empty string otherwise.
fn normalize_auth_kind(raw: &str) -> String {
    const KNOWN_KINDS: [&str; 6] = ["oauth", "auto", "bearer", "x-api-key", "custom", "none"];

    let lowered = raw.trim().to_ascii_lowercase();
    if KNOWN_KINDS.contains(&lowered.as_str()) {
        lowered
    } else {
        String::new()
    }
}
1364
1365fn redacted_secret_header_value(secret_ref: &McpSecretRef) -> String {
1366    match secret_ref {
1367        McpSecretRef::BearerEnv { .. } => "Bearer ".to_string(),
1368        McpSecretRef::Env { .. } | McpSecretRef::Store { .. } => MCP_SECRET_PLACEHOLDER.to_string(),
1369    }
1370}
1371
1372fn resolve_secret_header_values(
1373    secret_headers: &HashMap<String, McpSecretRef>,
1374    current_tenant: &TenantContext,
1375) -> HashMap<String, String> {
1376    let mut out = HashMap::new();
1377    for (header_name, secret_ref) in secret_headers {
1378        if let Some(value) = resolve_secret_ref_value(secret_ref, current_tenant) {
1379            if !value.trim().is_empty() {
1380                out.insert(header_name.clone(), value);
1381            }
1382        }
1383    }
1384    out
1385}
1386
1387fn delete_secret_header_refs(
1388    secret_headers: &HashMap<String, McpSecretRef>,
1389    current_tenant: &TenantContext,
1390) {
1391    for secret_ref in secret_headers.values() {
1392        if let McpSecretRef::Store {
1393            secret_id,
1394            tenant_context,
1395        } = secret_ref
1396        {
1397            if tenant_context != current_tenant {
1398                continue;
1399            }
1400            let _ = tandem_core::delete_provider_auth(secret_id);
1401        }
1402    }
1403}
1404
1405fn delete_oauth_secret_ref(oauth: Option<&McpOAuthConfig>, current_tenant: &TenantContext) {
1406    let Some(secret_ref) = oauth.and_then(|oauth| oauth.client_secret_ref.as_ref()) else {
1407        return;
1408    };
1409    if let McpSecretRef::Store {
1410        secret_id,
1411        tenant_context,
1412    } = secret_ref
1413    {
1414        if tenant_context == current_tenant {
1415            let _ = tandem_core::delete_provider_auth(secret_id);
1416        }
1417    }
1418}