Skip to main content

tooltest_core/
session.rs

1//! Error handling strategy for the rmcp-backed session driver.
2//!
3//! We preserve rmcp error types inside `SessionError` to keep transport and
4//! session layers aligned and to retain full error context for debugging.
5
6use std::fmt;
7
8use crate::{HttpConfig, StdioConfig, ToolInvocation};
9use log::debug;
10use rmcp::model::Tool;
11use rmcp::service::{ClientInitializeError, RoleClient, RunningService, ServiceError, ServiceExt};
12use serde::Serialize;
/// Errors emitted by the rmcp-backed session driver.
///
/// The rmcp error variants are boxed to keep the enum size small; match on
/// `SessionError` and then inspect the boxed error as needed. The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
///
/// The wrapped error of every variant is also returned from
/// `std::error::Error::source`, so error-reporting chains keep the original
/// rmcp / I/O context.
#[non_exhaustive]
#[derive(Debug)]
pub enum SessionError {
    /// Initialization failed while establishing the session (the error
    /// returned by rmcp's `serve` in `SessionDriver::connect_with_transport`).
    Initialize(Box<ClientInitializeError>),
    /// The session failed while sending or receiving requests, e.g. during
    /// `SessionDriver::call_tool` or `SessionDriver::list_tools`.
    Service(Box<ServiceError>),
    /// An I/O error from transport setup; in this module it is produced when
    /// the stdio child-process transport cannot be created.
    Transport(Box<std::io::Error>),
}
27
28impl fmt::Display for SessionError {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            SessionError::Initialize(error) => write!(f, "initialize error: {error}"),
32            SessionError::Service(error) => write!(f, "service error: {error}"),
33            SessionError::Transport(error) => write!(f, "transport error: {error}"),
34        }
35    }
36}
37
38impl std::error::Error for SessionError {
39    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
40        match self {
41            SessionError::Initialize(error) => Some(error.as_ref()),
42            SessionError::Service(error) => Some(error.as_ref()),
43            SessionError::Transport(error) => Some(error.as_ref()),
44        }
45    }
46}
47
48impl From<ClientInitializeError> for SessionError {
49    fn from(error: ClientInitializeError) -> Self {
50        Self::Initialize(Box::new(error))
51    }
52}
53
54impl From<ServiceError> for SessionError {
55    fn from(error: ServiceError) -> Self {
56        Self::Service(Box::new(error))
57    }
58}
59
60impl From<std::io::Error> for SessionError {
61    fn from(error: std::io::Error) -> Self {
62        Self::Transport(Box::new(error))
63    }
64}
65
/// Session driver that uses rmcp client/session APIs.
///
/// Wraps a running rmcp client service (with `()` as the client handler) and
/// exposes the operations this crate needs: listing tools, calling tools, and
/// reading the server-reported protocol version.
pub struct SessionDriver {
    // The running rmcp client session; every request goes through its peer.
    service: RunningService<RoleClient, ()>,
}
70
// Unit tests are kept in a separate file under `tests/internal/`, but are
// compiled as a child module of this one so they can reach its private items.
#[cfg(test)]
#[path = "../tests/internal/session_unit_tests.rs"]
mod tests;
74
75impl SessionDriver {
76    /// Connects to an MCP server over stdio using rmcp child-process transport.
77    pub async fn connect_stdio(config: &StdioConfig) -> Result<Self, SessionError> {
78        use rmcp::transport::TokioChildProcess;
79        use tokio::process::Command;
80
81        let mut command = Command::new(config.command());
82        command.args(&config.args).envs(&config.env);
83        if let Some(cwd) = &config.cwd {
84            command.current_dir(cwd);
85        }
86        let transport = TokioChildProcess::new(command)?;
87        Self::connect_with_transport(transport).await
88    }
89
90    /// Connects to an MCP server over HTTP using rmcp streamable HTTP transport.
91    pub async fn connect_http(config: &HttpConfig) -> Result<Self, SessionError> {
92        let transport = build_http_transport(config);
93        Self::connect_with_transport(transport).await
94    }
95
96    /// Connects using a custom rmcp transport implementation.
97    pub async fn connect_with_transport<T, E, A>(transport: T) -> Result<Self, SessionError>
98    where
99        T: rmcp::transport::IntoTransport<RoleClient, E, A>,
100        E: std::error::Error + Send + Sync + 'static,
101    {
102        let service = ().serve(transport).await?;
103        Ok(Self { service })
104    }
105
106    /// Sends a tool invocation via rmcp and returns the response.
107    pub async fn call_tool(
108        &self,
109        invocation: ToolInvocation,
110    ) -> Result<rmcp::model::CallToolResult, SessionError> {
111        log_io("call_tool request", &invocation);
112        let response = self.service.peer().call_tool(invocation).await?;
113        log_io("call_tool response", &response);
114        Ok(response)
115    }
116
117    /// Lists all tools available from the MCP session.
118    pub async fn list_tools(&self) -> Result<Vec<Tool>, SessionError> {
119        log_io_message("list_tools request");
120        let tools = self.service.peer().list_all_tools().await?;
121        log_io("list_tools response", &tools);
122        Ok(tools)
123    }
124
125    /// Returns the server-reported MCP protocol version, if available.
126    pub fn server_protocol_version(&self) -> Option<String> {
127        self.service
128            .peer()
129            .peer_info()
130            .map(|info| info.protocol_version.to_string())
131    }
132}
133
/// `log` target used for the request/response tracing emitted by this module.
const IO_LOG_TARGET: &str = "tooltest.io_logs";
135
136fn log_io_message(message: &str) {
137    debug!(target: IO_LOG_TARGET, "{message}");
138}
139
140fn log_io<T: Serialize>(label: &str, value: &T) {
141    debug!(
142        target: IO_LOG_TARGET,
143        "{label}: {}",
144        serde_json::to_string(value)
145            .unwrap_or_else(|error| format!("<serialize error: {error}>"))
146    );
147}
148
149/// Builds an HTTP transport config for MCP communication.
150#[cfg_attr(coverage, allow(dead_code))]
151fn http_transport_config(
152    config: &HttpConfig,
153) -> rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
154    use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
155
156    let mut transport_config = StreamableHttpClientTransportConfig::with_uri(config.url());
157    if let Some(auth_token) = &config.auth_token {
158        let token = auth_token.trim();
159        let token = token.strip_prefix("Bearer ").unwrap_or(token);
160        transport_config = transport_config.auth_header(token.to_string());
161    }
162    transport_config
163}
164
165/// Builds an HTTP transport for MCP communication.
166fn build_http_transport(
167    config: &HttpConfig,
168) -> rmcp::transport::StreamableHttpClientTransport<reqwest::Client> {
169    use rmcp::transport::StreamableHttpClientTransport;
170
171    StreamableHttpClientTransport::from_config(http_transport_config(config))
172}