rust_mcp_sdk/hyper_servers/
server.rs1use crate::{
2 error::SdkResult,
3 id_generator::{FastIdGenerator, UuidGenerator},
4 mcp_server::hyper_runtime::HyperRuntime,
5 mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
6};
7#[cfg(feature = "ssl")]
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10use std::{
11 net::{SocketAddr, ToSocketAddrs},
12 path::Path,
13 sync::Arc,
14 time::Duration,
15};
16use tokio::signal;
17
18use super::{
19 app_state::AppState,
20 error::{TransportServerError, TransportServerResult},
21 routes::app_routes,
22 InMemorySessionStore,
23};
24use crate::schema::InitializeResult;
25use axum::Router;
26use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
27
/// Default value for `HyperServerOptions::ping_interval` (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_millis(12_000);
/// Timeout, in seconds, passed to `Handle::graceful_shutdown` after a shutdown signal.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
/// Path of the SSE endpoint when no custom one is configured.
const DEFAULT_SSE_ENDPOINT: &'static str = "/sse";
/// Path of the SSE messages endpoint when no custom one is configured.
const DEFAULT_MESSAGES_ENDPOINT: &'static str = "/messages";
/// Path of the Streamable HTTP endpoint when no custom one is configured.
const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &'static str = "/mcp";

