rust_mcp_sdk/hyper_servers/
server.rs

1use super::{
2    error::{TransportServerError, TransportServerResult},
3    routes::app_routes,
4};
5#[cfg(feature = "auth")]
6use crate::auth::AuthProvider;
7#[cfg(feature = "auth")]
8use crate::mcp_http::middleware::AuthMiddleware;
9use crate::{
10    error::SdkResult,
11    id_generator::{FastIdGenerator, UuidGenerator},
12    mcp_http::{
13        http_utils::{
14            DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
15        },
16        middleware::DnsRebindProtector,
17        McpAppState, McpHttpHandler,
18    },
19    mcp_server::hyper_runtime::HyperRuntime,
20    mcp_traits::{IdGenerator, McpServerHandler},
21    session_store::InMemorySessionStore,
22};
23use crate::{mcp_http::Middleware, schema::InitializeResult};
24use axum::Router;
25#[cfg(feature = "ssl")]
26use axum_server::tls_rustls::RustlsConfig;
27use axum_server::Handle;
28use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
29use std::{
30    net::{SocketAddr, ToSocketAddrs},
31    path::Path,
32    sync::Arc,
33    time::Duration,
34};
35use tokio::signal;
36
/// Interval at which the server pings connected clients when no custom value is set (12 s).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_millis(12_000);
/// Timeout, in seconds, handed to the server handle when a graceful shutdown is triggered.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
40
41/// Configuration struct for the Hyper server
42/// Used to configure the HyperServer instance.
43pub struct HyperServerOptions {
44    /// Hostname or IP address the server will bind to (default: "127.0.0.1")
45    pub host: String,
46
47    /// Hostname or IP address the server will bind to (default: "8080")
48    pub port: u16,
49
50    /// Optional thread-safe session id generator to generate unique session IDs.
51    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
52
53    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
54    pub custom_streamable_http_endpoint: Option<String>,
55
56    /// Shared transport configuration used by the server
57    pub transport_options: Arc<TransportOptions>,
58
59    /// Event store for resumability support
60    /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
61    pub event_store: Option<Arc<dyn EventStore>>,
62
63    /// This setting only applies to streamable HTTP.
64    /// If true, the server will return JSON responses instead of starting an SSE stream.
65    /// This can be useful for simple request/response scenarios without streaming.
66    /// Default is false (SSE streams are preferred).
67    pub enable_json_response: Option<bool>,
68
69    /// Interval between automatic ping messages sent to clients to detect disconnects
70    pub ping_interval: Duration,
71
72    /// Enables SSL/TLS if set to `true`
73    pub enable_ssl: bool,
74
75    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
76    /// Required if `enable_ssl` is `true`.
77    pub ssl_cert_path: Option<String>,
78
79    /// Path to the SSL/TLS private key file (e.g., "key.pem").
80    /// Required if `enable_ssl` is `true`.
81    pub ssl_key_path: Option<String>,
82
83    /// List of allowed host header values for DNS rebinding protection.
84    /// If not specified, host validation is disabled.
85    pub allowed_hosts: Option<Vec<String>>,
86
87    /// List of allowed origin header values for DNS rebinding protection.
88    /// If not specified, origin validation is disabled.
89    pub allowed_origins: Option<Vec<String>>,
90
91    /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
92    /// Default is false for backwards compatibility.
93    pub dns_rebinding_protection: bool,
94
95    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
96    pub sse_support: bool,
97
98    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
99    /// Applicable only if sse_support is true
100    pub custom_sse_endpoint: Option<String>,
101
102    /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`)
103    /// Applicable only if sse_support is true
104    pub custom_messages_endpoint: Option<String>,
105
106    /// Optional authentication provider for protecting MCP server.
107    #[cfg(feature = "auth")]
108    pub auth: Option<Arc<dyn AuthProvider>>,
109}
110
111impl HyperServerOptions {
112    /// Validates the server configuration options
113    ///
114    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
115    ///
116    /// # Returns
117    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
118    pub fn validate(&self) -> TransportServerResult<()> {
119        if self.enable_ssl {
120            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
121                return Err(TransportServerError::InvalidServerOptions(
122                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
123                        .into(),
124                ));
125            }
126
127            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
128                return Err(TransportServerError::InvalidServerOptions(
129                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
130                ));
131            }
132
133            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
134                return Err(TransportServerError::InvalidServerOptions(
135                    "'ssl_key_path' does not point to a valid or existing file.".into(),
136                ));
137            }
138        }
139
140        Ok(())
141    }
142
143    /// Resolves the server address from host and port
144    ///
145    /// Validates the configuration and converts the host/port into a SocketAddr.
146    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
147    ///
148    /// # Returns
149    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
150    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
151        self.validate()?;
152
153        let mut host = self.host.to_string();
154        if let Some(stripped) = self.host.strip_prefix("http://") {
155            if self.enable_ssl {
156                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
157            }
158            host = stripped.to_string();
159        } else if let Some(stripped) = host.strip_prefix("https://") {
160            host = stripped.to_string();
161        }
162
163        let addr = {
164            let mut iter = (host, self.port)
165                .to_socket_addrs()
166                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
167            match iter.next() {
168                Some(addr) => addr,
169                None => format!("{}:{}", self.host, self.port).parse().map_err(
170                    |err: std::net::AddrParseError| {
171                        TransportServerError::ServerStartError(err.to_string())
172                    },
173                )?,
174            }
175        };
176        Ok(addr)
177    }
178
179    pub fn base_url(&self) -> String {
180        format!(
181            "{}://{}:{}",
182            if self.enable_ssl { "https" } else { "http" },
183            self.host,
184            self.port
185        )
186    }
187    pub fn streamable_http_url(&self) -> String {
188        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
189    }
190    pub fn sse_url(&self) -> String {
191        format!("{}{}", self.base_url(), self.sse_endpoint())
192    }
193    pub fn sse_message_url(&self) -> String {
194        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
195    }
196
197    pub fn sse_endpoint(&self) -> &str {
198        self.custom_sse_endpoint
199            .as_deref()
200            .unwrap_or(DEFAULT_SSE_ENDPOINT)
201    }
202
203    pub fn sse_messages_endpoint(&self) -> &str {
204        self.custom_messages_endpoint
205            .as_deref()
206            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
207    }
208
209    pub fn streamable_http_endpoint(&self) -> &str {
210        self.custom_streamable_http_endpoint
211            .as_deref()
212            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
213    }
214
215    pub fn needs_dns_protection(&self) -> bool {
216        self.dns_rebinding_protection
217            && (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
218    }
219}
220
221/// Default implementation for HyperServerOptions
222///
223/// Provides default values for the server configuration, including 127.0.0.1 address,
224/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
225impl Default for HyperServerOptions {
226    fn default() -> Self {
227        Self {
228            host: "127.0.0.1".to_string(),
229            port: 8080,
230            custom_sse_endpoint: None,
231            custom_streamable_http_endpoint: None,
232            custom_messages_endpoint: None,
233            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
234            transport_options: Default::default(),
235            enable_ssl: false,
236            ssl_cert_path: None,
237            ssl_key_path: None,
238            session_id_generator: None,
239            enable_json_response: None,
240            sse_support: true,
241            allowed_hosts: None,
242            allowed_origins: None,
243            dns_rebinding_protection: false,
244            event_store: None,
245            #[cfg(feature = "auth")]
246            auth: None,
247        }
248    }
249}
250
/// Hyper server struct for managing the Axum-based web server
pub struct HyperServer {
    /// Axum router with the MCP routes, plus any routes added through `with_route`.
    app: Router,
    /// Shared MCP application state (session store, handler, id generators, ...).
    state: Arc<McpAppState>,
    /// Options the server was built with. Some optional fields (e.g. the session id
    /// generator, DNS-protection lists, auth provider) may have been taken out during
    /// construction.
    pub(crate) options: HyperServerOptions,
    /// Handle used to trigger a (graceful) shutdown of the running server.
    handle: Handle,
}
258
259impl HyperServer {
260    /// Creates a new HyperServer instance
261    ///
262    /// Initializes the server with the provided server details, handler, and options.
263    ///
264    /// # Arguments
265    /// * `server_details` - Initialization result from the MCP schema
266    /// * `handler` - Shared MCP server handler with static lifetime
267    /// * `server_options` - Server configuration options
268    ///
269    /// # Returns
270    /// * `Self` - A new HyperServer instance
271    pub(crate) fn new(
272        server_details: InitializeResult,
273        handler: Arc<dyn McpServerHandler + 'static>,
274        mut server_options: HyperServerOptions,
275    ) -> Self {
276        let state: Arc<McpAppState> = Arc::new(McpAppState {
277            session_store: Arc::new(InMemorySessionStore::new()),
278            id_generator: server_options
279                .session_id_generator
280                .take()
281                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
282            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
283            server_details: Arc::new(server_details),
284            handler,
285            ping_interval: server_options.ping_interval,
286            transport_options: Arc::clone(&server_options.transport_options),
287            enable_json_response: server_options.enable_json_response.unwrap_or(false),
288            event_store: server_options.event_store.as_ref().map(Arc::clone),
289        });
290
291        // populate middlewares
292        let mut middlewares: Vec<Arc<dyn Middleware>> = vec![];
293        if server_options.needs_dns_protection() {
294            //dns pritection middleware
295            middlewares.push(Arc::new(DnsRebindProtector::new(
296                server_options.allowed_hosts.take(),
297                server_options.allowed_origins.take(),
298            )));
299        }
300
301        let http_handler = {
302            #[cfg(feature = "auth")]
303            {
304                let auth_provider = server_options.auth.take();
305                // add auth middleware if there is a auth_provider
306                if let Some(auth_provider) = auth_provider.as_ref() {
307                    middlewares.push(Arc::new(AuthMiddleware::new(auth_provider.clone())))
308                }
309                McpHttpHandler::new(auth_provider, middlewares)
310            }
311            #[cfg(not(feature = "auth"))]
312            McpHttpHandler::new(middlewares)
313        };
314
315        let app = app_routes(Arc::clone(&state), &server_options, http_handler);
316
317        Self {
318            app,
319            state,
320            options: server_options,
321            handle: Handle::new(),
322        }
323    }
324
325    /// Returns a shared reference to the application state
326    ///
327    /// # Returns
328    /// * `Arc<McpAppState>` - Shared application state
329    pub fn state(&self) -> Arc<McpAppState> {
330        Arc::clone(&self.state)
331    }
332
333    /// Adds a new route to the server
334    ///
335    /// # Arguments
336    /// * `path` - The route path (static string)
337    /// * `route` - The Axum MethodRouter for handling the route
338    ///
339    /// # Returns
340    /// * `Self` - The modified HyperServer instance
341    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
342        self.app = self.app.route(path, route);
343        self
344    }
345
346    /// Generates server information string
347    ///
348    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
349    ///
350    /// # Arguments
351    /// * `addr` - Optional SocketAddr; if None, resolves from options
352    ///
353    /// # Returns
354    /// * `TransportServerResult<String>` - The server information string or an error
355    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
356        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
357        let server_type = if self.options.enable_ssl {
358            "SSL server"
359        } else {
360            "Server"
361        };
362        let protocol = if self.options.enable_ssl {
363            "https"
364        } else {
365            "http"
366        };
367
368        let mut server_url = format!(
369            "\n• Streamable HTTP {} is available at {}://{}{}",
370            server_type,
371            protocol,
372            addr,
373            self.options.streamable_http_endpoint()
374        );
375
376        if self.options.sse_support {
377            let sse_url = format!(
378                "\n• SSE {} is available at {}://{}{}",
379                server_type,
380                protocol,
381                addr,
382                self.options.sse_endpoint()
383            );
384            server_url.push_str(&sse_url);
385        };
386
387        Ok(server_url)
388    }
389
390    pub fn options(&self) -> &HyperServerOptions {
391        &self.options
392    }
393
394    // pub fn with_layer<L>(mut self, layer: L) -> Self
395    // where
396    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
397    //     L::Service: Send + Sync + 'static,
398    // {
399    //     self.router = self.router.layer(layer);
400    //     self
401    // }
402
403    /// Starts the server with SSL support (available when "ssl" feature is enabled)
404    ///
405    /// # Arguments
406    /// * `addr` - The server address to bind to
407    ///
408    /// # Returns
409    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
410    #[cfg(feature = "ssl")]
411    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
412        let config = RustlsConfig::from_pem_file(
413            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
414            self.options.ssl_key_path.as_deref().unwrap_or_default(),
415        )
416        .await
417        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
418
419        tracing::info!("{}", self.server_info(Some(addr)).await?);
420
421        // Spawn a task to trigger shutdown on signal
422        let handle_clone = self.handle.clone();
423        let state_clone = self.state().clone();
424        tokio::spawn(async move {
425            shutdown_signal(handle_clone, state_clone).await;
426        });
427
428        let handle_clone = self.handle.clone();
429        axum_server::bind_rustls(addr, config)
430            .handle(handle_clone)
431            .serve(self.app.into_make_service())
432            .await
433            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
434    }
435
436    /// Returns server handle that could be used for graceful shutdown
437    pub fn server_handle(&self) -> Handle {
438        self.handle.clone()
439    }
440
441    /// Starts the server without SSL
442    ///
443    /// # Arguments
444    /// * `addr` - The server address to bind to
445    ///
446    /// # Returns
447    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
448    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
449        tracing::info!("{}", self.server_info(Some(addr)).await?);
450
451        // Spawn a task to trigger shutdown on signal
452        let handle_clone = self.handle.clone();
453        tokio::spawn(async move {
454            shutdown_signal(handle_clone, self.state.clone()).await;
455        });
456
457        let handle_clone = self.handle.clone();
458        axum_server::bind(addr)
459            .handle(handle_clone)
460            .serve(self.app.into_make_service())
461            .await
462            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
463    }
464
465    /// Starts the server, choosing SSL or HTTP based on configuration
466    ///
467    /// Resolves the server address and starts the server in either SSL or HTTP mode.
468    /// Panics if SSL is requested but the "ssl" feature is not enabled.
469    ///
470    /// # Returns
471    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
472    pub async fn start(self) -> SdkResult<()> {
473        let runtime = HyperRuntime::create(self).await?;
474        runtime.await_server().await
475    }
476
477    /// Similar to start() , but returns a HyperRuntime after server started
478    ///
479    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
480    ///
481    /// # Returns
482    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
483    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
484        HyperRuntime::create(self).await
485    }
486}
487
488// Shutdown signal handler
489async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
490    // Wait for a Ctrl+C or SIGTERM signal
491    let ctrl_c = async {
492        signal::ctrl_c()
493            .await
494            .expect("Failed to install Ctrl+C handler");
495    };
496
497    #[cfg(unix)]
498    let terminate = async {
499        signal::unix::signal(signal::unix::SignalKind::terminate())
500            .expect("Failed to install signal handler")
501            .recv()
502            .await;
503    };
504
505    #[cfg(not(unix))]
506    let terminate = std::future::pending::<()>();
507
508    tokio::select! {
509        _ = ctrl_c => {},
510        _ = terminate => {},
511    }
512
513    tracing::info!("Signal received, starting graceful shutdown");
514    state.session_store.clear().await;
515    // Trigger graceful shutdown with a timeout
516    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::NamedTempFile;

    /// Creates an empty temporary `.pem` file and returns it together with its path.
    /// Keep the returned `NamedTempFile` alive for as long as the file must exist.
    fn temp_pem_file() -> (NamedTempFile, String) {
        let file = NamedTempFile::with_suffix(".pem").expect("Expected to create test pem file");
        let path = file
            .path()
            .to_str()
            .expect("Expected a UTF-8 temp file path")
            .to_string();
        (file, path)
    }

    #[test]
    fn test_server_options_default_urls() {
        let options = HyperServerOptions::default();
        assert_eq!(options.base_url(), "http://127.0.0.1:8080");
        assert_eq!(
            options.streamable_http_endpoint(),
            DEFAULT_STREAMABLE_HTTP_ENDPOINT
        );
        assert_eq!(options.sse_endpoint(), DEFAULT_SSE_ENDPOINT);
        assert_eq!(options.sse_messages_endpoint(), DEFAULT_MESSAGES_ENDPOINT);
        assert_eq!(
            options.streamable_http_url(),
            format!("http://127.0.0.1:8080{}", DEFAULT_STREAMABLE_HTTP_ENDPOINT)
        );
    }

    #[test]
    fn test_server_options_base_url_custom() {
        let options = HyperServerOptions {
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(options.base_url(), "https://127.0.0.1:8081");
    }

    #[test]
    fn test_server_options_streamable_http_custom() {
        let options = HyperServerOptions {
            custom_streamable_http_endpoint: Some(String::from("/abcd/mcp")),
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(
            options.streamable_http_url(),
            "https://127.0.0.1:8081/abcd/mcp"
        );
        assert_eq!(options.streamable_http_endpoint(), "/abcd/mcp");
    }

    #[test]
    fn test_server_options_sse_custom() {
        let options = HyperServerOptions {
            custom_sse_endpoint: Some(String::from("/abcd/sse")),
            host: String::from("127.0.0.1"),
            port: 8081,
            enable_ssl: true,
            ..Default::default()
        };
        assert_eq!(options.sse_url(), "https://127.0.0.1:8081/abcd/sse");
        assert_eq!(options.sse_endpoint(), "/abcd/sse");
    }

    #[test]
    fn test_server_options_sse_messages_custom() {
        let options = HyperServerOptions {
            custom_messages_endpoint: Some(String::from("/abcd/messages")),
            ..Default::default()
        };
        assert_eq!(
            options.sse_message_url(),
            "http://127.0.0.1:8080/abcd/messages"
        );
        assert_eq!(options.sse_messages_endpoint(), "/abcd/messages");
    }

    #[test]
    fn test_server_options_needs_dns_protection() {
        let options = HyperServerOptions::default();

        // should be false by default
        assert!(!options.needs_dns_protection());

        // should still be false unless allowed_hosts or allowed_origins are also provided
        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            ..Default::default()
        };
        assert!(!options.needs_dns_protection());

        // should be true when dns_rebinding_protection is true and allowed_hosts is provided
        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            allowed_hosts: Some(vec![String::from("127.0.0.1")]),
            ..Default::default()
        };
        assert!(options.needs_dns_protection());

        // should be true when dns_rebinding_protection is true and allowed_origins is provided
        let options = HyperServerOptions {
            dns_rebinding_protection: true,
            allowed_origins: Some(vec![String::from("http://127.0.0.1:8080")]),
            ..Default::default()
        };
        assert!(options.needs_dns_protection());

        // lists alone are not enough while the flag is off
        let options = HyperServerOptions {
            allowed_hosts: Some(vec![String::from("127.0.0.1")]),
            ..Default::default()
        };
        assert!(!options.needs_dns_protection());
    }

    #[test]
    fn test_server_options_validate() {
        let options = HyperServerOptions::default();
        assert!(options.validate().is_ok());

        // with ssl enabled but no cert or key provided, validate should fail
        let options = HyperServerOptions {
            enable_ssl: true,
            ..Default::default()
        };
        assert!(options.validate().is_err());

        // with ssl enabled and invalid cert/key paths, validate should fail
        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")),
            ssl_key_path: Some(String::from("/invalid/path/to/key.pem")),
            ..Default::default()
        };
        assert!(options.validate().is_err());

        let (_cert_file, ssl_cert_path) = temp_pem_file();
        let (_key_file, ssl_key_path) = temp_pem_file();

        // with ssl enabled and only a (valid) cert path, validate should fail
        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(ssl_cert_path.clone()),
            ..Default::default()
        };
        assert!(options.validate().is_err());

        // with ssl enabled and valid cert/key paths, validate should succeed
        let options = HyperServerOptions {
            enable_ssl: true,
            ssl_cert_path: Some(ssl_cert_path),
            ssl_key_path: Some(ssl_key_path),
            ..Default::default()
        };
        assert!(options.validate().is_ok());
    }

    #[tokio::test]
    async fn test_server_options_resolve_server_address() {
        let expected: SocketAddr = "8.6.7.5:309".parse().expect("valid socket address");

        let options = HyperServerOptions::default();
        assert!(options.resolve_server_address().await.is_ok());

        // valid host should still work
        let options = HyperServerOptions {
            host: String::from("8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(options.resolve_server_address().await.ok(), Some(expected));

        // valid host (prepended with http://) should still work
        let options = HyperServerOptions {
            host: String::from("http://8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(options.resolve_server_address().await.ok(), Some(expected));

        // valid host (prepended with https://) should still work
        let options = HyperServerOptions {
            host: String::from("https://8.6.7.5"),
            port: 309,
            ..Default::default()
        };
        assert_eq!(options.resolve_server_address().await.ok(), Some(expected));

        // invalid host should raise an error; the reserved `.invalid` TLD (RFC 6761) is
        // guaranteed not to resolve, unlike a bare single-label name
        let options = HyperServerOptions {
            host: String::from("invalid-host.invalid"),
            port: 309,
            ..Default::default()
        };
        assert!(options.resolve_server_address().await.is_err());
    }
}