wallfacer_core/run/
exec.rs1use std::{collections::HashMap, future::Future, pin::Pin, sync::Mutex, time::Duration};
8
9use async_trait::async_trait;
10use rmcp::model::Tool;
11use serde_json::Value;
12
13use crate::client::{CallOutcome, Client, ClientError};
14
/// Heap-allocated, pinned, `Send` future that resolves to a [`CallOutcome`].
///
/// This is the type-erased return type of asynchronous mock handlers.
type CallFuture = Pin<Box<dyn Future<Output = CallOutcome> + Send>>;
18
/// Boxed, thread-safe closure that turns borrowed call arguments directly into a
/// [`CallOutcome`] without awaiting anything.
type SyncHandler = Box<dyn Fn(&Value) -> CallOutcome + Send + Sync>;
21
/// Boxed, thread-safe closure that turns borrowed call arguments into a
/// [`CallFuture`] the caller must await to obtain the [`CallOutcome`].
type AsyncHandler = Box<dyn Fn(&Value) -> CallFuture + Send + Sync>;
24
/// Handler stored for a single mock tool.
enum MockHandler {
    /// Produces the outcome immediately when called.
    Sync(SyncHandler),
    /// Produces a future that yields the outcome once awaited.
    Async(AsyncHandler),
}
29
/// Client operations abstracted behind a trait so that the real [`Client`] and
/// the in-memory [`MockClient`] can be used interchangeably.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait McpExec: Send + Sync {
    /// Returns the tools the implementation exposes, or a [`ClientError`].
    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError>;

    /// Invokes the tool called `name` with `arguments` and reports the result
    /// as a [`CallOutcome`].
    ///
    /// How `timeout` is honoured is up to the implementation: [`Client`]
    /// forwards it to its own `call_tool`, while [`MockClient`] ignores it.
    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome;

    /// Asks the implementation to reconnect, returning a [`ClientError`] on
    /// failure.
    async fn reconnect(&self) -> Result<(), ClientError>;
}
50
/// [`McpExec`] for the real [`Client`]: every method forwards to the method of
/// the same name reached through the `Client::` path.
// NOTE(review): the `Client::method(self, ..)` paths are expected to resolve to
// `Client`'s inherent methods (defined outside this file); if one were missing,
// the path would resolve back to this trait method and recurse — confirm.
#[async_trait]
impl McpExec for Client {
    /// Forwards to `Client::list_tools`.
    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
        Client::list_tools(self).await
    }

    /// Forwards to `Client::call_tool`, passing `name`, `arguments` and
    /// `timeout` through unchanged.
    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
        Client::call_tool(self, name, arguments, timeout).await
    }

    /// Forwards to `Client::reconnect`.
    async fn reconnect(&self) -> Result<(), ClientError> {
        Client::reconnect(self).await
    }
}
65
/// In-memory [`McpExec`] implementation that answers tool calls from
/// registered closures instead of talking to a server.
pub struct MockClient {
    /// Tool descriptors handed back by `list_tools`.
    tools: Vec<Tool>,
    /// Handlers keyed by tool name; looked up on every `call_tool`.
    handlers: Mutex<HashMap<String, MockHandler>>,
    /// Number of times `reconnect` has been called.
    reconnect_count: Mutex<usize>,
}
76
77impl MockClient {
78 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 tools: Vec::new(),
83 handlers: Mutex::new(HashMap::new()),
84 reconnect_count: Mutex::new(0),
85 }
86 }
87
88 pub fn register<F>(mut self, tool: Tool, handler: F) -> Self
92 where
93 F: Fn(&Value) -> CallOutcome + Send + Sync + 'static,
94 {
95 let name = tool.name.to_string();
96 self.tools.push(tool);
97 self.handlers
98 .lock()
99 .unwrap_or_else(|p| p.into_inner())
100 .insert(name, MockHandler::Sync(Box::new(handler)));
101 self
102 }
103
104 pub fn register_async<F, Fut>(mut self, tool: Tool, handler: F) -> Self
108 where
109 F: Fn(&Value) -> Fut + Send + Sync + 'static,
110 Fut: Future<Output = CallOutcome> + Send + 'static,
111 {
112 let name = tool.name.to_string();
113 self.tools.push(tool);
114 let boxed: AsyncHandler = Box::new(move |args| Box::pin(handler(args)));
115 self.handlers
116 .lock()
117 .unwrap_or_else(|p| p.into_inner())
118 .insert(name, MockHandler::Async(boxed));
119 self
120 }
121
122 #[must_use]
124 pub fn reconnect_count(&self) -> usize {
125 *self
126 .reconnect_count
127 .lock()
128 .unwrap_or_else(|p| p.into_inner())
129 }
130}
131
132impl Default for MockClient {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138#[async_trait]
139impl McpExec for MockClient {
140 async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
141 Ok(self.tools.clone())
142 }
143
144 async fn call_tool(&self, name: &str, arguments: Value, _timeout: Duration) -> CallOutcome {
145 let future = {
148 let handlers = self.handlers.lock().unwrap_or_else(|p| p.into_inner());
149 match handlers.get(name) {
150 Some(MockHandler::Sync(handler)) => return handler(&arguments),
151 Some(MockHandler::Async(handler)) => handler(&arguments),
152 None => return CallOutcome::ProtocolError(format!("unknown tool `{name}`")),
153 }
154 };
155 future.await
156 }
157
158 async fn reconnect(&self) -> Result<(), ClientError> {
159 *self
160 .reconnect_count
161 .lock()
162 .unwrap_or_else(|p| p.into_inner()) += 1;
163 Ok(())
164 }
165}