Skip to main content

zeph_mcp/
executor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5
6use parking_lot::RwLock;
7
8use zeph_common::ToolName;
9use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, extract_fenced_blocks};
10use zeph_tools::registry::{InvocationHint, ToolDef};
11
12use crate::manager::McpManager;
13use crate::tool::McpTool;
14
15/// [`ToolExecutor`] implementation that dispatches tool calls to MCP servers.
16///
17/// `McpToolExecutor` bridges the `zeph-tools` dispatch layer and `McpManager`. It
18/// maintains a local snapshot of the registered MCP tools (updated via
19/// [`set_tools`](McpToolExecutor::set_tools)) and resolves tool calls by matching
20/// the sanitized tool ID against the snapshot before forwarding to the manager.
21///
22/// # Security invariant
23///
24/// [`execute_tool_call`](McpToolExecutor::execute_tool_call) sets
25/// `ToolOutput::tool_name` to [`McpTool::qualified_name`] (i.e. `"server_id:name"`).
26/// The `':'` in the name is the signal used by `zeph-core`'s `sanitize_tool_output()`
27/// to route responses through the quarantine pipeline. Do not change this.
28///
29/// # Fenced-block execution
30///
31/// [`execute`](McpToolExecutor::execute) parses ```` ```mcp ```` fenced blocks
32/// from LLM output and validates each `server:tool` pair against the registered list
33/// before dispatching, preventing prompt injection from routing calls to unknown servers.
34#[derive(Debug, Clone)]
35pub struct McpToolExecutor {
36    manager: Arc<McpManager>,
37    tools: Arc<RwLock<Vec<McpTool>>>,
38}
39
40impl McpToolExecutor {
41    /// Create a new executor from a shared `McpManager` and a shared tool list.
42    ///
43    /// The `tools` `RwLock` is updated via [`set_tools`](Self::set_tools) after each
44    /// connect or refresh. Pass the same `Arc<RwLock<Vec<McpTool>>>` to both the executor
45    /// and the code that handles `tools/list_changed` events.
46    #[must_use]
47    pub fn new(manager: Arc<McpManager>, tools: Arc<RwLock<Vec<McpTool>>>) -> Self {
48        Self { manager, tools }
49    }
50
51    /// Replace the registered tool snapshot.
52    ///
53    /// Logs a `WARN` for each `sanitized_id` collision: when two tools map to the same
54    /// sanitized ID the second is unreachable via [`execute_tool_call`](Self::execute_tool_call).
55    pub fn set_tools(&self, tools: Vec<McpTool>) {
56        // Warn on sanitized_id collisions: two tools mapping to the same id means
57        // the second will be unreachable via execute_tool_call.
58        let mut seen = std::collections::HashMap::new();
59        for t in &tools {
60            let sid = t.sanitized_id();
61            if let Some(prev) = seen.insert(sid.clone(), t.qualified_name()) {
62                tracing::warn!(
63                    sanitized_id = %sid,
64                    first = %prev,
65                    second = %t.qualified_name(),
66                    "MCP tool sanitized_id collision: second tool will be unreachable"
67                );
68            }
69        }
70        let mut guard = self.tools.write();
71        *guard = tools;
72    }
73}
74
75impl ToolExecutor for McpToolExecutor {
76    fn tool_definitions(&self) -> Vec<ToolDef> {
77        let tools = self.tools.read();
78        tools
79            .iter()
80            .map(|t| ToolDef {
81                id: t.sanitized_id().into(),
82                description: t.description.clone().into(),
83                schema: serde_json::from_value(t.input_schema.clone())
84                    .unwrap_or_else(|_| schemars::Schema::default()),
85                invocation: InvocationHint::ToolCall,
86                output_schema: t.output_schema.clone(),
87            })
88            .collect()
89    }
90
91    #[cfg_attr(
92        feature = "profiling",
93        tracing::instrument(name = "mcp.executor.execute_tool_call", skip_all, fields(tool_id = %call.tool_id))
94    )]
95    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
96        // Lookup by sanitized_id because the LLM sees sanitized names (no ':' character).
97        //
98        // IMPORTANT: ToolOutput.tool_name MUST be set to qualified_name() (not sanitized_id()).
99        // sanitize_tool_output() in zeph-core classifies tool output as external/untrusted MCP
100        // content by checking tool_name.contains(':').  Breaking this invariant would silently
101        // route MCP responses through the local/trusted pipeline, bypassing quarantine.
102        let found = {
103            let tools = self.tools.read();
104            tools
105                .iter()
106                .find(|t| t.sanitized_id() == call.tool_id)
107                .cloned()
108        };
109        let Some(tool) = found else {
110            return Ok(None);
111        };
112
113        let args = serde_json::Value::Object(call.params.clone());
114        let result = self
115            .manager
116            .call_tool(&tool.server_id, &tool.name, args)
117            .await
118            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
119
120        let raw_text = result
121            .content
122            .iter()
123            .filter_map(|c| {
124                if let rmcp::model::RawContent::Text(t) = &c.raw {
125                    Some(t.text.as_str())
126                } else {
127                    None
128                }
129            })
130            .collect::<Vec<_>>()
131            .join("\n");
132
133        let text = crate::sanitize::intent_anchor_wrap(&tool.server_id, &tool.name, &raw_text);
134
135        Ok(Some(ToolOutput {
136            tool_name: tool.qualified_name().into(),
137            summary: text,
138            blocks_executed: 1,
139            filter_stats: None,
140            diff: None,
141            streamed: false,
142            terminal_id: None,
143            locations: None,
144            raw_response: None,
145            claim_source: Some(zeph_tools::ClaimSource::Mcp),
146        }))
147    }
148
149    #[cfg_attr(
150        feature = "profiling",
151        tracing::instrument(name = "mcp.executor.execute", skip_all)
152    )]
153    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
154        let blocks = extract_fenced_blocks(response, "mcp");
155        if blocks.is_empty() {
156            return Ok(None);
157        }
158
159        let mut outputs = Vec::with_capacity(blocks.len());
160        #[allow(clippy::cast_possible_truncation)]
161        let blocks_executed = blocks.len() as u32;
162
163        for block in &blocks {
164            let instruction: McpInstruction =
165                serde_json::from_str(block).map_err(|e: serde_json::Error| {
166                    ToolError::Execution(std::io::Error::other(e.to_string()))
167                })?;
168
169            // SECURITY: Validate server:tool against the registered tool list before dispatch.
170            // This prevents a prompt injection from routing calls to unregistered servers or tools.
171            let found = {
172                let tools = self.tools.read();
173                tools
174                    .iter()
175                    .find(|t| t.server_id == instruction.server && t.name == instruction.tool)
176                    .cloned()
177            };
178            let Some(tool) = found else {
179                return Err(ToolError::Execution(std::io::Error::other(format!(
180                    "MCP tool {}:{} not in registered tool list",
181                    instruction.server, instruction.tool
182                ))));
183            };
184
185            // Delegate to execute_tool_call() so all security layers apply.
186            let call = ToolCall {
187                tool_id: tool.sanitized_id().into(),
188                params: match instruction.args {
189                    serde_json::Value::Object(map) => map,
190                    _ => serde_json::Map::new(),
191                },
192                caller_id: None,
193                context: None,
194                tool_call_id: String::new(),
195                skill_name: None,
196            };
197            if let Some(output) = self.execute_tool_call(&call).await? {
198                outputs.push(output.summary);
199            }
200        }
201
202        Ok(Some(ToolOutput {
203            // SECURITY: Use qualified format so quarantine routing works (tool_name must contain ':').
204            tool_name: ToolName::new("mcp:fenced_block"),
205            summary: outputs.join("\n\n"),
206            blocks_executed,
207            filter_stats: None,
208            diff: None,
209            streamed: false,
210            terminal_id: None,
211            locations: None,
212            raw_response: None,
213            claim_source: Some(zeph_tools::ClaimSource::Mcp),
214        }))
215    }
216}
217
218#[derive(serde::Deserialize)]
219struct McpInstruction {
220    server: String,
221    tool: String,
222    #[serde(default = "default_args")]
223    args: serde_json::Value,
224}
225
226fn default_args() -> serde_json::Value {
227    serde_json::Value::Object(serde_json::Map::new())
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::PolicyEnforcer;

    /// Manager with no configured servers: every `call_tool` on it fails.
    fn make_manager() -> Arc<McpManager> {
        Arc::new(McpManager::new(vec![], vec![], PolicyEnforcer::new(vec![])))
    }

    /// Executor with an empty tool snapshot.
    fn make_executor() -> McpToolExecutor {
        executor_with(vec![])
    }

    /// Executor whose snapshot is `tools`, installed directly (bypassing `set_tools`).
    fn executor_with(tools: Vec<McpTool>) -> McpToolExecutor {
        McpToolExecutor::new(make_manager(), Arc::new(RwLock::new(tools)))
    }

    /// Minimal tool: description `"d"`, empty input schema, no output schema.
    fn make_tool(server_id: &'static str, name: &'static str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: "d".into(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Native tool call for `tool_id` with no parameters.
    fn make_call(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    #[test]
    fn parse_instruction_full() {
        let json = r#"{"server": "github", "tool": "create_issue", "args": {"title": "bug"}}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.server, "github");
        assert_eq!(instr.tool, "create_issue");
        assert_eq!(instr.args["title"], "bug");
    }

    #[test]
    fn parse_instruction_no_args() {
        let json = r#"{"server": "fs", "tool": "list_dir"}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.server, "fs");
        assert_eq!(instr.tool, "list_dir");
        assert!(instr.args.is_object());
    }

    #[test]
    fn parse_instruction_empty_args() {
        let json = r#"{"server": "s", "tool": "t", "args": {}}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert!(instr.args.as_object().unwrap().is_empty());
    }

    #[test]
    fn parse_instruction_missing_server_fails() {
        let json = r#"{"tool": "t"}"#;
        assert!(serde_json::from_str::<McpInstruction>(json).is_err());
    }

    #[test]
    fn parse_instruction_missing_tool_fails() {
        let json = r#"{"server": "s"}"#;
        assert!(serde_json::from_str::<McpInstruction>(json).is_err());
    }

    #[test]
    fn extract_mcp_blocks() {
        let text = "Here:\n```mcp\n{\"server\":\"a\",\"tool\":\"b\"}\n```\nDone";
        let blocks = extract_fenced_blocks(text, "mcp");
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("\"server\""));
    }

    #[test]
    fn no_mcp_blocks() {
        let text = "```bash\necho hello\n```";
        let blocks = extract_fenced_blocks(text, "mcp");
        assert!(blocks.is_empty());
    }

    #[test]
    fn multiple_mcp_blocks() {
        let text = "```mcp\n{\"server\":\"a\",\"tool\":\"b\"}\n```\n\
                    text\n\
                    ```mcp\n{\"server\":\"c\",\"tool\":\"d\"}\n```";
        let blocks = extract_fenced_blocks(text, "mcp");
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn parse_instruction_invalid_json() {
        let json = r"{not valid json}";
        assert!(serde_json::from_str::<McpInstruction>(json).is_err());
    }

    #[test]
    fn parse_instruction_extra_fields_ignored() {
        let json = r#"{"server":"s","tool":"t","args":{},"extra":"ignored"}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.server, "s");
        assert_eq!(instr.tool, "t");
    }

    #[test]
    fn parse_instruction_args_array() {
        let json = r#"{"server":"s","tool":"t","args":["a","b"]}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert!(instr.args.is_array());
    }

    #[test]
    fn parse_instruction_args_nested() {
        let json = r#"{"server":"s","tool":"t","args":{"nested":{"key":"val"}}}"#;
        let instr: McpInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.args["nested"]["key"], "val");
    }

    #[test]
    fn default_args_is_empty_object() {
        let val = default_args();
        assert!(val.is_object());
        assert!(val.as_object().unwrap().is_empty());
    }

    #[test]
    fn extract_mcp_blocks_empty_input() {
        let blocks = extract_fenced_blocks("", "mcp");
        assert!(blocks.is_empty());
    }

    #[test]
    fn extract_mcp_blocks_other_lang_ignored() {
        let text =
            "```json\n{\"key\":\"val\"}\n```\n```mcp\n{\"server\":\"a\",\"tool\":\"b\"}\n```";
        let blocks = extract_fenced_blocks(text, "mcp");
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("\"server\""));
    }

    #[test]
    fn executor_construction() {
        let executor = make_executor();
        let dbg = format!("{executor:?}");
        assert!(dbg.contains("McpToolExecutor"));
    }

    #[test]
    fn tool_definitions_empty_when_no_tools() {
        let executor = make_executor();
        assert!(executor.tool_definitions().is_empty());
    }

    #[test]
    fn tool_definitions_returns_sanitized_names() {
        let executor = executor_with(vec![McpTool {
            description: "Create a GitHub issue".into(),
            ..make_tool("gh", "create_issue")
        }]);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "gh_create_issue");
        assert_eq!(defs[0].description.as_ref(), "Create a GitHub issue");
    }

    #[test]
    fn set_tools_updates_definitions() {
        let executor = make_executor();
        assert!(executor.tool_definitions().is_empty());
        executor.set_tools(vec![make_tool("fs", "list_dir")]);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "fs_list_dir");
    }

    #[test]
    fn set_tools_replaces_previous_snapshot() {
        // set_tools replaces the snapshot wholesale; it does not append.
        let executor = make_executor();
        executor.set_tools(vec![make_tool("a", "one"), make_tool("a", "two")]);
        assert_eq!(executor.tool_definitions().len(), 2);
        executor.set_tools(vec![make_tool("b", "three")]);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "b_three");
    }

    #[tokio::test]
    async fn execute_no_blocks_returns_none() {
        let executor = make_executor();
        let result = executor.execute("no mcp blocks here").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_invalid_json_block_returns_error() {
        let executor = make_executor();
        let text = "```mcp\nnot json\n```";
        let result = executor.execute(text).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn execute_valid_block_tool_not_registered_returns_error() {
        // Tool is not in the registered list → rejected before any server call.
        let executor = make_executor();
        let text = "```mcp\n{\"server\":\"missing\",\"tool\":\"t\"}\n```";
        let result = executor.execute(text).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("not in registered tool list"),
            "expected 'not in registered tool list' in: {err_msg}"
        );
    }

    #[tokio::test]
    async fn execute_registered_tool_with_disconnected_server_returns_error() {
        // The registered-list check passes, so the block reaches the manager. No server is
        // connected, so the manager call fails and the error propagates out of `execute`.
        let executor = make_executor();
        executor.set_tools(vec![make_tool("srv", "tool")]);
        let text = "```mcp\n{\"server\":\"srv\",\"tool\":\"tool\"}\n```";
        let result = executor.execute(text).await;
        assert!(result.is_err(), "expected Err when server is not connected");
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_format_returns_none() {
        let executor = make_executor();
        let result = executor
            .execute_tool_call(&make_call("no_colon_here"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_server_returns_none() {
        // Lookup is by sanitized_id, which never contains ':', so a qualified-looking id
        // can never match a registered tool.
        let executor = make_executor();
        let result = executor
            .execute_tool_call(&make_call("unknown_server:tool"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    // --- sanitized_id routing tests ---

    #[tokio::test]
    async fn execute_tool_call_by_sanitized_id_not_found_returns_none() {
        // The only registered tool has sanitized_id "gh_create_issue"; a different sanitized
        // id must not match it.
        let executor = executor_with(vec![make_tool("gh", "create_issue")]);
        let result = executor
            .execute_tool_call(&make_call("gh_list_issues"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_by_sanitized_id_matched_but_server_missing_returns_err() {
        // The lookup succeeds ("missing_server_some_tool"), but the manager has no connected
        // server, so the call must return an error (not None).
        let executor = executor_with(vec![make_tool("missing_server", "some_tool")]);
        let result = executor
            .execute_tool_call(&make_call("missing_server_some_tool"))
            .await;
        assert!(result.is_err(), "expected Err when server is not connected");
    }

    #[test]
    fn tool_definitions_sanitized_id_has_no_colon() {
        // No advertised tool id may contain ':', even when server or tool names do.
        let executor = executor_with(vec![
            make_tool("srv-one", "tool:with:colons"),
            make_tool("srv.two", "normal_tool"),
        ]);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        for def in &defs {
            assert!(
                !def.id.contains(':'),
                "tool id must not contain ':' but got: {}",
                def.id
            );
        }
    }

    #[test]
    fn tool_definitions_sanitized_id_matches_expected_pattern() {
        // Verify that every character in every id matches [a-zA-Z0-9_-].
        let executor = executor_with(vec![make_tool("my.server", "tool name!")]);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 1);
        let id = defs[0].id.as_ref();
        assert!(
            id.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'),
            "id contains invalid chars: {id}"
        );
    }
}