rust_mcp_sdk/hyper_servers/server.rs

1use crate::{
2    error::SdkResult,
3    id_generator::{FastIdGenerator, UuidGenerator},
4    mcp_http::{
5        utils::{
6            DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7        },
8        McpAppState, McpHttpHandler,
9    },
10    mcp_server::hyper_runtime::HyperRuntime,
11    mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
12    session_store::InMemorySessionStore,
13};
14#[cfg(feature = "ssl")]
15use axum_server::tls_rustls::RustlsConfig;
16use axum_server::Handle;
17use std::{
18    net::{SocketAddr, ToSocketAddrs},
19    path::Path,
20    sync::Arc,
21    time::Duration,
22};
23use tokio::signal;
24
25use super::{
26    error::{TransportServerError, TransportServerResult},
27    routes::app_routes,
28};
29use crate::schema::InitializeResult;
30use axum::Router;
31use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
32
/// Default interval between automatic pings sent to connected clients (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_millis(12_000);

/// Grace period, in seconds, passed to `Handle::graceful_shutdown` once a shutdown signal
/// arrives. (The historical spelling of the name is kept because existing code refers to it.)
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
36
37/// Configuration struct for the Hyper server
38/// Used to configure the HyperServer instance.
39pub struct HyperServerOptions {
40    /// Hostname or IP address the server will bind to (default: "127.0.0.1")
41    pub host: String,
42
43    /// Hostname or IP address the server will bind to (default: "8080")
44    pub port: u16,
45
46    /// Optional thread-safe session id generator to generate unique session IDs.
47    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
48
49    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
50    pub custom_streamable_http_endpoint: Option<String>,
51
52    /// Shared transport configuration used by the server
53    pub transport_options: Arc<TransportOptions>,
54
55    /// Event store for resumability support
56    /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
57    pub event_store: Option<Arc<dyn EventStore>>,
58
59    /// This setting only applies to streamable HTTP.
60    /// If true, the server will return JSON responses instead of starting an SSE stream.
61    /// This can be useful for simple request/response scenarios without streaming.
62    /// Default is false (SSE streams are preferred).
63    pub enable_json_response: Option<bool>,
64
65    /// Interval between automatic ping messages sent to clients to detect disconnects
66    pub ping_interval: Duration,
67
68    /// Enables SSL/TLS if set to `true`
69    pub enable_ssl: bool,
70
71    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
72    /// Required if `enable_ssl` is `true`.
73    pub ssl_cert_path: Option<String>,
74
75    /// Path to the SSL/TLS private key file (e.g., "key.pem").
76    /// Required if `enable_ssl` is `true`.
77    pub ssl_key_path: Option<String>,
78
79    /// List of allowed host header values for DNS rebinding protection.
80    /// If not specified, host validation is disabled.
81    pub allowed_hosts: Option<Vec<String>>,
82
83    /// List of allowed origin header values for DNS rebinding protection.
84    /// If not specified, origin validation is disabled.
85    pub allowed_origins: Option<Vec<String>>,
86
87    /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
88    /// Default is false for backwards compatibility.
89    pub dns_rebinding_protection: bool,
90
91    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
92    pub sse_support: bool,
93
94    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
95    /// Applicable only if sse_support is true
96    pub custom_sse_endpoint: Option<String>,
97
98    /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`)
99    /// Applicable only if sse_support is true
100    pub custom_messages_endpoint: Option<String>,
101}
102
103impl HyperServerOptions {
104    /// Validates the server configuration options
105    ///
106    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
107    ///
108    /// # Returns
109    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
110    pub fn validate(&self) -> TransportServerResult<()> {
111        if self.enable_ssl {
112            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
113                return Err(TransportServerError::InvalidServerOptions(
114                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
115                        .into(),
116                ));
117            }
118
119            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
120                return Err(TransportServerError::InvalidServerOptions(
121                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
122                ));
123            }
124
125            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
126                return Err(TransportServerError::InvalidServerOptions(
127                    "'ssl_key_path' does not point to a valid or existing file.".into(),
128                ));
129            }
130        }
131
132        Ok(())
133    }
134
135    /// Resolves the server address from host and port
136    ///
137    /// Validates the configuration and converts the host/port into a SocketAddr.
138    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
139    ///
140    /// # Returns
141    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
142    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
143        self.validate()?;
144
145        let mut host = self.host.to_string();
146        if let Some(stripped) = self.host.strip_prefix("http://") {
147            if self.enable_ssl {
148                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
149            }
150            host = stripped.to_string();
151        } else if let Some(stripped) = host.strip_prefix("https://") {
152            host = stripped.to_string();
153        }
154
155        let addr = {
156            let mut iter = (host, self.port)
157                .to_socket_addrs()
158                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
159            match iter.next() {
160                Some(addr) => addr,
161                None => format!("{}:{}", self.host, self.port).parse().map_err(
162                    |err: std::net::AddrParseError| {
163                        TransportServerError::ServerStartError(err.to_string())
164                    },
165                )?,
166            }
167        };
168        Ok(addr)
169    }
170
171    pub fn base_url(&self) -> String {
172        format!(
173            "{}://{}:{}",
174            if self.enable_ssl { "https" } else { "http" },
175            self.host,
176            self.port
177        )
178    }
179    pub fn streamable_http_url(&self) -> String {
180        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
181    }
182    pub fn sse_url(&self) -> String {
183        format!("{}{}", self.base_url(), self.sse_endpoint())
184    }
185    pub fn sse_message_url(&self) -> String {
186        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
187    }
188
189    pub fn sse_endpoint(&self) -> &str {
190        self.custom_sse_endpoint
191            .as_deref()
192            .unwrap_or(DEFAULT_SSE_ENDPOINT)
193    }
194
195    pub fn sse_messages_endpoint(&self) -> &str {
196        self.custom_messages_endpoint
197            .as_deref()
198            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
199    }
200
201    pub fn streamable_http_endpoint(&self) -> &str {
202        self.custom_messages_endpoint
203            .as_deref()
204            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
205    }
206}
207
208/// Default implementation for HyperServerOptions
209///
210/// Provides default values for the server configuration, including 127.0.0.1 address,
211/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
212impl Default for HyperServerOptions {
213    fn default() -> Self {
214        Self {
215            host: "127.0.0.1".to_string(),
216            port: 8080,
217            custom_sse_endpoint: None,
218            custom_streamable_http_endpoint: None,
219            custom_messages_endpoint: None,
220            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
221            transport_options: Default::default(),
222            enable_ssl: false,
223            ssl_cert_path: None,
224            ssl_key_path: None,
225            session_id_generator: None,
226            enable_json_response: None,
227            sse_support: true,
228            allowed_hosts: None,
229            allowed_origins: None,
230            dns_rebinding_protection: false,
231            event_store: None,
232        }
233    }
234}
235
/// Hyper server struct for managing the Axum-based web server.
pub struct HyperServer {
    /// Axum router containing the MCP routes plus any routes added via `with_route`.
    app: Router,
    /// Shared application state (session store, handler, id generators, ...) given to the routes.
    state: Arc<McpAppState>,
    /// Options the server was built with; `new` moves some fields (session id generator,
    /// allowed hosts/origins) out of them and into `state`.
    pub(crate) options: HyperServerOptions,
    /// `axum_server` handle used to trigger a graceful shutdown.
    handle: Handle,
}
243
244impl HyperServer {
245    /// Creates a new HyperServer instance
246    ///
247    /// Initializes the server with the provided server details, handler, and options.
248    ///
249    /// # Arguments
250    /// * `server_details` - Initialization result from the MCP schema
251    /// * `handler` - Shared MCP server handler with static lifetime
252    /// * `server_options` - Server configuration options
253    ///
254    /// # Returns
255    /// * `Self` - A new HyperServer instance
256    pub(crate) fn new(
257        server_details: InitializeResult,
258        handler: Arc<dyn McpServerHandler + 'static>,
259        mut server_options: HyperServerOptions,
260    ) -> Self {
261        let state: Arc<McpAppState> = Arc::new(McpAppState {
262            session_store: Arc::new(InMemorySessionStore::new()),
263            id_generator: server_options
264                .session_id_generator
265                .take()
266                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
267            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
268            server_details: Arc::new(server_details),
269            handler,
270            ping_interval: server_options.ping_interval,
271            transport_options: Arc::clone(&server_options.transport_options),
272            enable_json_response: server_options.enable_json_response.unwrap_or(false),
273            allowed_hosts: server_options.allowed_hosts.take(),
274            allowed_origins: server_options.allowed_origins.take(),
275            dns_rebinding_protection: server_options.dns_rebinding_protection,
276            event_store: server_options.event_store.as_ref().map(Arc::clone),
277        });
278
279        let http_handler = McpHttpHandler::new(); //TODO: add auth handlers
280        let app = app_routes(Arc::clone(&state), &server_options, http_handler);
281        Self {
282            app,
283            state,
284            options: server_options,
285            handle: Handle::new(),
286        }
287    }
288
289    /// Returns a shared reference to the application state
290    ///
291    /// # Returns
292    /// * `Arc<McpAppState>` - Shared application state
293    pub fn state(&self) -> Arc<McpAppState> {
294        Arc::clone(&self.state)
295    }
296
297    /// Adds a new route to the server
298    ///
299    /// # Arguments
300    /// * `path` - The route path (static string)
301    /// * `route` - The Axum MethodRouter for handling the route
302    ///
303    /// # Returns
304    /// * `Self` - The modified HyperServer instance
305    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
306        self.app = self.app.route(path, route);
307        self
308    }
309
310    /// Generates server information string
311    ///
312    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
313    ///
314    /// # Arguments
315    /// * `addr` - Optional SocketAddr; if None, resolves from options
316    ///
317    /// # Returns
318    /// * `TransportServerResult<String>` - The server information string or an error
319    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
320        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
321        let server_type = if self.options.enable_ssl {
322            "SSL server"
323        } else {
324            "Server"
325        };
326        let protocol = if self.options.enable_ssl {
327            "https"
328        } else {
329            "http"
330        };
331
332        let mut server_url = format!(
333            "\n• Streamable HTTP {} is available at {}://{}{}",
334            server_type,
335            protocol,
336            addr,
337            self.options.streamable_http_endpoint()
338        );
339
340        if self.options.sse_support {
341            let sse_url = format!(
342                "\n• SSE {} is available at {}://{}{}",
343                server_type,
344                protocol,
345                addr,
346                self.options.sse_endpoint()
347            );
348            server_url.push_str(&sse_url);
349        };
350
351        Ok(server_url)
352    }
353
354    pub fn options(&self) -> &HyperServerOptions {
355        &self.options
356    }
357
358    // pub fn with_layer<L>(mut self, layer: L) -> Self
359    // where
360    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
361    //     L::Service: Send + Sync + 'static,
362    // {
363    //     self.router = self.router.layer(layer);
364    //     self
365    // }
366
367    /// Starts the server with SSL support (available when "ssl" feature is enabled)
368    ///
369    /// # Arguments
370    /// * `addr` - The server address to bind to
371    ///
372    /// # Returns
373    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
374    #[cfg(feature = "ssl")]
375    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
376        let config = RustlsConfig::from_pem_file(
377            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
378            self.options.ssl_key_path.as_deref().unwrap_or_default(),
379        )
380        .await
381        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
382
383        tracing::info!("{}", self.server_info(Some(addr)).await?);
384
385        // Spawn a task to trigger shutdown on signal
386        let handle_clone = self.handle.clone();
387        let state_clone = self.state().clone();
388        tokio::spawn(async move {
389            shutdown_signal(handle_clone, state_clone).await;
390        });
391
392        let handle_clone = self.handle.clone();
393        axum_server::bind_rustls(addr, config)
394            .handle(handle_clone)
395            .serve(self.app.into_make_service())
396            .await
397            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
398    }
399
400    /// Returns server handle that could be used for graceful shutdown
401    pub fn server_handle(&self) -> Handle {
402        self.handle.clone()
403    }
404
405    /// Starts the server without SSL
406    ///
407    /// # Arguments
408    /// * `addr` - The server address to bind to
409    ///
410    /// # Returns
411    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
412    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
413        tracing::info!("{}", self.server_info(Some(addr)).await?);
414
415        // Spawn a task to trigger shutdown on signal
416        let handle_clone = self.handle.clone();
417        tokio::spawn(async move {
418            shutdown_signal(handle_clone, self.state.clone()).await;
419        });
420
421        let handle_clone = self.handle.clone();
422        axum_server::bind(addr)
423            .handle(handle_clone)
424            .serve(self.app.into_make_service())
425            .await
426            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
427    }
428
429    /// Starts the server, choosing SSL or HTTP based on configuration
430    ///
431    /// Resolves the server address and starts the server in either SSL or HTTP mode.
432    /// Panics if SSL is requested but the "ssl" feature is not enabled.
433    ///
434    /// # Returns
435    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
436    pub async fn start(self) -> SdkResult<()> {
437        let runtime = HyperRuntime::create(self).await?;
438        runtime.await_server().await
439    }
440
441    /// Similar to start() , but returns a HyperRuntime after server started
442    ///
443    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
444    ///
445    /// # Returns
446    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
447    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
448        HyperRuntime::create(self).await
449    }
450}
451
452// Shutdown signal handler
453async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
454    // Wait for a Ctrl+C or SIGTERM signal
455    let ctrl_c = async {
456        signal::ctrl_c()
457            .await
458            .expect("Failed to install Ctrl+C handler");
459    };
460
461    #[cfg(unix)]
462    let terminate = async {
463        signal::unix::signal(signal::unix::SignalKind::terminate())
464            .expect("Failed to install signal handler")
465            .recv()
466            .await;
467    };
468
469    #[cfg(not(unix))]
470    let terminate = std::future::pending::<()>();
471
472    tokio::select! {
473        _ = ctrl_c => {},
474        _ = terminate => {},
475    }
476
477    tracing::info!("Signal received, starting graceful shutdown");
478    state.session_store.clear().await;
479    // Trigger graceful shutdown with a timeout
480    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
481}