1use std::{
15 collections::HashMap, future::Future, path::Path, process::Stdio, sync::Arc, time::Duration,
16};
17
18use http::{HeaderName, HeaderValue};
19use rmcp::{
20 model::{CallToolRequestParams, CallToolResult, Prompt, Resource, Tool},
21 service::{RoleClient, RunningService, RxJsonRpcMessage, TxJsonRpcMessage},
22 transport::{
23 async_rw::AsyncRwTransport, streamable_http_client::StreamableHttpClientTransportConfig,
24 StreamableHttpClientTransport, Transport as RmcpTransport,
25 },
26 ServiceExt,
27};
28use serde_json::Value;
29use thiserror::Error;
30use tokio::{
31 process::{Child, ChildStdin, ChildStdout, Command},
32 sync::RwLock,
33 time,
34};
35
36use crate::target::{Target, Transport as TargetTransport};
37
/// How long a stdio server child is given to exit on its own during a graceful close
/// before it is forcibly killed.
const CHILD_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(3_000);
39
40#[derive(Debug, Error)]
41pub enum ClientError {
42 #[error("failed to spawn stdio transport: {0}")]
43 Spawn(#[source] std::io::Error),
44 #[error("failed to initialize MCP client: {0}")]
45 Initialize(String),
46 #[error("invalid HTTP header {name}: {message}")]
47 InvalidHeader { name: String, message: String },
48 #[error("failed to shut down MCP client: {0}")]
49 Shutdown(#[source] tokio::task::JoinError),
50 #[error("MCP request failed: {0}")]
51 Request(String),
52}
53
54pub type Result<T> = std::result::Result<T, ClientError>;
55
56#[derive(Clone)]
60pub struct Client {
61 service: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
65 target: Target,
66}
67
68#[derive(Debug)]
69pub enum CallOutcome {
70 Ok(CallToolResult),
71 Hang(Duration),
72 Crash(String),
73 ProtocolError(String),
74}
75
76impl Client {
77 pub async fn connect(target: &Target) -> Result<Self> {
78 let service = build_service(target).await?;
79 Ok(Self {
80 service: Arc::new(RwLock::new(Some(service))),
81 target: target.clone(),
82 })
83 }
84
85 pub async fn reconnect(&self) -> Result<()> {
91 let replacement = build_service(&self.target).await?;
94 let mut guard = self.service.write().await;
95 if let Some(old) = guard.take() {
96 let _ = old.cancel().await;
97 }
98 *guard = Some(replacement);
99 Ok(())
100 }
101
102 pub async fn list_tools(&self) -> Result<Vec<Tool>> {
103 let guard = self.service.read().await;
104 let service = guard
105 .as_ref()
106 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
107 service
108 .list_all_tools()
109 .await
110 .map_err(|error| ClientError::Request(error.to_string()))
111 }
112
113 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
114 let guard = self.service.read().await;
115 let service = guard
116 .as_ref()
117 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
118 service
119 .list_all_resources()
120 .await
121 .map_err(|error| ClientError::Request(error.to_string()))
122 }
123
124 pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
125 let guard = self.service.read().await;
126 let service = guard
127 .as_ref()
128 .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
129 service
130 .list_all_prompts()
131 .await
132 .map_err(|error| ClientError::Request(error.to_string()))
133 }
134
135 pub async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
136 let arguments = match arguments {
137 Value::Object(map) => Some(map),
138 Value::Null => None,
139 other => {
140 return CallOutcome::ProtocolError(format!(
141 "tool arguments must be a JSON object or null, got {other}"
142 ));
143 }
144 };
145
146 let request = match arguments {
147 Some(arguments) => {
148 CallToolRequestParams::new(name.to_owned()).with_arguments(arguments)
149 }
150 None => CallToolRequestParams::new(name.to_owned()),
151 };
152
153 let guard = self.service.read().await;
157 let Some(service) = guard.as_ref() else {
158 return CallOutcome::ProtocolError("client has been shut down".into());
159 };
160 match time::timeout(timeout, service.call_tool(request)).await {
161 Ok(Ok(result)) => CallOutcome::Ok(result),
162 Ok(Err(error)) if service.is_transport_closed() => {
163 CallOutcome::Crash(error.to_string())
164 }
165 Ok(Err(error)) => CallOutcome::ProtocolError(error.to_string()),
166 Err(_) => CallOutcome::Hang(timeout),
167 }
168 }
169
170 pub async fn shutdown(&self) -> Result<()> {
174 let mut guard = self.service.write().await;
175 match guard.take() {
176 Some(service) => service
177 .cancel()
178 .await
179 .map(|_| ())
180 .map_err(ClientError::Shutdown),
181 None => Ok(()),
182 }
183 }
184
185 pub fn target(&self) -> &Target {
186 &self.target
187 }
188}
189
190async fn build_service(target: &Target) -> Result<RunningService<RoleClient, ()>> {
191 let service = match &target.transport {
192 TargetTransport::Stdio { command, args, env } => {
193 let mut process = Command::new(command);
194 process.args(args).envs(env);
195 let transport = StdioChildTransport::spawn(process).map_err(ClientError::Spawn)?;
196 ().serve(transport)
197 .await
198 .map_err(|error| ClientError::Initialize(error.to_string()))?
199 }
200 TargetTransport::Http { url, headers } => {
201 let headers = header_map(headers)?;
202 let config =
203 StreamableHttpClientTransportConfig::with_uri(url.clone()).custom_headers(headers);
204 let transport = StreamableHttpClientTransport::from_config(config);
205 ().serve(transport)
206 .await
207 .map_err(|error| ClientError::Initialize(error.to_string()))?
208 }
209 };
210 Ok(service)
211}
212
/// Location of the fixture configuration file (`tests/fixtures/wallfacer.toml`) under the
/// repository root `repo_root`.
pub fn fixture_config_path(repo_root: &Path) -> std::path::PathBuf {
    let mut config = repo_root.to_path_buf();
    config.push("tests/fixtures/wallfacer.toml");
    config
}
216
217fn header_map(headers: &HashMap<String, String>) -> Result<HashMap<HeaderName, HeaderValue>> {
218 headers
219 .iter()
220 .map(|(name, value)| {
221 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
222 ClientError::InvalidHeader {
223 name: name.clone(),
224 message: error.to_string(),
225 }
226 })?;
227 let header_value =
228 HeaderValue::from_str(value).map_err(|error| ClientError::InvalidHeader {
229 name: name.clone(),
230 message: error.to_string(),
231 })?;
232 Ok((header_name, header_value))
233 })
234 .collect()
235}
236
237struct StdioChildTransport {
238 child: Option<Child>,
239 transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
240}
241
242impl StdioChildTransport {
243 fn spawn(mut command: Command) -> std::io::Result<Self> {
244 command
245 .stdin(Stdio::piped())
246 .stdout(Stdio::piped())
247 .stderr(Stdio::inherit());
248
249 let mut child = command.spawn()?;
250 let stdout = child
251 .stdout
252 .take()
253 .ok_or_else(|| std::io::Error::other("child stdout was already taken"))?;
254 let stdin = child
255 .stdin
256 .take()
257 .ok_or_else(|| std::io::Error::other("child stdin was already taken"))?;
258
259 Ok(Self {
260 child: Some(child),
261 transport: AsyncRwTransport::new_client(stdout, stdin),
262 })
263 }
264
265 async fn close_child(&mut self) -> std::io::Result<()> {
266 self.transport.close().await?;
267
268 if let Some(mut child) = self.child.take() {
269 match time::timeout(CHILD_SHUTDOWN_TIMEOUT, child.wait()).await {
270 Ok(status) => {
271 status?;
272 }
273 Err(_) => {
274 child.kill().await?;
275 }
276 }
277 }
278
279 Ok(())
280 }
281}
282
283impl Drop for StdioChildTransport {
284 fn drop(&mut self) {
285 if let Some(mut child) = self.child.take() {
286 let _ = child.start_kill();
287 tokio::spawn(async move {
288 let _ = child.wait().await;
289 });
290 }
291 }
292}
293
294impl RmcpTransport<RoleClient> for StdioChildTransport {
295 type Error = std::io::Error;
296
297 fn send(
298 &mut self,
299 item: TxJsonRpcMessage<RoleClient>,
300 ) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
301 self.transport.send(item)
302 }
303
304 fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
305 self.transport.receive()
306 }
307
308 fn close(&mut self) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send {
309 self.close_child()
310 }
311}