1use crate::{HttpConfig, StdioConfig, ToolInvocation, TraceEntry};
7use rmcp::model::Tool;
8use rmcp::service::{ClientInitializeError, RoleClient, RunningService, ServiceError, ServiceExt};
9#[non_exhaustive]
14#[derive(Debug)]
15pub enum SessionError {
16 Initialize(Box<ClientInitializeError>),
18 Service(Box<ServiceError>),
20 Transport(Box<std::io::Error>),
22}
23
24impl From<ClientInitializeError> for SessionError {
25 fn from(error: ClientInitializeError) -> Self {
26 Self::Initialize(Box::new(error))
27 }
28}
29
30impl From<ServiceError> for SessionError {
31 fn from(error: ServiceError) -> Self {
32 Self::Service(Box::new(error))
33 }
34}
35
36impl From<std::io::Error> for SessionError {
37 fn from(error: std::io::Error) -> Self {
38 Self::Transport(Box::new(error))
39 }
40}
41
/// A connected MCP client session.
///
/// Construct one with [`SessionDriver::connect_stdio`],
/// [`SessionDriver::connect_http`], or [`SessionDriver::connect_with_transport`],
/// then issue requests with [`SessionDriver::list_tools`],
/// [`SessionDriver::send_tool_call`], or [`SessionDriver::run_invocations`].
pub struct SessionDriver {
    /// The running `rmcp` client service, created with `()` as its client
    /// handler; all requests go through its peer.
    service: RunningService<RoleClient, ()>,
}
46
// Unit tests live in a separate file under `tests/internal/` and are compiled
// only for test builds. Declaring them as a child module (rather than as an
// integration test) lets them reach this module's private items.
#[cfg(test)]
#[path = "../tests/internal/session_unit_tests.rs"]
mod tests;
50
51impl SessionDriver {
52 pub async fn connect_stdio(config: &StdioConfig) -> Result<Self, SessionError> {
54 use rmcp::transport::TokioChildProcess;
55 use tokio::process::Command;
56
57 let mut command = Command::new(&config.command);
58 command.args(&config.args).envs(&config.env);
59 if let Some(cwd) = &config.cwd {
60 command.current_dir(cwd);
61 }
62 let transport = TokioChildProcess::new(command)?;
63 Self::connect_with_transport(transport).await
64 }
65
66 pub async fn connect_http(config: &HttpConfig) -> Result<Self, SessionError> {
68 let transport = build_http_transport(config);
69 Self::connect_with_transport(transport).await
70 }
71
72 pub async fn connect_with_transport<T, E, A>(transport: T) -> Result<Self, SessionError>
74 where
75 T: rmcp::transport::IntoTransport<RoleClient, E, A>,
76 E: std::error::Error + Send + Sync + 'static,
77 {
78 let service = ().serve(transport).await?;
79 Ok(Self { service })
80 }
81
82 pub async fn send_tool_call(
84 &self,
85 invocation: ToolInvocation,
86 ) -> Result<TraceEntry, SessionError> {
87 let response = self.service.peer().call_tool(invocation.clone()).await?;
88 Ok(TraceEntry::tool_call_with_response(invocation, response))
89 }
90
91 pub async fn list_tools(&self) -> Result<Vec<Tool>, SessionError> {
93 let tools = self.service.peer().list_all_tools().await?;
94 Ok(tools)
95 }
96
97 pub async fn run_invocations<I>(&self, invocations: I) -> Result<Vec<TraceEntry>, SessionError>
99 where
100 I: IntoIterator<Item = ToolInvocation>,
101 {
102 let mut trace = Vec::new();
103 for invocation in invocations {
104 trace.push(self.send_tool_call(invocation).await?);
105 }
106 Ok(trace)
107 }
108}
109
110#[cfg_attr(coverage, allow(dead_code))]
112fn http_transport_config(
113 config: &HttpConfig,
114) -> rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
115 use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
116
117 let mut transport_config = StreamableHttpClientTransportConfig::with_uri(config.url.as_str());
118 if let Some(auth_token) = &config.auth_token {
119 let token = auth_token.trim();
120 let token = token.strip_prefix("Bearer ").unwrap_or(token);
121 transport_config = transport_config.auth_header(token.to_string());
122 }
123 transport_config
124}
125
126fn build_http_transport(
128 config: &HttpConfig,
129) -> rmcp::transport::StreamableHttpClientTransport<reqwest::Client> {
130 use rmcp::transport::StreamableHttpClientTransport;
131
132 StreamableHttpClientTransport::from_config(http_transport_config(config))
133}