Skip to main content

wallfacer_core/run/
exec.rs

1//! Execution abstraction over the MCP client.
2//!
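//! Plans are written against the [`McpExec`] trait rather than the concrete
//! [`crate::client::Client`], so they can be unit-tested with [`MockClient`]
//! without spawning a child process. Every trait method takes `&self`, so a
//! shared reference to the executor is all a plan needs to issue calls.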
6
7use std::{collections::HashMap, future::Future, pin::Pin, sync::Mutex, time::Duration};
8
9use async_trait::async_trait;
10use rmcp::model::Tool;
11use serde_json::Value;
12
13use crate::client::{CallOutcome, Client, ClientError};
14
15/// Boxed future returned by [`MockClient`] async handlers. The future
16/// returns a [`CallOutcome`].
17type CallFuture = Pin<Box<dyn Future<Output = CallOutcome> + Send>>;
18
19/// Sync closure type for [`MockClient`] tool handlers.
20type SyncHandler = Box<dyn Fn(&Value) -> CallOutcome + Send + Sync>;
21
22/// Async closure type for [`MockClient`] tool handlers.
23type AsyncHandler = Box<dyn Fn(&Value) -> CallFuture + Send + Sync>;
24
25enum MockHandler {
26    Sync(SyncHandler),
27    Async(AsyncHandler),
28}
29
30/// Minimum surface a plan needs from an MCP client.
31///
32/// All methods take `&self` so plans can drive concurrent calls
33/// (`torture`-style) and recover from faults via `reconnect` without ever
34/// holding an exclusive borrow. Phase E1 made this possible by moving the
35/// production [`Client`] behind an internal `Arc<RwLock<...>>`.
36#[async_trait]
37pub trait McpExec: Send + Sync {
38    /// Lists every tool exposed by the server.
39    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError>;
40
41    /// Calls a tool, applying `timeout`. Returns a [`CallOutcome`] that the
42    /// caller pattern-matches; this method itself never errors.
43    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome;
44
45    /// Tears down the transport and rebuilds it. Used after a hang/crash
46    /// so subsequent calls can succeed. Concurrent callers see either the
47    /// old or the new transport, never a torn state.
48    async fn reconnect(&self) -> Result<(), ClientError>;
49}
50
51#[async_trait]
52impl McpExec for Client {
53    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
54        Client::list_tools(self).await
55    }
56
57    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
58        Client::call_tool(self, name, arguments, timeout).await
59    }
60
61    async fn reconnect(&self) -> Result<(), ClientError> {
62        Client::reconnect(self).await
63    }
64}
65
66/// In-memory MCP client used for plan unit tests. Each tool is registered
67/// with a closure that maps the call arguments to a [`CallOutcome`]. The
68/// mock is fully synchronous internally; concurrency is provided by `tokio`.
69pub struct MockClient {
70    tools: Vec<Tool>,
71    handlers: Mutex<HashMap<String, MockHandler>>,
72    /// Number of `reconnect` calls observed; useful for tests that assert
73    /// the plan recovered after a fault.
74    reconnect_count: Mutex<usize>,
75}
76
77impl MockClient {
78    /// Creates an empty mock with no tools registered.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            tools: Vec::new(),
83            handlers: Mutex::new(HashMap::new()),
84            reconnect_count: Mutex::new(0),
85        }
86    }
87
88    /// Registers a tool with a synchronous handler. The handler is invoked
89    /// on every `call_tool`; it can return any [`CallOutcome`] variant to
90    /// simulate crashes, hangs, or protocol errors.
91    pub fn register<F>(mut self, tool: Tool, handler: F) -> Self
92    where
93        F: Fn(&Value) -> CallOutcome + Send + Sync + 'static,
94    {
95        let name = tool.name.to_string();
96        self.tools.push(tool);
97        self.handlers
98            .lock()
99            .unwrap_or_else(|p| p.into_inner())
100            .insert(name, MockHandler::Sync(Box::new(handler)));
101        self
102    }
103
104    /// Registers a tool with an async handler. Useful when testing
105    /// cancellation: the handler can `await` indefinitely and rely on
106    /// the caller's cancellation to drop it.
107    pub fn register_async<F, Fut>(mut self, tool: Tool, handler: F) -> Self
108    where
109        F: Fn(&Value) -> Fut + Send + Sync + 'static,
110        Fut: Future<Output = CallOutcome> + Send + 'static,
111    {
112        let name = tool.name.to_string();
113        self.tools.push(tool);
114        let boxed: AsyncHandler = Box::new(move |args| Box::pin(handler(args)));
115        self.handlers
116            .lock()
117            .unwrap_or_else(|p| p.into_inner())
118            .insert(name, MockHandler::Async(boxed));
119        self
120    }
121
122    /// Returns the number of times `reconnect` has been invoked.
123    #[must_use]
124    pub fn reconnect_count(&self) -> usize {
125        *self
126            .reconnect_count
127            .lock()
128            .unwrap_or_else(|p| p.into_inner())
129    }
130}
131
132impl Default for MockClient {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138#[async_trait]
139impl McpExec for MockClient {
140    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
141        Ok(self.tools.clone())
142    }
143
144    async fn call_tool(&self, name: &str, arguments: Value, _timeout: Duration) -> CallOutcome {
145        // Build the future under the mutex (cheap), then drop the guard
146        // before awaiting so concurrent calls aren't serialized.
147        let future = {
148            let handlers = self.handlers.lock().unwrap_or_else(|p| p.into_inner());
149            match handlers.get(name) {
150                Some(MockHandler::Sync(handler)) => return handler(&arguments),
151                Some(MockHandler::Async(handler)) => handler(&arguments),
152                None => return CallOutcome::ProtocolError(format!("unknown tool `{name}`")),
153            }
154        };
155        future.await
156    }
157
158    async fn reconnect(&self) -> Result<(), ClientError> {
159        *self
160            .reconnect_count
161            .lock()
162            .unwrap_or_else(|p| p.into_inner()) += 1;
163        Ok(())
164    }
165}