1use std::fmt;
7
8use crate::{HttpConfig, StdioConfig, ToolInvocation};
9use log::debug;
10use rmcp::model::Tool;
11use rmcp::service::{ClientInitializeError, RoleClient, RunningService, ServiceError, ServiceExt};
12use serde::Serialize;
13#[non_exhaustive]
18#[derive(Debug)]
19pub enum SessionError {
20 Initialize(Box<ClientInitializeError>),
22 Service(Box<ServiceError>),
24 Transport(Box<std::io::Error>),
26}
27
28impl fmt::Display for SessionError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 SessionError::Initialize(error) => write!(f, "initialize error: {error}"),
32 SessionError::Service(error) => write!(f, "service error: {error}"),
33 SessionError::Transport(error) => write!(f, "transport error: {error}"),
34 }
35 }
36}
37
38impl std::error::Error for SessionError {
39 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
40 match self {
41 SessionError::Initialize(error) => Some(error.as_ref()),
42 SessionError::Service(error) => Some(error.as_ref()),
43 SessionError::Transport(error) => Some(error.as_ref()),
44 }
45 }
46}
47
48impl From<ClientInitializeError> for SessionError {
49 fn from(error: ClientInitializeError) -> Self {
50 Self::Initialize(Box::new(error))
51 }
52}
53
54impl From<ServiceError> for SessionError {
55 fn from(error: ServiceError) -> Self {
56 Self::Service(Box::new(error))
57 }
58}
59
60impl From<std::io::Error> for SessionError {
61 fn from(error: std::io::Error) -> Self {
62 Self::Transport(Box::new(error))
63 }
64}
65
66pub struct SessionDriver {
68 service: RunningService<RoleClient, ()>,
69}
70
// Unit tests for this module are kept in an external file and pulled in as a child
// module for test builds only.
#[cfg(test)]
#[path = "../tests/internal/session_unit_tests.rs"]
mod tests;
74
75impl SessionDriver {
76 pub async fn connect_stdio(config: &StdioConfig) -> Result<Self, SessionError> {
78 use rmcp::transport::TokioChildProcess;
79 use tokio::process::Command;
80
81 let mut command = Command::new(config.command());
82 command.args(&config.args).envs(&config.env);
83 if let Some(cwd) = &config.cwd {
84 command.current_dir(cwd);
85 }
86 let transport = TokioChildProcess::new(command)?;
87 Self::connect_with_transport(transport).await
88 }
89
90 pub async fn connect_http(config: &HttpConfig) -> Result<Self, SessionError> {
92 let transport = build_http_transport(config);
93 Self::connect_with_transport(transport).await
94 }
95
96 pub async fn connect_with_transport<T, E, A>(transport: T) -> Result<Self, SessionError>
98 where
99 T: rmcp::transport::IntoTransport<RoleClient, E, A>,
100 E: std::error::Error + Send + Sync + 'static,
101 {
102 let service = ().serve(transport).await?;
103 Ok(Self { service })
104 }
105
106 pub async fn call_tool(
108 &self,
109 invocation: ToolInvocation,
110 ) -> Result<rmcp::model::CallToolResult, SessionError> {
111 log_io("call_tool request", &invocation);
112 let response = self.service.peer().call_tool(invocation).await?;
113 log_io("call_tool response", &response);
114 Ok(response)
115 }
116
117 pub async fn list_tools(&self) -> Result<Vec<Tool>, SessionError> {
119 log_io_message("list_tools request");
120 let tools = self.service.peer().list_all_tools().await?;
121 log_io("list_tools response", &tools);
122 Ok(tools)
123 }
124
125 pub fn server_protocol_version(&self) -> Option<String> {
127 self.service
128 .peer()
129 .peer_info()
130 .map(|info| info.protocol_version.to_string())
131 }
132}
133
/// Log target attached to every request/response record emitted by this module.
const IO_LOG_TARGET: &str = "tooltest.io_logs";
135
136fn log_io_message(message: &str) {
137 debug!(target: IO_LOG_TARGET, "{message}");
138}
139
140fn log_io<T: Serialize>(label: &str, value: &T) {
141 debug!(
142 target: IO_LOG_TARGET,
143 "{label}: {}",
144 serde_json::to_string(value)
145 .unwrap_or_else(|error| format!("<serialize error: {error}>"))
146 );
147}
148
149#[cfg_attr(coverage, allow(dead_code))]
151fn http_transport_config(
152 config: &HttpConfig,
153) -> rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
154 use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
155
156 let mut transport_config = StreamableHttpClientTransportConfig::with_uri(config.url());
157 if let Some(auth_token) = &config.auth_token {
158 let token = auth_token.trim();
159 let token = token.strip_prefix("Bearer ").unwrap_or(token);
160 transport_config = transport_config.auth_header(token.to_string());
161 }
162 transport_config
163}
164
165fn build_http_transport(
167 config: &HttpConfig,
168) -> rmcp::transport::StreamableHttpClientTransport<reqwest::Client> {
169 use rmcp::transport::StreamableHttpClientTransport;
170
171 StreamableHttpClientTransport::from_config(http_transport_config(config))
172}