rust_mcp_sdk/hyper_servers/
server.rs

1use crate::{
2    error::SdkResult, mcp_server::hyper_runtime::HyperRuntime,
3    mcp_traits::mcp_handler::McpServerHandler,
4};
5#[cfg(feature = "ssl")]
6use axum_server::tls_rustls::RustlsConfig;
7use axum_server::Handle;
8use std::{
9    net::{SocketAddr, ToSocketAddrs},
10    path::Path,
11    sync::Arc,
12    time::Duration,
13};
14use tokio::signal;
15
16use super::{
17    app_state::AppState,
18    error::{TransportServerError, TransportServerResult},
19    routes::app_routes,
20    IdGenerator, InMemorySessionStore, UuidGenerator,
21};
22use crate::schema::InitializeResult;
23use axum::Router;
24use rust_mcp_transport::TransportOptions;
25
// Default interval between ping messages sent to clients (12 seconds)
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
// Timeout, in seconds, handed to the server handle when a graceful shutdown is triggered.
// NOTE(review): the identifier is misspelled ("TMEOUT"); it is kept unchanged here because
// `shutdown_signal` in this file refers to it by this exact name.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
// Default Server-Sent Events (SSE) endpoint path
const DEFAULT_SSE_ENDPOINT: &str = "/sse";
// Default MCP Messages endpoint path
const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
// Default Streamable HTTP endpoint path
const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
35
/// Configuration struct for the Hyper server
/// Used to configure the HyperServer instance.
///
/// A `Default` implementation is provided in this file, so callers can override
/// only the fields they care about with struct-update syntax.
pub struct HyperServerOptions {
    /// Hostname or IP address the server will bind to (default: "127.0.0.1")
    pub host: String,

    /// TCP port number the server will bind to (default: 8080)
    pub port: u16,

    /// Optional thread-safe session id generator to generate unique session IDs.
    /// When `None`, the server falls back to a `UuidGenerator`.
    pub session_id_generator: Option<Arc<dyn IdGenerator>>,

    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
    pub custom_streamable_http_endpoint: Option<String>,

    /// Shared transport configuration used by the server
    pub transport_options: Arc<TransportOptions>,

    /// This setting only applies to streamable HTTP.
    /// If true, the server will return JSON responses instead of starting an SSE stream.
    /// This can be useful for simple request/response scenarios without streaming.
    /// Default is false (SSE streams are preferred); `None` is treated as `false`.
    pub enable_json_response: Option<bool>,

    /// Interval between automatic ping messages sent to clients to detect disconnects
    /// (default: 12 seconds)
    pub ping_interval: Duration,

    /// Enables SSL/TLS if set to `true` (default: false)
    pub enable_ssl: bool,

    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
    /// Required if `enable_ssl` is `true`.
    pub ssl_cert_path: Option<String>,

    /// Path to the SSL/TLS private key file (e.g., "key.pem").
    /// Required if `enable_ssl` is `true`.
    pub ssl_key_path: Option<String>,

    /// List of allowed host header values for DNS rebinding protection.
    /// If not specified, host validation is disabled.
    pub allowed_hosts: Option<Vec<String>>,

    /// List of allowed origin header values for DNS rebinding protection.
    /// If not specified, origin validation is disabled.
    pub allowed_origins: Option<Vec<String>>,

    /// Enable DNS rebinding protection (requires `allowed_hosts` and/or `allowed_origins`
    /// to be configured).
    /// Default is false for backwards compatibility.
    pub dns_rebinding_protection: bool,

    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
    pub sse_support: bool,

    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
    /// Applicable only if sse_support is true
    pub custom_sse_endpoint: Option<String>,

