1use super::{
2 error::{TransportServerError, TransportServerResult},
3 routes::app_routes,
4};
5#[cfg(feature = "auth")]
6use crate::auth::AuthProvider;
7#[cfg(feature = "auth")]
8use crate::mcp_http::middleware::AuthMiddleware;
9use crate::{
10 error::SdkResult,
11 id_generator::{FastIdGenerator, UuidGenerator},
12 mcp_http::{
13 http_utils::{
14 DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
15 },
16 middleware::DnsRebindProtector,
17 McpAppState, McpHttpHandler,
18 },
19 mcp_server::hyper_runtime::HyperRuntime,
20 mcp_traits::{IdGenerator, McpServerHandler},
21 session_store::InMemorySessionStore,
22};
23use crate::{mcp_http::Middleware, schema::InitializeResult};
24use axum::Router;
25#[cfg(feature = "ssl")]
26use axum_server::tls_rustls::RustlsConfig;
27use axum_server::Handle;
28use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
29use std::{
30 net::{SocketAddr, ToSocketAddrs},
31 path::Path,
32 sync::Arc,
33 time::Duration,
34};
35use tokio::signal;
36
/// Default value of [`HyperServerOptions::ping_interval`] (12 seconds); it is copied into
/// the shared `McpAppState` when the server is built.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
/// Timeout, in seconds, passed to `Handle::graceful_shutdown` once a shutdown signal
/// (Ctrl+C / SIGTERM) has been received.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
40
/// Configuration for a [`HyperServer`].
///
/// Construct it with struct-update syntax on top of [`Default::default`], e.g.
/// `HyperServerOptions { port: 8081, ..Default::default() }`.
pub struct HyperServerOptions {
    /// Hostname or IP address to bind to (default `"127.0.0.1"`). A leading `http://` or
    /// `https://` is stripped when the bind address is resolved.
    pub host: String,

    /// TCP port to bind to (default `8080`).
    pub port: u16,

    /// Generator for new session ids. When `None`, the server falls back to a
    /// `UuidGenerator`.
    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,

    /// Path overriding `DEFAULT_STREAMABLE_HTTP_ENDPOINT` for the Streamable HTTP endpoint.
    pub custom_streamable_http_endpoint: Option<String>,

    /// Transport options shared (via `Arc`) with the application state.
    pub transport_options: Arc<TransportOptions>,

    /// Optional event store handed to the application state.
    /// NOTE(review): presumably used for stream resumption — confirm in `McpAppState` users.
    pub event_store: Option<Arc<dyn EventStore>>,

    /// Whether JSON responses are enabled; `None` is treated as `false`.
    pub enable_json_response: Option<bool>,

    /// Ping interval stored in the application state
    /// (default `DEFAULT_CLIENT_PING_INTERVAL`, 12 seconds).
    pub ping_interval: Duration,

    /// Serve over TLS. When `true`, `validate` requires both `ssl_cert_path` and
    /// `ssl_key_path`, and the generated URLs use the `https` scheme.
    pub enable_ssl: bool,

    /// Path to the PEM-encoded certificate file used when `enable_ssl` is `true`.
    pub ssl_cert_path: Option<String>,

    /// Path to the PEM-encoded private key file used when `enable_ssl` is `true`.
    pub ssl_key_path: Option<String>,

    /// Host values accepted by the DNS rebinding protection middleware.
    pub allowed_hosts: Option<Vec<String>>,

    /// Origin values accepted by the DNS rebinding protection middleware.
    pub allowed_origins: Option<Vec<String>>,

    /// Enables the DNS rebinding protection middleware. It is only installed when at
    /// least one of `allowed_hosts` / `allowed_origins` is also set.
    pub dns_rebinding_protection: bool,

    /// Whether the legacy SSE transport is offered in addition to Streamable HTTP
    /// (default `true`).
    pub sse_support: bool,

    /// Path overriding `DEFAULT_SSE_ENDPOINT` for the SSE endpoint.
    pub custom_sse_endpoint: Option<String>,

    /// Path overriding `DEFAULT_MESSAGES_ENDPOINT` for the SSE messages endpoint.
    pub custom_messages_endpoint: Option<String>,

