Skip to main content

agent_api/
mcp.rs

1use std::collections::BTreeMap;
2use std::path::PathBuf;
3use std::process::ExitStatus;
4use std::time::Duration;
5
6use crate::AgentWrapperError;
7
// Versioned capability identifier strings for the four MCP management operations.
// NOTE(review): where these ids are advertised or matched is outside this section — confirm.

/// Capability identifier string for the MCP `list` operation (v1).
pub(crate) const CAPABILITY_MCP_LIST_V1: &str = "agent_api.tools.mcp.list.v1";
/// Capability identifier string for the MCP `get` operation (v1).
pub(crate) const CAPABILITY_MCP_GET_V1: &str = "agent_api.tools.mcp.get.v1";
/// Capability identifier string for the MCP `add` operation (v1).
pub(crate) const CAPABILITY_MCP_ADD_V1: &str = "agent_api.tools.mcp.add.v1";
/// Capability identifier string for the MCP `remove` operation (v1).
pub(crate) const CAPABILITY_MCP_REMOVE_V1: &str = "agent_api.tools.mcp.remove.v1";
12
// Fixed `InvalidRequest` messages. They are constant strings (no interpolation), so a rejected
// raw input value is never echoed back through them.

/// Message returned by `normalize_server_name` when the name is empty after trimming.
const ERR_MCP_SERVER_NAME_EMPTY: &str = "mcp server name must be non-empty";
/// Message returned by `normalize_add_transport` when the stdio `command` vector has no items.
const ERR_MCP_ADD_STDIO_COMMAND_EMPTY: &str =
    "mcp add stdio.command must contain at least one item";
/// Message returned by `normalize_url` when the URL is empty after trimming.
const ERR_MCP_ADD_URL_EMPTY: &str = "mcp add url must be non-empty";
/// Message returned by `normalize_url` when the URL fails to parse, is not `http`/`https`, or
/// lacks the `//` authority separator after the scheme.
const ERR_MCP_ADD_URL_INVALID: &str = "mcp add url must be an absolute http or https URL";
/// Message returned by `normalize_bearer_token_env_var` when the name is empty after trimming.
const ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_EMPTY: &str =
    "mcp add bearer_token_env_var must be non-empty";
/// Message returned by `normalize_bearer_token_env_var` when the trimmed name fails
/// `is_valid_env_var_name`.
const ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_INVALID: &str =
    "mcp add bearer_token_env_var must match ^[A-Za-z_][A-Za-z0-9_]*$";
