rust_mcp_sdk/hyper_servers/
hyper_runtime.rs1use std::{sync::Arc, time::Duration};
2
3use crate::{
4 mcp_server::HyperServer,
5 schema::{
6 schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
7 CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams,
8 ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams,
9 PromptListChangedNotificationParams, ResourceListChangedNotificationParams,
10 ResourceUpdatedNotificationParams, ToolListChangedNotificationParams,
11 },
12 McpServer,
13};
14
15use axum_server::Handle;
16use rust_mcp_transport::SessionId;
17use tokio::{sync::Mutex, task::JoinHandle};
18
19use crate::{
20 error::SdkResult,
21 hyper_servers::app_state::AppState,
22 mcp_server::{
23 error::{TransportServerError, TransportServerResult},
24 ServerRuntime,
25 },
26};
27
28pub struct HyperRuntime {
29 pub(crate) state: Arc<AppState>,
30 pub(crate) server_task: JoinHandle<Result<(), TransportServerError>>,
31 pub(crate) server_handle: Handle,
32}
33
34impl HyperRuntime {
35 pub async fn create(server: HyperServer) -> SdkResult<Self> {
36 let addr = server.options.resolve_server_address().await?;
37 let state = server.state();
38
39 let server_handle = server.server_handle();
40
41 let server_task = tokio::spawn(async move {
42 #[cfg(feature = "ssl")]
43 if server.options.enable_ssl {
44 server.start_ssl(addr).await
45 } else {
46 server.start_http(addr).await
47 }
48
49 #[cfg(not(feature = "ssl"))]
50 if server.options.enable_ssl {
51 panic!("SSL requested but the 'ssl' feature is not enabled");
52 } else {
53 server.start_http(addr).await
54 }
55 });
56
57 Ok(Self {
58 state,
59 server_task,
60 server_handle,
61 })
62 }
63
64 pub fn graceful_shutdown(&self, timeout: Option<Duration>) {
65 self.server_handle.graceful_shutdown(timeout);
66 }
67
68 pub async fn await_server(self) -> SdkResult<()> {
69 let result = self.server_task.await?;
70 result.map_err(|err| err.into())
71 }
72
73 pub async fn sessions(&self) -> Vec<String> {
75 self.state.session_store.keys().await
76 }
77
78 pub async fn runtime_by_session(
80 &self,
81 session_id: &SessionId,
82 ) -> TransportServerResult<Arc<Mutex<Arc<ServerRuntime>>>> {
83 self.state.session_store.get(session_id).await.ok_or(
84 TransportServerError::SessionIdInvalid(session_id.to_string()),
85 )
86 }
87
88 pub async fn send_request(
89 &self,
90 session_id: &SessionId,
91 request: RequestFromServer,
92 timeout: Option<Duration>,
93 ) -> SdkResult<ResultFromClient> {
94 let runtime = self.runtime_by_session(session_id).await?;
95 let runtime = runtime.lock().await.to_owned();
96 runtime.request(request, timeout).await
97 }
98
99 pub async fn send_notification(
100 &self,
101 session_id: &SessionId,
102 notification: NotificationFromServer,
103 ) -> SdkResult<()> {
104 let runtime = self.runtime_by_session(session_id).await?;
105 let runtime = runtime.lock().await.to_owned();
106 runtime.send_notification(notification).await
107 }
108
109 pub async fn list_roots(
115 &self,
116 session_id: &SessionId,
117 params: Option<ListRootsRequestParams>,
118 ) -> SdkResult<ListRootsResult> {
119 let runtime = self.runtime_by_session(session_id).await?;
120 let runtime = runtime.lock().await.to_owned();
121 runtime.list_roots(params).await
122 }
123
124 pub async fn send_logging_message(
125 &self,
126 session_id: &SessionId,
127 params: LoggingMessageNotificationParams,
128 ) -> SdkResult<()> {
129 let runtime = self.runtime_by_session(session_id).await?;
130 let runtime = runtime.lock().await.to_owned();
131 runtime.send_logging_message(params).await
132 }
133
134 pub async fn send_prompt_list_changed(
138 &self,
139 session_id: &SessionId,
140 params: Option<PromptListChangedNotificationParams>,
141 ) -> SdkResult<()> {
142 let runtime = self.runtime_by_session(session_id).await?;
143 let runtime = runtime.lock().await.to_owned();
144 runtime.send_prompt_list_changed(params).await
145 }
146
147 pub async fn send_resource_list_changed(
151 &self,
152 session_id: &SessionId,
153 params: Option<ResourceListChangedNotificationParams>,
154 ) -> SdkResult<()> {
155 let runtime = self.runtime_by_session(session_id).await?;
156 let runtime = runtime.lock().await.to_owned();
157 runtime.send_resource_list_changed(params).await
158 }
159
160 pub async fn send_resource_updated(
164 &self,
165 session_id: &SessionId,
166 params: ResourceUpdatedNotificationParams,
167 ) -> SdkResult<()> {
168 let runtime = self.runtime_by_session(session_id).await?;
169 let runtime = runtime.lock().await.to_owned();
170 runtime.send_resource_updated(params).await
171 }
172
173 pub async fn send_tool_list_changed(
177 &self,
178 session_id: &SessionId,
179 params: Option<ToolListChangedNotificationParams>,
180 ) -> SdkResult<()> {
181 let runtime = self.runtime_by_session(session_id).await?;
182 let runtime = runtime.lock().await.to_owned();
183 runtime.send_tool_list_changed(params).await
184 }
185
186 pub async fn ping(
197 &self,
198 session_id: &SessionId,
199 timeout: Option<Duration>,
200 ) -> SdkResult<crate::schema::Result> {
201 let runtime = self.runtime_by_session(session_id).await?;
202 let runtime = runtime.lock().await.to_owned();
203 runtime.ping(timeout).await
204 }
205
206 pub async fn create_message(
212 &self,
213 session_id: &SessionId,
214 params: CreateMessageRequestParams,
215 ) -> SdkResult<CreateMessageResult> {
216 let runtime = self.runtime_by_session(session_id).await?;
217 let runtime = runtime.lock().await.to_owned();
218 runtime.create_message(params).await
219 }
220
221 pub async fn client_info(
222 &self,
223 session_id: &SessionId,
224 ) -> SdkResult<Option<InitializeRequestParams>> {
225 let runtime = self.runtime_by_session(session_id).await?;
226 let runtime = runtime.lock().await.to_owned();
227 Ok(runtime.client_info())
228 }
229}