Skip to main content

tandem_runtime/mcp_parts/
part01.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::{LocalImplicitTenant, SecretRef, TenantContext, ToolResult};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// MCP protocol revision string used by this client.
// NOTE(review): presumably sent during the MCP handshake; the use site is outside this chunk.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Name this client identifies itself with.
// NOTE(review): use site is outside this chunk — confirm it is only used for client identification.
const MCP_CLIENT_NAME: &str = "tandem";
/// Client version, taken from this crate's `CARGO_PKG_VERSION` at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Cooldown in milliseconds that `call_tool` passes to `pending_auth_short_circuit`
/// when deciding whether a tool with a pending auth challenge may be probed again.
const MCP_AUTH_REPROBE_COOLDOWN_MS: u64 = 15_000;
/// Placeholder text for secret values.
// NOTE(review): this is an empty string and its use site is outside this chunk — confirm
// that an empty placeholder is intended.
const MCP_SECRET_PLACEHOLDER: &str = "";
19
/// One cached tool definition, stored per server in [`McpServer::tool_cache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Tool description as reported by the server.
    pub description: String,
    /// JSON input schema for the tool; falls back to `Value::default()` when absent.
    #[serde(default)]
    pub input_schema: Value,
    /// Millisecond timestamp (from `now_ms()`) of the refresh that produced this entry.
    pub fetched_at_ms: u64,
    /// Result of `schema_hash` over `input_schema`, computed when the entry was cached.
    pub schema_hash: String,
}
29
/// Persisted and runtime state for a single registered MCP server.
///
/// Fields marked `#[serde(skip)]` hold resolved secret material and are never written to
/// or read from the state file; they are rebuilt when the registry loads its state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Server name; the registry stores the entry under this same (trimmed) name.
    pub name: String,
    /// Transport descriptor, interpreted by `parse_stdio_transport` / `parse_remote_endpoint`.
    pub transport: String,
    /// Normalized auth kind set via `McpRegistry::set_auth_kind`; omitted from output when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub auth_kind: String,
    /// Whether the server may be used; defaults to `default_enabled()` when missing.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Runtime connection flag; reset to `false` whenever state is loaded from disk.
    pub connected: bool,
    /// Reset to `None` on load and on disconnect.
    // NOTE(review): presumably the stdio child's process id — it is assigned outside this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    /// Most recent error message recorded for this server, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<String>,
    /// Most recent authorization challenge recorded for this server, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_auth_challenge: Option<McpAuthChallenge>,
    /// Session id returned by the remote endpoint; cleared on disconnect/disable/failure.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mcp_session_id: Option<String>,
    /// Headers persisted with the entry (first map returned by `split_headers_for_storage`).
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Header name to secret reference; the resolved values live in `secret_header_values`.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub secret_headers: HashMap<String, McpSecretRef>,
    /// Tools cached by the last successful `refresh`.
    #[serde(default)]
    pub tool_cache: Vec<McpToolCacheEntry>,
    /// Millisecond timestamp of the last successful `refresh`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools_fetched_at_ms: Option<u64>,
    /// Pending authorization state keyed by `canonical_tool_key(tool_name)`.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub pending_auth_by_tool: HashMap<String, PendingMcpAuth>,
    /// Resolved secret header values; never serialized, rebuilt from `secret_headers` on load.
    #[serde(default, skip)]
    pub secret_header_values: HashMap<String, String>,
    /// OAuth refresh configuration set via `McpRegistry::set_oauth_refresh_config`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oauth: Option<McpOAuthConfig>,
}
62
/// Reference to a secret value, serialized with a `type` tag in `snake_case`
/// (`store`, `env`, `bearer_env`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpSecretRef {
    /// Secret kept in the provider auth store under `secret_id`
    /// (written by `set_bearer_token` and `set_oauth_refresh_config`).
    Store {
        /// Identifier of the stored secret.
        secret_id: String,
        /// Tenant the secret was stored for; defaults when missing from persisted state.
        #[serde(default)]
        tenant_context: TenantContext,
    },
    /// Secret referenced by an environment variable name.
    // NOTE(review): presumably the variable's value is used verbatim — resolver is outside this chunk.
    Env {
        /// Environment variable name.
        env: String,
    },
    /// Bearer token referenced by an environment variable name.
    // NOTE(review): presumably resolved as `Bearer <value>` — confirm in the resolver.
    BearerEnv {
        /// Environment variable name.
        env: String,
    },
}
78
/// Authorization challenge raised by an MCP server for a tool or for discovery.
///
/// `call_tool` surfaces these fields to callers under the `mcpAuth` metadata key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAuthChallenge {
    /// Identifier of this challenge (reported as `challengeId`).
    pub challenge_id: String,
    /// Tool the challenge was raised for.
    pub tool_name: String,
    /// URL the user is told to visit ("Authorize here: ...").
    pub authorization_url: String,
    /// Human-readable message; also used as the server's `last_error` in several paths.
    pub message: String,
    /// Millisecond timestamp of when the challenge was requested.
    pub requested_at_ms: u64,
    /// Status string of the challenge (reported as `status`).
    pub status: String,
}
88
/// Per-tool record of an outstanding authorization challenge, built from an
/// [`McpAuthChallenge`] by `pending_auth_from_challenge` and stored in
/// [`McpServer::pending_auth_by_tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingMcpAuth {
    /// Identifier of the originating challenge.
    pub challenge_id: String,
    /// URL the user must visit to authorize.
    pub authorization_url: String,
    /// Human-readable challenge message.
    pub message: String,
    /// Status string of the pending authorization.
    pub status: String,
    /// Millisecond timestamp of when the pending state was first recorded.
    // NOTE(review): assigned in `pending_auth_from_challenge`, outside this chunk — confirm.
    pub first_seen_ms: u64,
    /// Millisecond timestamp of the last tool call that probed this pending state
    /// (`call_tool` updates it before sending the request).
    pub last_probe_ms: u64,
}
98
/// OAuth token-refresh configuration attached to a server via
/// `McpRegistry::set_oauth_refresh_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpOAuthConfig {
    /// Identifier of the OAuth provider.
    pub provider_id: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// OAuth client id.
    pub client_id: String,
    /// Reference to the stored client secret, when one was supplied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_secret_ref: Option<McpSecretRef>,
    /// Resolved client secret; never serialized, re-resolved from `client_secret_ref` on load.
    #[serde(default, skip)]
    pub client_secret_value: Option<String>,
}
109
/// Failure modes of remote tool discovery as handled by `McpRegistry::refresh`.
#[derive(Debug, Clone)]
enum DiscoverRemoteToolsError {
    /// Plain error message; `refresh` may retry after an OAuth token refresh for these.
    Message(String),
    /// The server answered with an authorization challenge instead of a tool list.
    AuthChallenge(McpAuthChallenge),
}
115
116impl From<String> for DiscoverRemoteToolsError {
117    fn from(value: String) -> Self {
118        Self::Message(value)
119    }
120}
121
/// Tool exposed by a registered MCP server, as returned by `McpRegistry::list_tools`
/// and `McpRegistry::server_tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Name of the server that provides the tool.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Name used as the sort key when tools are listed.
    // NOTE(review): presumably combines server and tool name — built outside this chunk.
    pub namespaced_name: String,
    /// Tool description.
    pub description: String,
    /// JSON input schema for the tool; falls back to `Value::default()` when absent.
    #[serde(default)]
    pub input_schema: Value,
    /// Millisecond timestamp of when the tool definition was fetched.
    pub fetched_at_ms: u64,
    /// Hash of `input_schema`.
    pub schema_hash: String,
}
133
/// Registry of MCP servers and their child processes.
///
/// Every field is an `Arc`, so clones of the registry share the same underlying state.
#[derive(Clone)]
pub struct McpRegistry {
    /// Registered servers keyed by server name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the file the server map is persisted to.
    state_file: Arc<PathBuf>,
}
140
141impl McpRegistry {
142    pub fn new() -> Self {
143        Self::new_with_state_file(resolve_state_file())
144    }
145
146    pub fn new_with_state_file(state_file: PathBuf) -> Self {
147        let (loaded_state, migrated) = load_state(&state_file);
148        let loaded = loaded_state
149            .into_iter()
150            .map(|(k, mut v)| {
151                v.connected = false;
152                v.pid = None;
153                if v.name.trim().is_empty() {
154                    v.name = k.clone();
155                }
156                if v.headers.is_empty() {
157                    v.headers = HashMap::new();
158                }
159                if v.secret_headers.is_empty() {
160                    v.secret_headers = HashMap::new();
161                }
162                let tenant_context = local_tenant_context();
163                v.secret_header_values =
164                    resolve_secret_header_values(&v.secret_headers, &tenant_context);
165                if let Some(oauth) = v.oauth.as_mut() {
166                    oauth.client_secret_value =
167                        oauth.client_secret_ref.as_ref().and_then(|secret_ref| {
168                            resolve_secret_ref_value(secret_ref, &tenant_context)
169                        });
170                }
171                (k, v)
172            })
173            .collect::<HashMap<_, _>>();
174        if migrated {
175            persist_state_blocking(&state_file, &loaded);
176        }
177        Self {
178            servers: Arc::new(RwLock::new(loaded)),
179            processes: Arc::new(Mutex::new(HashMap::new())),
180            state_file: Arc::new(state_file),
181        }
182    }
183
184    pub async fn list(&self) -> HashMap<String, McpServer> {
185        self.servers.read().await.clone()
186    }
187
188    pub async fn list_public(&self) -> HashMap<String, McpServer> {
189        self.servers
190            .read()
191            .await
192            .iter()
193            .map(|(name, server)| (name.clone(), redacted_server_view(server)))
194            .collect()
195    }
196
197    pub async fn add(&self, name: String, transport: String) {
198        self.add_or_update(name, transport, HashMap::new(), true)
199            .await;
200    }
201
202    pub async fn add_or_update(
203        &self,
204        name: String,
205        transport: String,
206        headers: HashMap<String, String>,
207        enabled: bool,
208    ) {
209        self.add_or_update_with_secret_refs(name, transport, headers, HashMap::new(), enabled)
210            .await;
211    }
212
213    pub async fn add_or_update_with_secret_refs(
214        &self,
215        name: String,
216        transport: String,
217        headers: HashMap<String, String>,
218        secret_headers: HashMap<String, McpSecretRef>,
219        enabled: bool,
220    ) {
221        let normalized_name = name.trim().to_string();
222        let tenant_context = local_tenant_context();
223        let (persisted_headers, persisted_secret_headers, secret_header_values) =
224            split_headers_for_storage(&normalized_name, headers, secret_headers, &tenant_context);
225        let mut servers = self.servers.write().await;
226        let existing = servers.get(&normalized_name).cloned();
227        let preserve_cache = existing.as_ref().is_some_and(|row| {
228            row.transport == transport
229                && effective_headers(row)
230                    == combine_headers(&persisted_headers, &secret_header_values)
231        });
232        let existing_tool_cache = if preserve_cache {
233            existing
234                .as_ref()
235                .map(|row| row.tool_cache.clone())
236                .unwrap_or_default()
237        } else {
238            Vec::new()
239        };
240        let existing_fetched_at = if preserve_cache {
241            existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
242        } else {
243            None
244        };
245        let server = McpServer {
246            name: normalized_name.clone(),
247            transport,
248            auth_kind: existing
249                .as_ref()
250                .map(|row| row.auth_kind.clone())
251                .unwrap_or_default(),
252            enabled,
253            connected: false,
254            pid: None,
255            last_error: None,
256            last_auth_challenge: None,
257            mcp_session_id: None,
258            headers: persisted_headers,
259            secret_headers: persisted_secret_headers,
260            tool_cache: existing_tool_cache,
261            tools_fetched_at_ms: existing_fetched_at,
262            pending_auth_by_tool: HashMap::new(),
263            secret_header_values,
264            oauth: existing.as_ref().and_then(|row| row.oauth.clone()),
265        };
266        servers.insert(normalized_name, server);
267        drop(servers);
268        self.persist_state().await;
269    }
270
271    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
272        let mut servers = self.servers.write().await;
273        let Some(server) = servers.get_mut(name) else {
274            return false;
275        };
276        server.enabled = enabled;
277        if !enabled {
278            server.connected = false;
279            server.pid = None;
280            server.last_auth_challenge = None;
281            server.mcp_session_id = None;
282            server.pending_auth_by_tool.clear();
283        }
284        drop(servers);
285        if !enabled {
286            if let Some(mut child) = self.processes.lock().await.remove(name) {
287                let _ = child.kill().await;
288                let _ = child.wait().await;
289            }
290        }
291        self.persist_state().await;
292        true
293    }
294
295    pub async fn remove(&self, name: &str) -> bool {
296        let removed_server = {
297            let mut servers = self.servers.write().await;
298            servers.remove(name)
299        };
300        let Some(server) = removed_server else {
301            return false;
302        };
303        let current_tenant = local_tenant_context();
304        delete_secret_header_refs(&server.secret_headers, &current_tenant);
305        delete_oauth_secret_ref(server.oauth.as_ref(), &current_tenant);
306
307        if let Some(mut child) = self.processes.lock().await.remove(name) {
308            let _ = child.kill().await;
309            let _ = child.wait().await;
310        }
311        self.persist_state().await;
312        true
313    }
314
315    pub async fn connect(&self, name: &str) -> bool {
316        let server = {
317            let servers = self.servers.read().await;
318            let Some(server) = servers.get(name) else {
319                return false;
320            };
321            server.clone()
322        };
323
324        if !server.enabled {
325            let mut servers = self.servers.write().await;
326            if let Some(entry) = servers.get_mut(name) {
327                entry.connected = false;
328                entry.pid = None;
329                entry.last_error = Some("MCP server is disabled".to_string());
330                entry.last_auth_challenge = None;
331                entry.mcp_session_id = None;
332                entry.pending_auth_by_tool.clear();
333            }
334            drop(servers);
335            self.persist_state().await;
336            return false;
337        }
338
339        if let Some(command_text) = parse_stdio_transport(&server.transport) {
340            return self.connect_stdio(name, command_text).await;
341        }
342
343        if parse_remote_endpoint(&server.transport).is_some() {
344            return self.refresh(name).await.is_ok();
345        }
346
347        let mut servers = self.servers.write().await;
348        if let Some(entry) = servers.get_mut(name) {
349            entry.connected = true;
350            entry.pid = None;
351            entry.last_error = None;
352            entry.last_auth_challenge = None;
353            entry.mcp_session_id = None;
354            entry.pending_auth_by_tool.clear();
355        }
356        drop(servers);
357        self.persist_state().await;
358        true
359    }
360
361    pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
362        let server = {
363            let servers = self.servers.read().await;
364            let Some(server) = servers.get(name) else {
365                return Err("MCP server not found".to_string());
366            };
367            server.clone()
368        };
369
370        if !server.enabled {
371            return Err("MCP server is disabled".to_string());
372        }
373
374        let endpoint = parse_remote_endpoint(&server.transport)
375            .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
376
377        let _ = self.ensure_oauth_bearer_token_fresh(name, false).await;
378        let server = {
379            let servers = self.servers.read().await;
380            let Some(server) = servers.get(name) else {
381                return Err("MCP server not found".to_string());
382            };
383            server.clone()
384        };
385        let request_headers = effective_headers(&server);
386        let discovery = self
387            .discover_remote_tools(name, &endpoint, &request_headers)
388            .await;
389        let (tools, session_id) = match discovery {
390            Ok(result) => result,
391            Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
392                let mut servers = self.servers.write().await;
393                if let Some(entry) = servers.get_mut(name) {
394                    entry.connected = false;
395                    entry.pid = None;
396                    entry.last_error = Some(challenge.message.clone());
397                    entry.last_auth_challenge = Some(challenge.clone());
398                    entry.mcp_session_id = None;
399                    entry.pending_auth_by_tool.clear();
400                    entry.tool_cache.clear();
401                    entry.tools_fetched_at_ms = None;
402                }
403                drop(servers);
404                self.persist_state().await;
405                return Err(format!(
406                    "MCP server '{name}' requires authorization: {}",
407                    challenge.message
408                ));
409            }
410            Err(DiscoverRemoteToolsError::Message(err)) => {
411                if should_retry_mcp_oauth_refresh(&server, &err)
412                    && self.ensure_oauth_bearer_token_fresh(name, true).await?
413                {
414                    let refreshed_server = {
415                        let servers = self.servers.read().await;
416                        servers
417                            .get(name)
418                            .cloned()
419                            .ok_or_else(|| "MCP server not found".to_string())?
420                    };
421                    match self
422                        .discover_remote_tools(
423                            name,
424                            &endpoint,
425                            &effective_headers(&refreshed_server),
426                        )
427                        .await
428                    {
429                        Ok(result) => result,
430                        Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
431                            let mut servers = self.servers.write().await;
432                            if let Some(entry) = servers.get_mut(name) {
433                                entry.connected = false;
434                                entry.pid = None;
435                                entry.last_error = Some(challenge.message.clone());
436                                entry.last_auth_challenge = Some(challenge.clone());
437                                entry.mcp_session_id = None;
438                                entry.pending_auth_by_tool.clear();
439                                entry.tool_cache.clear();
440                                entry.tools_fetched_at_ms = None;
441                            }
442                            drop(servers);
443                            self.persist_state().await;
444                            return Err(format!(
445                                "MCP server '{name}' requires authorization: {}",
446                                challenge.message
447                            ));
448                        }
449                        Err(DiscoverRemoteToolsError::Message(retry_err)) => {
450                            let mut servers = self.servers.write().await;
451                            if let Some(entry) = servers.get_mut(name) {
452                                entry.connected = false;
453                                entry.pid = None;
454                                entry.last_error = Some(retry_err.clone());
455                                entry.last_auth_challenge = None;
456                                entry.mcp_session_id = None;
457                                entry.pending_auth_by_tool.clear();
458                                entry.tool_cache.clear();
459                                entry.tools_fetched_at_ms = None;
460                            }
461                            drop(servers);
462                            self.persist_state().await;
463                            return Err(retry_err);
464                        }
465                    }
466                } else {
467                    let mut servers = self.servers.write().await;
468                    if let Some(entry) = servers.get_mut(name) {
469                        entry.connected = false;
470                        entry.pid = None;
471                        entry.last_error = Some(err.clone());
472                        entry.last_auth_challenge = None;
473                        entry.mcp_session_id = None;
474                        entry.pending_auth_by_tool.clear();
475                        entry.tool_cache.clear();
476                        entry.tools_fetched_at_ms = None;
477                    }
478                    drop(servers);
479                    self.persist_state().await;
480                    return Err(err);
481                }
482            }
483        };
484
485        let now = now_ms();
486        let cache = tools
487            .iter()
488            .map(|tool| McpToolCacheEntry {
489                tool_name: tool.tool_name.clone(),
490                description: tool.description.clone(),
491                input_schema: tool.input_schema.clone(),
492                fetched_at_ms: now,
493                schema_hash: schema_hash(&tool.input_schema),
494            })
495            .collect::<Vec<_>>();
496
497        let mut servers = self.servers.write().await;
498        if let Some(entry) = servers.get_mut(name) {
499            entry.connected = true;
500            entry.pid = None;
501            entry.last_error = None;
502            entry.last_auth_challenge = None;
503            entry.mcp_session_id = session_id;
504            entry.tool_cache = cache;
505            entry.tools_fetched_at_ms = Some(now);
506            entry.pending_auth_by_tool.clear();
507        }
508        drop(servers);
509        self.persist_state().await;
510        Ok(self.server_tools(name).await)
511    }
512
513    pub async fn disconnect(&self, name: &str) -> bool {
514        if let Some(mut child) = self.processes.lock().await.remove(name) {
515            let _ = child.kill().await;
516            let _ = child.wait().await;
517        }
518        let mut servers = self.servers.write().await;
519        if let Some(server) = servers.get_mut(name) {
520            server.connected = false;
521            server.pid = None;
522            server.last_auth_challenge = None;
523            server.mcp_session_id = None;
524            server.pending_auth_by_tool.clear();
525            drop(servers);
526            self.persist_state().await;
527            return true;
528        }
529        false
530    }
531
532    pub async fn complete_auth(&self, name: &str) -> bool {
533        let mut servers = self.servers.write().await;
534        let Some(server) = servers.get_mut(name) else {
535            return false;
536        };
537        server.last_error = None;
538        server.last_auth_challenge = None;
539        server.pending_auth_by_tool.clear();
540        drop(servers);
541        self.persist_state().await;
542        true
543    }
544
545    pub async fn set_auth_kind(&self, name: &str, auth_kind: String) -> bool {
546        let normalized = normalize_auth_kind(&auth_kind);
547        let mut servers = self.servers.write().await;
548        let Some(server) = servers.get_mut(name) else {
549            return false;
550        };
551        server.auth_kind = normalized;
552        drop(servers);
553        self.persist_state().await;
554        true
555    }
556
557    pub async fn record_server_auth_challenge(
558        &self,
559        name: &str,
560        challenge: McpAuthChallenge,
561        last_error: Option<String>,
562    ) -> bool {
563        let mut servers = self.servers.write().await;
564        let Some(server) = servers.get_mut(name) else {
565            return false;
566        };
567        let tool_key = canonical_tool_key(&challenge.tool_name);
568        server.connected = false;
569        server.pid = None;
570        server.last_error = last_error.or_else(|| Some(challenge.message.clone()));
571        server.last_auth_challenge = Some(challenge.clone());
572        server.mcp_session_id = None;
573        server.pending_auth_by_tool.clear();
574        server
575            .pending_auth_by_tool
576            .insert(tool_key, pending_auth_from_challenge(&challenge));
577        drop(servers);
578        self.persist_state().await;
579        true
580    }
581
582    pub async fn clear_server_auth_challenge(&self, name: &str) -> bool {
583        let mut servers = self.servers.write().await;
584        let Some(server) = servers.get_mut(name) else {
585            return false;
586        };
587        server.last_auth_challenge = None;
588        server.pending_auth_by_tool.clear();
589        drop(servers);
590        self.persist_state().await;
591        true
592    }
593
594    pub async fn set_bearer_token(&self, name: &str, token: &str) -> Result<bool, String> {
595        let trimmed = token.trim();
596        if trimmed.is_empty() {
597            return Err("oauth access token cannot be empty".to_string());
598        }
599        let current_tenant = local_tenant_context();
600        let mut servers = self.servers.write().await;
601        let Some(server) = servers.get_mut(name) else {
602            return Ok(false);
603        };
604        let header_name = "Authorization".to_string();
605        let secret_id = mcp_header_secret_id(name, &header_name);
606        tandem_core::set_provider_auth(&secret_id, &format!("Bearer {trimmed}"))
607            .map_err(|error| error.to_string())?;
608        server.secret_headers.insert(
609            header_name.clone(),
610            McpSecretRef::Store {
611                secret_id: secret_id.clone(),
612                tenant_context: current_tenant,
613            },
614        );
615        server
616            .secret_header_values
617            .insert(header_name.clone(), format!("Bearer {trimmed}"));
618        server.headers.remove(&header_name);
619        drop(servers);
620        self.persist_state().await;
621        Ok(true)
622    }
623
624    pub async fn set_oauth_refresh_config(
625        &self,
626        name: &str,
627        provider_id: String,
628        token_endpoint: String,
629        client_id: String,
630        client_secret: Option<String>,
631    ) -> Result<bool, String> {
632        let current_tenant = local_tenant_context();
633        let mut servers = self.servers.write().await;
634        let Some(server) = servers.get_mut(name) else {
635            return Ok(false);
636        };
637
638        let client_secret_ref = client_secret
639            .as_deref()
640            .map(str::trim)
641            .filter(|value| !value.is_empty())
642            .map(|value| -> Result<McpSecretRef, String> {
643                let secret_id = mcp_oauth_client_secret_id(name);
644                tandem_core::set_provider_auth(&secret_id, value)
645                    .map_err(|error| error.to_string())?;
646                Ok(McpSecretRef::Store {
647                    secret_id,
648                    tenant_context: current_tenant.clone(),
649                })
650            })
651            .transpose()?;
652        if client_secret_ref.is_none() {
653            let secret_id = mcp_oauth_client_secret_id(name);
654            let _ = tandem_core::delete_provider_auth(&secret_id);
655        }
656
657        server.oauth = Some(McpOAuthConfig {
658            provider_id,
659            token_endpoint,
660            client_id,
661            client_secret_ref,
662            client_secret_value: client_secret
663                .map(|value| value.trim().to_string())
664                .filter(|value| !value.is_empty()),
665        });
666        drop(servers);
667        self.persist_state().await;
668        Ok(true)
669    }
670
671    pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
672        let mut out = self
673            .servers
674            .read()
675            .await
676            .values()
677            .filter(|server| server.enabled && server.connected)
678            .flat_map(server_tool_rows)
679            .collect::<Vec<_>>();
680        out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
681        out
682    }
683
684    pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
685        let Some(server) = self.servers.read().await.get(name).cloned() else {
686            return Vec::new();
687        };
688        let mut rows = server_tool_rows(&server);
689        rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
690        rows
691    }
692
693    pub async fn call_tool(
694        &self,
695        server_name: &str,
696        tool_name: &str,
697        args: Value,
698    ) -> Result<ToolResult, String> {
699        let server = {
700            let servers = self.servers.read().await;
701            let Some(server) = servers.get(server_name) else {
702                return Err(format!("MCP server '{server_name}' not found"));
703            };
704            server.clone()
705        };
706
707        if !server.enabled {
708            return Err(format!("MCP server '{server_name}' is disabled"));
709        }
710        if !server.connected {
711            return Err(format!("MCP server '{server_name}' is not connected"));
712        }
713
714        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
715            "MCP tools/call currently supports HTTP/S transports only".to_string()
716        })?;
717        let canonical_tool = canonical_tool_key(tool_name);
718        let now = now_ms();
719        let _ = self
720            .ensure_oauth_bearer_token_fresh(server_name, false)
721            .await;
722        let server = {
723            let servers = self.servers.read().await;
724            let Some(server) = servers.get(server_name) else {
725                return Err(format!("MCP server '{server_name}' not found"));
726            };
727            server.clone()
728        };
729        if let Some(blocked) = pending_auth_short_circuit(
730            &server,
731            &canonical_tool,
732            tool_name,
733            now,
734            MCP_AUTH_REPROBE_COOLDOWN_MS,
735        ) {
736            return Ok(ToolResult {
737                output: blocked.output,
738                metadata: json!({
739                    "server": server_name,
740                    "tool": tool_name,
741                    "result": Value::Null,
742                    "mcpAuth": blocked.mcp_auth
743                }),
744            });
745        }
746        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);
747
748        {
749            let mut servers = self.servers.write().await;
750            if let Some(row) = servers.get_mut(server_name) {
751                if let Some(pending) = row.pending_auth_by_tool.get_mut(&canonical_tool) {
752                    pending.last_probe_ms = now;
753                }
754            }
755        }
756
757        let request = json!({
758            "jsonrpc": "2.0",
759            "id": format!("call-{}-{}", server_name, now_ms()),
760            "method": "tools/call",
761            "params": {
762                "name": tool_name,
763                "arguments": normalized_args
764            }
765        });
766        let (response, session_id) = match post_json_rpc_with_session(
767            &endpoint,
768            &effective_headers(&server),
769            request.clone(),
770            server.mcp_session_id.as_deref(),
771        )
772        .await
773        {
774            Ok(result) => result,
775            Err(error) => {
776                if should_retry_mcp_oauth_refresh(&server, &error)
777                    && self
778                        .ensure_oauth_bearer_token_fresh(server_name, true)
779                        .await?
780                {
781                    let refreshed_server = {
782                        let servers = self.servers.read().await;
783                        servers
784                            .get(server_name)
785                            .cloned()
786                            .ok_or_else(|| format!("MCP server '{server_name}' not found"))?
787                    };
788                    post_json_rpc_with_session(
789                        &endpoint,
790                        &effective_headers(&refreshed_server),
791                        request,
792                        refreshed_server.mcp_session_id.as_deref(),
793                    )
794                    .await?
795                } else {
796                    return Err(error);
797                }
798            }
799        };
800        if session_id.is_some() {
801            let mut servers = self.servers.write().await;
802            if let Some(row) = servers.get_mut(server_name) {
803                row.mcp_session_id = session_id;
804            }
805            drop(servers);
806            self.persist_state().await;
807        }
808
809        if let Some(err) = response.get("error") {
810            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
811                let output = format!(
812                    "{}\n\nAuthorize here: {}",
813                    challenge.message, challenge.authorization_url
814                );
815                {
816                    let mut servers = self.servers.write().await;
817                    if let Some(row) = servers.get_mut(server_name) {
818                        row.last_auth_challenge = Some(challenge.clone());
819                        row.last_error = None;
820                        row.pending_auth_by_tool.insert(
821                            canonical_tool.clone(),
822                            pending_auth_from_challenge(&challenge),
823                        );
824                    }
825                }
826                self.persist_state().await;
827                return Ok(ToolResult {
828                    output,
829                    metadata: json!({
830                        "server": server_name,
831                        "tool": tool_name,
832                        "result": Value::Null,
833                        "mcpAuth": {
834                            "required": true,
835                            "challengeId": challenge.challenge_id,
836                            "tool": challenge.tool_name,
837                            "authorizationUrl": challenge.authorization_url,
838                            "message": challenge.message,
839                            "status": challenge.status
840                        }
841                    }),
842                });
843            }
844            let message = err
845                .get("message")
846                .and_then(|v| v.as_str())
847                .unwrap_or("MCP tools/call failed");
848            return Err(message.to_string());
849        }
850
851        let result = response.get("result").cloned().unwrap_or(Value::Null);
852        let auth_challenge = extract_auth_challenge(&result, tool_name);
853        let output = if let Some(challenge) = auth_challenge.as_ref() {
854            format!(
855                "{}\n\nAuthorize here: {}",
856                challenge.message, challenge.authorization_url
857            )
858        } else {
859            result
860                .get("content")
861                .map(render_mcp_content)
862                .or_else(|| result.get("output").map(|v| v.to_string()))
863                .unwrap_or_else(|| result.to_string())
864        };
865
866        {
867            let mut servers = self.servers.write().await;
868            if let Some(row) = servers.get_mut(server_name) {
869                row.last_auth_challenge = auth_challenge.clone();
870                if let Some(challenge) = auth_challenge.as_ref() {
871                    row.pending_auth_by_tool.insert(
872                        canonical_tool.clone(),
873                        pending_auth_from_challenge(challenge),
874                    );
875                } else {
876                    row.pending_auth_by_tool.remove(&canonical_tool);
877                }
878            }
879        }
880        self.persist_state().await;
881
882        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
883            json!({
884                "required": true,
885                "challengeId": challenge.challenge_id,
886                "tool": challenge.tool_name,
887                "authorizationUrl": challenge.authorization_url,
888                "message": challenge.message,
889                "status": challenge.status
890            })
891        });
892
893        Ok(ToolResult {
894            output,
895            metadata: json!({
896                "server": server_name,
897                "tool": tool_name,
898                "result": result,
899                "mcpAuth": auth_metadata
900            }),
901        })
902    }
903
904    async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
905        match spawn_stdio_process(command_text).await {
906            Ok(child) => {
907                let pid = child.id();
908                self.processes.lock().await.insert(name.to_string(), child);
909                let mut servers = self.servers.write().await;
910                if let Some(server) = servers.get_mut(name) {
911                    server.connected = true;
912                    server.pid = pid;
913                    server.last_error = None;
914                    server.last_auth_challenge = None;
915                    server.pending_auth_by_tool.clear();
916                }
917                drop(servers);
918                self.persist_state().await;
919                true
920            }
921            Err(err) => {
922                let mut servers = self.servers.write().await;
923                if let Some(server) = servers.get_mut(name) {
924                    server.connected = false;
925                    server.pid = None;
926                    server.last_error = Some(err);
927                    server.last_auth_challenge = None;
928                    server.pending_auth_by_tool.clear();
929                }
930                drop(servers);
931                self.persist_state().await;
932                false
933            }
934        }
935    }
936
937    async fn discover_remote_tools(
938        &self,
939        server_name: &str,
940        endpoint: &str,
941        headers: &HashMap<String, String>,
942    ) -> Result<(Vec<McpRemoteTool>, Option<String>), DiscoverRemoteToolsError> {
943        let initialize = json!({
944            "jsonrpc": "2.0",
945            "id": "initialize-1",
946            "method": "initialize",
947            "params": {
948                "protocolVersion": MCP_PROTOCOL_VERSION,
949                "capabilities": {},
950                "clientInfo": {
951                    "name": MCP_CLIENT_NAME,
952                    "version": MCP_CLIENT_VERSION,
953                }
954            }
955        });
956        let (init_response, mut session_id) =
957            post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
958        if let Some(err) = init_response.get("error") {
959            if let Some(challenge) = extract_auth_challenge(err, server_name) {
960                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
961            }
962            let message = err
963                .get("message")
964                .and_then(|v| v.as_str())
965                .unwrap_or("MCP initialize failed");
966            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
967        }
968
969        let tools_list = json!({
970            "jsonrpc": "2.0",
971            "id": "tools-list-1",
972            "method": "tools/list",
973            "params": {}
974        });
975        let (tools_response, next_session_id) =
976            post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
977                .await?;
978        if next_session_id.is_some() {
979            session_id = next_session_id;
980        }
981        if let Some(err) = tools_response.get("error") {
982            if let Some(challenge) = extract_auth_challenge(err, server_name) {
983                return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
984            }
985            let message = err
986                .get("message")
987                .and_then(|v| v.as_str())
988                .unwrap_or("MCP tools/list failed");
989            return Err(DiscoverRemoteToolsError::Message(message.to_string()));
990        }
991
992        let tools = tools_response
993            .get("result")
994            .and_then(|v| v.get("tools"))
995            .and_then(|v| v.as_array())
996            .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
997
998        let now = now_ms();
999        let mut out = Vec::new();
1000        for row in tools {
1001            let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
1002                continue;
1003            };
1004            let description = row
1005                .get("description")
1006                .and_then(|v| v.as_str())
1007                .unwrap_or("")
1008                .to_string();
1009            let mut input_schema = row
1010                .get("inputSchema")
1011                .or_else(|| row.get("input_schema"))
1012                .cloned()
1013                .unwrap_or_else(|| json!({"type":"object"}));
1014            normalize_tool_input_schema(&mut input_schema);
1015            out.push(McpRemoteTool {
1016                server_name: String::new(),
1017                tool_name: tool_name.to_string(),
1018                namespaced_name: String::new(),
1019                description,
1020                input_schema,
1021                fetched_at_ms: now,
1022                schema_hash: String::new(),
1023            });
1024        }
1025
1026        Ok((out, session_id))
1027    }
1028
1029    async fn persist_state(&self) {
1030        let snapshot = self.servers.read().await.clone();
1031        persist_state_blocking(self.state_file.as_path(), &snapshot);
1032    }
1033
1034    async fn ensure_oauth_bearer_token_fresh(
1035        &self,
1036        name: &str,
1037        force: bool,
1038    ) -> Result<bool, String> {
1039        let server = {
1040            let servers = self.servers.read().await;
1041            servers.get(name).cloned()
1042        }
1043        .ok_or_else(|| format!("MCP server '{name}' not found"))?;
1044        let Some(oauth) = server.oauth.clone() else {
1045            return Ok(false);
1046        };
1047        let Some(credential) = tandem_core::load_provider_oauth_credential(&oauth.provider_id)
1048        else {
1049            return Ok(false);
1050        };
1051
1052        let should_refresh = force
1053            || credential.expires_at_ms <= now_ms().saturating_add(60_000)
1054            || credential.access_token.trim().is_empty();
1055        if !should_refresh {
1056            return Ok(false);
1057        }
1058
1059        let refreshed = refresh_mcp_oauth_credential(&oauth, &credential).await?;
1060        self.set_bearer_token(name, &refreshed.access_token).await?;
1061        tandem_core::set_provider_oauth_credential(&oauth.provider_id, refreshed)
1062            .map_err(|error| error.to_string())?;
1063        Ok(true)
1064    }
1065}
1066
1067impl Default for McpRegistry {
1068    fn default() -> Self {
1069        Self::new()
1070    }
1071}
1072
/// Serde default for `McpServer::enabled`: a server entry that omits the
/// field is treated as enabled.
fn default_enabled() -> bool {
    true
}
1076
1077fn persist_state_blocking(path: &Path, snapshot: &HashMap<String, McpServer>) {
1078    if let Some(parent) = path.parent() {
1079        let _ = std::fs::create_dir_all(parent);
1080    }
1081    if let Ok(payload) = serde_json::to_string_pretty(snapshot) {
1082        let _ = std::fs::write(path, payload);
1083    }
1084}
1085
1086fn resolve_state_file() -> PathBuf {
1087    if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
1088        return PathBuf::from(path);
1089    }
1090    if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1091        let trimmed = state_dir.trim();
1092        if !trimmed.is_empty() {
1093            return PathBuf::from(trimmed).join("mcp_servers.json");
1094        }
1095    }
1096    if let Some(data_dir) = dirs::data_dir() {
1097        return data_dir
1098            .join("tandem")
1099            .join("data")
1100            .join("mcp_servers.json");
1101    }
1102    dirs::home_dir()
1103        .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
1104        .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1105}
1106
1107fn load_state(path: &Path) -> (HashMap<String, McpServer>, bool) {
1108    let Ok(raw) = std::fs::read_to_string(path) else {
1109        return (HashMap::new(), false);
1110    };
1111    let mut migrated = false;
1112    let mut parsed = serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default();
1113    for (name, server) in parsed.iter_mut() {
1114        let tenant_context = local_tenant_context();
1115        let (headers, secret_headers, secret_header_values, server_migrated) =
1116            migrate_server_headers(name, server, &tenant_context);
1117        migrated = migrated || server_migrated;
1118        server.headers = headers;
1119        server.secret_headers = secret_headers;
1120        server.secret_header_values = secret_header_values;
1121    }
1122    (parsed, migrated)
1123}
1124
1125fn migrate_server_headers(
1126    server_name: &str,
1127    server: &McpServer,
1128    current_tenant: &TenantContext,
1129) -> (
1130    HashMap<String, String>,
1131    HashMap<String, McpSecretRef>,
1132    HashMap<String, String>,
1133    bool,
1134) {
1135    let original_effective = effective_headers(server);
1136    let mut persisted_secret_headers = server.secret_headers.clone();
1137    let mut secret_header_values =
1138        resolve_secret_header_values(&persisted_secret_headers, current_tenant);
1139    let mut persisted_headers = server.headers.clone();
1140    let mut migrated = false;
1141
1142    let header_keys = persisted_headers.keys().cloned().collect::<Vec<_>>();
1143    for header_name in header_keys {
1144        let Some(value) = persisted_headers.get(&header_name).cloned() else {
1145            continue;
1146        };
1147        if persisted_secret_headers.contains_key(&header_name) {
1148            continue;
1149        }
1150        if let Some(secret_ref) = parse_secret_header_reference(value.trim()) {
1151            persisted_headers.remove(&header_name);
1152            let resolved =
1153                resolve_secret_ref_value(&secret_ref, current_tenant).unwrap_or_default();
1154            persisted_secret_headers.insert(header_name.clone(), secret_ref);
1155            if !resolved.is_empty() {
1156                secret_header_values.insert(header_name.clone(), resolved);
1157            }
1158            migrated = true;
1159            continue;
1160        }
1161        if header_name_is_sensitive(&header_name) && !value.trim().is_empty() {
1162            let secret_id = mcp_header_secret_id(server_name, &header_name);
1163            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1164                persisted_headers.remove(&header_name);
1165                persisted_secret_headers.insert(
1166                    header_name.clone(),
1167                    McpSecretRef::Store {
1168                        secret_id: secret_id.clone(),
1169                        tenant_context: current_tenant.clone(),
1170                    },
1171                );
1172                secret_header_values.insert(header_name.clone(), value);
1173                migrated = true;
1174            }
1175        }
1176    }
1177
1178    if !migrated {
1179        let effective = combine_headers(&persisted_headers, &secret_header_values);
1180        migrated = effective != original_effective;
1181    }
1182
1183    (
1184        persisted_headers,
1185        persisted_secret_headers,
1186        secret_header_values,
1187        migrated,
1188    )
1189}
1190
1191fn split_headers_for_storage(
1192    server_name: &str,
1193    headers: HashMap<String, String>,
1194    explicit_secret_headers: HashMap<String, McpSecretRef>,
1195    current_tenant: &TenantContext,
1196) -> (
1197    HashMap<String, String>,
1198    HashMap<String, McpSecretRef>,
1199    HashMap<String, String>,
1200) {
1201    let mut persisted_headers = HashMap::new();
1202    let mut persisted_secret_headers = HashMap::new();
1203    let mut secret_header_values = HashMap::new();
1204
1205    for (header_name, raw_value) in headers {
1206        let value = raw_value.trim().to_string();
1207        if value.is_empty() {
1208            continue;
1209        }
1210        if let Some(secret_ref) = parse_secret_header_reference(&value) {
1211            if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1212                secret_header_values.insert(header_name.clone(), resolved);
1213            }
1214            persisted_secret_headers.insert(header_name, secret_ref);
1215            continue;
1216        }
1217        if header_name_is_sensitive(&header_name) {
1218            let secret_id = mcp_header_secret_id(server_name, &header_name);
1219            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1220                persisted_secret_headers.insert(
1221                    header_name.clone(),
1222                    McpSecretRef::Store {
1223                        secret_id: secret_id.clone(),
1224                        tenant_context: current_tenant.clone(),
1225                    },
1226                );
1227                secret_header_values.insert(header_name, value);
1228                continue;
1229            }
1230        }
1231        persisted_headers.insert(header_name, value);
1232    }
1233
1234    for (header_name, secret_ref) in explicit_secret_headers {
1235        if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1236            secret_header_values.insert(header_name.clone(), resolved);
1237        }
1238        persisted_headers.remove(&header_name);
1239        persisted_secret_headers.insert(header_name, secret_ref);
1240    }
1241
1242    (
1243        persisted_headers,
1244        persisted_secret_headers,
1245        secret_header_values,
1246    )
1247}
1248
/// Merges plain headers with resolved secret header values.
///
/// Starts from a copy of `headers` and overlays every entry of
/// `secret_header_values` whose value is not blank; on a name clash the
/// secret value wins. Blank secret values never remove or replace a header.
fn combine_headers(
    headers: &HashMap<String, String>,
    secret_header_values: &HashMap<String, String>,
) -> HashMap<String, String> {
    let overrides = secret_header_values
        .iter()
        .filter(|(_, value)| !value.trim().is_empty())
        .map(|(name, value)| (name.clone(), value.clone()));
    let mut merged = headers.clone();
    merged.extend(overrides);
    merged
}
1261
1262fn effective_headers(server: &McpServer) -> HashMap<String, String> {
1263    combine_headers(&server.headers, &server.secret_header_values)
1264}
1265
1266fn redacted_server_view(server: &McpServer) -> McpServer {
1267    let mut clone = server.clone();
1268    for (header_name, secret_ref) in &clone.secret_headers {
1269        clone.headers.insert(
1270            header_name.clone(),
1271            redacted_secret_header_value(secret_ref),
1272        );
1273    }
1274    clone.secret_header_values.clear();
1275    if let Some(oauth) = clone.oauth.as_mut() {
1276        oauth.client_secret_ref = None;
1277        oauth.client_secret_value = None;
1278    }
1279    clone
1280}
1281
/// Normalises an auth-kind string.
///
/// The input is trimmed and ASCII-lowercased. If the result is one of
/// `oauth`, `auto`, `bearer`, `x-api-key`, `custom` or `none` it is returned;
/// any other input yields an empty string.
fn normalize_auth_kind(raw: &str) -> String {
    const KNOWN_KINDS: [&str; 6] = ["oauth", "auto", "bearer", "x-api-key", "custom", "none"];

    let normalized = raw.trim().to_ascii_lowercase();
    if KNOWN_KINDS.contains(&normalized.as_str()) {
        normalized
    } else {
        String::new()
    }
}
1290
1291fn redacted_secret_header_value(secret_ref: &McpSecretRef) -> String {
1292    match secret_ref {
1293        McpSecretRef::BearerEnv { .. } => "Bearer ".to_string(),
1294        McpSecretRef::Env { .. } | McpSecretRef::Store { .. } => MCP_SECRET_PLACEHOLDER.to_string(),
1295    }
1296}
1297
1298fn resolve_secret_header_values(
1299    secret_headers: &HashMap<String, McpSecretRef>,
1300    current_tenant: &TenantContext,
1301) -> HashMap<String, String> {
1302    let mut out = HashMap::new();
1303    for (header_name, secret_ref) in secret_headers {
1304        if let Some(value) = resolve_secret_ref_value(secret_ref, current_tenant) {
1305            if !value.trim().is_empty() {
1306                out.insert(header_name.clone(), value);
1307            }
1308        }
1309    }
1310    out
1311}
1312
1313fn delete_secret_header_refs(
1314    secret_headers: &HashMap<String, McpSecretRef>,
1315    current_tenant: &TenantContext,
1316) {
1317    for secret_ref in secret_headers.values() {
1318        if let McpSecretRef::Store {
1319            secret_id,
1320            tenant_context,
1321        } = secret_ref
1322        {
1323            if tenant_context != current_tenant {
1324                continue;
1325            }
1326            let _ = tandem_core::delete_provider_auth(secret_id);
1327        }
1328    }
1329}
1330
1331fn delete_oauth_secret_ref(oauth: Option<&McpOAuthConfig>, current_tenant: &TenantContext) {
1332    let Some(secret_ref) = oauth.and_then(|oauth| oauth.client_secret_ref.as_ref()) else {
1333        return;
1334    };
1335    if let McpSecretRef::Store {
1336        secret_id,
1337        tenant_context,
1338    } = secret_ref
1339    {
1340        if tenant_context == current_tenant {
1341            let _ = tandem_core::delete_provider_auth(secret_id);
1342        }
1343    }
1344}