Skip to main content

wallfacer_core/
client.rs

1//! MCP client wrapper.
2//!
3//! Phase E1 puts the underlying `rmcp` service behind an
4//! `Arc<RwLock<...>>` so:
5//!
6//! * `Client` is `Clone` (cheap `Arc::clone`); torture can fan out across
7//!   many tasks sharing the same connection.
8//! * `list_tools` and `call_tool` take `&self` and acquire a *read* lock,
9//!   allowing concurrent calls to be in flight at once.
10//! * `reconnect` takes `&self` and acquires a *write* lock to atomically
11//!   tear down and rebuild the underlying service. Concurrent callers see
12//!   either the old or the new transport, never a torn state.
13
14use std::{
15    collections::HashMap, future::Future, path::Path, process::Stdio, sync::Arc, time::Duration,
16};
17
18use http::{HeaderName, HeaderValue};
19use rmcp::{
20    model::{CallToolRequestParams, CallToolResult, Prompt, Resource, Tool},
21    service::{RoleClient, RunningService, RxJsonRpcMessage, TxJsonRpcMessage},
22    transport::{
23        async_rw::AsyncRwTransport, streamable_http_client::StreamableHttpClientTransportConfig,
24        StreamableHttpClientTransport, Transport as RmcpTransport,
25    },
26    ServiceExt,
27};
28use serde_json::Value;
29use thiserror::Error;
30use tokio::{
31    process::{Child, ChildStdin, ChildStdout, Command},
32    sync::RwLock,
33    time,
34};
35
36use crate::target::{Target, Transport as TargetTransport};
37
/// How long `StdioChildTransport::close_child` waits for the child process to
/// exit on its own before it is killed.
const CHILD_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(3);
39
/// Errors produced while connecting to, talking to, or shutting down an MCP
/// server through [`Client`].
#[derive(Debug, Error)]
pub enum ClientError {
    /// Spawning the child process for a stdio target failed.
    #[error("failed to spawn stdio transport: {0}")]
    Spawn(#[source] std::io::Error),
    /// Starting the MCP service over the transport failed. Only the error's
    /// `Display` text is kept.
    #[error("failed to initialize MCP client: {0}")]
    Initialize(String),
    /// A configured HTTP header name or value could not be parsed. `name` is
    /// the header name as configured; `message` is the parser's error text.
    #[error("invalid HTTP header {name}: {message}")]
    InvalidHeader { name: String, message: String },
    /// Cancelling the running service during shutdown returned a join error.
    #[error("failed to shut down MCP client: {0}")]
    Shutdown(#[source] tokio::task::JoinError),
    /// A listing request failed, or the client had already been shut down.
    /// Only the error's `Display` text is kept.
    #[error("MCP request failed: {0}")]
    Request(String),
}
53
/// Result alias for this module; the error type is always [`ClientError`].
pub type Result<T> = std::result::Result<T, ClientError>;
55
/// Cheaply-cloneable MCP client. Cloning shares the same underlying
/// transport via `Arc`. After [`Client::shutdown`] the transport is gone:
/// the `list_*` methods return `ClientError::Request` and
/// [`Client::call_tool`] returns `CallOutcome::ProtocolError`, until a
/// successful [`Client::reconnect`] installs a fresh transport.
#[derive(Clone)]
pub struct Client {
    /// `Some` while the transport is live; `None` after `shutdown`.
    /// `RwLock` allows concurrent `&self` calls (read) and exclusive
    /// `reconnect` (write).
    service: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
    /// The target this client was connected to; `reconnect` rebuilds the
    /// transport from it and `target()` exposes it to callers.
    target: Target,
}
67
/// Classified result of a single [`Client::call_tool`] invocation.
#[derive(Debug)]
pub enum CallOutcome {
    /// The call completed and the server returned a result.
    Ok(CallToolResult),
    /// The call did not complete within the timeout; carries that timeout.
    Hang(Duration),
    /// The call failed and the transport was found closed afterwards; carries
    /// the error text.
    Crash(String),
    /// Any other failure: the call returned an error while the transport was
    /// still open, the arguments were neither a JSON object nor null, or the
    /// client had been shut down. Carries the error text.
    ProtocolError(String),
}
75
76impl Client {
77    pub async fn connect(target: &Target) -> Result<Self> {
78        let service = build_service(target).await?;
79        Ok(Self {
80            service: Arc::new(RwLock::new(Some(service))),
81            target: target.clone(),
82        })
83    }
84
85    /// Tears down the current transport and opens a new one. Other tasks
86    /// holding a `&self` reference will see the new transport on their
87    /// next call. Phase E1 changed this from `&mut self` to `&self` so
88    /// callers can recover after a fault without exclusive ownership of
89    /// the `Client`.
90    pub async fn reconnect(&self) -> Result<()> {
91        // Build the replacement *before* dropping the old one to keep the
92        // window where the client has no transport as small as possible.
93        let replacement = build_service(&self.target).await?;
94        let mut guard = self.service.write().await;
95        if let Some(old) = guard.take() {
96            let _ = old.cancel().await;
97        }
98        *guard = Some(replacement);
99        Ok(())
100    }
101
102    pub async fn list_tools(&self) -> Result<Vec<Tool>> {
103        let guard = self.service.read().await;
104        let service = guard
105            .as_ref()
106            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
107        service
108            .list_all_tools()
109            .await
110            .map_err(|error| ClientError::Request(error.to_string()))
111    }
112
113    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
114        let guard = self.service.read().await;
115        let service = guard
116            .as_ref()
117            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
118        service
119            .list_all_resources()
120            .await
121            .map_err(|error| ClientError::Request(error.to_string()))
122    }
123
124    pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
125        let guard = self.service.read().await;
126        let service = guard
127            .as_ref()
128            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
129        service
130            .list_all_prompts()
131            .await
132            .map_err(|error| ClientError::Request(error.to_string()))
133    }
134
135    pub async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
136        let arguments = match arguments {
137            Value::Object(map) => Some(map),
138            Value::Null => None,
139            other => {
140                return CallOutcome::ProtocolError(format!(
141                    "tool arguments must be a JSON object or null, got {other}"
142                ));
143            }
144        };
145
146        let request = match arguments {
147            Some(arguments) => {
148                CallToolRequestParams::new(name.to_owned()).with_arguments(arguments)
149            }
150            None => CallToolRequestParams::new(name.to_owned()),
151        };
152
153        // Snapshot the current transport for the duration of the call. We
154        // hold the read lock across the await so a concurrent reconnect
155        // waits until our call finishes before swapping.
156        let guard = self.service.read().await;
157        let Some(service) = guard.as_ref() else {
158            return CallOutcome::ProtocolError("client has been shut down".into());
159        };
160        match time::timeout(timeout, service.call_tool(request)).await {
161            Ok(Ok(result)) => CallOutcome::Ok(result),
162            Ok(Err(error)) if service.is_transport_closed() => {
163                CallOutcome::Crash(error.to_string())
164            }
165            Ok(Err(error)) => CallOutcome::ProtocolError(error.to_string()),
166            Err(_) => CallOutcome::Hang(timeout),
167        }
168    }
169
170    /// Shuts the underlying transport down. After this call other clones of
171    /// this `Client` keep working semantically but return `ProtocolError`
172    /// on every method (the transport is gone).
173    pub async fn shutdown(&self) -> Result<()> {
174        let mut guard = self.service.write().await;
175        match guard.take() {
176            Some(service) => service
177                .cancel()
178                .await
179                .map(|_| ())
180                .map_err(ClientError::Shutdown),
181            None => Ok(()),
182        }
183    }
184
185    pub fn target(&self) -> &Target {
186        &self.target
187    }
188}
189
190async fn build_service(target: &Target) -> Result<RunningService<RoleClient, ()>> {
191    let service = match &target.transport {
192        TargetTransport::Stdio { command, args, env } => {
193            let mut process = Command::new(command);
194            process.args(args).envs(env);
195            let transport = StdioChildTransport::spawn(process).map_err(ClientError::Spawn)?;
196            ().serve(transport)
197                .await
198                .map_err(|error| ClientError::Initialize(error.to_string()))?
199        }
200        TargetTransport::Http { url, headers } => {
201            let headers = header_map(headers)?;
202            let config =
203                StreamableHttpClientTransportConfig::with_uri(url.clone()).custom_headers(headers);
204            let transport = StreamableHttpClientTransport::from_config(config);
205            ().serve(transport)
206                .await
207                .map_err(|error| ClientError::Initialize(error.to_string()))?
208        }
209    };
210    Ok(service)
211}
212
/// Returns the location of the fixture `wallfacer.toml` beneath `repo_root`.
///
/// The path is only computed; nothing checks that the file exists.
pub fn fixture_config_path(repo_root: &Path) -> std::path::PathBuf {
    let mut fixture = repo_root.to_path_buf();
    fixture.push("tests/fixtures/wallfacer.toml");
    fixture
}
216
217fn header_map(headers: &HashMap<String, String>) -> Result<HashMap<HeaderName, HeaderValue>> {
218    headers
219        .iter()
220        .map(|(name, value)| {
221            let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
222                ClientError::InvalidHeader {
223                    name: name.clone(),
224                    message: error.to_string(),
225                }
226            })?;
227            let header_value =
228                HeaderValue::from_str(value).map_err(|error| ClientError::InvalidHeader {
229                    name: name.clone(),
230                    message: error.to_string(),
231                })?;
232            Ok((header_name, header_value))
233        })
234        .collect()
235}
236
/// MCP transport over a spawned child process's stdin/stdout that also owns
/// the child so it can be waited on or killed when the transport closes or
/// is dropped.
struct StdioChildTransport {
    /// The spawned process; taken (left as `None`) once the transport has
    /// handed it off for shutdown.
    child: Option<Child>,
    /// Message transport reading from the child's stdout and writing to its
    /// stdin.
    transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
}
241
242impl StdioChildTransport {
243    fn spawn(mut command: Command) -> std::io::Result<Self> {
244        command
245            .stdin(Stdio::piped())
246            .stdout(Stdio::piped())
247            .stderr(Stdio::inherit());
248
249        let mut child = command.spawn()?;
250        let stdout = child
251            .stdout
252            .take()
253            .ok_or_else(|| std::io::Error::other("child stdout was already taken"))?;
254        let stdin = child
255            .stdin
256            .take()
257            .ok_or_else(|| std::io::Error::other("child stdin was already taken"))?;
258
259        Ok(Self {
260            child: Some(child),
261            transport: AsyncRwTransport::new_client(stdout, stdin),
262        })
263    }
264
265    async fn close_child(&mut self) -> std::io::Result<()> {
266        self.transport.close().await?;
267
268        if let Some(mut child) = self.child.take() {
269            match time::timeout(CHILD_SHUTDOWN_TIMEOUT, child.wait()).await {
270                Ok(status) => {
271                    status?;
272                }
273                Err(_) => {
274                    child.kill().await?;
275                }
276            }
277        }
278
279        Ok(())
280    }
281}
282
283impl Drop for StdioChildTransport {
284    fn drop(&mut self) {
285        if let Some(mut child) = self.child.take() {
286            let _ = child.start_kill();
287            tokio::spawn(async move {
288                let _ = child.wait().await;
289            });
290        }
291    }
292}
293
/// Transport implementation that forwards message traffic to the inner
/// stdin/stdout transport and additionally shuts the child process down on
/// close.
impl RmcpTransport<RoleClient> for StdioChildTransport {
    type Error = std::io::Error;

    /// Forwards an outgoing message to the inner transport.
    fn send(
        &mut self,
        item: TxJsonRpcMessage<RoleClient>,
    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
        self.transport.send(item)
    }

    /// Receives the next incoming message from the inner transport.
    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
        self.transport.receive()
    }

    /// Closes the inner transport and then waits for (or kills) the child
    /// process; see `close_child`.
    fn close(&mut self) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send {
        self.close_child()
    }
}
311}