22
/// Per-request execution context carried by every MCP management request in this module.
///
/// The derived [`Default`] yields no working directory, no timeout, and an empty env map. The
/// request normalizers in this module pass the context through without modifying it.
#[derive(Debug, Default, Clone)]
pub struct AgentWrapperMcpCommandContext {
    /// Optional working directory associated with the command.
    // NOTE(review): how backends apply this (e.g. relative-path handling) is not visible here.
    pub working_dir: Option<PathBuf>,
    /// Optional time limit associated with the command.
    pub timeout: Option<Duration>,
    /// Environment variable map (name -> value) carried with the request; values are kept
    /// verbatim by the normalizers in this module.
    pub env: BTreeMap<String, String>,
}
29
/// Result of an MCP management command: its exit status plus the captured output streams.
#[derive(Debug, Clone)]
pub struct AgentWrapperMcpCommandOutput {
    /// Exit status of the command.
    pub status: ExitStatus,
    /// Captured standard output.
    ///
    /// Backends should populate both `stdout` and `stderr` via
    /// `crate::bounds::enforce_mcp_output_bound` so the two streams stay aligned with the
    /// pinned MM-C04 truncation algorithm.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
    /// Truncation flag for `stdout`.
    // NOTE(review): presumably set when the output bound shortened the stream — confirm
    // against `crate::bounds`.
    pub stdout_truncated: bool,
    /// Truncation flag for `stderr` (same caveat as `stdout_truncated`).
    pub stderr_truncated: bool,
}
40
41#[derive(Clone, Debug, Default)]
42pub struct AgentWrapperMcpListRequest {
43    pub context: AgentWrapperMcpCommandContext,
44}
45
46#[derive(Clone, Debug)]
47pub struct AgentWrapperMcpGetRequest {
48    pub name: String,
49    pub context: AgentWrapperMcpCommandContext,
50}
51
52#[derive(Clone, Debug)]
53pub struct AgentWrapperMcpRemoveRequest {
54    pub name: String,
55    pub context: AgentWrapperMcpCommandContext,
56}
57
/// Transport description supplied when adding an MCP server.
#[derive(Debug, Clone)]
pub enum AgentWrapperMcpAddTransport {
    /// Launches an MCP server via stdio.
    Stdio {
        /// Argv of the command to launch; MUST contain at least one item.
        command: Vec<String>,
        /// Extra argv items appended after `command`.
        args: Vec<String>,
        /// Environment variables injected into the MCP server process.
        env: BTreeMap<String, String>,
    },
    /// Connects to a streamable HTTP MCP server.
    Url {
        /// Endpoint URL; normalization in this module accepts only absolute `http`/`https`
        /// URLs.
        url: String,
        /// Optional environment variable name; normalization requires it to match
        /// `^[A-Za-z_][A-Za-z0-9_]*$`.
        // NOTE(review): presumably names the variable holding the bearer token — confirm
        // with the backend that consumes it.
        bearer_token_env_var: Option<String>,
    },
}
75
76#[derive(Clone, Debug)]
77pub struct AgentWrapperMcpAddRequest {
78    pub name: String,
79    pub transport: AgentWrapperMcpAddTransport,
80    pub context: AgentWrapperMcpCommandContext,
81}
82
83pub(crate) fn normalize_server_name(name: &str) -> Result<String, AgentWrapperError> {
84    let name = name.trim();
85    if name.is_empty() {
86        return Err(invalid_request(ERR_MCP_SERVER_NAME_EMPTY));
87    }
88
89    Ok(name.to_string())
90}
91
92pub(crate) fn normalize_add_transport(
93    transport: AgentWrapperMcpAddTransport,
94) -> Result<AgentWrapperMcpAddTransport, AgentWrapperError> {
95    match transport {
96        AgentWrapperMcpAddTransport::Stdio { command, args, env } => {
97            if command.is_empty() {
98                return Err(invalid_request(ERR_MCP_ADD_STDIO_COMMAND_EMPTY));
99            }
100
101            Ok(AgentWrapperMcpAddTransport::Stdio {
102                command: normalize_stdio_items(command, "mcp add stdio.command")?,
103                args: normalize_stdio_items(args, "mcp add stdio.args")?,
104                env,
105            })
106        }
107        AgentWrapperMcpAddTransport::Url {
108            url,
109            bearer_token_env_var,
110        } => {
111            let url = normalize_url(url)?;
112            let bearer_token_env_var = normalize_bearer_token_env_var(bearer_token_env_var)?;
113            Ok(AgentWrapperMcpAddTransport::Url {
114                url,
115                bearer_token_env_var,
116            })
117        }
118    }
119}
120
121pub(crate) fn normalize_mcp_get_request(
122    request: AgentWrapperMcpGetRequest,
123) -> Result<AgentWrapperMcpGetRequest, AgentWrapperError> {
124    Ok(AgentWrapperMcpGetRequest {
125        name: normalize_server_name(&request.name)?,
126        context: request.context,
127    })
128}
129
130pub(crate) fn normalize_mcp_add_request(
131    request: AgentWrapperMcpAddRequest,
132) -> Result<AgentWrapperMcpAddRequest, AgentWrapperError> {
133    Ok(AgentWrapperMcpAddRequest {
134        name: normalize_server_name(&request.name)?,
135        transport: normalize_add_transport(request.transport)?,
136        context: request.context,
137    })
138}
139
140pub(crate) fn normalize_mcp_remove_request(
141    request: AgentWrapperMcpRemoveRequest,
142) -> Result<AgentWrapperMcpRemoveRequest, AgentWrapperError> {
143    Ok(AgentWrapperMcpRemoveRequest {
144        name: normalize_server_name(&request.name)?,
145        context: request.context,
146    })
147}
148
149fn normalize_stdio_items(
150    items: Vec<String>,
151    field: &str,
152) -> Result<Vec<String>, AgentWrapperError> {
153    items
154        .into_iter()
155        .enumerate()
156        .map(|(idx, item)| {
157            let trimmed = item.trim();
158            if trimmed.is_empty() {
159                return Err(invalid_request(format!("{field}[{idx}] must be non-empty")));
160            }
161
162            Ok(trimmed.to_string())
163        })
164        .collect()
165}
166
167fn normalize_url(url: String) -> Result<String, AgentWrapperError> {
168    let url = url.trim();
169    if url.is_empty() {
170        return Err(invalid_request(ERR_MCP_ADD_URL_EMPTY));
171    }
172
173    let parsed = url::Url::parse(url).map_err(|_| invalid_request(ERR_MCP_ADD_URL_INVALID))?;
174    match parsed.scheme() {
175        "http" | "https" if has_http_authority_separator(url) => Ok(url.to_string()),
176        _ => Err(invalid_request(ERR_MCP_ADD_URL_INVALID)),
177    }
178}
179
/// Reports whether `url` literally starts with `http://` or `https://` (ASCII
/// case-insensitive).
///
/// `str::get` is used for the prefix slice so an input that is too short, or whose prefix
/// boundary falls inside a multi-byte character, yields `false` instead of panicking.
fn has_http_authority_separator(url: &str) -> bool {
    ["http://", "https://"].iter().any(|prefix| {
        url.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    })
}
187
188fn normalize_bearer_token_env_var(
189    value: Option<String>,
190) -> Result<Option<String>, AgentWrapperError> {
191    match value {
192        Some(value) => {
193            let trimmed = value.trim();
194            if trimmed.is_empty() {
195                return Err(invalid_request(ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_EMPTY));
196            }
197            if !is_valid_env_var_name(trimmed) {
198                return Err(invalid_request(ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_INVALID));
199            }
200
201            Ok(Some(trimmed.to_string()))
202        }
203        None => Ok(None),
204    }
205}
206
/// Checks `value` against `^[A-Za-z_][A-Za-z0-9_]*$`.
///
/// The check works on bytes: every byte of a multi-byte UTF-8 character is non-ASCII and is
/// therefore rejected, the same as a per-`char` ASCII test would.
fn is_valid_env_var_name(value: &str) -> bool {
    let is_word_byte = |byte: &u8| byte.is_ascii_alphanumeric() || *byte == b'_';

    match value.as_bytes() {
        [] => false,
        // First byte: a word byte that is not a digit, i.e. an ASCII letter or `_`.
        [first, rest @ ..] => {
            !first.is_ascii_digit() && is_word_byte(first) && rest.iter().all(is_word_byte)
        }
    }
}
218
219fn invalid_request(message: impl Into<String>) -> AgentWrapperError {
220    AgentWrapperError::InvalidRequest {
221        message: message.into(),
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Context with every field populated; the env value keeps its surrounding spaces so tests
    /// can verify that normalization leaves the context untouched.
    fn sample_context() -> AgentWrapperMcpCommandContext {
        AgentWrapperMcpCommandContext {
            working_dir: Some(PathBuf::from("relative/workdir")),
            timeout: Some(Duration::from_secs(30)),
            env: BTreeMap::from([(
                "UNCHANGED".to_string(),
                "  value with spaces  ".to_string(),
            )]),
        }
    }

    /// Asserts that `result` is an `InvalidRequest` whose message equals `expected_message` and
    /// contains none of `redacted_values`.
    fn assert_invalid_request(
        result: Result<(), AgentWrapperError>,
        expected_message: &str,
        redacted_values: &[&str],
    ) {
        let message = match result {
            Ok(()) => panic!("expected InvalidRequest"),
            Err(AgentWrapperError::InvalidRequest { message }) => message,
            Err(other) => panic!("expected InvalidRequest, got {other:?}"),
        };

        assert_eq!(message, expected_message);
        for value in redacted_values {
            assert!(
                !message.contains(value),
                "message leaked raw input `{value}`: {message}"
            );
        }
    }

    #[test]
    fn normalize_server_name_trims_and_rejects_empty_values() {
        let trimmed = normalize_server_name("  demo-server  ").expect("name should normalize");
        assert_eq!(trimmed, "demo-server");

        let blank = normalize_server_name("   \n\t  ").map(drop);
        assert_invalid_request(blank, ERR_MCP_SERVER_NAME_EMPTY, &[]);
    }

    #[test]
    fn normalize_mcp_add_request_trims_fields_and_preserves_context_and_env_maps() {
        let context = sample_context();
        let transport_env = BTreeMap::from([("KEEP".to_string(), "  exact value  ".to_string())]);

        let normalized = normalize_mcp_add_request(AgentWrapperMcpAddRequest {
            name: "  example  ".to_string(),
            transport: AgentWrapperMcpAddTransport::Stdio {
                command: vec!["  bin/example  ".to_string()],
                args: vec!["  --flag  ".to_string(), "  value  ".to_string()],
                env: transport_env.clone(),
            },
            context: context.clone(),
        })
        .expect("request should normalize");

        assert_eq!(normalized.name, "example");
        assert_eq!(normalized.context.working_dir, context.working_dir);
        assert_eq!(normalized.context.timeout, context.timeout);
        assert_eq!(normalized.context.env, context.env);

        let AgentWrapperMcpAddTransport::Stdio { command, args, env } = normalized.transport
        else {
            panic!("expected stdio transport");
        };
        assert_eq!(command, ["bin/example"]);
        assert_eq!(args, ["--flag", "value"]);
        assert_eq!(env, transport_env);
    }

    #[test]
    fn normalize_add_transport_accepts_and_trims_valid_url_transport() {
        let normalized = normalize_add_transport(AgentWrapperMcpAddTransport::Url {
            url: "  https://example.com/mcp  ".to_string(),
            bearer_token_env_var: Some("  TOKEN_NAME  ".to_string()),
        })
        .expect("url transport should normalize");

        let AgentWrapperMcpAddTransport::Url {
            url,
            bearer_token_env_var,
        } = normalized
        else {
            panic!("expected url transport");
        };
        assert_eq!(url, "https://example.com/mcp");
        assert_eq!(bearer_token_env_var.as_deref(), Some("TOKEN_NAME"));
    }

    #[test]
    fn normalize_add_transport_rejects_invalid_stdio_fields_without_leaking_raw_values() {
        let secret = "SECRET_STDIO_VALUE";

        // Runs normalization on a stdio transport with an empty env and discards the Ok value.
        let normalize_stdio = |command: Vec<String>, args: Vec<String>| {
            normalize_add_transport(AgentWrapperMcpAddTransport::Stdio {
                command,
                args,
                env: BTreeMap::new(),
            })
            .map(drop)
        };
        // A valid (secret-bearing) item followed by a whitespace-only item at index 1.
        let secret_then_blank = || vec![format!("  {secret}  "), "   ".to_string()];

        assert_invalid_request(
            normalize_stdio(Vec::new(), Vec::new()),
            ERR_MCP_ADD_STDIO_COMMAND_EMPTY,
            &[],
        );

        assert_invalid_request(
            normalize_stdio(secret_then_blank(), Vec::new()),
            "mcp add stdio.command[1] must be non-empty",
            &[secret],
        );

        assert_invalid_request(
            normalize_stdio(vec!["cmd".to_string()], secret_then_blank()),
            "mcp add stdio.args[1] must be non-empty",
            &[secret],
        );
    }

    #[test]
    fn normalize_add_transport_rejects_invalid_url_fields_without_leaking_raw_values() {
        let secret = "SECRET_URL_VALUE";

        // Runs normalization on a url transport and discards the Ok value.
        let normalize_url_transport = |url: String, bearer_token_env_var: Option<String>| {
            normalize_add_transport(AgentWrapperMcpAddTransport::Url {
                url,
                bearer_token_env_var,
            })
            .map(drop)
        };

        assert_invalid_request(
            normalize_url_transport("   ".to_string(), None),
            ERR_MCP_ADD_URL_EMPTY,
            &[],
        );

        let invalid_urls = [
            format!(" {secret} "),
            format!("relative/{secret}"),
            format!("ftp://{secret}.example.com"),
            format!("http:// space/{secret}"),
            format!("https:{secret}.example.com"),
            format!("http:{secret}"),
            format!("https:/{secret}.example.com"),
        ];
        for raw in invalid_urls {
            assert_invalid_request(
                normalize_url_transport(raw, None),
                ERR_MCP_ADD_URL_INVALID,
                &[secret],
            );
        }

        assert_invalid_request(
            normalize_url_transport(
                "https://example.com/mcp".to_string(),
                Some("   ".to_string()),
            ),
            ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_EMPTY,
            &[],
        );

        let invalid_env_var_names = [
            format!("9{secret}"),
            format!("BAD-{secret}"),
            format!("bad space {secret}"),
        ];
        for raw in invalid_env_var_names {
            assert_invalid_request(
                normalize_url_transport("https://example.com/mcp".to_string(), Some(raw)),
                ERR_MCP_ADD_BEARER_TOKEN_ENV_VAR_INVALID,
                &[secret],
            );
        }
    }

    #[test]
    fn normalize_get_and_remove_requests_trim_name_and_preserve_context() {
        let context = sample_context();

        let get_request = AgentWrapperMcpGetRequest {
            name: "  get-name  ".to_string(),
            context: context.clone(),
        };
        let get = normalize_mcp_get_request(get_request).expect("get request should normalize");
        assert_eq!(get.name, "get-name");
        assert_eq!(get.context.env, context.env);

        let remove_request = AgentWrapperMcpRemoveRequest {
            name: "  remove-name  ".to_string(),
            context: context.clone(),
        };
        let remove = normalize_mcp_remove_request(remove_request)
            .expect("remove request should normalize");
        assert_eq!(remove.name, "remove-name");
        assert_eq!(remove.context.working_dir, context.working_dir);
        assert_eq!(remove.context.timeout, context.timeout);
    }
}