    /// Optional authentication provider. When set, an `AuthMiddleware` is installed and
    /// the provider is handed to the HTTP handler.
    #[cfg(feature = "auth")]
    pub auth: Option<Arc<dyn AuthProvider>>,
}
110
111impl HyperServerOptions {
112 pub fn validate(&self) -> TransportServerResult<()> {
119 if self.enable_ssl {
120 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
121 return Err(TransportServerError::InvalidServerOptions(
122 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
123 .into(),
124 ));
125 }
126
127 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
128 return Err(TransportServerError::InvalidServerOptions(
129 "'ssl_cert_path' does not point to a valid or existing file.".into(),
130 ));
131 }
132
133 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
134 return Err(TransportServerError::InvalidServerOptions(
135 "'ssl_key_path' does not point to a valid or existing file.".into(),
136 ));
137 }
138 }
139
140 Ok(())
141 }
142
143 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
151 self.validate()?;
152
153 let mut host = self.host.to_string();
154 if let Some(stripped) = self.host.strip_prefix("http://") {
155 if self.enable_ssl {
156 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
157 }
158 host = stripped.to_string();
159 } else if let Some(stripped) = host.strip_prefix("https://") {
160 host = stripped.to_string();
161 }
162
163 let addr = {
164 let mut iter = (host, self.port)
165 .to_socket_addrs()
166 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
167 match iter.next() {
168 Some(addr) => addr,
169 None => format!("{}:{}", self.host, self.port).parse().map_err(
170 |err: std::net::AddrParseError| {
171 TransportServerError::ServerStartError(err.to_string())
172 },
173 )?,
174 }
175 };
176 Ok(addr)
177 }
178
179 pub fn base_url(&self) -> String {
180 format!(
181 "{}://{}:{}",
182 if self.enable_ssl { "https" } else { "http" },
183 self.host,
184 self.port
185 )
186 }
187 pub fn streamable_http_url(&self) -> String {
188 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
189 }
190 pub fn sse_url(&self) -> String {
191 format!("{}{}", self.base_url(), self.sse_endpoint())
192 }
193 pub fn sse_message_url(&self) -> String {
194 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
195 }
196
197 pub fn sse_endpoint(&self) -> &str {
198 self.custom_sse_endpoint
199 .as_deref()
200 .unwrap_or(DEFAULT_SSE_ENDPOINT)
201 }
202
203 pub fn sse_messages_endpoint(&self) -> &str {
204 self.custom_messages_endpoint
205 .as_deref()
206 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
207 }
208
209 pub fn streamable_http_endpoint(&self) -> &str {
210 self.custom_streamable_http_endpoint
211 .as_deref()
212 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
213 }
214
215 pub fn needs_dns_protection(&self) -> bool {
216 self.dns_rebinding_protection
217 && (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
218 }
219}
220
221impl Default for HyperServerOptions {
226 fn default() -> Self {
227 Self {
228 host: "127.0.0.1".to_string(),
229 port: 8080,
230 custom_sse_endpoint: None,
231 custom_streamable_http_endpoint: None,
232 custom_messages_endpoint: None,
233 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
234 transport_options: Default::default(),
235 enable_ssl: false,
236 ssl_cert_path: None,
237 ssl_key_path: None,
238 session_id_generator: None,
239 enable_json_response: None,
240 sse_support: true,
241 allowed_hosts: None,
242 allowed_origins: None,
243 dns_rebinding_protection: false,
244 event_store: None,
245 #[cfg(feature = "auth")]
246 auth: None,
247 }
248 }
249}
250
/// MCP server exposing the Streamable HTTP transport (and, optionally, SSE) through an
/// axum [`Router`] served by `axum_server`.
pub struct HyperServer {
    /// Router holding the MCP routes plus any routes added via `with_route`.
    app: Router,
    /// Shared application state handed to the route handlers.
    state: Arc<McpAppState>,
    /// Options the server was built with. Note that `new` takes `session_id_generator`
    /// (and, when used, `allowed_hosts`, `allowed_origins` and `auth`) out of them.
    pub(crate) options: HyperServerOptions,
    /// `axum_server` handle used when serving and to trigger graceful shutdown.
    handle: Handle,
}
258
259impl HyperServer {
260 pub(crate) fn new(
272 server_details: InitializeResult,
273 handler: Arc<dyn McpServerHandler + 'static>,
274 mut server_options: HyperServerOptions,
275 ) -> Self {
276 let state: Arc<McpAppState> = Arc::new(McpAppState {
277 session_store: Arc::new(InMemorySessionStore::new()),
278 id_generator: server_options
279 .session_id_generator
280 .take()
281 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
282 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
283 server_details: Arc::new(server_details),
284 handler,
285 ping_interval: server_options.ping_interval,
286 transport_options: Arc::clone(&server_options.transport_options),
287 enable_json_response: server_options.enable_json_response.unwrap_or(false),
288 event_store: server_options.event_store.as_ref().map(Arc::clone),
289 });
290
291 let mut middlewares: Vec<Arc<dyn Middleware>> = vec![];
293 if server_options.needs_dns_protection() {
294 middlewares.push(Arc::new(DnsRebindProtector::new(
296 server_options.allowed_hosts.take(),
297 server_options.allowed_origins.take(),
298 )));
299 }
300
301 let http_handler = {
302 #[cfg(feature = "auth")]
303 {
304 let auth_provider = server_options.auth.take();
305 if let Some(auth_provider) = auth_provider.as_ref() {
307 middlewares.push(Arc::new(AuthMiddleware::new(auth_provider.clone())))
308 }
309 McpHttpHandler::new(auth_provider, middlewares)
310 }
311 #[cfg(not(feature = "auth"))]
312 McpHttpHandler::new(middlewares)
313 };
314
315 let app = app_routes(Arc::clone(&state), &server_options, http_handler);
316
317 Self {
318 app,
319 state,
320 options: server_options,
321 handle: Handle::new(),
322 }
323 }
324
325 pub fn state(&self) -> Arc<McpAppState> {
330 Arc::clone(&self.state)
331 }
332
333 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
342 self.app = self.app.route(path, route);
343 self
344 }
345
346 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
356 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
357 let server_type = if self.options.enable_ssl {
358 "SSL server"
359 } else {
360 "Server"
361 };
362 let protocol = if self.options.enable_ssl {
363 "https"
364 } else {
365 "http"
366 };
367
368 let mut server_url = format!(
369 "\n• Streamable HTTP {} is available at {}://{}{}",
370 server_type,
371 protocol,
372 addr,
373 self.options.streamable_http_endpoint()
374 );
375
376 if self.options.sse_support {
377 let sse_url = format!(
378 "\n• SSE {} is available at {}://{}{}",
379 server_type,
380 protocol,
381 addr,
382 self.options.sse_endpoint()
383 );
384 server_url.push_str(&sse_url);
385 };
386
387 Ok(server_url)
388 }
389
390 pub fn options(&self) -> &HyperServerOptions {
391 &self.options
392 }
393
394 #[cfg(feature = "ssl")]
411 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
412 let config = RustlsConfig::from_pem_file(
413 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
414 self.options.ssl_key_path.as_deref().unwrap_or_default(),
415 )
416 .await
417 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
418
419 tracing::info!("{}", self.server_info(Some(addr)).await?);
420
421 let handle_clone = self.handle.clone();
423 let state_clone = self.state().clone();
424 tokio::spawn(async move {
425 shutdown_signal(handle_clone, state_clone).await;
426 });
427
428 let handle_clone = self.handle.clone();
429 axum_server::bind_rustls(addr, config)
430 .handle(handle_clone)
431 .serve(self.app.into_make_service())
432 .await
433 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
434 }
435
436 pub fn server_handle(&self) -> Handle {
438 self.handle.clone()
439 }
440
441 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
449 tracing::info!("{}", self.server_info(Some(addr)).await?);
450
451 let handle_clone = self.handle.clone();
453 tokio::spawn(async move {
454 shutdown_signal(handle_clone, self.state.clone()).await;
455 });
456
457 let handle_clone = self.handle.clone();
458 axum_server::bind(addr)
459 .handle(handle_clone)
460 .serve(self.app.into_make_service())
461 .await
462 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
463 }
464
465 pub async fn start(self) -> SdkResult<()> {
473 let runtime = HyperRuntime::create(self).await?;
474 runtime.await_server().await
475 }
476
477 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
484 HyperRuntime::create(self).await
485 }
486}
487
488async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
490 let ctrl_c = async {
492 signal::ctrl_c()
493 .await
494 .expect("Failed to install Ctrl+C handler");
495 };
496
497 #[cfg(unix)]
498 let terminate = async {
499 signal::unix::signal(signal::unix::SignalKind::terminate())
500 .expect("Failed to install signal handler")
501 .recv()
502 .await;
503 };
504
505 #[cfg(not(unix))]
506 let terminate = std::future::pending::<()>();
507
508 tokio::select! {
509 _ = ctrl_c => {},
510 _ = terminate => {},
511 }
512
513 tracing::info!("Signal received, starting graceful shutdown");
514 state.session_store.clear().await;
515 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::NamedTempFile;

    #[test]
    fn test_server_options_defaults() {
        let options = HyperServerOptions::default();
        assert_eq!(options.base_url(), "http://127.0.0.1:8080");
        assert_eq!(options.sse_endpoint(), DEFAULT_SSE_ENDPOINT);
        assert_eq!(options.sse_messages_endpoint(), DEFAULT_MESSAGES_ENDPOINT);
        assert_eq!(
            options.streamable_http_endpoint(),
            DEFAULT_STREAMABLE_HTTP_ENDPOINT
        );
        assert!(options.sse_support);
        assert!(!options.enable_ssl);
        assert_eq!(options.ping_interval, DEFAULT_CLIENT_PING_INTERVAL);
    }

    #[test]
    fn test_server_options_base_url_custom() {
        let options = HyperServerOptions {
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(options.base_url(), "https://127.0.0.1:8081");
    }

    #[test]
    fn test_server_options_streamable_http_custom() {
        let options = HyperServerOptions {
            custom_streamable_http_endpoint: Some(String::from("/abcd/mcp")),
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(
            options.streamable_http_url(),
            "https://127.0.0.1:8081/abcd/mcp"
        );
        assert_eq!(options.streamable_http_endpoint(), "/abcd/mcp");
    }

    #[test]
    fn test_server_options_sse_custom() {
        let options = HyperServerOptions {
            custom_sse_endpoint: Some(String::from("/abcd/sse")),
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(options.sse_url(), "https://127.0.0.1:8081/abcd/sse");
        assert_eq!(options.sse_endpoint(), "/abcd/sse");
    }

    #[test]
    fn test_server_options_sse_messages_custom() {
        let options = HyperServerOptions {
            custom_messages_endpoint: Some(String::from("/abcd/messages")),
            ..Default::default()
        };
        assert_eq!(
            options.sse_message_url(),
            "http://127.0.0.1:8080/abcd/messages"
        );
        assert_eq!(options.sse_messages_endpoint(), "/abcd/messages");
    }

    #[test]
    fn test_server_options_needs_dns_protection() {
        let options = HyperServerOptions::default();

        // Protection is disabled by default.
        assert!(!options.needs_dns_protection());

        // Enabled but with nothing to check against: still not needed.
        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            ..Default::default()
        };
        assert!(!options.needs_dns_protection());

        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            allowed_hosts: Some(vec![String::from("127.0.0.1")]),
            ..Default::default()
        };
        assert!(options.needs_dns_protection());

        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            allowed_origins: Some(vec![String::from("http://127.0.0.1:8080")]),
            ..Default::default()
        };
        assert!(options.needs_dns_protection());
    }

    #[test]
    fn test_server_options_validate() {
        let options = HyperServerOptions::default();
        assert!(options.validate().is_ok());

        // SSL without any certificate/key paths.
        let options = HyperServerOptions {
            enable_ssl: true,
            ..Default::default()
        };
        assert!(options.validate().is_err());

        // SSL with only one of the two paths.
        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")),
            ..Default::default()
        };
        assert!(options.validate().is_err());

        // SSL with paths that do not exist.
        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")),
            ssl_key_path: Some(String::from("/invalid/path/to/key.pem")),
            ..Default::default()
        };
        assert!(options.validate().is_err());

        let cert_file =
            NamedTempFile::with_suffix(".pem").expect("Expected to create test cert file");
        let ssl_cert_path = cert_file
            .path()
            .to_str()
            .expect("Expected to get cert path")
            .to_string();
        let key_file =
            NamedTempFile::with_suffix(".pem").expect("Expected to create test key file");
        let ssl_key_path = key_file
            .path()
            .to_str()
            .expect("Expected to get key path")
            .to_string();

        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(ssl_cert_path),
            ssl_key_path: Some(ssl_key_path),
            ..Default::default()
        };
        assert!(options.validate().is_ok());
    }

    #[tokio::test]
    async fn test_server_options_resolve_server_address() {
        let options = HyperServerOptions::default();
        assert!(options.resolve_server_address().await.is_ok());

        let expected: SocketAddr = "8.6.7.5:309".parse().expect("valid socket address");

        let options = HyperServerOptions {
            host: String::from("8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(
            options.resolve_server_address().await.expect("should resolve"),
            expected
        );

        // A leading scheme is ignored when resolving.
        let options = HyperServerOptions {
            host: String::from("http://8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(
            options.resolve_server_address().await.expect("should resolve"),
            expected
        );

        let options = HyperServerOptions {
            host: String::from("https://8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(
            options.resolve_server_address().await.expect("should resolve"),
            expected
        );

        // IPv6 literals resolve without brackets in `host`.
        let options = HyperServerOptions {
            host: String::from("::1"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(
            options.resolve_server_address().await.expect("should resolve"),
            "[::1]:309".parse::<SocketAddr>().expect("valid socket address")
        );

        // `.invalid` is reserved by RFC 6761 and should never resolve, so this check does
        // not depend on the local DNS configuration (search domains, wildcard records).
        let options = HyperServerOptions {
            host: String::from("invalid-host.invalid"),
            port: 309,
            ..Default::default()
        };
        assert!(options.resolve_server_address().await.is_err());
    }
}