1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use sha2::{Digest, Sha256};
10use tandem_types::{LocalImplicitTenant, SecretRef, TenantContext, ToolResult};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, RwLock};
13
/// Value sent as `protocolVersion` in the MCP `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Value sent as `clientInfo.name` in the MCP `initialize` request.
const MCP_CLIENT_NAME: &str = "tandem";
/// Value sent as `clientInfo.version`; taken from the crate version at compile time.
const MCP_CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Cooldown in milliseconds that `call_tool` passes to `pending_auth_short_circuit`.
const MCP_AUTH_REPROBE_COOLDOWN_MS: u64 = 15_000;
/// Placeholder text for secret values.
// NOTE(review): not referenced in the visible part of this file, and the literal is
// empty — confirm the empty string is intentional and was not lost in transit.
const MCP_SECRET_PLACEHOLDER: &str = "";
19
/// One cached tool description, stored per server in `McpServer::tool_cache`.
///
/// `McpRegistry::refresh` rebuilds these entries from the tools returned by discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCacheEntry {
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Description reported for the tool (may be empty).
    pub description: String,
    /// Input schema of the tool; falls back to `Value::default()` when missing on load.
    #[serde(default)]
    pub input_schema: Value,
    /// Timestamp from `now_ms()` taken when the entry was written by `refresh`.
    pub fetched_at_ms: u64,
    /// Result of `schema_hash(&input_schema)` at the time the entry was written.
    pub schema_hash: String,
}
29
/// Configuration plus runtime state of one registered MCP server.
///
/// The whole struct is what gets serialized into the registry state file; fields
/// marked `skip` are runtime-only and are re-resolved when the registry is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Registry key of the server. Filled from the map key on load when blank.
    pub name: String,
    /// Transport descriptor, interpreted by `parse_stdio_transport` / `parse_remote_endpoint`.
    pub transport: String,
    /// Auth kind as normalized by `normalize_auth_kind`; omitted from output when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub auth_kind: String,
    /// Whether the server may be connected/called. Defaults to `true` when absent on load.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Whether the last connect/refresh succeeded. Always reset to `false` on load.
    pub connected: bool,
    /// Process id of the spawned stdio child, if any. Always reset to `None` on load.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    /// Most recent error message recorded for this server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<String>,
    /// Most recent authorization challenge recorded for this server.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_auth_challenge: Option<McpAuthChallenge>,
    /// Session id returned by the remote endpoint, replayed on later requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mcp_session_id: Option<String>,
    /// Plain request headers that are persisted as-is.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Header name -> reference describing where the secret header value lives.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub secret_headers: HashMap<String, McpSecretRef>,
    /// Tools discovered by the last successful `refresh`.
    #[serde(default)]
    pub tool_cache: Vec<McpToolCacheEntry>,
    /// Timestamp of the last successful `refresh`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools_fetched_at_ms: Option<u64>,
    /// Outstanding authorization requests, keyed by `canonical_tool_key(tool_name)`.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub pending_auth_by_tool: HashMap<String, PendingMcpAuth>,
    /// Optional tool allow-list as stored by `set_allowed_tools` (trimmed, de-duplicated).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
    /// Resolved secret header values. Never serialized; rebuilt from `secret_headers` on load.
    #[serde(default, skip)]
    pub secret_header_values: HashMap<String, String>,
    /// OAuth token-refresh configuration, when one was set via `set_oauth_refresh_config`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oauth: Option<McpOAuthConfig>,
}
64
/// Pointer to where a secret value is kept, so the value itself is not persisted.
///
/// Serialized as an internally tagged object: `"type"` is `"store"`, `"env"` or
/// `"bearer_env"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpSecretRef {
    /// Secret kept in the provider auth store under `secret_id`
    /// (written with `tandem_core::set_provider_auth` elsewhere in this file).
    Store {
        secret_id: String,
        /// Tenant the secret belongs to; defaults to `TenantContext::default()` on load.
        #[serde(default)]
        tenant_context: TenantContext,
    },
    /// Secret named by `env`.
    // NOTE(review): presumably read from the environment variable `env` — the resolver
    // is outside the visible chunk; confirm.
    Env {
        env: String,
    },
    /// Secret named by `env`, bearer flavour.
    // NOTE(review): presumably resolved like `Env` and prefixed with "Bearer " — confirm
    // against `resolve_secret_ref_value`.
    BearerEnv {
        env: String,
    },
}
80
/// An authorization challenge returned by an MCP server, as produced by
/// `extract_auth_challenge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAuthChallenge {
    /// Identifier of the challenge; surfaced to callers as `mcpAuth.challengeId`.
    pub challenge_id: String,
    /// Name the challenge is attributed to. During tool discovery the server name is
    /// passed in this position, so this is not always a tool name.
    pub tool_name: String,
    /// URL the user is told to open ("Authorize here: ...").
    pub authorization_url: String,
    /// Human-readable explanation; also recorded as the server's `last_error` by `refresh`.
    pub message: String,
    /// Timestamp of the challenge. NOTE(review): presumably `now_ms()` at extraction — confirm.
    pub requested_at_ms: u64,
    /// Challenge status string; surfaced to callers as `mcpAuth.status`.
    pub status: String,
}
90
/// Per-tool record of an authorization request that has not been completed yet.
///
/// Stored in `McpServer::pending_auth_by_tool`; built from a challenge by
/// `pending_auth_from_challenge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingMcpAuth {
    pub challenge_id: String,
    pub authorization_url: String,
    pub message: String,
    pub status: String,
    /// NOTE(review): presumably when the challenge was first recorded — set outside this chunk.
    pub first_seen_ms: u64,
    /// Updated by `call_tool` to the current time right before it re-probes the server.
    pub last_probe_ms: u64,
}
100
/// OAuth settings used to refresh a server's bearer token
/// (see `McpRegistry::set_oauth_refresh_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpOAuthConfig {
    /// Key used with `tandem_core::load_provider_oauth_credential` /
    /// `set_provider_oauth_credential`.
    pub provider_id: String,
    pub token_endpoint: String,
    pub client_id: String,
    /// Where the client secret is stored, when one was supplied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_secret_ref: Option<McpSecretRef>,
    /// Resolved client secret. Never serialized; re-resolved from `client_secret_ref` on load.
    #[serde(default, skip)]
    pub client_secret_value: Option<String>,
}
111
/// Failure modes of `McpRegistry::discover_remote_tools`.
#[derive(Debug, Clone)]
enum DiscoverRemoteToolsError {
    /// Plain error text (transport failure, JSON-RPC error message, malformed result).
    Message(String),
    /// The server answered with an error from which an authorization challenge was extracted.
    AuthChallenge(McpAuthChallenge),
}
117
118impl From<String> for DiscoverRemoteToolsError {
119 fn from(value: String) -> Self {
120 Self::Message(value)
121 }
122}
123
/// A tool exposed by a remote MCP server, as returned to callers of the registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRemoteTool {
    /// Owning server. Left empty by `discover_remote_tools`.
    // NOTE(review): presumably filled in by `server_tool_rows` — confirm.
    pub server_name: String,
    /// Tool name as reported by the server.
    pub tool_name: String,
    /// Sort key used by `list_tools` / `server_tools`. Left empty by `discover_remote_tools`.
    pub namespaced_name: String,
    pub description: String,
    /// Input schema; falls back to `Value::default()` when missing on deserialize.
    #[serde(default)]
    pub input_schema: Value,
    pub fetched_at_ms: u64,
    /// Left empty by `discover_remote_tools`; `refresh` computes the hash for the cache entry.
    pub schema_hash: String,
}
135
/// Registry of MCP servers. Cloning yields another handle to the same shared state,
/// because every field is behind an `Arc`.
#[derive(Clone)]
pub struct McpRegistry {
    /// Server entries keyed by server name.
    servers: Arc<RwLock<HashMap<String, McpServer>>>,
    /// Child processes spawned for stdio transports, keyed by server name.
    processes: Arc<Mutex<HashMap<String, Child>>>,
    /// Path of the JSON file the server map is persisted to.
    state_file: Arc<PathBuf>,
}
142
143impl McpRegistry {
144 pub fn new() -> Self {
145 Self::new_with_state_file(resolve_state_file())
146 }
147
148 pub fn new_with_state_file(state_file: PathBuf) -> Self {
149 let (loaded_state, migrated) = load_state(&state_file);
150 let loaded = loaded_state
151 .into_iter()
152 .map(|(k, mut v)| {
153 v.connected = false;
154 v.pid = None;
155 if v.name.trim().is_empty() {
156 v.name = k.clone();
157 }
158 if v.headers.is_empty() {
159 v.headers = HashMap::new();
160 }
161 if v.secret_headers.is_empty() {
162 v.secret_headers = HashMap::new();
163 }
164 let tenant_context = local_tenant_context();
165 v.secret_header_values =
166 resolve_secret_header_values(&v.secret_headers, &tenant_context);
167 if let Some(oauth) = v.oauth.as_mut() {
168 oauth.client_secret_value =
169 oauth.client_secret_ref.as_ref().and_then(|secret_ref| {
170 resolve_secret_ref_value(secret_ref, &tenant_context)
171 });
172 }
173 (k, v)
174 })
175 .collect::<HashMap<_, _>>();
176 if migrated {
177 persist_state_blocking(&state_file, &loaded);
178 }
179 Self {
180 servers: Arc::new(RwLock::new(loaded)),
181 processes: Arc::new(Mutex::new(HashMap::new())),
182 state_file: Arc::new(state_file),
183 }
184 }
185
186 pub async fn list(&self) -> HashMap<String, McpServer> {
187 self.servers.read().await.clone()
188 }
189
190 pub async fn list_public(&self) -> HashMap<String, McpServer> {
191 self.servers
192 .read()
193 .await
194 .iter()
195 .map(|(name, server)| (name.clone(), redacted_server_view(server)))
196 .collect()
197 }
198
199 pub async fn add(&self, name: String, transport: String) {
200 self.add_or_update(name, transport, HashMap::new(), true)
201 .await;
202 }
203
204 pub async fn add_or_update(
205 &self,
206 name: String,
207 transport: String,
208 headers: HashMap<String, String>,
209 enabled: bool,
210 ) {
211 self.add_or_update_with_secret_refs(name, transport, headers, HashMap::new(), enabled)
212 .await;
213 }
214
215 pub async fn add_or_update_with_secret_refs(
216 &self,
217 name: String,
218 transport: String,
219 headers: HashMap<String, String>,
220 secret_headers: HashMap<String, McpSecretRef>,
221 enabled: bool,
222 ) {
223 let normalized_name = name.trim().to_string();
224 let tenant_context = local_tenant_context();
225 let (persisted_headers, persisted_secret_headers, secret_header_values) =
226 split_headers_for_storage(&normalized_name, headers, secret_headers, &tenant_context);
227 let mut servers = self.servers.write().await;
228 let existing = servers.get(&normalized_name).cloned();
229 let preserve_cache = existing.as_ref().is_some_and(|row| {
230 row.transport == transport
231 && effective_headers(row)
232 == combine_headers(&persisted_headers, &secret_header_values)
233 });
234 let existing_tool_cache = if preserve_cache {
235 existing
236 .as_ref()
237 .map(|row| row.tool_cache.clone())
238 .unwrap_or_default()
239 } else {
240 Vec::new()
241 };
242 let existing_fetched_at = if preserve_cache {
243 existing.as_ref().and_then(|row| row.tools_fetched_at_ms)
244 } else {
245 None
246 };
247 let server = McpServer {
248 name: normalized_name.clone(),
249 transport,
250 auth_kind: existing
251 .as_ref()
252 .map(|row| row.auth_kind.clone())
253 .unwrap_or_default(),
254 enabled,
255 connected: false,
256 pid: None,
257 last_error: None,
258 last_auth_challenge: None,
259 mcp_session_id: None,
260 headers: persisted_headers,
261 secret_headers: persisted_secret_headers,
262 tool_cache: existing_tool_cache,
263 tools_fetched_at_ms: existing_fetched_at,
264 pending_auth_by_tool: HashMap::new(),
265 allowed_tools: existing.as_ref().and_then(|row| row.allowed_tools.clone()),
266 secret_header_values,
267 oauth: existing.as_ref().and_then(|row| row.oauth.clone()),
268 };
269 servers.insert(normalized_name, server);
270 drop(servers);
271 self.persist_state().await;
272 }
273
274 pub async fn set_allowed_tools(&self, name: &str, allowed_tools: Option<Vec<String>>) -> bool {
275 let mut servers = self.servers.write().await;
276 let Some(server) = servers.get_mut(name) else {
277 return false;
278 };
279 let normalized = allowed_tools.map(normalize_allowed_tool_names);
280 if server.allowed_tools == normalized {
281 return true;
282 }
283 server.allowed_tools = normalized;
284 drop(servers);
285 self.persist_state().await;
286 true
287 }
288
289 pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
290 let mut servers = self.servers.write().await;
291 let Some(server) = servers.get_mut(name) else {
292 return false;
293 };
294 server.enabled = enabled;
295 if !enabled {
296 server.connected = false;
297 server.pid = None;
298 server.last_auth_challenge = None;
299 server.mcp_session_id = None;
300 server.pending_auth_by_tool.clear();
301 }
302 drop(servers);
303 if !enabled {
304 if let Some(mut child) = self.processes.lock().await.remove(name) {
305 let _ = child.kill().await;
306 let _ = child.wait().await;
307 }
308 }
309 self.persist_state().await;
310 true
311 }
312
313 pub async fn remove(&self, name: &str) -> bool {
314 let removed_server = {
315 let mut servers = self.servers.write().await;
316 servers.remove(name)
317 };
318 let Some(server) = removed_server else {
319 return false;
320 };
321 let current_tenant = local_tenant_context();
322 delete_secret_header_refs(&server.secret_headers, ¤t_tenant);
323 delete_oauth_secret_ref(server.oauth.as_ref(), ¤t_tenant);
324
325 if let Some(mut child) = self.processes.lock().await.remove(name) {
326 let _ = child.kill().await;
327 let _ = child.wait().await;
328 }
329 self.persist_state().await;
330 true
331 }
332
333 pub async fn connect(&self, name: &str) -> bool {
334 let server = {
335 let servers = self.servers.read().await;
336 let Some(server) = servers.get(name) else {
337 return false;
338 };
339 server.clone()
340 };
341
342 if !server.enabled {
343 let mut servers = self.servers.write().await;
344 if let Some(entry) = servers.get_mut(name) {
345 entry.connected = false;
346 entry.pid = None;
347 entry.last_error = Some("MCP server is disabled".to_string());
348 entry.last_auth_challenge = None;
349 entry.mcp_session_id = None;
350 entry.pending_auth_by_tool.clear();
351 }
352 drop(servers);
353 self.persist_state().await;
354 return false;
355 }
356
357 if let Some(command_text) = parse_stdio_transport(&server.transport) {
358 return self.connect_stdio(name, command_text).await;
359 }
360
361 if parse_remote_endpoint(&server.transport).is_some() {
362 return self.refresh(name).await.is_ok();
363 }
364
365 let mut servers = self.servers.write().await;
366 if let Some(entry) = servers.get_mut(name) {
367 entry.connected = true;
368 entry.pid = None;
369 entry.last_error = None;
370 entry.last_auth_challenge = None;
371 entry.mcp_session_id = None;
372 entry.pending_auth_by_tool.clear();
373 }
374 drop(servers);
375 self.persist_state().await;
376 true
377 }
378
379 pub async fn refresh(&self, name: &str) -> Result<Vec<McpRemoteTool>, String> {
380 let server = {
381 let servers = self.servers.read().await;
382 let Some(server) = servers.get(name) else {
383 return Err("MCP server not found".to_string());
384 };
385 server.clone()
386 };
387
388 if !server.enabled {
389 return Err("MCP server is disabled".to_string());
390 }
391
392 let endpoint = parse_remote_endpoint(&server.transport)
393 .ok_or_else(|| "MCP refresh currently supports HTTP/S transports only".to_string())?;
394
395 let _ = self.ensure_oauth_bearer_token_fresh(name, false).await;
396 let server = {
397 let servers = self.servers.read().await;
398 let Some(server) = servers.get(name) else {
399 return Err("MCP server not found".to_string());
400 };
401 server.clone()
402 };
403 let request_headers = effective_headers(&server);
404 let discovery = self
405 .discover_remote_tools(name, &endpoint, &request_headers)
406 .await;
407 let (tools, session_id) = match discovery {
408 Ok(result) => result,
409 Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
410 let mut servers = self.servers.write().await;
411 if let Some(entry) = servers.get_mut(name) {
412 entry.connected = false;
413 entry.pid = None;
414 entry.last_error = Some(challenge.message.clone());
415 entry.last_auth_challenge = Some(challenge.clone());
416 entry.mcp_session_id = None;
417 entry.pending_auth_by_tool.clear();
418 entry.tool_cache.clear();
419 entry.tools_fetched_at_ms = None;
420 }
421 drop(servers);
422 self.persist_state().await;
423 return Err(format!(
424 "MCP server '{name}' requires authorization: {}",
425 challenge.message
426 ));
427 }
428 Err(DiscoverRemoteToolsError::Message(err)) => {
429 if should_retry_mcp_oauth_refresh(&server, &err)
430 && self.ensure_oauth_bearer_token_fresh(name, true).await?
431 {
432 let refreshed_server = {
433 let servers = self.servers.read().await;
434 servers
435 .get(name)
436 .cloned()
437 .ok_or_else(|| "MCP server not found".to_string())?
438 };
439 match self
440 .discover_remote_tools(
441 name,
442 &endpoint,
443 &effective_headers(&refreshed_server),
444 )
445 .await
446 {
447 Ok(result) => result,
448 Err(DiscoverRemoteToolsError::AuthChallenge(challenge)) => {
449 let mut servers = self.servers.write().await;
450 if let Some(entry) = servers.get_mut(name) {
451 entry.connected = false;
452 entry.pid = None;
453 entry.last_error = Some(challenge.message.clone());
454 entry.last_auth_challenge = Some(challenge.clone());
455 entry.mcp_session_id = None;
456 entry.pending_auth_by_tool.clear();
457 entry.tool_cache.clear();
458 entry.tools_fetched_at_ms = None;
459 }
460 drop(servers);
461 self.persist_state().await;
462 return Err(format!(
463 "MCP server '{name}' requires authorization: {}",
464 challenge.message
465 ));
466 }
467 Err(DiscoverRemoteToolsError::Message(retry_err)) => {
468 let mut servers = self.servers.write().await;
469 if let Some(entry) = servers.get_mut(name) {
470 entry.connected = false;
471 entry.pid = None;
472 entry.last_error = Some(retry_err.clone());
473 entry.last_auth_challenge = None;
474 entry.mcp_session_id = None;
475 entry.pending_auth_by_tool.clear();
476 entry.tool_cache.clear();
477 entry.tools_fetched_at_ms = None;
478 }
479 drop(servers);
480 self.persist_state().await;
481 return Err(retry_err);
482 }
483 }
484 } else {
485 let mut servers = self.servers.write().await;
486 if let Some(entry) = servers.get_mut(name) {
487 entry.connected = false;
488 entry.pid = None;
489 entry.last_error = Some(err.clone());
490 entry.last_auth_challenge = None;
491 entry.mcp_session_id = None;
492 entry.pending_auth_by_tool.clear();
493 entry.tool_cache.clear();
494 entry.tools_fetched_at_ms = None;
495 }
496 drop(servers);
497 self.persist_state().await;
498 return Err(err);
499 }
500 }
501 };
502
503 let now = now_ms();
504 let cache = tools
505 .iter()
506 .map(|tool| McpToolCacheEntry {
507 tool_name: tool.tool_name.clone(),
508 description: tool.description.clone(),
509 input_schema: tool.input_schema.clone(),
510 fetched_at_ms: now,
511 schema_hash: schema_hash(&tool.input_schema),
512 })
513 .collect::<Vec<_>>();
514
515 let mut servers = self.servers.write().await;
516 if let Some(entry) = servers.get_mut(name) {
517 entry.connected = true;
518 entry.pid = None;
519 entry.last_error = None;
520 entry.last_auth_challenge = None;
521 entry.mcp_session_id = session_id;
522 entry.tool_cache = cache;
523 entry.tools_fetched_at_ms = Some(now);
524 entry.pending_auth_by_tool.clear();
525 }
526 drop(servers);
527 self.persist_state().await;
528 Ok(self.server_tools(name).await)
529 }
530
531 pub async fn disconnect(&self, name: &str) -> bool {
532 if let Some(mut child) = self.processes.lock().await.remove(name) {
533 let _ = child.kill().await;
534 let _ = child.wait().await;
535 }
536 let mut servers = self.servers.write().await;
537 if let Some(server) = servers.get_mut(name) {
538 server.connected = false;
539 server.pid = None;
540 server.last_auth_challenge = None;
541 server.mcp_session_id = None;
542 server.pending_auth_by_tool.clear();
543 drop(servers);
544 self.persist_state().await;
545 return true;
546 }
547 false
548 }
549
550 pub async fn complete_auth(&self, name: &str) -> bool {
551 let mut servers = self.servers.write().await;
552 let Some(server) = servers.get_mut(name) else {
553 return false;
554 };
555 server.last_error = None;
556 server.last_auth_challenge = None;
557 server.pending_auth_by_tool.clear();
558 drop(servers);
559 self.persist_state().await;
560 true
561 }
562
563 pub async fn set_auth_kind(&self, name: &str, auth_kind: String) -> bool {
564 let normalized = normalize_auth_kind(&auth_kind);
565 let mut servers = self.servers.write().await;
566 let Some(server) = servers.get_mut(name) else {
567 return false;
568 };
569 server.auth_kind = normalized;
570 drop(servers);
571 self.persist_state().await;
572 true
573 }
574
575 pub async fn record_server_auth_challenge(
576 &self,
577 name: &str,
578 challenge: McpAuthChallenge,
579 last_error: Option<String>,
580 ) -> bool {
581 let mut servers = self.servers.write().await;
582 let Some(server) = servers.get_mut(name) else {
583 return false;
584 };
585 let tool_key = canonical_tool_key(&challenge.tool_name);
586 server.connected = false;
587 server.pid = None;
588 server.last_error = last_error.or_else(|| Some(challenge.message.clone()));
589 server.last_auth_challenge = Some(challenge.clone());
590 server.mcp_session_id = None;
591 server.pending_auth_by_tool.clear();
592 server
593 .pending_auth_by_tool
594 .insert(tool_key, pending_auth_from_challenge(&challenge));
595 drop(servers);
596 self.persist_state().await;
597 true
598 }
599
600 pub async fn clear_server_auth_challenge(&self, name: &str) -> bool {
601 let mut servers = self.servers.write().await;
602 let Some(server) = servers.get_mut(name) else {
603 return false;
604 };
605 server.last_auth_challenge = None;
606 server.pending_auth_by_tool.clear();
607 drop(servers);
608 self.persist_state().await;
609 true
610 }
611
612 pub async fn set_bearer_token(&self, name: &str, token: &str) -> Result<bool, String> {
613 let trimmed = token.trim();
614 if trimmed.is_empty() {
615 return Err("oauth access token cannot be empty".to_string());
616 }
617 let current_tenant = local_tenant_context();
618 let mut servers = self.servers.write().await;
619 let Some(server) = servers.get_mut(name) else {
620 return Ok(false);
621 };
622 let header_name = "Authorization".to_string();
623 let secret_id = mcp_header_secret_id(name, &header_name);
624 tandem_core::set_provider_auth(&secret_id, &format!("Bearer {trimmed}"))
625 .map_err(|error| error.to_string())?;
626 server.secret_headers.insert(
627 header_name.clone(),
628 McpSecretRef::Store {
629 secret_id: secret_id.clone(),
630 tenant_context: current_tenant,
631 },
632 );
633 server
634 .secret_header_values
635 .insert(header_name.clone(), format!("Bearer {trimmed}"));
636 server.headers.remove(&header_name);
637 drop(servers);
638 self.persist_state().await;
639 Ok(true)
640 }
641
642 pub async fn set_oauth_refresh_config(
643 &self,
644 name: &str,
645 provider_id: String,
646 token_endpoint: String,
647 client_id: String,
648 client_secret: Option<String>,
649 ) -> Result<bool, String> {
650 let current_tenant = local_tenant_context();
651 let mut servers = self.servers.write().await;
652 let Some(server) = servers.get_mut(name) else {
653 return Ok(false);
654 };
655
656 let client_secret_ref = client_secret
657 .as_deref()
658 .map(str::trim)
659 .filter(|value| !value.is_empty())
660 .map(|value| -> Result<McpSecretRef, String> {
661 let secret_id = mcp_oauth_client_secret_id(name);
662 tandem_core::set_provider_auth(&secret_id, value)
663 .map_err(|error| error.to_string())?;
664 Ok(McpSecretRef::Store {
665 secret_id,
666 tenant_context: current_tenant.clone(),
667 })
668 })
669 .transpose()?;
670 if client_secret_ref.is_none() {
671 let secret_id = mcp_oauth_client_secret_id(name);
672 let _ = tandem_core::delete_provider_auth(&secret_id);
673 }
674
675 server.oauth = Some(McpOAuthConfig {
676 provider_id,
677 token_endpoint,
678 client_id,
679 client_secret_ref,
680 client_secret_value: client_secret
681 .map(|value| value.trim().to_string())
682 .filter(|value| !value.is_empty()),
683 });
684 drop(servers);
685 self.persist_state().await;
686 Ok(true)
687 }
688
689 pub async fn list_tools(&self) -> Vec<McpRemoteTool> {
690 let mut out = self
691 .servers
692 .read()
693 .await
694 .values()
695 .filter(|server| server.enabled && server.connected)
696 .flat_map(server_tool_rows)
697 .collect::<Vec<_>>();
698 out.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
699 out
700 }
701
702 pub async fn server_tools(&self, name: &str) -> Vec<McpRemoteTool> {
703 let Some(server) = self.servers.read().await.get(name).cloned() else {
704 return Vec::new();
705 };
706 let mut rows = server_tool_rows(&server);
707 rows.sort_by(|a, b| a.namespaced_name.cmp(&b.namespaced_name));
708 rows
709 }
710
    /// Invokes `tool_name` on the remote MCP server `server_name` via JSON-RPC `tools/call`.
    ///
    /// Authorization challenges are NOT errors: they come back as `Ok(ToolResult)` whose
    /// `output` tells the user where to authorize and whose metadata carries an
    /// `mcpAuth` object. This applies to a challenge found in the JSON-RPC error, one
    /// found in the result, and the short-circuit answer produced by
    /// `pending_auth_short_circuit` (which is returned without contacting the server).
    ///
    /// # Errors
    /// Returns `Err` when the server is unknown, disabled, not connected or not an
    /// HTTP/S transport; when the request fails (after at most one retry with a forced
    /// OAuth token refresh); when the forced token refresh itself fails; or when the
    /// server answers with a JSON-RPC error that is not an auth challenge.
    pub async fn call_tool(
        &self,
        server_name: &str,
        tool_name: &str,
        args: Value,
    ) -> Result<ToolResult, String> {
        let server = {
            let servers = self.servers.read().await;
            let Some(server) = servers.get(server_name) else {
                return Err(format!("MCP server '{server_name}' not found"));
            };
            server.clone()
        };

        if !server.enabled {
            return Err(format!("MCP server '{server_name}' is disabled"));
        }
        if !server.connected {
            return Err(format!("MCP server '{server_name}' is not connected"));
        }

        let endpoint = parse_remote_endpoint(&server.transport).ok_or_else(|| {
            "MCP tools/call currently supports HTTP/S transports only".to_string()
        })?;
        let canonical_tool = canonical_tool_key(tool_name);
        let now = now_ms();
        // Best-effort proactive token refresh; a failure is ignored here because the
        // request path below retries once with a forced refresh.
        let _ = self
            .ensure_oauth_bearer_token_fresh(server_name, false)
            .await;
        // Re-read the entry so the request uses headers updated by the token refresh.
        let server = {
            let servers = self.servers.read().await;
            let Some(server) = servers.get(server_name) else {
                return Err(format!("MCP server '{server_name}' not found"));
            };
            server.clone()
        };
        // A pending authorization for this tool may answer the call without a request.
        if let Some(blocked) = pending_auth_short_circuit(
            &server,
            &canonical_tool,
            tool_name,
            now,
            MCP_AUTH_REPROBE_COOLDOWN_MS,
        ) {
            return Ok(ToolResult {
                output: blocked.output,
                metadata: json!({
                    "server": server_name,
                    "tool": tool_name,
                    "result": Value::Null,
                    "mcpAuth": blocked.mcp_auth
                }),
            });
        }
        let normalized_args = normalize_mcp_tool_args(&server, tool_name, args);

        // Stamp the probe time on an existing pending entry. This update is in-memory
        // only here; it reaches disk with the next `persist_state` call.
        {
            let mut servers = self.servers.write().await;
            if let Some(row) = servers.get_mut(server_name) {
                if let Some(pending) = row.pending_auth_by_tool.get_mut(&canonical_tool) {
                    pending.last_probe_ms = now;
                }
            }
        }

        let request = json!({
            "jsonrpc": "2.0",
            "id": format!("call-{}-{}", server_name, now_ms()),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": normalized_args
            }
        });
        let (response, session_id) = match post_json_rpc_with_session(
            &endpoint,
            &effective_headers(&server),
            request.clone(),
            server.mcp_session_id.as_deref(),
        )
        .await
        {
            Ok(result) => result,
            Err(error) => {
                // Retry exactly once, and only if a forced token refresh actually
                // produced a new token; otherwise surface the original error.
                if should_retry_mcp_oauth_refresh(&server, &error)
                    && self
                        .ensure_oauth_bearer_token_fresh(server_name, true)
                        .await?
                {
                    let refreshed_server = {
                        let servers = self.servers.read().await;
                        servers
                            .get(server_name)
                            .cloned()
                            .ok_or_else(|| format!("MCP server '{server_name}' not found"))?
                    };
                    post_json_rpc_with_session(
                        &endpoint,
                        &effective_headers(&refreshed_server),
                        request,
                        refreshed_server.mcp_session_id.as_deref(),
                    )
                    .await?
                } else {
                    return Err(error);
                }
            }
        };
        // Keep the stored session id only when the response carried one.
        if session_id.is_some() {
            let mut servers = self.servers.write().await;
            if let Some(row) = servers.get_mut(server_name) {
                row.mcp_session_id = session_id;
            }
            drop(servers);
            self.persist_state().await;
        }

        if let Some(err) = response.get("error") {
            // A JSON-RPC error that encodes an auth challenge becomes a successful
            // ToolResult carrying the authorization instructions.
            if let Some(challenge) = extract_auth_challenge(err, tool_name) {
                let output = format!(
                    "{}\n\nAuthorize here: {}",
                    challenge.message, challenge.authorization_url
                );
                {
                    let mut servers = self.servers.write().await;
                    if let Some(row) = servers.get_mut(server_name) {
                        row.last_auth_challenge = Some(challenge.clone());
                        row.last_error = None;
                        row.pending_auth_by_tool.insert(
                            canonical_tool.clone(),
                            pending_auth_from_challenge(&challenge),
                        );
                    }
                }
                self.persist_state().await;
                return Ok(ToolResult {
                    output,
                    metadata: json!({
                        "server": server_name,
                        "tool": tool_name,
                        "result": Value::Null,
                        "mcpAuth": {
                            "required": true,
                            "challengeId": challenge.challenge_id,
                            "tool": challenge.tool_name,
                            "authorizationUrl": challenge.authorization_url,
                            "message": challenge.message,
                            "status": challenge.status
                        }
                    }),
                });
            }
            let message = err
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("MCP tools/call failed");
            return Err(message.to_string());
        }

        // A missing `result` is treated as JSON null.
        let result = response.get("result").cloned().unwrap_or(Value::Null);
        // A challenge can also arrive inside a successful result.
        let auth_challenge = extract_auth_challenge(&result, tool_name);
        // Output preference: auth instructions, then rendered `content`, then the
        // JSON text of `output`, then the JSON text of the whole result.
        let output = if let Some(challenge) = auth_challenge.as_ref() {
            format!(
                "{}\n\nAuthorize here: {}",
                challenge.message, challenge.authorization_url
            )
        } else {
            result
                .get("content")
                .map(render_mcp_content)
                .or_else(|| result.get("output").map(|v| v.to_string()))
                .unwrap_or_else(|| result.to_string())
        };

        // Record the challenge, or clear this tool's pending authorization on success.
        {
            let mut servers = self.servers.write().await;
            if let Some(row) = servers.get_mut(server_name) {
                row.last_auth_challenge = auth_challenge.clone();
                if let Some(challenge) = auth_challenge.as_ref() {
                    row.pending_auth_by_tool.insert(
                        canonical_tool.clone(),
                        pending_auth_from_challenge(challenge),
                    );
                } else {
                    row.pending_auth_by_tool.remove(&canonical_tool);
                }
            }
        }
        self.persist_state().await;

        // `None` serializes as JSON null in the metadata below.
        let auth_metadata = auth_challenge.as_ref().map(|challenge| {
            json!({
                "required": true,
                "challengeId": challenge.challenge_id,
                "tool": challenge.tool_name,
                "authorizationUrl": challenge.authorization_url,
                "message": challenge.message,
                "status": challenge.status
            })
        });

        Ok(ToolResult {
            output,
            metadata: json!({
                "server": server_name,
                "tool": tool_name,
                "result": result,
                "mcpAuth": auth_metadata
            }),
        })
    }
