rust_mcp_sdk/hyper_servers/
server.rs1use crate::{
2 error::SdkResult,
3 id_generator::{FastIdGenerator, UuidGenerator},
4 mcp_http::{
5 utils::{
6 DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7 },
8 McpAppState, McpHttpHandler,
9 },
10 mcp_server::hyper_runtime::HyperRuntime,
11 mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
12 session_store::InMemorySessionStore,
13};
14#[cfg(feature = "ssl")]
15use axum_server::tls_rustls::RustlsConfig;
16use axum_server::Handle;
17use std::{
18 net::{SocketAddr, ToSocketAddrs},
19 path::Path,
20 sync::Arc,
21 time::Duration,
22};
23use tokio::signal;
24
25use super::{
26 error::{TransportServerError, TransportServerResult},
27 routes::app_routes,
28};
29use crate::schema::InitializeResult;
30use axum::Router;
31use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
32
/// Default value of `HyperServerOptions::ping_interval` (twelve seconds); it is forwarded
/// to the app state as the client ping interval.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_millis(12_000);
/// Grace period, in seconds, handed to the server handle when a graceful shutdown starts.
/// (The name's spelling is kept as-is because `shutdown_signal` refers to it.)
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5_u64;
36
37pub struct HyperServerOptions {
40 pub host: String,
42
43 pub port: u16,
45
46 pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
48
49 pub custom_streamable_http_endpoint: Option<String>,
51
52 pub transport_options: Arc<TransportOptions>,
54
55 pub event_store: Option<Arc<dyn EventStore>>,
58
59 pub enable_json_response: Option<bool>,
64
65 pub ping_interval: Duration,
67
68 pub enable_ssl: bool,
70
71 pub ssl_cert_path: Option<String>,
74
75 pub ssl_key_path: Option<String>,
78
79 pub allowed_hosts: Option<Vec<String>>,
82
83 pub allowed_origins: Option<Vec<String>>,
86
87 pub dns_rebinding_protection: bool,
90
91 pub sse_support: bool,
93
94 pub custom_sse_endpoint: Option<String>,
97
98 pub custom_messages_endpoint: Option<String>,
101}
102
103impl HyperServerOptions {
104 pub fn validate(&self) -> TransportServerResult<()> {
111 if self.enable_ssl {
112 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
113 return Err(TransportServerError::InvalidServerOptions(
114 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
115 .into(),
116 ));
117 }
118
119 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
120 return Err(TransportServerError::InvalidServerOptions(
121 "'ssl_cert_path' does not point to a valid or existing file.".into(),
122 ));
123 }
124
125 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
126 return Err(TransportServerError::InvalidServerOptions(
127 "'ssl_key_path' does not point to a valid or existing file.".into(),
128 ));
129 }
130 }
131
132 Ok(())
133 }
134
135 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
143 self.validate()?;
144
145 let mut host = self.host.to_string();
146 if let Some(stripped) = self.host.strip_prefix("http://") {
147 if self.enable_ssl {
148 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
149 }
150 host = stripped.to_string();
151 } else if let Some(stripped) = host.strip_prefix("https://") {
152 host = stripped.to_string();
153 }
154
155 let addr = {
156 let mut iter = (host, self.port)
157 .to_socket_addrs()
158 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
159 match iter.next() {
160 Some(addr) => addr,
161 None => format!("{}:{}", self.host, self.port).parse().map_err(
162 |err: std::net::AddrParseError| {
163 TransportServerError::ServerStartError(err.to_string())
164 },
165 )?,
166 }
167 };
168 Ok(addr)
169 }
170
171 pub fn base_url(&self) -> String {
172 format!(
173 "{}://{}:{}",
174 if self.enable_ssl { "https" } else { "http" },
175 self.host,
176 self.port
177 )
178 }
179 pub fn streamable_http_url(&self) -> String {
180 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
181 }
182 pub fn sse_url(&self) -> String {
183 format!("{}{}", self.base_url(), self.sse_endpoint())
184 }
185 pub fn sse_message_url(&self) -> String {
186 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
187 }
188
189 pub fn sse_endpoint(&self) -> &str {
190 self.custom_sse_endpoint
191 .as_deref()
192 .unwrap_or(DEFAULT_SSE_ENDPOINT)
193 }
194
195 pub fn sse_messages_endpoint(&self) -> &str {
196 self.custom_messages_endpoint
197 .as_deref()
198 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
199 }
200
201 pub fn streamable_http_endpoint(&self) -> &str {
202 self.custom_messages_endpoint
203 .as_deref()
204 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
205 }
206}
207
208impl Default for HyperServerOptions {
213 fn default() -> Self {
214 Self {
215 host: "127.0.0.1".to_string(),
216 port: 8080,
217 custom_sse_endpoint: None,
218 custom_streamable_http_endpoint: None,
219 custom_messages_endpoint: None,
220 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
221 transport_options: Default::default(),
222 enable_ssl: false,
223 ssl_cert_path: None,
224 ssl_key_path: None,
225 session_id_generator: None,
226 enable_json_response: None,
227 sse_support: true,
228 allowed_hosts: None,
229 allowed_origins: None,
230 dns_rebinding_protection: false,
231 event_store: None,
232 }
233 }
234}
235
236pub struct HyperServer {
238 app: Router,
239 state: Arc<McpAppState>,
240 pub(crate) options: HyperServerOptions,
241 handle: Handle,
242}
243
244impl HyperServer {
245 pub(crate) fn new(
257 server_details: InitializeResult,
258 handler: Arc<dyn McpServerHandler + 'static>,
259 mut server_options: HyperServerOptions,
260 ) -> Self {
261 let state: Arc<McpAppState> = Arc::new(McpAppState {
262 session_store: Arc::new(InMemorySessionStore::new()),
263 id_generator: server_options
264 .session_id_generator
265 .take()
266 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
267 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
268 server_details: Arc::new(server_details),
269 handler,
270 ping_interval: server_options.ping_interval,
271 transport_options: Arc::clone(&server_options.transport_options),
272 enable_json_response: server_options.enable_json_response.unwrap_or(false),
273 allowed_hosts: server_options.allowed_hosts.take(),
274 allowed_origins: server_options.allowed_origins.take(),
275 dns_rebinding_protection: server_options.dns_rebinding_protection,
276 event_store: server_options.event_store.as_ref().map(Arc::clone),
277 });
278
279 let http_handler = McpHttpHandler::new(); let app = app_routes(Arc::clone(&state), &server_options, http_handler);
281 Self {
282 app,
283 state,
284 options: server_options,
285 handle: Handle::new(),
286 }
287 }
288
289 pub fn state(&self) -> Arc<McpAppState> {
294 Arc::clone(&self.state)
295 }
296
297 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
306 self.app = self.app.route(path, route);
307 self
308 }
309
310 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
320 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
321 let server_type = if self.options.enable_ssl {
322 "SSL server"
323 } else {
324 "Server"
325 };
326 let protocol = if self.options.enable_ssl {
327 "https"
328 } else {
329 "http"
330 };
331
332 let mut server_url = format!(
333 "\n• Streamable HTTP {} is available at {}://{}{}",
334 server_type,
335 protocol,
336 addr,
337 self.options.streamable_http_endpoint()
338 );
339
340 if self.options.sse_support {
341 let sse_url = format!(
342 "\n• SSE {} is available at {}://{}{}",
343 server_type,
344 protocol,
345 addr,
346 self.options.sse_endpoint()
347 );
348 server_url.push_str(&sse_url);
349 };
350
351 Ok(server_url)
352 }
353
354 pub fn options(&self) -> &HyperServerOptions {
355 &self.options
356 }
357
358 #[cfg(feature = "ssl")]
375 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
376 let config = RustlsConfig::from_pem_file(
377 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
378 self.options.ssl_key_path.as_deref().unwrap_or_default(),
379 )
380 .await
381 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
382
383 tracing::info!("{}", self.server_info(Some(addr)).await?);
384
385 let handle_clone = self.handle.clone();
387 let state_clone = self.state().clone();
388 tokio::spawn(async move {
389 shutdown_signal(handle_clone, state_clone).await;
390 });
391
392 let handle_clone = self.handle.clone();
393 axum_server::bind_rustls(addr, config)
394 .handle(handle_clone)
395 .serve(self.app.into_make_service())
396 .await
397 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
398 }
399
400 pub fn server_handle(&self) -> Handle {
402 self.handle.clone()
403 }
404
405 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
413 tracing::info!("{}", self.server_info(Some(addr)).await?);
414
415 let handle_clone = self.handle.clone();
417 tokio::spawn(async move {
418 shutdown_signal(handle_clone, self.state.clone()).await;
419 });
420
421 let handle_clone = self.handle.clone();
422 axum_server::bind(addr)
423 .handle(handle_clone)
424 .serve(self.app.into_make_service())
425 .await
426 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
427 }
428
429 pub async fn start(self) -> SdkResult<()> {
437 let runtime = HyperRuntime::create(self).await?;
438 runtime.await_server().await
439 }
440
441 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
448 HyperRuntime::create(self).await
449 }
450}
451
452async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
454 let ctrl_c = async {
456 signal::ctrl_c()
457 .await
458 .expect("Failed to install Ctrl+C handler");
459 };
460
461 #[cfg(unix)]
462 let terminate = async {
463 signal::unix::signal(signal::unix::SignalKind::terminate())
464 .expect("Failed to install signal handler")
465 .recv()
466 .await;
467 };
468
469 #[cfg(not(unix))]
470 let terminate = std::future::pending::<()>();
471
472 tokio::select! {
473 _ = ctrl_c => {},
474 _ = terminate => {},
475 }
476
477 tracing::info!("Signal received, starting graceful shutdown");
478 state.session_store.clear().await;
479 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
481}