rust_mcp_sdk/hyper_servers/
server.rs

1use crate::{
2    error::SdkResult,
3    id_generator::{FastIdGenerator, UuidGenerator},
4    mcp_server::hyper_runtime::HyperRuntime,
5    mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
6};
7#[cfg(feature = "ssl")]
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10use std::{
11    net::{SocketAddr, ToSocketAddrs},
12    path::Path,
13    sync::Arc,
14    time::Duration,
15};
16use tokio::signal;
17
18use super::{
19    app_state::AppState,
20    error::{TransportServerError, TransportServerResult},
21    routes::app_routes,
22    InMemorySessionStore,
23};
24use crate::schema::InitializeResult;
25use axum::Router;
26use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
27
/// Default path of the Streamable HTTP endpoint.
const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
/// Default path of the legacy Server-Sent Events (SSE) endpoint.
const DEFAULT_SSE_ENDPOINT: &str = "/sse";
/// Default path of the MCP messages endpoint used by the SSE transport.
const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";

/// Default interval between automatic client pings (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
/// Timeout, in seconds, passed to `Handle::graceful_shutdown` when a shutdown signal arrives.
// The identifier's spelling is kept as-is because other code in this file references it.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
37
38/// Configuration struct for the Hyper server
39/// Used to configure the HyperServer instance.
40pub struct HyperServerOptions {
41    /// Hostname or IP address the server will bind to (default: "127.0.0.1")
42    pub host: String,
43
44    /// Hostname or IP address the server will bind to (default: "8080")
45    pub port: u16,
46
47    /// Optional thread-safe session id generator to generate unique session IDs.
48    pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
49
50    /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`)
51    pub custom_streamable_http_endpoint: Option<String>,
52
53    /// Shared transport configuration used by the server
54    pub transport_options: Arc<TransportOptions>,
55
56    /// Event store for resumability support
57    /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
58    pub event_store: Option<Arc<dyn EventStore>>,
59
60    /// This setting only applies to streamable HTTP.
61    /// If true, the server will return JSON responses instead of starting an SSE stream.
62    /// This can be useful for simple request/response scenarios without streaming.
63    /// Default is false (SSE streams are preferred).
64    pub enable_json_response: Option<bool>,
65
66    /// Interval between automatic ping messages sent to clients to detect disconnects
67    pub ping_interval: Duration,
68
69    /// Enables SSL/TLS if set to `true`
70    pub enable_ssl: bool,
71
72    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
73    /// Required if `enable_ssl` is `true`.
74    pub ssl_cert_path: Option<String>,
75
76    /// Path to the SSL/TLS private key file (e.g., "key.pem").
77    /// Required if `enable_ssl` is `true`.
78    pub ssl_key_path: Option<String>,
79
80    /// List of allowed host header values for DNS rebinding protection.
81    /// If not specified, host validation is disabled.
82    pub allowed_hosts: Option<Vec<String>>,
83
84    /// List of allowed origin header values for DNS rebinding protection.
85    /// If not specified, origin validation is disabled.
86    pub allowed_origins: Option<Vec<String>>,
87
88    /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
89    /// Default is false for backwards compatibility.
90    pub dns_rebinding_protection: bool,
91
92    /// If set to true, the SSE transport will also be supported for backward compatibility (default: true)
93    pub sse_support: bool,
94
95    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
96    /// Applicable only if sse_support is true
97    pub custom_sse_endpoint: Option<String>,
98
99    /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`)
100    /// Applicable only if sse_support is true
101    pub custom_messages_endpoint: Option<String>,
102}
103
104impl HyperServerOptions {
105    /// Validates the server configuration options
106    ///
107    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
108    ///
109    /// # Returns
110    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
111    pub fn validate(&self) -> TransportServerResult<()> {
112        if self.enable_ssl {
113            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
114                return Err(TransportServerError::InvalidServerOptions(
115                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
116                        .into(),
117                ));
118            }
119
120            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
121                return Err(TransportServerError::InvalidServerOptions(
122                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
123                ));
124            }
125
126            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
127                return Err(TransportServerError::InvalidServerOptions(
128                    "'ssl_key_path' does not point to a valid or existing file.".into(),
129                ));
130            }
131        }
132
133        Ok(())
134    }
135
136    /// Resolves the server address from host and port
137    ///
138    /// Validates the configuration and converts the host/port into a SocketAddr.
139    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
140    ///
141    /// # Returns
142    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
143    pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
144        self.validate()?;
145
146        let mut host = self.host.to_string();
147        if let Some(stripped) = self.host.strip_prefix("http://") {
148            if self.enable_ssl {
149                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
150            }
151            host = stripped.to_string();
152        } else if let Some(stripped) = host.strip_prefix("https://") {
153            host = stripped.to_string();
154        }
155
156        let addr = {
157            let mut iter = (host, self.port)
158                .to_socket_addrs()
159                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
160            match iter.next() {
161                Some(addr) => addr,
162                None => format!("{}:{}", self.host, self.port).parse().map_err(
163                    |err: std::net::AddrParseError| {
164                        TransportServerError::ServerStartError(err.to_string())
165                    },
166                )?,
167            }
168        };
169        Ok(addr)
170    }
171
172    pub fn base_url(&self) -> String {
173        format!(
174            "{}://{}:{}",
175            if self.enable_ssl { "https" } else { "http" },
176            self.host,
177            self.port
178        )
179    }
180    pub fn streamable_http_url(&self) -> String {
181        format!("{}{}", self.base_url(), self.streamable_http_endpoint())
182    }
183    pub fn sse_url(&self) -> String {
184        format!("{}{}", self.base_url(), self.sse_endpoint())
185    }
186    pub fn sse_message_url(&self) -> String {
187        format!("{}{}", self.base_url(), self.sse_messages_endpoint())
188    }
189
190    pub fn sse_endpoint(&self) -> &str {
191        self.custom_sse_endpoint
192            .as_deref()
193            .unwrap_or(DEFAULT_SSE_ENDPOINT)
194    }
195
196    pub fn sse_messages_endpoint(&self) -> &str {
197        self.custom_messages_endpoint
198            .as_deref()
199            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
200    }
201
202    pub fn streamable_http_endpoint(&self) -> &str {
203        self.custom_messages_endpoint
204            .as_deref()
205            .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
206    }
207}
208
209/// Default implementation for HyperServerOptions
210///
211/// Provides default values for the server configuration, including 127.0.0.1 address,
212/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval.
213impl Default for HyperServerOptions {
214    fn default() -> Self {
215        Self {
216            host: "127.0.0.1".to_string(),
217            port: 8080,
218            custom_sse_endpoint: None,
219            custom_streamable_http_endpoint: None,
220            custom_messages_endpoint: None,
221            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
222            transport_options: Default::default(),
223            enable_ssl: false,
224            ssl_cert_path: None,
225            ssl_key_path: None,
226            session_id_generator: None,
227            enable_json_response: None,
228            sse_support: true,
229            allowed_hosts: None,
230            allowed_origins: None,
231            dns_rebinding_protection: false,
232            event_store: None,
233        }
234    }
235}
236
237/// Hyper server struct for managing the Axum-based web server
238pub struct HyperServer {
239    app: Router,
240    state: Arc<AppState>,
241    pub(crate) options: HyperServerOptions,
242    handle: Handle,
243}
244
245impl HyperServer {
246    /// Creates a new HyperServer instance
247    ///
248    /// Initializes the server with the provided server details, handler, and options.
249    ///
250    /// # Arguments
251    /// * `server_details` - Initialization result from the MCP schema
252    /// * `handler` - Shared MCP server handler with static lifetime
253    /// * `server_options` - Server configuration options
254    ///
255    /// # Returns
256    /// * `Self` - A new HyperServer instance
257    pub(crate) fn new(
258        server_details: InitializeResult,
259        handler: Arc<dyn McpServerHandler + 'static>,
260        mut server_options: HyperServerOptions,
261    ) -> Self {
262        let state: Arc<AppState> = Arc::new(AppState {
263            session_store: Arc::new(InMemorySessionStore::new()),
264            id_generator: server_options
265                .session_id_generator
266                .take()
267                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
268            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
269            server_details: Arc::new(server_details),
270            handler,
271            ping_interval: server_options.ping_interval,
272            sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
273            http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(),
274            transport_options: Arc::clone(&server_options.transport_options),
275            enable_json_response: server_options.enable_json_response.unwrap_or(false),
276            allowed_hosts: server_options.allowed_hosts.take(),
277            allowed_origins: server_options.allowed_origins.take(),
278            dns_rebinding_protection: server_options.dns_rebinding_protection,
279            event_store: server_options.event_store.as_ref().map(Arc::clone),
280        });
281        let app = app_routes(Arc::clone(&state), &server_options);
282        Self {
283            app,
284            state,
285            options: server_options,
286            handle: Handle::new(),
287        }
288    }
289
290    /// Returns a shared reference to the application state
291    ///
292    /// # Returns
293    /// * `Arc<AppState>` - Shared application state
294    pub fn state(&self) -> Arc<AppState> {
295        Arc::clone(&self.state)
296    }
297
298    /// Adds a new route to the server
299    ///
300    /// # Arguments
301    /// * `path` - The route path (static string)
302    /// * `route` - The Axum MethodRouter for handling the route
303    ///
304    /// # Returns
305    /// * `Self` - The modified HyperServer instance
306    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
307        self.app = self.app.route(path, route);
308        self
309    }
310
311    /// Generates server information string
312    ///
313    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
314    ///
315    /// # Arguments
316    /// * `addr` - Optional SocketAddr; if None, resolves from options
317    ///
318    /// # Returns
319    /// * `TransportServerResult<String>` - The server information string or an error
320    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
321        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
322        let server_type = if self.options.enable_ssl {
323            "SSL server"
324        } else {
325            "Server"
326        };
327        let protocol = if self.options.enable_ssl {
328            "https"
329        } else {
330            "http"
331        };
332
333        let mut server_url = format!(
334            "\n• Streamable HTTP {} is available at {}://{}{}",
335            server_type,
336            protocol,
337            addr,
338            self.options.streamable_http_endpoint()
339        );
340
341        if self.options.sse_support {
342            let sse_url = format!(
343                "\n• SSE {} is available at {}://{}{}",
344                server_type,
345                protocol,
346                addr,
347                self.options.sse_endpoint()
348            );
349            server_url.push_str(&sse_url);
350        };
351
352        Ok(server_url)
353    }
354
355    pub fn options(&self) -> &HyperServerOptions {
356        &self.options
357    }
358
359    // pub fn with_layer<L>(mut self, layer: L) -> Self
360    // where
361    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
362    //     L::Service: Send + Sync + 'static,
363    // {
364    //     self.router = self.router.layer(layer);
365    //     self
366    // }
367
368    /// Starts the server with SSL support (available when "ssl" feature is enabled)
369    ///
370    /// # Arguments
371    /// * `addr` - The server address to bind to
372    ///
373    /// # Returns
374    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
375    #[cfg(feature = "ssl")]
376    pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
377        let config = RustlsConfig::from_pem_file(
378            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
379            self.options.ssl_key_path.as_deref().unwrap_or_default(),
380        )
381        .await
382        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
383
384        tracing::info!("{}", self.server_info(Some(addr)).await?);
385
386        // Spawn a task to trigger shutdown on signal
387        let handle_clone = self.handle.clone();
388        let state_clone = self.state().clone();
389        tokio::spawn(async move {
390            shutdown_signal(handle_clone, state_clone).await;
391        });
392
393        let handle_clone = self.handle.clone();
394        axum_server::bind_rustls(addr, config)
395            .handle(handle_clone)
396            .serve(self.app.into_make_service())
397            .await
398            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
399    }
400
401    /// Returns server handle that could be used for graceful shutdown
402    pub fn server_handle(&self) -> Handle {
403        self.handle.clone()
404    }
405
406    /// Starts the server without SSL
407    ///
408    /// # Arguments
409    /// * `addr` - The server address to bind to
410    ///
411    /// # Returns
412    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
413    pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
414        tracing::info!("{}", self.server_info(Some(addr)).await?);
415
416        // Spawn a task to trigger shutdown on signal
417        let handle_clone = self.handle.clone();
418        tokio::spawn(async move {
419            shutdown_signal(handle_clone, self.state.clone()).await;
420        });
421
422        let handle_clone = self.handle.clone();
423        axum_server::bind(addr)
424            .handle(handle_clone)
425            .serve(self.app.into_make_service())
426            .await
427            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
428    }
429
430    /// Starts the server, choosing SSL or HTTP based on configuration
431    ///
432    /// Resolves the server address and starts the server in either SSL or HTTP mode.
433    /// Panics if SSL is requested but the "ssl" feature is not enabled.
434    ///
435    /// # Returns
436    /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise
437    pub async fn start(self) -> SdkResult<()> {
438        let runtime = HyperRuntime::create(self).await?;
439        runtime.await_server().await
440    }
441
442    /// Similar to start() , but returns a HyperRuntime after server started
443    ///
444    /// HyperRuntime could be used to access sessions and send server initiated messages if needed
445    ///
446    /// # Returns
447    /// * `SdkResult<HyperRuntime>` - Ok if the server starts successfully, Err otherwise
448    pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
449        HyperRuntime::create(self).await
450    }
451}
452
453// Shutdown signal handler
454async fn shutdown_signal(handle: Handle, state: Arc<AppState>) {
455    // Wait for a Ctrl+C or SIGTERM signal
456    let ctrl_c = async {
457        signal::ctrl_c()
458            .await
459            .expect("Failed to install Ctrl+C handler");
460    };
461
462    #[cfg(unix)]
463    let terminate = async {
464        signal::unix::signal(signal::unix::SignalKind::terminate())
465            .expect("Failed to install signal handler")
466            .recv()
467            .await;
468    };
469
470    #[cfg(not(unix))]
471    let terminate = std::future::pending::<()>();
472
473    tokio::select! {
474        _ = ctrl_c => {},
475        _ = terminate => {},
476    }
477
478    tracing::info!("Signal received, starting graceful shutdown");
479    state.session_store.clear().await;
480    // Trigger graceful shutdown with a timeout
481    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
482}