rust_mcp_sdk/hyper_servers/
server.rs1use crate::{
2 error::SdkResult,
3 id_generator::{FastIdGenerator, UuidGenerator},
4 mcp_http::{
5 http_utils::{
6 DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7 },
8 middleware::dns_rebind_protector::DnsRebindProtector,
9 McpAppState, McpHttpHandler,
10 },
11 mcp_server::hyper_runtime::HyperRuntime,
12 mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
13 session_store::InMemorySessionStore,
14};
15#[cfg(feature = "ssl")]
16use axum_server::tls_rustls::RustlsConfig;
17use axum_server::Handle;
18use std::{
19 net::{SocketAddr, ToSocketAddrs},
20 path::Path,
21 sync::Arc,
22 time::Duration,
23};
24use tokio::signal;
25
26use super::{
27 error::{TransportServerError, TransportServerResult},
28 routes::app_routes,
29};
30use crate::schema::InitializeResult;
31use axum::Router;
32use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
33
/// Default value for `HyperServerOptions::ping_interval`: twelve seconds.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_millis(12_000);
/// Grace period, in seconds, handed to `Handle::graceful_shutdown` once a shutdown signal arrives.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
37
38pub struct HyperServerOptions {
41 pub host: String,
43
44 pub port: u16,
46
47 pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
49
50 pub custom_streamable_http_endpoint: Option<String>,
52
53 pub transport_options: Arc<TransportOptions>,
55
56 pub event_store: Option<Arc<dyn EventStore>>,
59
60 pub enable_json_response: Option<bool>,
65
66 pub ping_interval: Duration,
68
69 pub enable_ssl: bool,
71
72 pub ssl_cert_path: Option<String>,
75
76 pub ssl_key_path: Option<String>,
79
80 pub allowed_hosts: Option<Vec<String>>,
83
84 pub allowed_origins: Option<Vec<String>>,
87
88 pub dns_rebinding_protection: bool,
91
92 pub sse_support: bool,
94
95 pub custom_sse_endpoint: Option<String>,
98
99 pub custom_messages_endpoint: Option<String>,
102}
103
104impl HyperServerOptions {
105 pub fn validate(&self) -> TransportServerResult<()> {
112 if self.enable_ssl {
113 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
114 return Err(TransportServerError::InvalidServerOptions(
115 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
116 .into(),
117 ));
118 }
119
120 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
121 return Err(TransportServerError::InvalidServerOptions(
122 "'ssl_cert_path' does not point to a valid or existing file.".into(),
123 ));
124 }
125
126 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
127 return Err(TransportServerError::InvalidServerOptions(
128 "'ssl_key_path' does not point to a valid or existing file.".into(),
129 ));
130 }
131 }
132
133 Ok(())
134 }
135
136 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
144 self.validate()?;
145
146 let mut host = self.host.to_string();
147 if let Some(stripped) = self.host.strip_prefix("http://") {
148 if self.enable_ssl {
149 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
150 }
151 host = stripped.to_string();
152 } else if let Some(stripped) = host.strip_prefix("https://") {
153 host = stripped.to_string();
154 }
155
156 let addr = {
157 let mut iter = (host, self.port)
158 .to_socket_addrs()
159 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
160 match iter.next() {
161 Some(addr) => addr,
162 None => format!("{}:{}", self.host, self.port).parse().map_err(
163 |err: std::net::AddrParseError| {
164 TransportServerError::ServerStartError(err.to_string())
165 },
166 )?,
167 }
168 };
169 Ok(addr)
170 }
171
172 pub fn base_url(&self) -> String {
173 format!(
174 "{}://{}:{}",
175 if self.enable_ssl { "https" } else { "http" },
176 self.host,
177 self.port
178 )
179 }
180 pub fn streamable_http_url(&self) -> String {
181 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
182 }
183 pub fn sse_url(&self) -> String {
184 format!("{}{}", self.base_url(), self.sse_endpoint())
185 }
186 pub fn sse_message_url(&self) -> String {
187 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
188 }
189
190 pub fn sse_endpoint(&self) -> &str {
191 self.custom_sse_endpoint
192 .as_deref()
193 .unwrap_or(DEFAULT_SSE_ENDPOINT)
194 }
195
196 pub fn sse_messages_endpoint(&self) -> &str {
197 self.custom_messages_endpoint
198 .as_deref()
199 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
200 }
201
202 pub fn streamable_http_endpoint(&self) -> &str {
203 self.custom_messages_endpoint
204 .as_deref()
205 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
206 }
207
208 pub fn needs_dns_protection(&self) -> bool {
209 self.dns_rebinding_protection
210 && (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
211 }
212}
213
214impl Default for HyperServerOptions {
219 fn default() -> Self {
220 Self {
221 host: "127.0.0.1".to_string(),
222 port: 8080,
223 custom_sse_endpoint: None,
224 custom_streamable_http_endpoint: None,
225 custom_messages_endpoint: None,
226 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
227 transport_options: Default::default(),
228 enable_ssl: false,
229 ssl_cert_path: None,
230 ssl_key_path: None,
231 session_id_generator: None,
232 enable_json_response: None,
233 sse_support: true,
234 allowed_hosts: None,
235 allowed_origins: None,
236 dns_rebinding_protection: false,
237 event_store: None,
238 }
239 }
240}
241
242pub struct HyperServer {
244 app: Router,
245 state: Arc<McpAppState>,
246 pub(crate) options: HyperServerOptions,
247 handle: Handle,
248}
249
250impl HyperServer {
251 pub(crate) fn new(
263 server_details: InitializeResult,
264 handler: Arc<dyn McpServerHandler + 'static>,
265 mut server_options: HyperServerOptions,
266 ) -> Self {
267 let state: Arc<McpAppState> = Arc::new(McpAppState {
268 session_store: Arc::new(InMemorySessionStore::new()),
269 id_generator: server_options
270 .session_id_generator
271 .take()
272 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
273 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
274 server_details: Arc::new(server_details),
275 handler,
276 ping_interval: server_options.ping_interval,
277 transport_options: Arc::clone(&server_options.transport_options),
278 enable_json_response: server_options.enable_json_response.unwrap_or(false),
279 event_store: server_options.event_store.as_ref().map(Arc::clone),
280 });
281
282 let mut http_handler = McpHttpHandler::new();
283
284 if server_options.needs_dns_protection() {
285 http_handler.add_middleware(DnsRebindProtector::new(
286 server_options.allowed_hosts.take(),
287 server_options.allowed_origins.take(),
288 ));
289 }
290
291 let app = app_routes(Arc::clone(&state), &server_options, http_handler);
292 Self {
293 app,
294 state,
295 options: server_options,
296 handle: Handle::new(),
297 }
298 }
299
300 pub fn state(&self) -> Arc<McpAppState> {
305 Arc::clone(&self.state)
306 }
307
308 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
317 self.app = self.app.route(path, route);
318 self
319 }
320
321 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
331 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
332 let server_type = if self.options.enable_ssl {
333 "SSL server"
334 } else {
335 "Server"
336 };
337 let protocol = if self.options.enable_ssl {
338 "https"
339 } else {
340 "http"
341 };
342
343 let mut server_url = format!(
344 "\n• Streamable HTTP {} is available at {}://{}{}",
345 server_type,
346 protocol,
347 addr,
348 self.options.streamable_http_endpoint()
349 );
350
351 if self.options.sse_support {
352 let sse_url = format!(
353 "\n• SSE {} is available at {}://{}{}",
354 server_type,
355 protocol,
356 addr,
357 self.options.sse_endpoint()
358 );
359 server_url.push_str(&sse_url);
360 };
361
362 Ok(server_url)
363 }
364
365 pub fn options(&self) -> &HyperServerOptions {
366 &self.options
367 }
368
369 #[cfg(feature = "ssl")]
386 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
387 let config = RustlsConfig::from_pem_file(
388 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
389 self.options.ssl_key_path.as_deref().unwrap_or_default(),
390 )
391 .await
392 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
393
394 tracing::info!("{}", self.server_info(Some(addr)).await?);
395
396 let handle_clone = self.handle.clone();
398 let state_clone = self.state().clone();
399 tokio::spawn(async move {
400 shutdown_signal(handle_clone, state_clone).await;
401 });
402
403 let handle_clone = self.handle.clone();
404 axum_server::bind_rustls(addr, config)
405 .handle(handle_clone)
406 .serve(self.app.into_make_service())
407 .await
408 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
409 }
410
411 pub fn server_handle(&self) -> Handle {
413 self.handle.clone()
414 }
415
416 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
424 tracing::info!("{}", self.server_info(Some(addr)).await?);
425
426 let handle_clone = self.handle.clone();
428 tokio::spawn(async move {
429 shutdown_signal(handle_clone, self.state.clone()).await;
430 });
431
432 let handle_clone = self.handle.clone();
433 axum_server::bind(addr)
434 .handle(handle_clone)
435 .serve(self.app.into_make_service())
436 .await
437 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
438 }
439
440 pub async fn start(self) -> SdkResult<()> {
448 let runtime = HyperRuntime::create(self).await?;
449 runtime.await_server().await
450 }
451
452 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
459 HyperRuntime::create(self).await
460 }
461}
462
463async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
465 let ctrl_c = async {
467 signal::ctrl_c()
468 .await
469 .expect("Failed to install Ctrl+C handler");
470 };
471
472 #[cfg(unix)]
473 let terminate = async {
474 signal::unix::signal(signal::unix::SignalKind::terminate())
475 .expect("Failed to install signal handler")
476 .recv()
477 .await;
478 };
479
480 #[cfg(not(unix))]
481 let terminate = std::future::pending::<()>();
482
483 tokio::select! {
484 _ = ctrl_c => {},
485 _ = terminate => {},
486 }
487
488 tracing::info!("Signal received, starting graceful shutdown");
489 state.session_store.clear().await;
490 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
492}