rust_mcp_sdk/hyper_servers/
server.rs

1use crate::{
2    error::SdkResult,
3    id_generator::{FastIdGenerator, UuidGenerator},
4    mcp_http::{
5        utils::{
6            DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7        },
8        McpAppState,
9    },
10    mcp_server::hyper_runtime::HyperRuntime,
11    mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
12    session_store::InMemorySessionStore,
13};
14#[cfg(feature = "ssl")]
15use axum_server::tls_rustls::RustlsConfig;
16use axum_server::Handle;
17use std::{
18    net::{SocketAddr, ToSocketAddrs},
19    path::Path,
20    sync::Arc,
21    time::Duration,
22};
23use tokio::signal;
24
25use super::{
26    error::{TransportServerError, TransportServerResult},
27    routes::app_routes,
28};
29use crate::schema::InitializeResult;
30use axum::Router;
31use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
32
/// Timeout, in seconds, handed to `Handle::graceful_shutdown` once a shutdown signal arrives.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
/// Ping interval applied when the caller does not configure one: twelve seconds.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
36
37/// Configuration struct for the Hyper server
38/// Used to configure the HyperServer instance.
39pub struct HyperServerOptions {
40    /// Hostname or IP address the server will bind to (default: "127.0.0.1")
41    pub host: String,
42
43    /// Hostname or IP address the server will bind to (default: "8080")
44    pub port: u16,
45
46    /// Optional thread-safe session id generator to generate unique session IDs.
47    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
48
49    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
50    pub custom_streamable_http_endpoint: Option<String>,
51
52    /// Shared transport configuration used by the server
53    pub transport_options: Arc<TransportOptions>,
54
55    /// Event store for resumability support
56    /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
57    pub event_store: Option<Arc<dyn EventStore>>,
58
59    /// This setting only applies to streamable HTTP.
60    /// If true, the server will return JSON responses instead of starting an SSE stream.
61    /// This can be useful for simple request/response scenarios without streaming.
62    /// Default is false (SSE streams are preferred).
63    pub enable_json_response: Option<bool>,
64
65    /// Interval between automatic ping messages sent to clients to detect disconnects
66    pub ping_interval: Duration,
67
68    /// Enables SSL/TLS if set to `true`
69    pub enable_ssl: bool,
70
71    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
72    /// Required if `enable_ssl` is `true`.
73    pub ssl_cert_path: Option<String>,
74
75    /// Path to the SSL/TLS private key file (e.g., "key.pem").
76    /// Required if `enable_ssl` is `true`.
77    pub ssl_key_path: Option<String>,
78
79    /// List of allowed host header values for DNS rebinding protection.
80    /// If not specified, host validation is disabled.
81    pub allowed_hosts: Option<Vec<String>>,
82
83    /// List of allowed origin header values for DNS rebinding protection.
84    /// If not specified, origin validation is disabled.
85    pub allowed_origins: Option<Vec<String>>,
86
87    /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
88    /// Default is false for backwards compatibility.
89    pub dns_rebinding_protection: bool,
90
91    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
92    pub sse_support: bool,
93
94    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
95    /// Applicable only if sse_support is true
96    pub custom_sse_endpoint: Option<String>,
97
98    /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`)
99    /// Applicable only if sse_support is true
100    pub custom_messages_endpoint: Option<String>,
101}
102
103impl HyperServerOptions {
104    /// Validates the server configuration options
105    ///
106    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
107    ///
108    /// # Returns
109    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
110    pub fn validate(&self) -> TransportServerResult<()> {
111        if self.enable_ssl {
112            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
113                return Err(TransportServerError::InvalidServerOptions(
114                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
115                        .into(),
116                ));
117            }
118
119            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
120                return Err(TransportServerError::InvalidServerOptions(
121                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
122                ));
123            }
124
125            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
126                return Err(TransportServerError::InvalidServerOptions(
127                    "'ssl_key_path' does not point to a valid or existing file.".into(),
128                ));
129            }
130        }
131
132        Ok(())
133    }
134
135    /// Resolves the server address from host and port
136    ///
137    /// Validates the configuration and converts the host/port into a SocketAddr.
138    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
139    ///
140    /// # Returns
141    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
142    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
143        self.validate()?;
144
145        let mut host = self.host.to_string();
146        if let Some(stripped) = self.host.strip_prefix("http://") {
147            if self.enable_ssl {
148                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
149            }
150            host = stripped.to_string();
151        } else if let Some(stripped) = host.strip_prefix("https://") {
152            host = stripped.to_string();
153        }
154
155        let addr = {
156            let mut iter = (host, self.port)
157                .to_socket_addrs()
158                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
159            match iter.next() {
160                Some(addr) => addr,
161                None => format!("{}:{}", self.host, self.port).parse().map_err(
162                    |err: std::net::AddrParseError| {
163                        TransportServerError::ServerStartError(err.to_string())
164                    },
165                )?,
166            }
167        };
168        Ok(addr)
169    }
170
171    pub fn base_url(&self) -> String {
172        format!(
173            "{}://{}:{}",
174            if self.enable_ssl { "https" } else { "http" },
175            self.host,
176            self.port
177        )
178    }
179    pub fn streamable_http_url(&self) -> String {
180        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
181    }
182    pub fn sse_url(&self) -> String {
183        format!("{}{}", self.base_url(), self.sse_endpoint())
184    }
185    pub fn sse_message_url(&self) -> String {
186        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
187    }
188
189    pub fn sse_endpoint(&self) -> &str {
190        self.custom_sse_endpoint
191            .as_deref()
192            .unwrap_or(DEFAULT_SSE_ENDPOINT)
193    }
194
195    pub fn sse_messages_endpoint(&self) -> &str {
196        self.custom_messages_endpoint
197            .as_deref()
198            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
199    }
200
201    pub fn streamable_http_endpoint(&self) -> &str {
202        self.custom_messages_endpoint
203            .as_deref()
204            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
205    }
206}
207
208/// Default implementation for HyperServerOptions
209///
210/// Provides default values for the server configuration, including 127.0.0.1 address,
211/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
212impl Default for HyperServerOptions {
213    fn default() -> Self {
214        Self {
215            host: "127.0.0.1".to_string(),
216            port: 8080,
217            custom_sse_endpoint: None,
218            custom_streamable_http_endpoint: None,
219            custom_messages_endpoint: None,
220            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
221            transport_options: Default::default(),
222            enable_ssl: false,
223            ssl_cert_path: None,
224            ssl_key_path: None,
225            session_id_generator: None,
226            enable_json_response: None,
227            sse_support: true,
228            allowed_hosts: None,
229            allowed_origins: None,
230            dns_rebinding_protection: false,
231            event_store: None,
232        }
233    }
234}
235
236/// Hyper server struct for managing the Axum-based web server
237pub struct HyperServer {
238    app: Router,
239    state: Arc<McpAppState>,
240    pub(crate) options: HyperServerOptions,
241    handle: Handle,
242}
243
244impl HyperServer {
245    /// Creates a new HyperServer instance
246    ///
247    /// Initializes the server with the provided server details, handler, and options.
248    ///
249    /// # Arguments
250    /// * `server_details` - Initialization result from the MCP schema
251    /// * `handler` - Shared MCP server handler with static lifetime
252    /// * `server_options` - Server configuration options
253    ///
254    /// # Returns
255    /// * `Self` - A new HyperServer instance
256    pub(crate) fn new(
257        server_details: InitializeResult,
258        handler: Arc<dyn McpServerHandler + 'static>,
259        mut server_options: HyperServerOptions,
260    ) -> Self {
261        let state: Arc<McpAppState> = Arc::new(McpAppState {
262            session_store: Arc::new(InMemorySessionStore::new()),
263            id_generator: server_options
264                .session_id_generator
265                .take()
266                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
267            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
268            server_details: Arc::new(server_details),
269            handler,
270            ping_interval: server_options.ping_interval,
271            transport_options: Arc::clone(&server_options.transport_options),
272            enable_json_response: server_options.enable_json_response.unwrap_or(false),
273            allowed_hosts: server_options.allowed_hosts.take(),
274            allowed_origins: server_options.allowed_origins.take(),
275            dns_rebinding_protection: server_options.dns_rebinding_protection,
276            event_store: server_options.event_store.as_ref().map(Arc::clone),
277        });
278        let app = app_routes(Arc::clone(&state), &server_options);
279        Self {
280            app,
281            state,
282            options: server_options,
283            handle: Handle::new(),
284        }
285    }
286
287    /// Returns a shared reference to the application state
288    ///
289    /// # Returns
290    /// * `Arc<McpAppState>` - Shared application state
291    pub fn state(&self) -> Arc<McpAppState> {
292        Arc::clone(&self.state)
293    }
294
295    /// Adds a new route to the server
296    ///
297    /// # Arguments
298    /// * `path` - The route path (static string)
299    /// * `route` - The Axum MethodRouter for handling the route
300    ///
301    /// # Returns
302    /// * `Self` - The modified HyperServer instance
303    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
304        self.app = self.app.route(path, route);
305        self
306    }
307
308    /// Generates server information string
309    ///
310    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
311    ///
312    /// # Arguments
313    /// * `addr` - Optional SocketAddr; if None, resolves from options
314    ///
315    /// # Returns
316    /// * `TransportServerResult<String>` - The server information string or an error
317    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
318        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
319        let server_type = if self.options.enable_ssl {
320            "SSL server"
321        } else {
322            "Server"
323        };
324        let protocol = if self.options.enable_ssl {
325            "https"
326        } else {
327            "http"
328        };
329
330        let mut server_url = format!(
331            "\n• Streamable HTTP {} is available at {}://{}{}",
332            server_type,
333            protocol,
334            addr,
335            self.options.streamable_http_endpoint()
336        );
337
338        if self.options.sse_support {
339            let sse_url = format!(
340                "\n• SSE {} is available at {}://{}{}",
341                server_type,
342                protocol,
343                addr,
344                self.options.sse_endpoint()
345            );
346            server_url.push_str(&sse_url);
347        };
348
349        Ok(server_url)
350    }
351
352    pub fn options(&self) -> &HyperServerOptions {
353        &self.options
354    }
355
356    // pub fn with_layer<L>(mut self, layer: L) -> Self
357    // where
358    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
359    //     L::Service: Send + Sync + 'static,
360    // {
361    //     self.router = self.router.layer(layer);
362    //     self
363    // }
364
365    /// Starts the server with SSL support (available when "ssl" feature is enabled)
366    ///
367    /// # Arguments
368    /// * `addr` - The server address to bind to
369    ///
370    /// # Returns
371    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
372    #[cfg(feature = "ssl")]
373    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
374        let config = RustlsConfig::from_pem_file(
375            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
376            self.options.ssl_key_path.as_deref().unwrap_or_default(),
377        )
378        .await
379        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
380
381        tracing::info!("{}", self.server_info(Some(addr)).await?);
382
383        // Spawn a task to trigger shutdown on signal
384        let handle_clone = self.handle.clone();
385        let state_clone = self.state().clone();
386        tokio::spawn(async move {
387            shutdown_signal(handle_clone, state_clone).await;
388        });
389
390        let handle_clone = self.handle.clone();
391        axum_server::bind_rustls(addr, config)
392            .handle(handle_clone)
393            .serve(self.app.into_make_service())
394            .await
395            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
396    }
397
398    /// Returns server handle that could be used for graceful shutdown
399    pub fn server_handle(&self) -> Handle {
400        self.handle.clone()
401    }
402
403    /// Starts the server without SSL
404    ///
405    /// # Arguments
406    /// * `addr` - The server address to bind to
407    ///
408    /// # Returns
409    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
410    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
411        tracing::info!("{}", self.server_info(Some(addr)).await?);
412
413        // Spawn a task to trigger shutdown on signal
414        let handle_clone = self.handle.clone();
415        tokio::spawn(async move {
416            shutdown_signal(handle_clone, self.state.clone()).await;
417        });
418
419        let handle_clone = self.handle.clone();
420        axum_server::bind(addr)
421            .handle(handle_clone)
422            .serve(self.app.into_make_service())
423            .await
424            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
425    }
426
427    /// Starts the server, choosing SSL or HTTP based on configuration
428    ///
429    /// Resolves the server address and starts the server in either SSL or HTTP mode.
430    /// Panics if SSL is requested but the "ssl" feature is not enabled.
431    ///
432    /// # Returns
433    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
434    pub async fn start(self) -> SdkResult<()> {
435        let runtime = HyperRuntime::create(self).await?;
436        runtime.await_server().await
437    }
438
439    /// Similar to start() , but returns a HyperRuntime after server started
440    ///
441    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
442    ///
443    /// # Returns
444    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
445    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
446        HyperRuntime::create(self).await
447    }
448}
449
450// Shutdown signal handler
451async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
452    // Wait for a Ctrl+C or SIGTERM signal
453    let ctrl_c = async {
454        signal::ctrl_c()
455            .await
456            .expect("Failed to install Ctrl+C handler");
457    };
458
459    #[cfg(unix)]
460    let terminate = async {
461        signal::unix::signal(signal::unix::SignalKind::terminate())
462            .expect("Failed to install signal handler")
463            .recv()
464            .await;
465    };
466
467    #[cfg(not(unix))]
468    let terminate = std::future::pending::<()>();
469
470    tokio::select! {
471        _ = ctrl_c => {},
472        _ = terminate => {},
473    }
474
475    tracing::info!("Signal received, starting graceful shutdown");
476    state.session_store.clear().await;
477    // Trigger graceful shutdown with a timeout
478    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
479}