38pub struct HyperServerOptions {
41 pub host: String,
43
44 pub port: u16,
46
47 pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
49
50 pub custom_streamable_http_endpoint: Option<String>,
52
53 pub transport_options: Arc<TransportOptions>,
55
56 pub event_store: Option<Arc<dyn EventStore>>,
59
60 pub enable_json_response: Option<bool>,
65
66 pub ping_interval: Duration,
68
69 pub enable_ssl: bool,
71
72 pub ssl_cert_path: Option<String>,
75
76 pub ssl_key_path: Option<String>,
79
80 pub allowed_hosts: Option<Vec<String>>,
83
84 pub allowed_origins: Option<Vec<String>>,
87
88 pub dns_rebinding_protection: bool,
91
92 pub sse_support: bool,
94
95 pub custom_sse_endpoint: Option<String>,
98
99 pub custom_messages_endpoint: Option<String>,
102}
103
104impl HyperServerOptions {
105 pub fn validate(&self) -> TransportServerResult<()> {
112 if self.enable_ssl {
113 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
114 return Err(TransportServerError::InvalidServerOptions(
115 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
116 .into(),
117 ));
118 }
119
120 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
121 return Err(TransportServerError::InvalidServerOptions(
122 "'ssl_cert_path' does not point to a valid or existing file.".into(),
123 ));
124 }
125
126 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
127 return Err(TransportServerError::InvalidServerOptions(
128 "'ssl_key_path' does not point to a valid or existing file.".into(),
129 ));
130 }
131 }
132
133 Ok(())
134 }
135
136 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
144 self.validate()?;
145
146 let mut host = self.host.to_string();
147 if let Some(stripped) = self.host.strip_prefix("http://") {
148 if self.enable_ssl {
149 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
150 }
151 host = stripped.to_string();
152 } else if let Some(stripped) = host.strip_prefix("https://") {
153 host = stripped.to_string();
154 }
155
156 let addr = {
157 let mut iter = (host, self.port)
158 .to_socket_addrs()
159 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
160 match iter.next() {
161 Some(addr) => addr,
162 None => format!("{}:{}", self.host, self.port).parse().map_err(
163 |err: std::net::AddrParseError| {
164 TransportServerError::ServerStartError(err.to_string())
165 },
166 )?,
167 }
168 };
169 Ok(addr)
170 }
171
172 pub fn base_url(&self) -> String {
173 format!(
174 "{}://{}:{}",
175 if self.enable_ssl { "https" } else { "http" },
176 self.host,
177 self.port
178 )
179 }
180 pub fn streamable_http_url(&self) -> String {
181 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
182 }
183 pub fn sse_url(&self) -> String {
184 format!("{}{}", self.base_url(), self.sse_endpoint())
185 }
186 pub fn sse_message_url(&self) -> String {
187 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
188 }
189
190 pub fn sse_endpoint(&self) -> &str {
191 self.custom_sse_endpoint
192 .as_deref()
193 .unwrap_or(DEFAULT_SSE_ENDPOINT)
194 }
195
196 pub fn sse_messages_endpoint(&self) -> &str {
197 self.custom_messages_endpoint
198 .as_deref()
199 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
200 }
201
202 pub fn streamable_http_endpoint(&self) -> &str {
203 self.custom_messages_endpoint
204 .as_deref()
205 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
206 }
207}
208
209impl Default for HyperServerOptions {
214 fn default() -> Self {
215 Self {
216 host: "127.0.0.1".to_string(),
217 port: 8080,
218 custom_sse_endpoint: None,
219 custom_streamable_http_endpoint: None,
220 custom_messages_endpoint: None,
221 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
222 transport_options: Default::default(),
223 enable_ssl: false,
224 ssl_cert_path: None,
225 ssl_key_path: None,
226 session_id_generator: None,
227 enable_json_response: None,
228 sse_support: true,
229 allowed_hosts: None,
230 allowed_origins: None,
231 dns_rebinding_protection: false,
232 event_store: None,
233 }
234 }
235}
236
237pub struct HyperServer {
239 app: Router,
240 state: Arc<AppState>,
241 pub(crate) options: HyperServerOptions,
242 handle: Handle,
243}
244
245impl HyperServer {
246 pub(crate) fn new(
258 server_details: InitializeResult,
259 handler: Arc<dyn McpServerHandler + 'static>,
260 mut server_options: HyperServerOptions,
261 ) -> Self {
262 let state: Arc<AppState> = Arc::new(AppState {
263 session_store: Arc::new(InMemorySessionStore::new()),
264 id_generator: server_options
265 .session_id_generator
266 .take()
267 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
268 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
269 server_details: Arc::new(server_details),
270 handler,
271 ping_interval: server_options.ping_interval,
272 sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
273 http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(),
274 transport_options: Arc::clone(&server_options.transport_options),
275 enable_json_response: server_options.enable_json_response.unwrap_or(false),
276 allowed_hosts: server_options.allowed_hosts.take(),
277 allowed_origins: server_options.allowed_origins.take(),
278 dns_rebinding_protection: server_options.dns_rebinding_protection,
279 event_store: server_options.event_store.as_ref().map(Arc::clone),
280 });
281 let app = app_routes(Arc::clone(&state), &server_options);
282 Self {
283 app,
284 state,
285 options: server_options,
286 handle: Handle::new(),
287 }
288 }
289
290 pub fn state(&self) -> Arc<AppState> {
295 Arc::clone(&self.state)
296 }
297
298 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
307 self.app = self.app.route(path, route);
308 self
309 }
310
311 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
321 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
322 let server_type = if self.options.enable_ssl {
323 "SSL server"
324 } else {
325 "Server"
326 };
327 let protocol = if self.options.enable_ssl {
328 "https"
329 } else {
330 "http"
331 };
332
333 let mut server_url = format!(
334 "\n• Streamable HTTP {} is available at {}://{}{}",
335 server_type,
336 protocol,
337 addr,
338 self.options.streamable_http_endpoint()
339 );
340
341 if self.options.sse_support {
342 let sse_url = format!(
343 "\n• SSE {} is available at {}://{}{}",
344 server_type,
345 protocol,
346 addr,
347 self.options.sse_endpoint()
348 );
349 server_url.push_str(&sse_url);
350 };
351
352 Ok(server_url)
353 }
354
355 pub fn options(&self) -> &HyperServerOptions {
356 &self.options
357 }
358
359 #[cfg(feature = "ssl")]
376 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
377 let config = RustlsConfig::from_pem_file(
378 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
379 self.options.ssl_key_path.as_deref().unwrap_or_default(),
380 )
381 .await
382 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
383
384 tracing::info!("{}", self.server_info(Some(addr)).await?);
385
386 let handle_clone = self.handle.clone();
388 let state_clone = self.state().clone();
389 tokio::spawn(async move {
390 shutdown_signal(handle_clone, state_clone).await;
391 });
392
393 let handle_clone = self.handle.clone();
394 axum_server::bind_rustls(addr, config)
395 .handle(handle_clone)
396 .serve(self.app.into_make_service())
397 .await
398 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
399 }
400
401 pub fn server_handle(&self) -> Handle {
403 self.handle.clone()
404 }
405
406 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
414 tracing::info!("{}", self.server_info(Some(addr)).await?);
415
416 let handle_clone = self.handle.clone();
418 tokio::spawn(async move {
419 shutdown_signal(handle_clone, self.state.clone()).await;
420 });
421
422 let handle_clone = self.handle.clone();
423 axum_server::bind(addr)
424 .handle(handle_clone)
425 .serve(self.app.into_make_service())
426 .await
427 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
428 }
429
430 pub async fn start(self) -> SdkResult<()> {
438 let runtime = HyperRuntime::create(self).await?;
439 runtime.await_server().await
440 }
441
442 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
449 HyperRuntime::create(self).await
450 }
451}
452
453async fn shutdown_signal(handle: Handle, state: Arc<AppState>) {
455 let ctrl_c = async {
457 signal::ctrl_c()
458 .await
459 .expect("Failed to install Ctrl+C handler");
460 };
461
462 #[cfg(unix)]
463 let terminate = async {
464 signal::unix::signal(signal::unix::SignalKind::terminate())
465 .expect("Failed to install signal handler")
466 .recv()
467 .await;
468 };
469
470 #[cfg(not(unix))]
471 let terminate = std::future::pending::<()>();
472
473 tokio::select! {
474 _ = ctrl_c => {},
475 _ = terminate => {},
476 }
477
478 tracing::info!("Signal received, starting graceful shutdown");
479 state.session_store.clear().await;
480 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
482}