tooltest_core/
session.rs

1//! Error handling strategy for the rmcp-backed session driver.
2//!
3//! We preserve rmcp error types inside `SessionError` to keep transport and
4//! session layers aligned and to retain full error context for debugging.
5
6use crate::{HttpConfig, StdioConfig, ToolInvocation, TraceEntry};
7use rmcp::model::Tool;
8use rmcp::service::{ClientInitializeError, RoleClient, RunningService, ServiceError, ServiceExt};
9/// Errors emitted by the rmcp-backed session driver.
10///
11/// The rmcp error variants are boxed to keep the enum size small; match on
12/// `SessionError` and then inspect the boxed error as needed.
13#[non_exhaustive]
14#[derive(Debug)]
15pub enum SessionError {
16    /// Initialization failed while establishing the session.
17    Initialize(Box<ClientInitializeError>),
18    /// The session failed while sending or receiving requests.
19    Service(Box<ServiceError>),
20    /// Failed to spawn or configure the stdio transport.
21    Transport(Box<std::io::Error>),
22}
23
24impl From<ClientInitializeError> for SessionError {
25    fn from(error: ClientInitializeError) -> Self {
26        Self::Initialize(Box::new(error))
27    }
28}
29
30impl From<ServiceError> for SessionError {
31    fn from(error: ServiceError) -> Self {
32        Self::Service(Box::new(error))
33    }
34}
35
36impl From<std::io::Error> for SessionError {
37    fn from(error: std::io::Error) -> Self {
38        Self::Transport(Box::new(error))
39    }
40}
41
42/// Session driver that uses rmcp client/session APIs.
43pub struct SessionDriver {
44    service: RunningService<RoleClient, ()>,
45}
46
47#[cfg(test)]
48#[path = "../tests/internal/session_unit_tests.rs"]
49mod tests;
50
51impl SessionDriver {
52    /// Connects to an MCP server over stdio using rmcp child-process transport.
53    pub async fn connect_stdio(config: &StdioConfig) -> Result<Self, SessionError> {
54        use rmcp::transport::TokioChildProcess;
55        use tokio::process::Command;
56
57        let mut command = Command::new(&config.command);
58        command.args(&config.args).envs(&config.env);
59        if let Some(cwd) = &config.cwd {
60            command.current_dir(cwd);
61        }
62        let transport = TokioChildProcess::new(command)?;
63        Self::connect_with_transport(transport).await
64    }
65
66    /// Connects to an MCP server over HTTP using rmcp streamable HTTP transport.
67    pub async fn connect_http(config: &HttpConfig) -> Result<Self, SessionError> {
68        let transport = build_http_transport(config);
69        Self::connect_with_transport(transport).await
70    }
71
72    /// Connects using a custom rmcp transport implementation.
73    pub async fn connect_with_transport<T, E, A>(transport: T) -> Result<Self, SessionError>
74    where
75        T: rmcp::transport::IntoTransport<RoleClient, E, A>,
76        E: std::error::Error + Send + Sync + 'static,
77    {
78        let service = ().serve(transport).await?;
79        Ok(Self { service })
80    }
81
82    /// Sends a tool invocation via rmcp and records the response.
83    pub async fn send_tool_call(
84        &self,
85        invocation: ToolInvocation,
86    ) -> Result<TraceEntry, SessionError> {
87        let response = self.service.peer().call_tool(invocation.clone()).await?;
88        Ok(TraceEntry::tool_call_with_response(invocation, response))
89    }
90
91    /// Lists all tools available from the MCP session.
92    pub async fn list_tools(&self) -> Result<Vec<Tool>, SessionError> {
93        let tools = self.service.peer().list_all_tools().await?;
94        Ok(tools)
95    }
96
97    /// Sends a sequence of tool invocations via rmcp.
98    pub async fn run_invocations<I>(&self, invocations: I) -> Result<Vec<TraceEntry>, SessionError>
99    where
100        I: IntoIterator<Item = ToolInvocation>,
101    {
102        let mut trace = Vec::new();
103        for invocation in invocations {
104            trace.push(self.send_tool_call(invocation).await?);
105        }
106        Ok(trace)
107    }
108}
109
110/// Builds an HTTP transport config for MCP communication.
111#[cfg_attr(coverage, allow(dead_code))]
112fn http_transport_config(
113    config: &HttpConfig,
114) -> rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
115    use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
116
117    let mut transport_config = StreamableHttpClientTransportConfig::with_uri(config.url.as_str());
118    if let Some(auth_token) = &config.auth_token {
119        let token = auth_token.trim();
120        let token = token.strip_prefix("Bearer ").unwrap_or(token);
121        transport_config = transport_config.auth_header(token.to_string());
122    }
123    transport_config
124}
125
126/// Builds an HTTP transport for MCP communication.
127fn build_http_transport(
128    config: &HttpConfig,
129) -> rmcp::transport::StreamableHttpClientTransport<reqwest::Client> {
130    use rmcp::transport::StreamableHttpClientTransport;
131
132    StreamableHttpClientTransport::from_config(http_transport_config(config))
133}