rust_mcp_sdk/hyper_servers/
server.rs

1use crate::{
2    error::SdkResult,
3    id_generator::{FastIdGenerator, UuidGenerator},
4    mcp_http::{
5        http_utils::{
6            DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7        },
8        middleware::dns_rebind_protector::DnsRebindProtector,
9        McpAppState, McpHttpHandler,
10    },
11    mcp_server::hyper_runtime::HyperRuntime,
12    mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
13    session_store::InMemorySessionStore,
14};
15#[cfg(feature = "ssl")]
16use axum_server::tls_rustls::RustlsConfig;
17use axum_server::Handle;
18use std::{
19    net::{SocketAddr, ToSocketAddrs},
20    path::Path,
21    sync::Arc,
22    time::Duration,
23};
24use tokio::signal;
25
26use super::{
27    error::{TransportServerError, TransportServerResult},
28    routes::app_routes,
29};
30use crate::schema::InitializeResult;
31use axum::Router;
32use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
33
/// Default interval between automatic ping messages sent to clients (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
/// Timeout, in seconds, handed to the server handle when a graceful shutdown is triggered.
// NOTE(review): the identifier is misspelled (`TMEOUT`); renaming it also requires
// updating its use in `shutdown_signal` within the same change.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
37
/// Configuration struct for the Hyper server.
///
/// Used to configure a `HyperServer` instance. Start from
/// `HyperServerOptions::default()` and override the fields you need.
pub struct HyperServerOptions {
    /// Hostname or IP address the server will bind to (default: "127.0.0.1").
    /// A leading `http://` or `https://` scheme is stripped before the address is resolved.
    pub host: String,

    /// Port number the server will bind to (default: 8080)
    pub port: u16,

    /// Optional thread-safe session id generator to generate unique session IDs.
    /// When `None`, the server falls back to a `UuidGenerator`.
    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,

    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
    pub custom_streamable_http_endpoint: Option<String>,

    /// Shared transport configuration used by the server
    pub transport_options: Arc<TransportOptions>,

    /// Event store for resumability support.
    /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
    pub event_store: Option<Arc<dyn EventStore>>,

    /// This setting only applies to streamable HTTP.
    /// If true, the server will return JSON responses instead of starting an SSE stream.
    /// This can be useful for simple request/response scenarios without streaming.
    /// `None` is treated as `false` (SSE streams are preferred).
    pub enable_json_response: Option<bool>,

    /// Interval between automatic ping messages sent to clients to detect disconnects
    /// (default: 12 seconds)
    pub ping_interval: Duration,

    /// Enables SSL/TLS if set to `true` (default: `false`)
    pub enable_ssl: bool,

    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
    /// Required, and must point to an existing file, if `enable_ssl` is `true`.
    pub ssl_cert_path: Option<String>,

    /// Path to the SSL/TLS private key file (e.g., "key.pem").
    /// Required, and must point to an existing file, if `enable_ssl` is `true`.
    pub ssl_key_path: Option<String>,

    /// List of allowed host header values for DNS rebinding protection.
    /// If not specified, host validation is disabled.
    pub allowed_hosts: Option<Vec<String>>,

    /// List of allowed origin header values for DNS rebinding protection.
    /// If not specified, origin validation is disabled.
    pub allowed_origins: Option<Vec<String>>,

