Skip to main content

zagens_runtime_adapters/mcp/
auth.rs

//! Remote MCP authentication: config → HTTP headers, with env substitution.
//!
//! Credentials should not live as plaintext in `mcp.json` when avoidable — use
//! `${ENV_VAR}` (or `$ENV_VAR`) placeholders, resolved at connection time from
//! the process environment, including values injected by the desktop shell.
6
7use std::collections::HashMap;
8
9use anyhow::{Context, Result};
10use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
11use serde::{Deserialize, Serialize};
12
13use super::config::McpServerConfig;
14
15/// Optional auth block on a remote MCP server (`sse` / `http` transports).
16#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
17pub struct McpAuthConfig {
18    /// `bearer` or `apiKey` (case-insensitive).
19    #[serde(default, rename = "type")]
20    pub auth_type: Option<String>,
21    /// Bearer token, or full `Bearer <token>` value. Supports `${ENV}`.
22    #[serde(default)]
23    pub token: Option<String>,
24    /// Header name for API-key auth (default `X-API-Key`).
25    #[serde(default)]
26    pub header: Option<String>,
27    /// API key value. Supports `${ENV}`. Accepts `apiKey` alias in JSON.
28    #[serde(default, alias = "apiKey")]
29    pub api_key: Option<String>,
30}
31
/// Lower-case header names whose values are treated as secrets: they are
/// redacted when a config is exported to the UI/API and restored on save.
/// Entries must stay lower-case ASCII; lookups fold the candidate name's case.
const SENSITIVE_HEADER_NAMES: &[&str] = &[
    // Credential-carrying authorization headers.
    "authorization",
    "proxy-authorization",
    // API-key / token style headers.
    "api-key",
    "x-api-key",
    "x-auth-token",
    // Session state.
    "cookie",
];
41
42impl McpServerConfig {
43    /// Resolve HTTP headers for remote transports: explicit `headers`, plus
44    /// shorthand `auth`, with `${VAR}` / `$VAR` env substitution.
45    pub fn resolve_http_headers(&self, server_name: &str) -> Result<HashMap<String, String>> {
46        let mut out = HashMap::new();
47
48        if let Some(auth) = &self.auth {
49            auth.apply_to_map(&mut out, server_name)?;
50        }
51
52        for (name, value) in &self.headers {
53            let resolved = resolve_env_placeholders(value)
54                .with_context(|| format!("MCP server '{server_name}' header '{name}'"))?;
55            out.insert(name.clone(), resolved);
56        }
57
58        Ok(out)
59    }
60
61    /// Return a copy safe to expose over HTTP APIs (secrets redacted).
62    #[must_use]
63    pub fn redacted_for_display(&self) -> Self {
64        let mut copy = self.clone();
65        copy.auth = copy.auth.as_ref().map(McpAuthConfig::redacted);
66        copy.headers = redact_header_map(&copy.headers);
67        copy
68    }
69}
70
71impl McpAuthConfig {
72    fn apply_to_map(&self, out: &mut HashMap<String, String>, server_name: &str) -> Result<()> {
73        let kind = self
74            .auth_type
75            .as_deref()
76            .map(str::trim)
77            .filter(|s| !s.is_empty())
78            .ok_or_else(|| {
79                anyhow::anyhow!("MCP server '{server_name}' auth block requires a 'type' field")
80            })?;
81
82        match kind.to_ascii_lowercase().as_str() {
83            "bearer" => {
84                let token = self.token.as_deref().ok_or_else(|| {
85                    anyhow::anyhow!(
86                        "MCP server '{server_name}' bearer auth requires a 'token' field"
87                    )
88                })?;
89                let resolved = resolve_env_placeholders(token)
90                    .with_context(|| format!("MCP server '{server_name}' bearer token"))?;
91                let value = normalize_bearer_value(&resolved);
92                out.insert("Authorization".to_string(), value);
93            }
94            "apikey" | "api_key" | "api-key" => {
95                let header = self
96                    .header
97                    .as_deref()
98                    .filter(|s| !s.trim().is_empty())
99                    .unwrap_or("X-API-Key");
100                let key = self
101                    .api_key
102                    .as_deref()
103                    .or(self.token.as_deref())
104                    .ok_or_else(|| {
105                        anyhow::anyhow!(
106                            "MCP server '{server_name}' apiKey auth requires 'apiKey' or 'token'"
107                        )
108                    })?;
109                let resolved = resolve_env_placeholders(key)
110                    .with_context(|| format!("MCP server '{server_name}' apiKey value"))?;
111                out.insert(header.to_string(), resolved);
112            }
113            other => anyhow::bail!(
114                "MCP server '{server_name}' unknown auth type '{other}' (expected bearer or apiKey)"
115            ),
116        }
117        Ok(())
118    }
119
120    #[must_use]
121    fn redacted(&self) -> Self {
122        Self {
123            auth_type: self.auth_type.clone(),
124            // Plaintext secrets are omitted so API consumers cannot echo them
125            // back on PUT; [`merge_preserved_secrets`] restores them on save.
126            token: self
127                .token
128                .as_ref()
129                .filter(|t| looks_like_env_placeholder(t))
130                .cloned(),
131            header: self.header.clone(),
132            api_key: self
133                .api_key
134                .as_ref()
135                .filter(|t| looks_like_env_placeholder(t))
136                .cloned(),
137        }
138    }
139}
140
141/// When the UI saves a server block after a redacted GET, restore secret
142/// fields that were omitted (not re-entered by the user).
143pub fn merge_preserved_secrets(new: &mut McpServerConfig, old: &McpServerConfig) {
144    match (&mut new.auth, &old.auth) {
145        (Some(new_auth), Some(old_auth)) => {
146            if new_auth
147                .token
148                .as_deref()
149                .is_none_or(|t| t.trim().is_empty())
150            {
151                new_auth.token = old_auth.token.clone();
152            }
153            if new_auth
154                .api_key
155                .as_deref()
156                .is_none_or(|t| t.trim().is_empty())
157            {
158                new_auth.api_key = old_auth.api_key.clone();
159            }
160            if new_auth.auth_type.is_none() {
161                new_auth.auth_type = old_auth.auth_type.clone();
162            }
163            if new_auth.header.is_none() {
164                new_auth.header = old_auth.header.clone();
165            }
166        }
167        (None, Some(old_auth)) if old_auth.token.is_some() || old_auth.api_key.is_some() => {
168            new.auth = Some(old_auth.clone());
169        }
170        _ => {}
171    }
172
173    for (name, value) in &old.headers {
174        if is_sensitive_header(name)
175            && !new.headers.contains_key(name)
176            && !looks_like_env_placeholder(value)
177        {
178            new.headers.insert(name.clone(), value.clone());
179        }
180    }
181}
182
183/// Substitute `${VAR}` and `$VAR` from the process environment.
184pub fn resolve_env_placeholders(raw: &str) -> Result<String> {
185    if !raw.contains('$') {
186        return Ok(raw.to_string());
187    }
188
189    let mut out = String::with_capacity(raw.len());
190    let bytes = raw.as_bytes();
191    let mut i = 0usize;
192    while i < bytes.len() {
193        if bytes[i] != b'$' {
194            out.push(bytes[i] as char);
195            i += 1;
196            continue;
197        }
198
199        // ${VAR}
200        if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
201            let start = i + 2;
202            let mut j = start;
203            while j < bytes.len() && bytes[j] != b'}' {
204                j += 1;
205            }
206            if j >= bytes.len() {
207                anyhow::bail!("unclosed '${{...}}' in value");
208            }
209            let name = std::str::from_utf8(&bytes[start..j])
210                .context("invalid UTF-8 in env placeholder")?
211                .trim();
212            if name.is_empty() {
213                anyhow::bail!("empty env placeholder '${{}}'");
214            }
215            let value = std::env::var(name)
216                .with_context(|| format!("environment variable '{name}' is not set"))?;
217            out.push_str(&value);
218            i = j + 1;
219            continue;
220        }
221
222        // $VAR (identifier: letter/underscore start, then alphanumeric/_)
223        let start = i + 1;
224        if start >= bytes.len() {
225            out.push('$');
226            break;
227        }
228        let first = bytes[start];
229        if !(first.is_ascii_alphabetic() || first == b'_') {
230            out.push('$');
231            i += 1;
232            continue;
233        }
234        let mut j = start + 1;
235        while j < bytes.len() {
236            let b = bytes[j];
237            if b.is_ascii_alphanumeric() || b == b'_' {
238                j += 1;
239            } else {
240                break;
241            }
242        }
243        let name =
244            std::str::from_utf8(&bytes[start..j]).context("invalid UTF-8 in env placeholder")?;
245        let value = std::env::var(name)
246            .with_context(|| format!("environment variable '{name}' is not set"))?;
247        out.push_str(&value);
248        i = j;
249    }
250    Ok(out)
251}
252
253/// Apply resolved headers as reqwest default headers (remote MCP only).
254pub fn apply_default_headers(
255    builder: reqwest::ClientBuilder,
256    headers: &HashMap<String, String>,
257) -> Result<reqwest::ClientBuilder> {
258    if headers.is_empty() {
259        return Ok(builder);
260    }
261    let mut map = HeaderMap::new();
262    for (name, value) in headers {
263        let name = HeaderName::from_bytes(name.as_bytes())
264            .with_context(|| format!("invalid HTTP header name '{name}'"))?;
265        let value = HeaderValue::from_str(value)
266            .with_context(|| format!("invalid HTTP header value for '{name}'"))?;
267        map.insert(name, value);
268    }
269    Ok(builder.default_headers(map))
270}
271
/// Build an `Authorization` header value from a bearer token.
///
/// The token is trimmed. If it already starts with `Bearer ` (any ASCII case)
/// it is returned as-is; otherwise `Bearer ` is prepended.
fn normalize_bearer_value(token: &str) -> String {
    let trimmed = token.trim();
    // Compare bytes instead of slicing `&trimmed[..7]`, which panics when byte 7
    // falls inside a multi-byte UTF-8 character.
    let has_scheme = trimmed
        .as_bytes()
        .get(..7)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(b"bearer "));
    if has_scheme {
        trimmed.to_string()
    } else {
        format!("Bearer {trimmed}")
    }
}
280
281fn redact_header_map(headers: &HashMap<String, String>) -> HashMap<String, String> {
282    headers
283        .iter()
284        .filter_map(|(k, v)| {
285            if is_sensitive_header(k) && !looks_like_env_placeholder(v) {
286                None
287            } else {
288                Some((k.clone(), v.clone()))
289            }
290        })
291        .collect()
292}
293
294fn is_sensitive_header(name: &str) -> bool {
295    let lower = name.trim().to_ascii_lowercase();
296    SENSITIVE_HEADER_NAMES.iter().any(|s| *s == lower)
297}
298
/// Whether `value` references an env placeholder rather than holding a literal
/// secret: it contains `${`, or it starts with `$` followed by an identifier
/// start (`[A-Za-z_]`), matching the grammar `resolve_env_placeholders`
/// substitutes.
///
/// NOTE(review): a value mixing literal text with a placeholder (for example
/// `sk-123${SUFFIX}`) still counts as a placeholder and is shown unredacted.
fn looks_like_env_placeholder(value: &str) -> bool {
    if value.contains("${") {
        return true;
    }
    // A `$` followed by a digit, space, other symbol or nothing is kept
    // literally by the resolver, so such a value may be a real secret.
    value
        .strip_prefix('$')
        .and_then(|rest| rest.bytes().next())
        .is_some_and(|b| b.is_ascii_alphabetic() || b == b'_')
}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Remote (`url`-based) server config with every other field defaulted.
    fn remote_config(
        transport: Option<&str>,
        headers: HashMap<String, String>,
        auth: Option<McpAuthConfig>,
    ) -> McpServerConfig {
        McpServerConfig {
            command: None,
            args: vec![],
            env: HashMap::new(),
            url: Some("https://example.com/mcp".to_string()),
            transport: transport.map(str::to_string),
            headers,
            auth,
            connect_timeout: None,
            execute_timeout: None,
            read_timeout: None,
            disabled: false,
            enabled: true,
            required: false,
            enabled_tools: vec![],
            disabled_tools: vec![],
        }
    }

    /// Bearer auth block with the given token.
    fn bearer(token: &str) -> McpAuthConfig {
        McpAuthConfig {
            auth_type: Some("bearer".to_string()),
            token: Some(token.to_string()),
            header: None,
            api_key: None,
        }
    }

    #[test]
    fn resolve_env_braced() {
        // SAFETY: `set_var` / `remove_var` are only sound while no other thread
        // reads or writes the environment. The `cargo test` harness runs tests on
        // multiple threads, so this relies on no concurrently running test
        // touching the environment; the variable name is unique to this test.
        unsafe {
            std::env::set_var("MCP_AUTH_TEST_BRACED_TOKEN", "secret-value");
        }
        let resolved = resolve_env_placeholders("Bearer ${MCP_AUTH_TEST_BRACED_TOKEN}");
        // Clean up before asserting so a failure does not leak the variable.
        // SAFETY: same as above.
        unsafe {
            std::env::remove_var("MCP_AUTH_TEST_BRACED_TOKEN");
        }
        assert_eq!(resolved.unwrap(), "Bearer secret-value");
    }

    #[test]
    fn bearer_auth_adds_prefix() {
        let cfg = remote_config(Some("http"), HashMap::new(), Some(bearer("tok123")));
        let headers = cfg.resolve_http_headers("test").unwrap();
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer tok123")
        );
    }

    #[test]
    fn custom_headers_override_auth() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer override".to_string());
        let cfg = remote_config(None, headers, Some(bearer("from-auth")));
        let resolved = cfg.resolve_http_headers("test").unwrap();
        assert_eq!(
            resolved.get("Authorization").map(String::as_str),
            Some("Bearer override")
        );
    }

    #[test]
    fn api_key_auth_uses_default_header() {
        let auth = McpAuthConfig {
            auth_type: Some("apiKey".to_string()),
            token: None,
            header: None,
            api_key: Some("k-123".to_string()),
        };
        let cfg = remote_config(Some("http"), HashMap::new(), Some(auth));
        let headers = cfg.resolve_http_headers("test").unwrap();
        assert_eq!(headers.get("X-API-Key").map(String::as_str), Some("k-123"));
    }

    #[test]
    fn unknown_auth_type_is_rejected() {
        let mut auth = bearer("tok");
        auth.auth_type = Some("digest".to_string());
        let cfg = remote_config(None, HashMap::new(), Some(auth));
        assert!(cfg.resolve_http_headers("test").is_err());
    }

    #[test]
    fn redacted_display_drops_literal_token_keeps_placeholder() {
        let literal = remote_config(None, HashMap::new(), Some(bearer("tok123")));
        assert_eq!(literal.redacted_for_display().auth.unwrap().token, None);

        let placeholder = remote_config(None, HashMap::new(), Some(bearer("${TOKEN}")));
        assert_eq!(
            placeholder.redacted_for_display().auth.unwrap().token.as_deref(),
            Some("${TOKEN}")
        );
    }
}