921
922 async fn connect_stdio(&self, name: &str, command_text: &str) -> bool {
923 match spawn_stdio_process(command_text).await {
924 Ok(child) => {
925 let pid = child.id();
926 self.processes.lock().await.insert(name.to_string(), child);
927 let mut servers = self.servers.write().await;
928 if let Some(server) = servers.get_mut(name) {
929 server.connected = true;
930 server.pid = pid;
931 server.last_error = None;
932 server.last_auth_challenge = None;
933 server.pending_auth_by_tool.clear();
934 }
935 drop(servers);
936 self.persist_state().await;
937 true
938 }
939 Err(err) => {
940 let mut servers = self.servers.write().await;
941 if let Some(server) = servers.get_mut(name) {
942 server.connected = false;
943 server.pid = None;
944 server.last_error = Some(err);
945 server.last_auth_challenge = None;
946 server.pending_auth_by_tool.clear();
947 }
948 drop(servers);
949 self.persist_state().await;
950 false
951 }
952 }
953 }
954
955 async fn discover_remote_tools(
956 &self,
957 server_name: &str,
958 endpoint: &str,
959 headers: &HashMap<String, String>,
960 ) -> Result<(Vec<McpRemoteTool>, Option<String>), DiscoverRemoteToolsError> {
961 let initialize = json!({
962 "jsonrpc": "2.0",
963 "id": "initialize-1",
964 "method": "initialize",
965 "params": {
966 "protocolVersion": MCP_PROTOCOL_VERSION,
967 "capabilities": {},
968 "clientInfo": {
969 "name": MCP_CLIENT_NAME,
970 "version": MCP_CLIENT_VERSION,
971 }
972 }
973 });
974 let (init_response, mut session_id) =
975 post_json_rpc_with_session(endpoint, headers, initialize, None).await?;
976 if let Some(err) = init_response.get("error") {
977 if let Some(challenge) = extract_auth_challenge(err, server_name) {
978 return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
979 }
980 let message = err
981 .get("message")
982 .and_then(|v| v.as_str())
983 .unwrap_or("MCP initialize failed");
984 return Err(DiscoverRemoteToolsError::Message(message.to_string()));
985 }
986
987 let tools_list = json!({
988 "jsonrpc": "2.0",
989 "id": "tools-list-1",
990 "method": "tools/list",
991 "params": {}
992 });
993 let (tools_response, next_session_id) =
994 post_json_rpc_with_session(endpoint, headers, tools_list, session_id.as_deref())
995 .await?;
996 if next_session_id.is_some() {
997 session_id = next_session_id;
998 }
999 if let Some(err) = tools_response.get("error") {
1000 if let Some(challenge) = extract_auth_challenge(err, server_name) {
1001 return Err(DiscoverRemoteToolsError::AuthChallenge(challenge));
1002 }
1003 let message = err
1004 .get("message")
1005 .and_then(|v| v.as_str())
1006 .unwrap_or("MCP tools/list failed");
1007 return Err(DiscoverRemoteToolsError::Message(message.to_string()));
1008 }
1009
1010 let tools = tools_response
1011 .get("result")
1012 .and_then(|v| v.get("tools"))
1013 .and_then(|v| v.as_array())
1014 .ok_or_else(|| "MCP tools/list result missing tools array".to_string())?;
1015
1016 let now = now_ms();
1017 let mut out = Vec::new();
1018 for row in tools {
1019 let Some(tool_name) = row.get("name").and_then(|v| v.as_str()) else {
1020 continue;
1021 };
1022 let description = row
1023 .get("description")
1024 .and_then(|v| v.as_str())
1025 .unwrap_or("")
1026 .to_string();
1027 let mut input_schema = row
1028 .get("inputSchema")
1029 .or_else(|| row.get("input_schema"))
1030 .cloned()
1031 .unwrap_or_else(|| json!({"type":"object"}));
1032 normalize_tool_input_schema(&mut input_schema);
1033 out.push(McpRemoteTool {
1034 server_name: String::new(),
1035 tool_name: tool_name.to_string(),
1036 namespaced_name: String::new(),
1037 description,
1038 input_schema,
1039 fetched_at_ms: now,
1040 schema_hash: String::new(),
1041 });
1042 }
1043
1044 Ok((out, session_id))
1045 }
1046
1047 async fn persist_state(&self) {
1048 let snapshot = self.servers.read().await.clone();
1049 persist_state_blocking(self.state_file.as_path(), &snapshot);
1050 }
1051
1052 async fn ensure_oauth_bearer_token_fresh(
1053 &self,
1054 name: &str,
1055 force: bool,
1056 ) -> Result<bool, String> {
1057 let server = {
1058 let servers = self.servers.read().await;
1059 servers.get(name).cloned()
1060 }
1061 .ok_or_else(|| format!("MCP server '{name}' not found"))?;
1062 let Some(oauth) = server.oauth.clone() else {
1063 return Ok(false);
1064 };
1065 let Some(credential) = tandem_core::load_provider_oauth_credential(&oauth.provider_id)
1066 else {
1067 return Ok(false);
1068 };
1069
1070 let should_refresh = force
1071 || credential.expires_at_ms <= now_ms().saturating_add(60_000)
1072 || credential.access_token.trim().is_empty();
1073 if !should_refresh {
1074 return Ok(false);
1075 }
1076
1077 let refreshed = refresh_mcp_oauth_credential(&oauth, &credential).await?;
1078 self.set_bearer_token(name, &refreshed.access_token).await?;
1079 tandem_core::set_provider_oauth_credential(&oauth.provider_id, refreshed)
1080 .map_err(|error| error.to_string())?;
1081 Ok(true)
1082 }
1083}
1084
1085impl Default for McpRegistry {
1086 fn default() -> Self {
1087 Self::new()
1088 }
1089}
1090
/// Serde default for `McpServer::enabled`: entries without the field load as enabled.
fn default_enabled() -> bool {
    true
}
1094
/// Trims every tool name, drops blank entries, and removes duplicates while
/// keeping the first occurrence of each name in its original position.
fn normalize_allowed_tool_names(raw: Vec<String>) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    raw.into_iter()
        .map(|tool| tool.trim().to_string())
        // `insert` returns false for a name that was already recorded, so
        // only the first occurrence passes the filter.
        .filter(|name| !name.is_empty() && seen.insert(name.clone()))
        .collect()
}
1107
1108fn persist_state_blocking(path: &Path, snapshot: &HashMap<String, McpServer>) {
1109 if let Some(parent) = path.parent() {
1110 let _ = std::fs::create_dir_all(parent);
1111 }
1112 if let Ok(payload) = serde_json::to_string_pretty(snapshot) {
1113 let _ = std::fs::write(path, payload);
1114 }
1115}
1116
1117fn resolve_state_file() -> PathBuf {
1118 if let Ok(path) = std::env::var("TANDEM_MCP_REGISTRY") {
1119 return PathBuf::from(path);
1120 }
1121 if let Ok(state_dir) = std::env::var("TANDEM_STATE_DIR") {
1122 let trimmed = state_dir.trim();
1123 if !trimmed.is_empty() {
1124 return PathBuf::from(trimmed).join("mcp_servers.json");
1125 }
1126 }
1127 if let Some(data_dir) = dirs::data_dir() {
1128 return data_dir
1129 .join("tandem")
1130 .join("data")
1131 .join("mcp_servers.json");
1132 }
1133 dirs::home_dir()
1134 .map(|home| home.join(".tandem").join("data").join("mcp_servers.json"))
1135 .unwrap_or_else(|| PathBuf::from("mcp_servers.json"))
1136}
1137
1138fn load_state(path: &Path) -> (HashMap<String, McpServer>, bool) {
1139 let Ok(raw) = std::fs::read_to_string(path) else {
1140 return (HashMap::new(), false);
1141 };
1142 let mut migrated = false;
1143 let mut parsed = serde_json::from_str::<HashMap<String, McpServer>>(&raw).unwrap_or_default();
1144 for (name, server) in parsed.iter_mut() {
1145 let tenant_context = local_tenant_context();
1146 let (headers, secret_headers, secret_header_values, server_migrated) =
1147 migrate_server_headers(name, server, &tenant_context);
1148 migrated = migrated || server_migrated;
1149 server.headers = headers;
1150 server.secret_headers = secret_headers;
1151 server.secret_header_values = secret_header_values;
1152 }
1153 (parsed, migrated)
1154}
1155
/// Computes the migrated header layout for a loaded server without mutating it.
///
/// Returns, in order: the headers that stay in plain `headers`, the secret
/// references for `secret_headers`, the resolved secret values for
/// `secret_header_values`, and a flag telling whether the result differs from
/// what the server currently holds. The caller is expected to assign the first
/// three back onto the server.
///
/// Side effect: a sensitive header with a non-blank value is written to the
/// secret store via `tandem_core::set_provider_auth`.
fn migrate_server_headers(
    server_name: &str,
    server: &McpServer,
    current_tenant: &TenantContext,
) -> (
    HashMap<String, String>,
    HashMap<String, McpSecretRef>,
    HashMap<String, String>,
    bool,
) {
    // Baseline used at the end to detect changes that did not move any header.
    let original_effective = effective_headers(server);
    let mut persisted_secret_headers = server.secret_headers.clone();
    // Start from the values of the secret references that are already recorded.
    let mut secret_header_values =
        resolve_secret_header_values(&persisted_secret_headers, current_tenant);
    let mut persisted_headers = server.headers.clone();
    let mut migrated = false;

    // Snapshot the keys first: entries are removed from `persisted_headers`
    // inside the loop.
    let header_keys = persisted_headers.keys().cloned().collect::<Vec<_>>();
    for header_name in header_keys {
        let Some(value) = persisted_headers.get(&header_name).cloned() else {
            continue;
        };
        // A header that already has a secret reference is left untouched,
        // including its entry in the plain headers.
        if persisted_secret_headers.contains_key(&header_name) {
            continue;
        }
        // Case 1: the plain value is itself a secret reference. Move it to the
        // secret references and cache the resolved value when it is non-empty.
        if let Some(secret_ref) = parse_secret_header_reference(value.trim()) {
            persisted_headers.remove(&header_name);
            let resolved =
                resolve_secret_ref_value(&secret_ref, current_tenant).unwrap_or_default();
            persisted_secret_headers.insert(header_name.clone(), secret_ref);
            if !resolved.is_empty() {
                secret_header_values.insert(header_name.clone(), resolved);
            }
            migrated = true;
            continue;
        }
        // Case 2: a sensitive header carrying a literal, non-blank value. It is
        // moved out of the plain headers only when the secret store accepts it;
        // on failure it stays in `persisted_headers` unchanged.
        if header_name_is_sensitive(&header_name) && !value.trim().is_empty() {
            let secret_id = mcp_header_secret_id(server_name, &header_name);
            if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
                persisted_headers.remove(&header_name);
                persisted_secret_headers.insert(
                    header_name.clone(),
                    McpSecretRef::Store {
                        secret_id: secret_id.clone(),
                        tenant_context: current_tenant.clone(),
                    },
                );
                secret_header_values.insert(header_name.clone(), value);
                migrated = true;
            }
        }
    }

    // Nothing was moved, but the resolved secret values may still differ from
    // the server's current `secret_header_values`; report that as a change too.
    if !migrated {
        let effective = combine_headers(&persisted_headers, &secret_header_values);
        migrated = effective != original_effective;
    }

    (
        persisted_headers,
        persisted_secret_headers,
        secret_header_values,
        migrated,
    )
}
1221
1222fn split_headers_for_storage(
1223 server_name: &str,
1224 headers: HashMap<String, String>,
1225 explicit_secret_headers: HashMap<String, McpSecretRef>,
1226 current_tenant: &TenantContext,
1227) -> (
1228 HashMap<String, String>,
1229 HashMap<String, McpSecretRef>,
1230 HashMap<String, String>,
1231) {
1232 let mut persisted_headers = HashMap::new();
1233 let mut persisted_secret_headers = HashMap::new();
1234 let mut secret_header_values = HashMap::new();
1235
1236 for (header_name, raw_value) in headers {
1237 let value = raw_value.trim().to_string();
1238 if value.is_empty() {
1239 continue;
1240 }
1241 if let Some(secret_ref) = parse_secret_header_reference(&value) {
1242 if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1243 secret_header_values.insert(header_name.clone(), resolved);
1244 }
1245 persisted_secret_headers.insert(header_name, secret_ref);
1246 continue;
1247 }
1248 if header_name_is_sensitive(&header_name) {
1249 let secret_id = mcp_header_secret_id(server_name, &header_name);
1250 if tandem_core::set_provider_auth(&secret_id, &value).is_ok() {
1251 persisted_secret_headers.insert(
1252 header_name.clone(),
1253 McpSecretRef::Store {
1254 secret_id: secret_id.clone(),
1255 tenant_context: current_tenant.clone(),
1256 },
1257 );
1258 secret_header_values.insert(header_name, value);
1259 continue;
1260 }
1261 }
1262 persisted_headers.insert(header_name, value);
1263 }
1264
1265 for (header_name, secret_ref) in explicit_secret_headers {
1266 if let Some(resolved) = resolve_secret_ref_value(&secret_ref, current_tenant) {
1267 secret_header_values.insert(header_name.clone(), resolved);
1268 }
1269 persisted_headers.remove(&header_name);
1270 persisted_secret_headers.insert(header_name, secret_ref);
1271 }
1272
1273 (
1274 persisted_headers,
1275 persisted_secret_headers,
1276 secret_header_values,
1277 )
1278}
1279
/// Merges plain headers with resolved secret header values.
///
/// Secret values win over a plain header of the same name. Secret values that
/// are empty or whitespace-only are skipped, leaving any plain header with
/// that name in place.
fn combine_headers(
    headers: &HashMap<String, String>,
    secret_header_values: &HashMap<String, String>,
) -> HashMap<String, String> {
    let mut merged = headers.clone();
    merged.extend(
        secret_header_values
            .iter()
            .filter(|(_, value)| !value.trim().is_empty())
            .map(|(name, value)| (name.clone(), value.clone())),
    );
    merged
}
1292
1293fn effective_headers(server: &McpServer) -> HashMap<String, String> {
1294 combine_headers(&server.headers, &server.secret_header_values)
1295}
1296
1297fn redacted_server_view(server: &McpServer) -> McpServer {
1298 let mut clone = server.clone();
1299 for (header_name, secret_ref) in &clone.secret_headers {
1300 clone.headers.insert(
1301 header_name.clone(),
1302 redacted_secret_header_value(secret_ref),
1303 );
1304 }
1305 clone.secret_header_values.clear();
1306 if let Some(oauth) = clone.oauth.as_mut() {
1307 oauth.client_secret_ref = None;
1308 oauth.client_secret_value = None;
1309 }
1310 clone
1311}
1312
/// Normalizes an auth kind to its trimmed, lower-case form.
///
/// Recognized kinds are `oauth`, `auto`, `bearer`, `x-api-key`, `custom` and
/// `none`; any other input yields an empty string.
fn normalize_auth_kind(raw: &str) -> String {
    let kind = raw.trim().to_ascii_lowercase();
    let recognized = matches!(
        kind.as_str(),
        "oauth" | "auto" | "bearer" | "x-api-key" | "custom" | "none"
    );
    if recognized {
        kind
    } else {
        String::new()
    }
}
1321
1322fn redacted_secret_header_value(secret_ref: &McpSecretRef) -> String {
1323 match secret_ref {
1324 McpSecretRef::BearerEnv { .. } => "Bearer ".to_string(),
1325 McpSecretRef::Env { .. } | McpSecretRef::Store { .. } => MCP_SECRET_PLACEHOLDER.to_string(),
1326 }
1327}
1328
1329fn resolve_secret_header_values(
1330 secret_headers: &HashMap<String, McpSecretRef>,
1331 current_tenant: &TenantContext,
1332) -> HashMap<String, String> {
1333 let mut out = HashMap::new();
1334 for (header_name, secret_ref) in secret_headers {
1335 if let Some(value) = resolve_secret_ref_value(secret_ref, current_tenant) {
1336 if !value.trim().is_empty() {
1337 out.insert(header_name.clone(), value);
1338 }
1339 }
1340 }
1341 out
1342}
1343
1344fn delete_secret_header_refs(
1345 secret_headers: &HashMap<String, McpSecretRef>,
1346 current_tenant: &TenantContext,
1347) {
1348 for secret_ref in secret_headers.values() {
1349 if let McpSecretRef::Store {
1350 secret_id,
1351 tenant_context,
1352 } = secret_ref
1353 {
1354 if tenant_context != current_tenant {
1355 continue;
1356 }
1357 let _ = tandem_core::delete_provider_auth(secret_id);
1358 }
1359 }
1360}
1361
1362fn delete_oauth_secret_ref(oauth: Option<&McpOAuthConfig>, current_tenant: &TenantContext) {
1363 let Some(secret_ref) = oauth.and_then(|oauth| oauth.client_secret_ref.as_ref()) else {
1364 return;
1365 };
1366 if let McpSecretRef::Store {
1367 secret_id,
1368 tenant_context,
1369 } = secret_ref
1370 {
1371 if tenant_context == current_tenant {
1372 let _ = tandem_core::delete_provider_auth(secret_id);
1373 }
1374 }
1375}