    /// Enable DNS rebinding protection (requires `allowed_hosts` and/or `allowed_origins`
    /// to be configured; without either of them this flag has no effect).
    /// Default is false for backwards compatibility.
    pub dns_rebinding_protection: bool,

    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
    pub sse_support: bool,

    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`).
    /// Applicable only if `sse_support` is true
    pub custom_sse_endpoint: Option<String>,

    /// Optional custom path for the MCP messages endpoint for SSE (default: `/messages`).
    /// Applicable only if `sse_support` is true
    pub custom_messages_endpoint: Option<String>,
}
103
104impl HyperServerOptions {
105    /// Validates the server configuration options
106    ///
107    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
108    ///
109    /// # Returns
110    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
111    pub fn validate(&self) -> TransportServerResult<()> {
112        if self.enable_ssl {
113            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
114                return Err(TransportServerError::InvalidServerOptions(
115                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
116                        .into(),
117                ));
118            }
119
120            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
121                return Err(TransportServerError::InvalidServerOptions(
122                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
123                ));
124            }
125
126            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
127                return Err(TransportServerError::InvalidServerOptions(
128                    "'ssl_key_path' does not point to a valid or existing file.".into(),
129                ));
130            }
131        }
132
133        Ok(())
134    }
135
136    /// Resolves the server address from host and port
137    ///
138    /// Validates the configuration and converts the host/port into a SocketAddr.
139    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
140    ///
141    /// # Returns
142    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
143    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
144        self.validate()?;
145
146        let mut host = self.host.to_string();
147        if let Some(stripped) = self.host.strip_prefix("http://") {
148            if self.enable_ssl {
149                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
150            }
151            host = stripped.to_string();
152        } else if let Some(stripped) = host.strip_prefix("https://") {
153            host = stripped.to_string();
154        }
155
156        let addr = {
157            let mut iter = (host, self.port)
158                .to_socket_addrs()
159                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
160            match iter.next() {
161                Some(addr) => addr,
162                None => format!("{}:{}", self.host, self.port).parse().map_err(
163                    |err: std::net::AddrParseError| {
164                        TransportServerError::ServerStartError(err.to_string())
165                    },
166                )?,
167            }
168        };
169        Ok(addr)
170    }
171
172    pub fn base_url(&self) -> String {
173        format!(
174            "{}://{}:{}",
175            if self.enable_ssl { "https" } else { "http" },
176            self.host,
177            self.port
178        )
179    }
180    pub fn streamable_http_url(&self) -> String {
181        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
182    }
183    pub fn sse_url(&self) -> String {
184        format!("{}{}", self.base_url(), self.sse_endpoint())
185    }
186    pub fn sse_message_url(&self) -> String {
187        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
188    }
189
190    pub fn sse_endpoint(&self) -> &str {
191        self.custom_sse_endpoint
192            .as_deref()
193            .unwrap_or(DEFAULT_SSE_ENDPOINT)
194    }
195
196    pub fn sse_messages_endpoint(&self) -> &str {
197        self.custom_messages_endpoint
198            .as_deref()
199            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
200    }
201
202    pub fn streamable_http_endpoint(&self) -> &str {
203        self.custom_messages_endpoint
204            .as_deref()
205            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
206    }
207
208    pub fn needs_dns_protection(&self) -> bool {
209        self.dns_rebinding_protection
210            && (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
211    }
212}
213
214/// Default implementation for HyperServerOptions
215///
216/// Provides default values for the server configuration, including 127.0.0.1 address,
217/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
218impl Default for HyperServerOptions {
219    fn default() -> Self {
220        Self {
221            host: "127.0.0.1".to_string(),
222            port: 8080,
223            custom_sse_endpoint: None,
224            custom_streamable_http_endpoint: None,
225            custom_messages_endpoint: None,
226            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
227            transport_options: Default::default(),
228            enable_ssl: false,
229            ssl_cert_path: None,
230            ssl_key_path: None,
231            session_id_generator: None,
232            enable_json_response: None,
233            sse_support: true,
234            allowed_hosts: None,
235            allowed_origins: None,
236            dns_rebinding_protection: false,
237            event_store: None,
238        }
239    }
240}
241
/// Hyper server struct for managing the Axum-based web server
pub struct HyperServer {
    /// Axum router holding the MCP routes plus any routes added via `with_route`
    app: Router,
    /// Application state shared with the route handlers
    state: Arc<McpAppState>,
    /// Options the server was created with
    pub(crate) options: HyperServerOptions,
    /// `axum_server` handle used to trigger a graceful shutdown
    handle: Handle,
}
249
250impl HyperServer {
251    /// Creates a new HyperServer instance
252    ///
253    /// Initializes the server with the provided server details, handler, and options.
254    ///
255    /// # Arguments
256    /// * `server_details` - Initialization result from the MCP schema
257    /// * `handler` - Shared MCP server handler with static lifetime
258    /// * `server_options` - Server configuration options
259    ///
260    /// # Returns
261    /// * `Self` - A new HyperServer instance
262    pub(crate) fn new(
263        server_details: InitializeResult,
264        handler: Arc<dyn McpServerHandler + 'static>,
265        mut server_options: HyperServerOptions,
266    ) -> Self {
267        let state: Arc<McpAppState> = Arc::new(McpAppState {
268            session_store: Arc::new(InMemorySessionStore::new()),
269            id_generator: server_options
270                .session_id_generator
271                .take()
272                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
273            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
274            server_details: Arc::new(server_details),
275            handler,
276            ping_interval: server_options.ping_interval,
277            transport_options: Arc::clone(&server_options.transport_options),
278            enable_json_response: server_options.enable_json_response.unwrap_or(false),
279            event_store: server_options.event_store.as_ref().map(Arc::clone),
280        });
281
282        let mut http_handler = McpHttpHandler::new();
283
284        if server_options.needs_dns_protection() {
285            http_handler.add_middleware(DnsRebindProtector::new(
286                server_options.allowed_hosts.take(),
287                server_options.allowed_origins.take(),
288            ));
289        }
290
291        let app = app_routes(Arc::clone(&state), &server_options, http_handler);
292        Self {
293            app,
294            state,
295            options: server_options,
296            handle: Handle::new(),
297        }
298    }
299
300    /// Returns a shared reference to the application state
301    ///
302    /// # Returns
303    /// * `Arc<McpAppState>` - Shared application state
304    pub fn state(&self) -> Arc<McpAppState> {
305        Arc::clone(&self.state)
306    }
307
308    /// Adds a new route to the server
309    ///
310    /// # Arguments
311    /// * `path` - The route path (static string)
312    /// * `route` - The Axum MethodRouter for handling the route
313    ///
314    /// # Returns
315    /// * `Self` - The modified HyperServer instance
316    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
317        self.app = self.app.route(path, route);
318        self
319    }
320
321    /// Generates server information string
322    ///
323    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
324    ///
325    /// # Arguments
326    /// * `addr` - Optional SocketAddr; if None, resolves from options
327    ///
328    /// # Returns
329    /// * `TransportServerResult<String>` - The server information string or an error
330    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
331        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
332        let server_type = if self.options.enable_ssl {
333            "SSL server"
334        } else {
335            "Server"
336        };
337        let protocol = if self.options.enable_ssl {
338            "https"
339        } else {
340            "http"
341        };
342
343        let mut server_url = format!(
344            "\n• Streamable HTTP {} is available at {}://{}{}",
345            server_type,
346            protocol,
347            addr,
348            self.options.streamable_http_endpoint()
349        );
350
351        if self.options.sse_support {
352            let sse_url = format!(
353                "\n• SSE {} is available at {}://{}{}",
354                server_type,
355                protocol,
356                addr,
357                self.options.sse_endpoint()
358            );
359            server_url.push_str(&sse_url);
360        };
361
362        Ok(server_url)
363    }
364
365    pub fn options(&self) -> &HyperServerOptions {
366        &self.options
367    }
368
369    // pub fn with_layer<L>(mut self, layer: L) -> Self
370    // where
371    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
372    //     L::Service: Send + Sync + 'static,
373    // {
374    //     self.router = self.router.layer(layer);
375    //     self
376    // }
377
378    /// Starts the server with SSL support (available when "ssl" feature is enabled)
379    ///
380    /// # Arguments
381    /// * `addr` - The server address to bind to
382    ///
383    /// # Returns
384    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
385    #[cfg(feature = "ssl")]
386    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
387        let config = RustlsConfig::from_pem_file(
388            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
389            self.options.ssl_key_path.as_deref().unwrap_or_default(),
390        )
391        .await
392        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
393
394        tracing::info!("{}", self.server_info(Some(addr)).await?);
395
396        // Spawn a task to trigger shutdown on signal
397        let handle_clone = self.handle.clone();
398        let state_clone = self.state().clone();
399        tokio::spawn(async move {
400            shutdown_signal(handle_clone, state_clone).await;
401        });
402
403        let handle_clone = self.handle.clone();
404        axum_server::bind_rustls(addr, config)
405            .handle(handle_clone)
406            .serve(self.app.into_make_service())
407            .await
408            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
409    }
410
411    /// Returns server handle that could be used for graceful shutdown
412    pub fn server_handle(&self) -> Handle {
413        self.handle.clone()
414    }
415
416    /// Starts the server without SSL
417    ///
418    /// # Arguments
419    /// * `addr` - The server address to bind to
420    ///
421    /// # Returns
422    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
423    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
424        tracing::info!("{}", self.server_info(Some(addr)).await?);
425
426        // Spawn a task to trigger shutdown on signal
427        let handle_clone = self.handle.clone();
428        tokio::spawn(async move {
429            shutdown_signal(handle_clone, self.state.clone()).await;
430        });
431
432        let handle_clone = self.handle.clone();
433        axum_server::bind(addr)
434            .handle(handle_clone)
435            .serve(self.app.into_make_service())
436            .await
437            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
438    }
439
440    /// Starts the server, choosing SSL or HTTP based on configuration
441    ///
442    /// Resolves the server address and starts the server in either SSL or HTTP mode.
443    /// Panics if SSL is requested but the "ssl" feature is not enabled.
444    ///
445    /// # Returns
446    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
447    pub async fn start(self) -> SdkResult<()> {
448        let runtime = HyperRuntime::create(self).await?;
449        runtime.await_server().await
450    }
451
452    /// Similar to start() , but returns a HyperRuntime after server started
453    ///
454    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
455    ///
456    /// # Returns
457    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
458    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
459        HyperRuntime::create(self).await
460    }
461}
462
463// Shutdown signal handler
464async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
465    // Wait for a Ctrl+C or SIGTERM signal
466    let ctrl_c = async {
467        signal::ctrl_c()
468            .await
469            .expect("Failed to install Ctrl+C handler");
470    };
471
472    #[cfg(unix)]
473    let terminate = async {
474        signal::unix::signal(signal::unix::SignalKind::terminate())
475            .expect("Failed to install signal handler")
476            .recv()
477            .await;
478    };
479
480    #[cfg(not(unix))]
481    let terminate = std::future::pending::<()>();
482
483    tokio::select! {
484        _ = ctrl_c => {},
485        _ = terminate => {},
486    }
487
488    tracing::info!("Signal received, starting graceful shutdown");
489    state.session_store.clear().await;
490    // Trigger graceful shutdown with a timeout
491    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
492}