rust_mcp_sdk/hyper_servers/
hyper_runtime.rs1use std::{sync::Arc, time::Duration};
2
3use crate::{
4 mcp_server::HyperServer,
5 schema::{
6 schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
7 CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams,
8 ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams,
9 PromptListChangedNotificationParams, ResourceListChangedNotificationParams,
10 ResourceUpdatedNotificationParams, ToolListChangedNotificationParams,
11 },
12 McpServer,
13};
14
15use axum_server::Handle;
16use rust_mcp_transport::SessionId;
17use tokio::{sync::Mutex, task::JoinHandle};
18
19use crate::{
20 error::SdkResult,
21 hyper_servers::app_state::AppState,
22 mcp_server::{
23 error::{TransportServerError, TransportServerResult},
24 ServerRuntime,
25 },
26};
27
28pub struct HyperRuntime {
29 pub(crate) state: Arc<AppState>,
30 pub(crate) server_task: JoinHandle<Result<(), TransportServerError>>,
31 pub(crate) server_handle: Handle,
32}
33
34impl HyperRuntime {
35 pub async fn create(server: HyperServer) -> SdkResult<Self> {
36 let addr = server.options.resolve_server_address().await?;
37 let state = server.state();
38
39 let server_handle = server.server_handle();
40
41 let server_task = tokio::spawn(async move {
42 #[cfg(feature = "ssl")]
43 if server.options.enable_ssl {
44 server.start_ssl(addr).await
45 } else {
46 server.start_http(addr).await
47 }
48
49 #[cfg(not(feature = "ssl"))]
50 if server.options.enable_ssl {
51 panic!("SSL requested but the 'ssl' feature is not enabled");
52 } else {
53 server.start_http(addr).await
54 }
55 });
56
57 Ok(Self {
58 state,
59 server_task,
60 server_handle,
61 })
62 }
63
64 pub fn graceful_shutdown(&self, timeout: Option<Duration>) {
65 self.server_handle.graceful_shutdown(timeout);
66 }
67
68 pub async fn await_server(self) -> SdkResult<()> {
69 let result = self.server_task.await?;
70 result.map_err(|err| err.into())
71 }
72
73 pub async fn runtime_by_session(
74 &self,
75 session_id: &SessionId,
76 ) -> TransportServerResult<Arc<Mutex<Arc<ServerRuntime>>>> {
77 self.state.session_store.get(session_id).await.ok_or(
78 TransportServerError::SessionIdInvalid(session_id.to_string()),
79 )
80 }
81
82 pub async fn send_request(
83 &self,
84 session_id: &SessionId,
85 request: RequestFromServer,
86 timeout: Option<Duration>,
87 ) -> SdkResult<ResultFromClient> {
88 let runtime = self.runtime_by_session(session_id).await?;
89 let runtime = runtime.lock().await.to_owned();
90 runtime.request(request, timeout).await
91 }
92
93 pub async fn send_notification(
94 &self,
95 session_id: &SessionId,
96 notification: NotificationFromServer,
97 ) -> SdkResult<()> {
98 let runtime = self.runtime_by_session(session_id).await?;
99 let runtime = runtime.lock().await.to_owned();
100 runtime.send_notification(notification).await
101 }
102
103 pub async fn list_roots(
109 &self,
110 session_id: &SessionId,
111 params: Option<ListRootsRequestParams>,
112 ) -> SdkResult<ListRootsResult> {
113 let runtime = self.runtime_by_session(session_id).await?;
114 let runtime = runtime.lock().await.to_owned();
115 runtime.list_roots(params).await
116 }
117
118 pub async fn send_logging_message(
119 &self,
120 session_id: &SessionId,
121 params: LoggingMessageNotificationParams,
122 ) -> SdkResult<()> {
123 let runtime = self.runtime_by_session(session_id).await?;
124 let runtime = runtime.lock().await.to_owned();
125 runtime.send_logging_message(params).await
126 }
127
128 pub async fn send_prompt_list_changed(
132 &self,
133 session_id: &SessionId,
134 params: Option<PromptListChangedNotificationParams>,
135 ) -> SdkResult<()> {
136 let runtime = self.runtime_by_session(session_id).await?;
137 let runtime = runtime.lock().await.to_owned();
138 runtime.send_prompt_list_changed(params).await
139 }
140
141 pub async fn send_resource_list_changed(
145 &self,
146 session_id: &SessionId,
147 params: Option<ResourceListChangedNotificationParams>,
148 ) -> SdkResult<()> {
149 let runtime = self.runtime_by_session(session_id).await?;
150 let runtime = runtime.lock().await.to_owned();
151 runtime.send_resource_list_changed(params).await
152 }
153
154 pub async fn send_resource_updated(
158 &self,
159 session_id: &SessionId,
160 params: ResourceUpdatedNotificationParams,
161 ) -> SdkResult<()> {
162 let runtime = self.runtime_by_session(session_id).await?;
163 let runtime = runtime.lock().await.to_owned();
164 runtime.send_resource_updated(params).await
165 }
166
167 pub async fn send_tool_list_changed(
171 &self,
172 session_id: &SessionId,
173 params: Option<ToolListChangedNotificationParams>,
174 ) -> SdkResult<()> {
175 let runtime = self.runtime_by_session(session_id).await?;
176 let runtime = runtime.lock().await.to_owned();
177 runtime.send_tool_list_changed(params).await
178 }
179
180 pub async fn ping(
191 &self,
192 session_id: &SessionId,
193 timeout: Option<Duration>,
194 ) -> SdkResult<crate::schema::Result> {
195 let runtime = self.runtime_by_session(session_id).await?;
196 let runtime = runtime.lock().await.to_owned();
197 runtime.ping(timeout).await
198 }
199
200 pub async fn create_message(
206 &self,
207 session_id: &SessionId,
208 params: CreateMessageRequestParams,
209 ) -> SdkResult<CreateMessageResult> {
210 let runtime = self.runtime_by_session(session_id).await?;
211 let runtime = runtime.lock().await.to_owned();
212 runtime.create_message(params).await
213 }
214
215 pub async fn client_info(
216 &self,
217 session_id: &SessionId,
218 ) -> SdkResult<Option<InitializeRequestParams>> {
219 let runtime = self.runtime_by_session(session_id).await?;
220 let runtime = runtime.lock().await.to_owned();
221 Ok(runtime.client_info())
222 }
223}