Skip to main content

wallfacer_core/
client.rs

//! MCP client wrapper.
//!
//! Phase E1 puts the underlying `rmcp` service behind an
//! `Arc<RwLock<...>>` so:
//!
//! * `Client` is `Clone` (cheap `Arc::clone`); torture can fan out across
//!   many tasks sharing the same connection.
//! * `list_tools` and `call_tool` take `&self` and acquire a *read* lock,
//!   allowing concurrent calls to be in flight at once.
//! * `reconnect` takes `&self` and acquires a *write* lock to atomically
//!   swap in a freshly built service. Concurrent callers see either the
//!   old or the new transport, never a torn state.

14use std::{
15    collections::HashMap, future::Future, path::Path, process::Stdio, sync::Arc, time::Duration,
16};
17
18use http::{HeaderName, HeaderValue};
19use rmcp::{
20    model::{CallToolRequestParams, CallToolResult, Prompt, Resource, ServerCapabilities, Tool},
21    service::{RoleClient, RunningService, RxJsonRpcMessage, TxJsonRpcMessage},
22    transport::{
23        async_rw::AsyncRwTransport, streamable_http_client::StreamableHttpClientTransportConfig,
24        StreamableHttpClientTransport, Transport as RmcpTransport,
25    },
26    ServiceExt,
27};
28use serde_json::Value;
29use thiserror::Error;
30use tokio::{
31    process::{Child, ChildStdin, ChildStdout, Command},
32    sync::RwLock,
33    time,
34};
35
36use crate::target::{Target, Transport as TargetTransport};
37
/// Grace period a stdio child gets to exit on its own after its transport
/// has been closed; once it elapses the child is killed.
const CHILD_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(3_000);
39
40#[derive(Debug, Error)]
41pub enum ClientError {
42    #[error("failed to spawn stdio transport: {0}")]
43    Spawn(#[source] std::io::Error),
44    #[error("failed to initialize MCP client: {0}")]
45    Initialize(String),
46    #[error("invalid HTTP header {name}: {message}")]
47    InvalidHeader { name: String, message: String },
48    #[error("failed to shut down MCP client: {0}")]
49    Shutdown(#[source] tokio::task::JoinError),
50    #[error("MCP request failed: {0}")]
51    Request(String),
52}
53
54pub type Result<T> = std::result::Result<T, ClientError>;
55
56/// Cheaply-cloneable MCP client. Cloning shares the same underlying
57/// transport via `Arc`. After [`Client::shutdown`] the transport is gone
58/// and further calls return `ClientError::Request`.
59#[derive(Clone)]
60pub struct Client {
61    /// `Some` while the transport is live; `None` after `shutdown`.
62    /// `RwLock` allows concurrent `&self` calls (read) and exclusive
63    /// `reconnect` (write).
64    service: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
65    target: Target,
66}
67
68#[derive(Debug)]
69pub enum CallOutcome {
70    Ok(CallToolResult),
71    Hang(Duration),
72    Crash(String),
73    ProtocolError(String),
74}
75
76impl Client {
77    pub async fn connect(target: &Target) -> Result<Self> {
78        let service = build_service(target).await?;
79        Ok(Self {
80            service: Arc::new(RwLock::new(Some(service))),
81            target: target.clone(),
82        })
83    }
84
85    /// Tears down the current transport and opens a new one. Other tasks
86    /// holding a `&self` reference will see the new transport on their
87    /// next call. Phase E1 changed this from `&mut self` to `&self` so
88    /// callers can recover after a fault without exclusive ownership of
89    /// the `Client`.
90    pub async fn reconnect(&self) -> Result<()> {
91        // Build the replacement *before* dropping the old one to keep the
92        // window where the client has no transport as small as possible.
93        let replacement = build_service(&self.target).await?;
94        let mut guard = self.service.write().await;
95        if let Some(old) = guard.take() {
96            let _ = old.cancel().await;
97        }
98        *guard = Some(replacement);
99        Ok(())
100    }
101
102    pub async fn list_tools(&self) -> Result<Vec<Tool>> {
103        let guard = self.service.read().await;
104        let service = guard
105            .as_ref()
106            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
107        service
108            .list_all_tools()
109            .await
110            .map_err(|error| ClientError::Request(error.to_string()))
111    }
112
113    /// Returns a clone of the server's announced [`ServerCapabilities`]
114    /// from the initial `initialize` handshake, or `None` if the client
115    /// has been shut down or the handshake hadn't completed.
116    ///
117    /// Per MCP spec, clients should not call `resources/list` or
118    /// `prompts/list` against a server that didn't declare the
119    /// corresponding capability — doing so produces a noisy
120    /// `-32601 method not found` from compliant servers. Use this to
121    /// gate optional listing calls.
122    pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
123        let guard = self.service.read().await;
124        guard
125            .as_ref()
126            .and_then(|service| service.peer_info())
127            .map(|info| info.capabilities.clone())
128    }
129
130    /// Lists the server's resources. Returns `Ok(vec![])` (silently)
131    /// when the server didn't declare the `resources` capability at
132    /// init time — this avoids spamming a `-32601 method not found`
133    /// failure for servers that legitimately don't expose resources.
134    /// Callers needing to distinguish "not advertised" from "empty"
135    /// should check [`Self::server_capabilities`] first.
136    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
137        let advertises = self
138            .server_capabilities()
139            .await
140            .is_some_and(|caps| caps.resources.is_some());
141        if !advertises {
142            return Ok(Vec::new());
143        }
144        let guard = self.service.read().await;
145        let service = guard
146            .as_ref()
147            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
148        service
149            .list_all_resources()
150            .await
151            .map_err(|error| ClientError::Request(error.to_string()))
152    }
153
154    /// Lists the server's prompts. Same capability-aware short-circuit
155    /// as [`Self::list_resources`]: returns `Ok(vec![])` when the
156    /// server didn't declare the `prompts` capability.
157    pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
158        let advertises = self
159            .server_capabilities()
160            .await
161            .is_some_and(|caps| caps.prompts.is_some());
162        if !advertises {
163            return Ok(Vec::new());
164        }
165        let guard = self.service.read().await;
166        let service = guard
167            .as_ref()
168            .ok_or_else(|| ClientError::Request("client has been shut down".into()))?;
169        service
170            .list_all_prompts()
171            .await
172            .map_err(|error| ClientError::Request(error.to_string()))
173    }
174
175    pub async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
176        let arguments = match arguments {
177            Value::Object(map) => Some(map),
178            Value::Null => None,
179            other => {
180                return CallOutcome::ProtocolError(format!(
181                    "tool arguments must be a JSON object or null, got {other}"
182                ));
183            }
184        };
185
186        let request = match arguments {
187            Some(arguments) => {
188                CallToolRequestParams::new(name.to_owned()).with_arguments(arguments)
189            }
190            None => CallToolRequestParams::new(name.to_owned()),
191        };
192
193        // Snapshot the current transport for the duration of the call. We
194        // hold the read lock across the await so a concurrent reconnect
195        // waits until our call finishes before swapping.
196        let guard = self.service.read().await;
197        let Some(service) = guard.as_ref() else {
198            return CallOutcome::ProtocolError("client has been shut down".into());
199        };
200        match time::timeout(timeout, service.call_tool(request)).await {
201            Ok(Ok(result)) => CallOutcome::Ok(result),
202            Ok(Err(error)) if service.is_transport_closed() => {
203                CallOutcome::Crash(error.to_string())
204            }
205            Ok(Err(error)) => CallOutcome::ProtocolError(error.to_string()),
206            Err(_) => CallOutcome::Hang(timeout),
207        }
208    }
209
210    /// Shuts the underlying transport down. After this call other clones of
211    /// this `Client` keep working semantically but return `ProtocolError`
212    /// on every method (the transport is gone).
213    pub async fn shutdown(&self) -> Result<()> {
214        let mut guard = self.service.write().await;
215        match guard.take() {
216            Some(service) => service
217                .cancel()
218                .await
219                .map(|_| ())
220                .map_err(ClientError::Shutdown),
221            None => Ok(()),
222        }
223    }
224
225    pub fn target(&self) -> &Target {
226        &self.target
227    }
228}
229
230async fn build_service(target: &Target) -> Result<RunningService<RoleClient, ()>> {
231    let service = match &target.transport {
232        TargetTransport::Stdio { command, args, env } => {
233            let mut process = Command::new(command);
234            process.args(args).envs(env);
235            let transport = StdioChildTransport::spawn(process).map_err(ClientError::Spawn)?;
236            ().serve(transport)
237                .await
238                .map_err(|error| ClientError::Initialize(error.to_string()))?
239        }
240        TargetTransport::Http { url, headers } => {
241            let headers = header_map(headers)?;
242            let config =
243                StreamableHttpClientTransportConfig::with_uri(url.clone()).custom_headers(headers);
244            let transport = StreamableHttpClientTransport::from_config(config);
245            ().serve(transport)
246                .await
247                .map_err(|error| ClientError::Initialize(error.to_string()))?
248        }
249    };
250    Ok(service)
251}
252
/// Location of the `wallfacer.toml` test fixture inside the repository
/// rooted at `repo_root`, i.e. `<repo_root>/tests/fixtures/wallfacer.toml`.
pub fn fixture_config_path(repo_root: &Path) -> std::path::PathBuf {
    let mut config = repo_root.to_path_buf();
    config.push("tests/fixtures/wallfacer.toml");
    config
}
256
257fn header_map(headers: &HashMap<String, String>) -> Result<HashMap<HeaderName, HeaderValue>> {
258    headers
259        .iter()
260        .map(|(name, value)| {
261            let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
262                ClientError::InvalidHeader {
263                    name: name.clone(),
264                    message: error.to_string(),
265                }
266            })?;
267            let header_value =
268                HeaderValue::from_str(value).map_err(|error| ClientError::InvalidHeader {
269                    name: name.clone(),
270                    message: error.to_string(),
271                })?;
272            Ok((header_name, header_value))
273        })
274        .collect()
275}
276
277struct StdioChildTransport {
278    child: Option<Child>,
279    transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
280}
281
282impl StdioChildTransport {
283    fn spawn(mut command: Command) -> std::io::Result<Self> {
284        command
285            .stdin(Stdio::piped())
286            .stdout(Stdio::piped())
287            .stderr(Stdio::inherit());
288
289        let mut child = command.spawn()?;
290        let stdout = child
291            .stdout
292            .take()
293            .ok_or_else(|| std::io::Error::other("child stdout was already taken"))?;
294        let stdin = child
295            .stdin
296            .take()
297            .ok_or_else(|| std::io::Error::other("child stdin was already taken"))?;
298
299        Ok(Self {
300            child: Some(child),
301            transport: AsyncRwTransport::new_client(stdout, stdin),
302        })
303    }
304
305    async fn close_child(&mut self) -> std::io::Result<()> {
306        self.transport.close().await?;
307
308        if let Some(mut child) = self.child.take() {
309            match time::timeout(CHILD_SHUTDOWN_TIMEOUT, child.wait()).await {
310                Ok(status) => {
311                    status?;
312                }
313                Err(_) => {
314                    child.kill().await?;
315                }
316            }
317        }
318
319        Ok(())
320    }
321}
322
323impl Drop for StdioChildTransport {
324    fn drop(&mut self) {
325        if let Some(mut child) = self.child.take() {
326            let _ = child.start_kill();
327            tokio::spawn(async move {
328                let _ = child.wait().await;
329            });
330        }
331    }
332}
333
334impl RmcpTransport<RoleClient> for StdioChildTransport {
335    type Error = std::io::Error;
336
337    fn send(
338        &mut self,
339        item: TxJsonRpcMessage<RoleClient>,
340    ) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send + 'static {
341        self.transport.send(item)
342    }
343
344    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
345        self.transport.receive()
346    }
347
348    fn close(&mut self) -> impl Future<Output = std::result::Result<(), Self::Error>> + Send {
349        self.close_child()
350    }
351}