zagens_runtime_adapters/mcp/
auth.rs1use std::collections::HashMap;
8
9use anyhow::{Context, Result};
10use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
11use serde::{Deserialize, Serialize};
12
13use super::config::McpServerConfig;
14
15#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
17pub struct McpAuthConfig {
18 #[serde(default, rename = "type")]
20 pub auth_type: Option<String>,
21 #[serde(default)]
23 pub token: Option<String>,
24 #[serde(default)]
26 pub header: Option<String>,
27 #[serde(default, alias = "apiKey")]
29 pub api_key: Option<String>,
30}
31
/// Lower-case names of HTTP headers whose values are treated as secrets.
///
/// Lookups lower-case the candidate name first, so every entry here must be
/// written in lower-case ASCII.
const SENSITIVE_HEADER_NAMES: &[&str] = &[
    // Credential-carrying request headers.
    "authorization",
    "x-api-key",
    "api-key",
    "x-auth-token",
    // Proxy credentials and session state.
    "proxy-authorization",
    "cookie",
];
41
42impl McpServerConfig {
43 pub fn resolve_http_headers(&self, server_name: &str) -> Result<HashMap<String, String>> {
46 let mut out = HashMap::new();
47
48 if let Some(auth) = &self.auth {
49 auth.apply_to_map(&mut out, server_name)?;
50 }
51
52 for (name, value) in &self.headers {
53 let resolved = resolve_env_placeholders(value)
54 .with_context(|| format!("MCP server '{server_name}' header '{name}'"))?;
55 out.insert(name.clone(), resolved);
56 }
57
58 Ok(out)
59 }
60
61 #[must_use]
63 pub fn redacted_for_display(&self) -> Self {
64 let mut copy = self.clone();
65 copy.auth = copy.auth.as_ref().map(McpAuthConfig::redacted);
66 copy.headers = redact_header_map(©.headers);
67 copy
68 }
69}
70
71impl McpAuthConfig {
72 fn apply_to_map(&self, out: &mut HashMap<String, String>, server_name: &str) -> Result<()> {
73 let kind = self
74 .auth_type
75 .as_deref()
76 .map(str::trim)
77 .filter(|s| !s.is_empty())
78 .ok_or_else(|| {
79 anyhow::anyhow!("MCP server '{server_name}' auth block requires a 'type' field")
80 })?;
81
82 match kind.to_ascii_lowercase().as_str() {
83 "bearer" => {
84 let token = self.token.as_deref().ok_or_else(|| {
85 anyhow::anyhow!(
86 "MCP server '{server_name}' bearer auth requires a 'token' field"
87 )
88 })?;
89 let resolved = resolve_env_placeholders(token)
90 .with_context(|| format!("MCP server '{server_name}' bearer token"))?;
91 let value = normalize_bearer_value(&resolved);
92 out.insert("Authorization".to_string(), value);
93 }
94 "apikey" | "api_key" | "api-key" => {
95 let header = self
96 .header
97 .as_deref()
98 .filter(|s| !s.trim().is_empty())
99 .unwrap_or("X-API-Key");
100 let key = self
101 .api_key
102 .as_deref()
103 .or(self.token.as_deref())
104 .ok_or_else(|| {
105 anyhow::anyhow!(
106 "MCP server '{server_name}' apiKey auth requires 'apiKey' or 'token'"
107 )
108 })?;
109 let resolved = resolve_env_placeholders(key)
110 .with_context(|| format!("MCP server '{server_name}' apiKey value"))?;
111 out.insert(header.to_string(), resolved);
112 }
113 other => anyhow::bail!(
114 "MCP server '{server_name}' unknown auth type '{other}' (expected bearer or apiKey)"
115 ),
116 }
117 Ok(())
118 }
119
120 #[must_use]
121 fn redacted(&self) -> Self {
122 Self {
123 auth_type: self.auth_type.clone(),
124 token: self
127 .token
128 .as_ref()
129 .filter(|t| looks_like_env_placeholder(t))
130 .cloned(),
131 header: self.header.clone(),
132 api_key: self
133 .api_key
134 .as_ref()
135 .filter(|t| looks_like_env_placeholder(t))
136 .cloned(),
137 }
138 }
139}
140
141pub fn merge_preserved_secrets(new: &mut McpServerConfig, old: &McpServerConfig) {
144 match (&mut new.auth, &old.auth) {
145 (Some(new_auth), Some(old_auth)) => {
146 if new_auth
147 .token
148 .as_deref()
149 .is_none_or(|t| t.trim().is_empty())
150 {
151 new_auth.token = old_auth.token.clone();
152 }
153 if new_auth
154 .api_key
155 .as_deref()
156 .is_none_or(|t| t.trim().is_empty())
157 {
158 new_auth.api_key = old_auth.api_key.clone();
159 }
160 if new_auth.auth_type.is_none() {
161 new_auth.auth_type = old_auth.auth_type.clone();
162 }
163 if new_auth.header.is_none() {
164 new_auth.header = old_auth.header.clone();
165 }
166 }
167 (None, Some(old_auth)) if old_auth.token.is_some() || old_auth.api_key.is_some() => {
168 new.auth = Some(old_auth.clone());
169 }
170 _ => {}
171 }
172
173 for (name, value) in &old.headers {
174 if is_sensitive_header(name)
175 && !new.headers.contains_key(name)
176 && !looks_like_env_placeholder(value)
177 {
178 new.headers.insert(name.clone(), value.clone());
179 }
180 }
181}
182
183pub fn resolve_env_placeholders(raw: &str) -> Result<String> {
185 if !raw.contains('$') {
186 return Ok(raw.to_string());
187 }
188
189 let mut out = String::with_capacity(raw.len());
190 let bytes = raw.as_bytes();
191 let mut i = 0usize;
192 while i < bytes.len() {
193 if bytes[i] != b'$' {
194 out.push(bytes[i] as char);
195 i += 1;
196 continue;
197 }
198
199 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
201 let start = i + 2;
202 let mut j = start;
203 while j < bytes.len() && bytes[j] != b'}' {
204 j += 1;
205 }
206 if j >= bytes.len() {
207 anyhow::bail!("unclosed '${{...}}' in value");
208 }
209 let name = std::str::from_utf8(&bytes[start..j])
210 .context("invalid UTF-8 in env placeholder")?
211 .trim();
212 if name.is_empty() {
213 anyhow::bail!("empty env placeholder '${{}}'");
214 }
215 let value = std::env::var(name)
216 .with_context(|| format!("environment variable '{name}' is not set"))?;
217 out.push_str(&value);
218 i = j + 1;
219 continue;
220 }
221
222 let start = i + 1;
224 if start >= bytes.len() {
225 out.push('$');
226 break;
227 }
228 let first = bytes[start];
229 if !(first.is_ascii_alphabetic() || first == b'_') {
230 out.push('$');
231 i += 1;
232 continue;
233 }
234 let mut j = start + 1;
235 while j < bytes.len() {
236 let b = bytes[j];
237 if b.is_ascii_alphanumeric() || b == b'_' {
238 j += 1;
239 } else {
240 break;
241 }
242 }
243 let name =
244 std::str::from_utf8(&bytes[start..j]).context("invalid UTF-8 in env placeholder")?;
245 let value = std::env::var(name)
246 .with_context(|| format!("environment variable '{name}' is not set"))?;
247 out.push_str(&value);
248 i = j;
249 }
250 Ok(out)
251}
252
253pub fn apply_default_headers(
255 builder: reqwest::ClientBuilder,
256 headers: &HashMap<String, String>,
257) -> Result<reqwest::ClientBuilder> {
258 if headers.is_empty() {
259 return Ok(builder);
260 }
261 let mut map = HeaderMap::new();
262 for (name, value) in headers {
263 let name = HeaderName::from_bytes(name.as_bytes())
264 .with_context(|| format!("invalid HTTP header name '{name}'"))?;
265 let value = HeaderValue::from_str(value)
266 .with_context(|| format!("invalid HTTP header value for '{name}'"))?;
267 map.insert(name, value);
268 }
269 Ok(builder.default_headers(map))
270}
271
/// Produces an `Authorization` header value from a bearer token.
///
/// The token is trimmed. If it already starts with `bearer ` (any ASCII
/// case) it is returned as is; otherwise `Bearer ` is prepended.
fn normalize_bearer_value(token: &str) -> String {
    let trimmed = token.trim();
    // `get(..7)` instead of `[..7]`: plain slicing panics when byte 7 falls
    // inside a multi-byte character of the token.
    let has_scheme = trimmed
        .get(..7)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("bearer "));
    if has_scheme {
        trimmed.to_string()
    } else {
        format!("Bearer {trimmed}")
    }
}
280
281fn redact_header_map(headers: &HashMap<String, String>) -> HashMap<String, String> {
282 headers
283 .iter()
284 .filter_map(|(k, v)| {
285 if is_sensitive_header(k) && !looks_like_env_placeholder(v) {
286 None
287 } else {
288 Some((k.clone(), v.clone()))
289 }
290 })
291 .collect()
292}
293
294fn is_sensitive_header(name: &str) -> bool {
295 let lower = name.trim().to_ascii_lowercase();
296 SENSITIVE_HEADER_NAMES.iter().any(|s| *s == lower)
297}
298
/// Heuristic: does `value` reference an environment variable rather than
/// hold a literal?
///
/// True when the value starts with `$` or contains `${` anywhere. A bare
/// `$NAME` that is not at the start is not detected.
fn looks_like_env_placeholder(value: &str) -> bool {
    if value.starts_with('$') {
        return true;
    }
    value.find("${").is_some()
}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an HTTP MCP server config that uses bearer auth with `token`.
    fn bearer_config(
        token: &str,
        transport: Option<&str>,
        headers: HashMap<String, String>,
    ) -> McpServerConfig {
        McpServerConfig {
            command: None,
            args: vec![],
            env: HashMap::new(),
            url: Some("https://example.com/mcp".to_string()),
            transport: transport.map(str::to_string),
            headers,
            auth: Some(McpAuthConfig {
                auth_type: Some("bearer".to_string()),
                token: Some(token.to_string()),
                header: None,
                api_key: None,
            }),
            connect_timeout: None,
            execute_timeout: None,
            read_timeout: None,
            disabled: false,
            enabled: true,
            required: false,
            enabled_tools: vec![],
            disabled_tools: vec![],
        }
    }

    #[test]
    fn resolve_env_braced() {
        // SAFETY: mutates the process environment; `MCP_TEST_TOKEN` is only
        // touched by this test in this module. Assumes no other thread is
        // reading or writing the environment concurrently (TODO confirm).
        unsafe {
            std::env::set_var("MCP_TEST_TOKEN", "secret-value");
        }
        let expanded = resolve_env_placeholders("Bearer ${MCP_TEST_TOKEN}").unwrap();
        assert_eq!(expanded, "Bearer secret-value");
        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var("MCP_TEST_TOKEN");
        }
    }

    #[test]
    fn bearer_auth_adds_prefix() {
        let cfg = bearer_config("tok123", Some("http"), HashMap::new());
        let headers = cfg.resolve_http_headers("test").unwrap();
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer tok123")
        );
    }

    #[test]
    fn custom_headers_override_auth() {
        let explicit = HashMap::from([(
            "Authorization".to_string(),
            "Bearer override".to_string(),
        )]);
        let cfg = bearer_config("from-auth", None, explicit);
        let resolved = cfg.resolve_http_headers("test").unwrap();
        assert_eq!(
            resolved.get("Authorization").map(String::as_str),
            Some("Bearer override")
        );
    }
}