    /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`)
    /// Applicable only if sse_support is true
    pub custom_messages_endpoint: Option<String>,
}
97
98impl HyperServerOptions {
99    /// Validates the server configuration options
100    ///
101    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
102    ///
103    /// # Returns
104    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
105    pub fn validate(&self) -> TransportServerResult<()> {
106        if self.enable_ssl {
107            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
108                return Err(TransportServerError::InvalidServerOptions(
109                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
110                        .into(),
111                ));
112            }
113
114            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
115                return Err(TransportServerError::InvalidServerOptions(
116                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
117                ));
118            }
119
120            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
121                return Err(TransportServerError::InvalidServerOptions(
122                    "'ssl_key_path' does not point to a valid or existing file.".into(),
123                ));
124            }
125        }
126
127        Ok(())
128    }
129
130    /// Resolves the server address from host and port
131    ///
132    /// Validates the configuration and converts the host/port into a SocketAddr.
133    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
134    ///
135    /// # Returns
136    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
137    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
138        self.validate()?;
139
140        let mut host = self.host.to_string();
141        if let Some(stripped) = self.host.strip_prefix("http://") {
142            if self.enable_ssl {
143                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
144            }
145            host = stripped.to_string();
146        } else if let Some(stripped) = host.strip_prefix("https://") {
147            host = stripped.to_string();
148        }
149
150        let addr = {
151            let mut iter = (host, self.port)
152                .to_socket_addrs()
153                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
154            match iter.next() {
155                Some(addr) => addr,
156                None => format!("{}:{}", self.host, self.port).parse().map_err(
157                    |err: std::net::AddrParseError| {
158                        TransportServerError::ServerStartError(err.to_string())
159                    },
160                )?,
161            }
162        };
163        Ok(addr)
164    }
165
166    pub fn base_url(&self) -> String {
167        format!(
168            "{}://{}:{}",
169            if self.enable_ssl { "https" } else { "http" },
170            self.host,
171            self.port
172        )
173    }
174    pub fn streamable_http_url(&self) -> String {
175        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
176    }
177    pub fn sse_url(&self) -> String {
178        format!("{}{}", self.base_url(), self.sse_endpoint())
179    }
180    pub fn sse_message_url(&self) -> String {
181        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
182    }
183
184    pub fn sse_endpoint(&self) -> &str {
185        self.custom_sse_endpoint
186            .as_deref()
187            .unwrap_or(DEFAULT_SSE_ENDPOINT)
188    }
189
190    pub fn sse_messages_endpoint(&self) -> &str {
191        self.custom_messages_endpoint
192            .as_deref()
193            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
194    }
195
196    pub fn streamable_http_endpoint(&self) -> &str {
197        self.custom_messages_endpoint
198            .as_deref()
199            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
200    }
201}
202
203/// Default implementation for HyperServerOptions
204///
205/// Provides default values for the server configuration, including 127.0.0.1 address,
206/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
207impl Default for HyperServerOptions {
208    fn default() -> Self {
209        Self {
210            host: "127.0.0.1".to_string(),
211            port: 8080,
212            custom_sse_endpoint: None,
213            custom_streamable_http_endpoint: None,
214            custom_messages_endpoint: None,
215            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
216            transport_options: Default::default(),
217            enable_ssl: false,
218            ssl_cert_path: None,
219            ssl_key_path: None,
220            session_id_generator: None,
221            enable_json_response: None,
222            sse_support: true,
223            allowed_hosts: None,
224            allowed_origins: None,
225            dns_rebinding_protection: false,
226        }
227    }
228}
229
/// Hyper server struct for managing the Axum-based web server
pub struct HyperServer {
    /// Axum router holding the server's routes (built by `app_routes`, extendable via `with_route`)
    app: Router,
    /// Application state shared with the routes
    state: Arc<AppState>,
    /// Configuration the server was created with
    pub(crate) options: HyperServerOptions,
    /// Server handle, used to trigger graceful shutdown
    handle: Handle,
}
237
238impl HyperServer {
239    /// Creates a new HyperServer instance
240    ///
241    /// Initializes the server with the provided server details, handler, and options.
242    ///
243    /// # Arguments
244    /// * `server_details` - Initialization result from the MCP schema
245    /// * `handler` - Shared MCP server handler with static lifetime
246    /// * `server_options` - Server configuration options
247    ///
248    /// # Returns
249    /// * `Self` - A new HyperServer instance
250    pub(crate) fn new(
251        server_details: InitializeResult,
252        handler: Arc<dyn McpServerHandler + 'static>,
253        mut server_options: HyperServerOptions,
254    ) -> Self {
255        let state: Arc<AppState> = Arc::new(AppState {
256            session_store: Arc::new(InMemorySessionStore::new()),
257            id_generator: server_options
258                .session_id_generator
259                .take()
260                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
261            server_details: Arc::new(server_details),
262            handler,
263            ping_interval: server_options.ping_interval,
264            sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
265            http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(),
266            transport_options: Arc::clone(&server_options.transport_options),
267            enable_json_response: server_options.enable_json_response.unwrap_or(false),
268            allowed_hosts: server_options.allowed_hosts.take(),
269            allowed_origins: server_options.allowed_origins.take(),
270            dns_rebinding_protection: server_options.dns_rebinding_protection,
271        });
272        let app = app_routes(Arc::clone(&state), &server_options);
273        Self {
274            app,
275            state,
276            options: server_options,
277            handle: Handle::new(),
278        }
279    }
280
281    /// Returns a shared reference to the application state
282    ///
283    /// # Returns
284    /// * `Arc<AppState>` - Shared application state
285    pub fn state(&self) -> Arc<AppState> {
286        Arc::clone(&self.state)
287    }
288
289    /// Adds a new route to the server
290    ///
291    /// # Arguments
292    /// * `path` - The route path (static string)
293    /// * `route` - The Axum MethodRouter for handling the route
294    ///
295    /// # Returns
296    /// * `Self` - The modified HyperServer instance
297    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
298        self.app = self.app.route(path, route);
299        self
300    }
301
302    /// Generates server information string
303    ///
304    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
305    ///
306    /// # Arguments
307    /// * `addr` - Optional SocketAddr; if None, resolves from options
308    ///
309    /// # Returns
310    /// * `TransportServerResult<String>` - The server information string or an error
311    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
312        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
313        let server_type = if self.options.enable_ssl {
314            "SSL server"
315        } else {
316            "Server"
317        };
318        let protocol = if self.options.enable_ssl {
319            "https"
320        } else {
321            "http"
322        };
323
324        let mut server_url = format!(
325            "\n• Streamable HTTP {} is available at {}://{}{}",
326            server_type,
327            protocol,
328            addr,
329            self.options.streamable_http_endpoint()
330        );
331
332        if self.options.sse_support {
333            let sse_url = format!(
334                "\n• SSE {} is available at {}://{}{}",
335                server_type,
336                protocol,
337                addr,
338                self.options.sse_endpoint()
339            );
340            server_url.push_str(&sse_url);
341        };
342
343        Ok(server_url)
344    }
345
346    pub fn options(&self) -> &HyperServerOptions {
347        &self.options
348    }
349
350    // pub fn with_layer<L>(mut self, layer: L) -> Self
351    // where
352    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
353    //     L::Service: Send + Sync + 'static,
354    // {
355    //     self.router = self.router.layer(layer);
356    //     self
357    // }
358
359    /// Starts the server with SSL support (available when "ssl" feature is enabled)
360    ///
361    /// # Arguments
362    /// * `addr` - The server address to bind to
363    ///
364    /// # Returns
365    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
366    #[cfg(feature = "ssl")]
367    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
368        let config = RustlsConfig::from_pem_file(
369            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
370            self.options.ssl_key_path.as_deref().unwrap_or_default(),
371        )
372        .await
373        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
374
375        tracing::info!("{}", self.server_info(Some(addr)).await?);
376
377        // Spawn a task to trigger shutdown on signal
378        let handle_clone = self.handle.clone();
379        let state_clone = self.state().clone();
380        tokio::spawn(async move {
381            shutdown_signal(handle_clone, state_clone).await;
382        });
383
384        let handle_clone = self.handle.clone();
385        axum_server::bind_rustls(addr, config)
386            .handle(handle_clone)
387            .serve(self.app.into_make_service())
388            .await
389            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
390    }
391
392    /// Returns server handle that could be used for graceful shutdown
393    pub fn server_handle(&self) -> Handle {
394        self.handle.clone()
395    }
396
397    /// Starts the server without SSL
398    ///
399    /// # Arguments
400    /// * `addr` - The server address to bind to
401    ///
402    /// # Returns
403    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
404    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
405        tracing::info!("{}", self.server_info(Some(addr)).await?);
406
407        // Spawn a task to trigger shutdown on signal
408        let handle_clone = self.handle.clone();
409        tokio::spawn(async move {
410            shutdown_signal(handle_clone, self.state.clone()).await;
411        });
412
413        let handle_clone = self.handle.clone();
414        axum_server::bind(addr)
415            .handle(handle_clone)
416            .serve(self.app.into_make_service())
417            .await
418            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
419    }
420
421    /// Starts the server, choosing SSL or HTTP based on configuration
422    ///
423    /// Resolves the server address and starts the server in either SSL or HTTP mode.
424    /// Panics if SSL is requested but the "ssl" feature is not enabled.
425    ///
426    /// # Returns
427    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
428    pub async fn start(self) -> SdkResult<()> {
429        let runtime = HyperRuntime::create(self).await?;
430        runtime.await_server().await
431    }
432
433    /// Similar to start() , but returns a HyperRuntime after server started
434    ///
435    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
436    ///
437    /// # Returns
438    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
439    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
440        HyperRuntime::create(self).await
441    }
442}
443
444// Shutdown signal handler
445async fn shutdown_signal(handle: Handle, state: Arc<AppState>) {
446    // Wait for a Ctrl+C or SIGTERM signal
447    let ctrl_c = async {
448        signal::ctrl_c()
449            .await
450            .expect("Failed to install Ctrl+C handler");
451    };
452
453    #[cfg(unix)]
454    let terminate = async {
455        signal::unix::signal(signal::unix::SignalKind::terminate())
456            .expect("Failed to install signal handler")
457            .recv()
458            .await;
459    };
460
461    #[cfg(not(unix))]
462    let terminate = std::future::pending::<()>();
463
464    tokio::select! {
465        _ = ctrl_c => {},
466        _ = terminate => {},
467    }
468
469    tracing::info!("Signal received, starting graceful shutdown");
470    state.session_store.clear().await;
471    // Trigger graceful shutdown with a timeout
472    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
473}