1use std::{
15 collections::HashMap, future::Future, path::Path, process::Stdio, sync::Arc, time::Duration,
16};
17
18use http::{HeaderName, HeaderValue};
19use rmcp::{
20 model::{CallToolRequestParams, CallToolResult, Prompt, Resource, ServerCapabilities, Tool},
21 service::{RoleClient, RunningService, RxJsonRpcMessage, TxJsonRpcMessage},
22 transport::{
23 async_rw::AsyncRwTransport, streamable_http_client::StreamableHttpClientTransportConfig,
24 StreamableHttpClientTransport, Transport as RmcpTransport,
25 },
26 ServiceExt,
27};
28use serde_json::Value;
29use thiserror::Error;
30use tokio::{
31 process::{Child, ChildStdin, ChildStdout, Command},
32 sync::RwLock,
33 time,
34};
35
36use crate::target::{Target, Transport as TargetTransport};
37
/// How long a stdio child may take to exit on its own after its transport has been closed
/// before it is forcibly killed.
const CHILD_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(3_000);
39
/// Errors produced while connecting to, talking to, or shutting down an MCP server.
///
/// NOTE(review): `Spawn` and `Shutdown` both interpolate their source into the Display
/// message *and* expose it through `source()`, so reporters that walk the error chain will
/// print the underlying message twice.
#[derive(Debug, Error)]
pub enum ClientError {
    /// The stdio server process could not be started (including failure to obtain its pipes).
    #[error("failed to spawn stdio transport: {0}")]
    Spawn(#[source] std::io::Error),
    /// The MCP initialization handshake failed; carries the rendered underlying error.
    #[error("failed to initialize MCP client: {0}")]
    Initialize(String),
    /// A configured HTTP header name or value was rejected; `name` is the configured key.
    #[error("invalid HTTP header {name}: {message}")]
    InvalidHeader { name: String, message: String },
    /// Cancelling the running session returned a task `JoinError`.
    #[error("failed to shut down MCP client: {0}")]
    Shutdown(#[source] tokio::task::JoinError),
    /// An MCP request failed, or was attempted after the client had been shut down.
    #[error("MCP request failed: {0}")]
    Request(String),
}
53
/// Result type used by this module, with [`ClientError`] as the error.
pub type Result<T> = std::result::Result<T, ClientError>;
55
/// Handle to a single MCP server session described by a [`Target`].
///
/// Clones share the same session (it lives behind an `Arc`), so a `reconnect` or `shutdown`
/// performed through one clone is observed by all of them.
#[derive(Clone)]
pub struct Client {
    /// The live session; `None` once `shutdown` has taken it.
    service: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
    /// Connection settings, kept so `reconnect` can rebuild the session.
    target: Target,
}
67
/// Classified result of [`Client::call_tool`].
#[derive(Debug)]
pub enum CallOutcome {
    /// The server answered the call with a result.
    Ok(CallToolResult),
    /// No answer arrived within the timeout; carries the timeout that elapsed.
    Hang(Duration),
    /// The request failed and the transport was reported closed afterwards.
    Crash(String),
    /// Any other failure: invalid arguments, a shut-down client, or a request error while
    /// the transport was still open.
    ProtocolError(String),
}
75
76impl Client {
77 pub async fn connect(target: &Target) -> Result<Self> {
78 let service = build_service(target).await?;
79 Ok(Self {
80 service: Arc::new(RwLock::new(Some(service))),
81 target: target.clone(),
82 })
83 }
84
85 pub async fn reconnect(&self) -> Result<()> {
91 let replacement = build_service(&self.target).await?;
94 let mut guard = self.service.write().await;
95 if let Some(old) = guard.take() {
96 let _ = old.cancel().await;
97 }
98 *guard = Some(replacement);
99 Ok(())
100 }
101
102 pub async fn list_tools(&self) -> Result<Vec<Tool>> {
103 let guard = self.service.read().await;
104 let service = guard
105 .as_ref()
106 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
107 service
108 .list_all_tools()
109 .await
110 .map_err(|error| ClientError::Request(error.to_string()))
111 }
112
113 pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
123 let guard = self.service.read().await;
124 guard
125 .as_ref()
126 .and_then(|service| service.peer_info())
127 .map(|info| info.capabilities.clone())
128 }
129
130 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
137 let advertises = self
138 .server_capabilities()
139 .await
140 .is_some_and(|caps| caps.resources.is_some());
141 if !advertises {
142 return Ok(Vec::new());
143 }
144 let guard = self.service.read().await;
145 let service = guard
146 .as_ref()
147 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
148 service
149 .list_all_resources()
150 .await
151 .map_err(|error| ClientError::Request(error.to_string()))
152 }
153
154 pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
158 let advertises = self
159 .server_capabilities()
160 .await
161 .is_some_and(|caps| caps.prompts.is_some());
162 if !advertises {
163 return Ok(Vec::new());
164 }
165 let guard = self.service.read().await;
166 let service = guard
167 .as_ref()
168 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
169 service
170 .list_all_prompts()
171 .await
172 .map_err(|error| ClientError::Request(error.to_string()))
173 }
174
175 pub async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
176 let arguments = match arguments {
177 Value::Object(map) => Some(map),
178 Value::Null => None,
179 other => {
180 return CallOutcome::ProtocolError(format!(
181 "tool arguments must be a JSON object or null, got {other}"
182 ));
183 }
184 };
185
186 let request = match arguments {
187 Some(arguments) => {
188 CallToolRequestParams::new(name.to_owned()).with_arguments(arguments)
189 }
190 None => CallToolRequestParams::new(name.to_owned()),
191 };
192
193 let guard = self.service.read().await;
197 let Some(service) = guard.as_ref() else {
198 return CallOutcome::ProtocolError("client has been shut down".into());
199 };
200 match time::timeout(timeout, service.call_tool(request)).await {
201 Ok(Ok(result)) => CallOutcome::Ok(result),
202 Ok(Err(error)) if service.is_transport_closed() => {
203 CallOutcome::Crash(error.to_string())
204 }
205 Ok(Err(error)) => CallOutcome::ProtocolError(error.to_string()),
206 Err(_) => CallOutcome::Hang(timeout),
207 }
208 }
209
210 pub async fn shutdown(&self) -> Result<()> {
214 let mut guard = self.service.write().await;
215 match guard.take() {
216 Some(service) => service
217 .cancel()
218 .await
219 .map(|_| ())
220 .map_err(ClientError::Shutdown),
221 None => Ok(()),
222 }
223 }
224
225 pub fn target(&self) -> &Target {
226 &self.target
227 }
228}
229
230async fn build_service(target: &Target) -> Result<RunningService<RoleClient, ()>> {
231 let service = match &target.transport {
232 TargetTransport::Stdio { command, args, env } => {
233 let mut process = Command::new(command);
234 process.args(args).envs(env);
235 let transport = StdioChildTransport::spawn(process).map_err(ClientError::Spawn)?;
236 ().serve(transport)
237 .await
238 .map_err(|error| ClientError::Initialize(error.to_string()))?
239 }
240 TargetTransport::Http { url, headers } => {
241 let headers = header_map(headers)?;
242 let config =
243 StreamableHttpClientTransportConfig::with_uri(url.clone()).custom_headers(headers);
244 let transport = StreamableHttpClientTransport::from_config(config);
245 ().serve(transport)
246 .await
247 .map_err(|error| ClientError::Initialize(error.to_string()))?
248 }
249 };
250 Ok(service)
251}
252
/// Returns the path of the `wallfacer.toml` test fixture inside the repository at `repo_root`.
pub fn fixture_config_path(repo_root: &Path) -> std::path::PathBuf {
    let mut fixture = repo_root.to_path_buf();
    fixture.push("tests/fixtures/wallfacer.toml");
    fixture
}
256
257fn header_map(headers: &HashMap<String, String>) -> Result<HashMap<HeaderName, HeaderValue>> {
258 headers
259 .iter()
260 .map(|(name, value)| {
261 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
262 ClientError::InvalidHeader {
263 name: name.clone(),
264 message: error.to_string(),
265 }
266 })?;
267 let header_value =
268 HeaderValue::from_str(value).map_err(|error| ClientError::InvalidHeader {
269 name: name.clone(),
270 message: error.to_string(),
271 })?;
272 Ok((header_name, header_value))
273 })
274 .collect()
275}
276
/// MCP client transport over the stdin/stdout pipes of a spawned child process.
///
/// Owning the `Child` lets `close` wait for (or kill) the process, and lets `Drop` kill it
/// when the transport is discarded without having been closed.
struct StdioChildTransport {
    /// The server process; `None` once `close` or `Drop` has taken it.
    child: Option<Child>,
    /// Message transport reading from the child's stdout and writing to its stdin.
    transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
}
281
282impl StdioChildTransport {
283 fn spawn(mut command: Command) -> std::io::Result<Self> {
284 command
285 .stdin(Stdio::piped())
286 .stdout(Stdio::piped())
287 .stderr(Stdio::inherit());
288
289 let mut child = command.spawn()?;
290 let stdout = child
291 .stdout
292 .take()
293 .ok_or_else(|| std::io::Error::other("child stdout was already taken"))?;
294 let stdin = child
295 .stdin
296 .take()
297 .ok_or_else(|| std::io::Error::other("child stdin was already taken"))?;
298
299 Ok(Self {
300 child: Some(child),
301 transport: AsyncRwTransport::new_client(stdout, stdin),
302 })
303 }
304
305 async fn close_child(&mut self) -> std::io::Result<()> {
306 self.transport.close().await?;
307
308 if let Some(mut child) = self.child.take() {
309 match time::timeout(CHILD_SHUTDOWN_TIMEOUT, child.wait()).await {
310 Ok(status) => {
311 status?;
312 }
313 Err(_) => {
314 child.kill().await?;
315 }
316 }
317 }
318
319 Ok(())
320 }
321}
322
323impl Drop for StdioChildTransport {
324 fn drop(&mut self) {
325 if let Some(mut child) = self.child.take() {
326 let _ = child.start_kill();
327 tokio::spawn(async move {
328 let _ = child.wait().await;
329 });
330 }
331 }
332}
333
334impl RmcpTransport<RoleClient> for StdioChildTransport {
335 type Error = std::io::Error;
336
337 fn send(
338 &mut self,
339 item: TxJsonRpcMessage<RoleClient>,
340 ) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
341 self.transport.send(item)
342 }
343
344 fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
345 self.transport.receive()
346 }
347
348 fn close(&mut self) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send {
349 self.close_child()
350 }
351}