rust_mcp_sdk/hyper_servers/server.rs

1use crate::mcp_traits::mcp_handler::McpServerHandler;
2#[cfg(feature = "ssl")]
3use axum_server::tls_rustls::RustlsConfig;
4use axum_server::Handle;
5use std::{
6    net::{SocketAddr, ToSocketAddrs},
7    path::Path,
8    sync::Arc,
9    time::Duration,
10};
11use tokio::signal;
12
13use super::{
14    app_state::AppState,
15    error::{TransportServerError, TransportServerResult},
16    routes::app_routes,
17    IdGenerator, InMemorySessionStore, UuidGenerator,
18};
19use axum::Router;
20use rust_mcp_schema::InitializeResult;
21use rust_mcp_transport::TransportOptions;
22
/// Default path of the Server-Sent Events (SSE) endpoint.
const DEFAULT_SSE_ENDPOINT: &str = "/sse";
/// Default path of the MCP messages endpoint.
const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
/// Default interval between automatic pings sent to connected clients (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
/// Timeout, in seconds, passed to `Handle::graceful_shutdown` once a shutdown signal arrives.
/// (The identifier's spelling is kept as-is because the shutdown handler refers to it.)
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 30;
30
31/// Configuration struct for the Hyper server
32/// Used to configure the HyperServer instance.
33pub struct HyperServerOptions {
34    /// Hostname or IP address the server will bind to (default: "localhost")
35    pub host: String,
36    /// Hostname or IP address the server will bind to (default: "8080")
37    pub port: u16,
38    /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
39    pub custom_sse_endpoint: Option<String>,
40    /// Optional custom path for the MCP messages endpoint (default: `/messages`)
41    pub custom_messages_endpoint: Option<String>,
42    /// Interval between automatic ping messages sent to clients to detect disconnects
43    pub ping_interval: Duration,
44    /// Enables SSL/TLS if set to `true`
45    pub enable_ssl: bool,
46    /// Path to the SSL/TLS certificate file (e.g., "cert.pem").
47    /// Required if `enable_ssl` is `true`.
48    pub ssl_cert_path: Option<String>,
49    /// Path to the SSL/TLS private key file (e.g., "key.pem").
50    /// Required if `enable_ssl` is `true`.
51    pub ssl_key_path: Option<String>,
52    /// Shared transport configuration used by the server
53    pub transport_options: Arc<TransportOptions>,
54    /// Optional thread-safe session id generator to generate unique session IDs.
55    pub session_id_generator: Option<Arc<dyn IdGenerator>>,
56}
57
58impl HyperServerOptions {
59    /// Validates the server configuration options
60    ///
61    /// Ensures that SSL-related paths are provided and valid when SSL is enabled.
62    ///
63    /// # Returns
64    /// * `TransportServerResult<()>` - Ok if validation passes, Err with TransportServerError if invalid
65    pub fn validate(&self) -> TransportServerResult<()> {
66        if self.enable_ssl {
67            if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
68                return Err(TransportServerError::InvalidServerOptions(
69                    "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
70                        .into(),
71                ));
72            }
73
74            if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
75                return Err(TransportServerError::InvalidServerOptions(
76                    "'ssl_cert_path' does not point to a valid or existing file.".into(),
77                ));
78            }
79
80            if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
81                return Err(TransportServerError::InvalidServerOptions(
82                    "'ssl_key_path' does not point to a valid or existing file.".into(),
83                ));
84            }
85        }
86
87        Ok(())
88    }
89
90    /// Resolves the server address from host and port
91    ///
92    /// Validates the configuration and converts the host/port into a SocketAddr.
93    /// Handles scheme prefixes (http:// or https://) and logs warnings for mismatches.
94    ///
95    /// # Returns
96    /// * `TransportServerResult<SocketAddr>` - The resolved server address or an error
97    async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
98        self.validate()?;
99
100        let mut host = self.host.to_string();
101        if let Some(stripped) = self.host.strip_prefix("http://") {
102            if self.enable_ssl {
103                tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
104            }
105            host = stripped.to_string();
106        } else if let Some(stripped) = host.strip_prefix("https://") {
107            host = stripped.to_string();
108        }
109
110        let addr = {
111            let mut iter = (host, self.port)
112                .to_socket_addrs()
113                .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
114            match iter.next() {
115                Some(addr) => addr,
116                None => format!("{}:{}", self.host, self.port).parse().map_err(
117                    |err: std::net::AddrParseError| {
118                        TransportServerError::ServerStartError(err.to_string())
119                    },
120                )?,
121            }
122        };
123        Ok(addr)
124    }
125
126    pub fn sse_endpoint(&self) -> &str {
127        self.custom_sse_endpoint
128            .as_deref()
129            .unwrap_or(DEFAULT_SSE_ENDPOINT)
130    }
131
132    pub fn sse_messages_endpoint(&self) -> &str {
133        self.custom_messages_endpoint
134            .as_deref()
135            .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
136    }
137}
138
139/// Default implementation for HyperServerOptions
140///
141/// Provides default values for the server configuration, including localhost address,
142/// port 8080, default SSE endpoint, and 12-second ping interval.
143impl Default for HyperServerOptions {
144    fn default() -> Self {
145        Self {
146            host: "127.0.0.1".to_string(),
147            port: 8080,
148            custom_sse_endpoint: None,
149            custom_messages_endpoint: None,
150            ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
151            transport_options: Default::default(),
152            enable_ssl: false,
153            ssl_cert_path: None,
154            ssl_key_path: None,
155            session_id_generator: None,
156        }
157    }
158}
159
160/// Hyper server struct for managing the Axum-based web server
161pub struct HyperServer {
162    app: Router,
163    state: Arc<AppState>,
164    options: HyperServerOptions,
165    handle: Handle,
166}
167
168impl HyperServer {
169    /// Creates a new HyperServer instance
170    ///
171    /// Initializes the server with the provided server details, handler, and options.
172    ///
173    /// # Arguments
174    /// * `server_details` - Initialization result from the MCP schema
175    /// * `handler` - Shared MCP server handler with static lifetime
176    /// * `server_options` - Server configuration options
177    ///
178    /// # Returns
179    /// * `Self` - A new HyperServer instance
180    pub(crate) fn new(
181        server_details: InitializeResult,
182        handler: Arc<dyn McpServerHandler + 'static>,
183        mut server_options: HyperServerOptions,
184    ) -> Self {
185        let state: Arc<AppState> = Arc::new(AppState {
186            session_store: Arc::new(InMemorySessionStore::new()),
187            id_generator: server_options
188                .session_id_generator
189                .take()
190                .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
191            server_details: Arc::new(server_details),
192            handler,
193            ping_interval: server_options.ping_interval,
194            sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
195            transport_options: Arc::clone(&server_options.transport_options),
196        });
197        let app = app_routes(Arc::clone(&state), &server_options);
198        Self {
199            app,
200            state,
201            options: server_options,
202            handle: Handle::new(),
203        }
204    }
205
206    /// Returns a shared reference to the application state
207    ///
208    /// # Returns
209    /// * `Arc<AppState>` - Shared application state
210    pub fn state(&self) -> Arc<AppState> {
211        Arc::clone(&self.state)
212    }
213
214    /// Adds a new route to the server
215    ///
216    /// # Arguments
217    /// * `path` - The route path (static string)
218    /// * `route` - The Axum MethodRouter for handling the route
219    ///
220    /// # Returns
221    /// * `Self` - The modified HyperServer instance
222    pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
223        self.app = self.app.route(path, route);
224        self
225    }
226
227    /// Generates server information string
228    ///
229    /// Constructs a string describing the server type, protocol, address, and SSE endpoint.
230    ///
231    /// # Arguments
232    /// * `addr` - Optional SocketAddr; if None, resolves from options
233    ///
234    /// # Returns
235    /// * `TransportServerResult<String>` - The server information string or an error
236    pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
237        let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
238        let server_type = if self.options.enable_ssl {
239            "SSL server"
240        } else {
241            "Server"
242        };
243        let protocol = if self.options.enable_ssl {
244            "https"
245        } else {
246            "http"
247        };
248
249        let server_url = format!(
250            "{} is available at {}://{}{}",
251            server_type,
252            protocol,
253            addr,
254            self.options.sse_endpoint()
255        );
256
257        Ok(server_url)
258    }
259
260    // pub fn with_layer<L>(mut self, layer: L) -> Self
261    // where
262    //     // L: Layer<axum::body::Body> + Clone + Send + Sync + 'static,
263    //     L::Service: Send + Sync + 'static,
264    // {
265    //     self.router = self.router.layer(layer);
266    //     self
267    // }
268
269    /// Starts the server with SSL support (available when "ssl" feature is enabled)
270    ///
271    /// # Arguments
272    /// * `addr` - The server address to bind to
273    ///
274    /// # Returns
275    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
276    #[cfg(feature = "ssl")]
277    async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
278        let config = RustlsConfig::from_pem_file(
279            self.options.ssl_cert_path.as_deref().unwrap_or_default(),
280            self.options.ssl_key_path.as_deref().unwrap_or_default(),
281        )
282        .await
283        .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
284
285        tracing::info!("{}", self.server_info(Some(addr)).await?);
286
287        // Spawn a task to trigger shutdown on signal
288        let handle_clone = self.handle.clone();
289        tokio::spawn(async move {
290            shutdown_signal(handle_clone).await;
291        });
292
293        let handle_clone = self.handle.clone();
294        axum_server::bind_rustls(addr, config)
295            .handle(handle_clone)
296            .serve(self.app.into_make_service())
297            .await
298            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
299    }
300
301    /// Returns server handle that could be used for graceful shutdown
302    pub fn server_handle(&self) -> Handle {
303        self.handle.clone()
304    }
305
306    /// Starts the server without SSL
307    ///
308    /// # Arguments
309    /// * `addr` - The server address to bind to
310    ///
311    /// # Returns
312    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
313    async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
314        tracing::info!("{}", self.server_info(Some(addr)).await?);
315
316        // Spawn a task to trigger shutdown on signal
317        let handle_clone = self.handle.clone();
318        tokio::spawn(async move {
319            shutdown_signal(handle_clone).await;
320        });
321
322        let handle_clone = self.handle.clone();
323        axum_server::bind(addr)
324            .handle(handle_clone)
325            .serve(self.app.into_make_service())
326            .await
327            .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
328    }
329
330    /// Starts the server, choosing SSL or HTTP based on configuration
331    ///
332    /// Resolves the server address and starts the server in either SSL or HTTP mode.
333    /// Panics if SSL is requested but the "ssl" feature is not enabled.
334    ///
335    /// # Returns
336    /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise
337    pub async fn start(self) -> TransportServerResult<()> {
338        let addr = self.options.resolve_server_address().await?;
339
340        #[cfg(feature = "ssl")]
341        if self.options.enable_ssl {
342            self.start_ssl(addr).await
343        } else {
344            self.start_http(addr).await
345        }
346
347        #[cfg(not(feature = "ssl"))]
348        if self.options.enable_ssl {
349            panic!("SSL requested but the 'ssl' feature is not enabled");
350        } else {
351            self.start_http(addr).await
352        }
353    }
354}
355
356// Shutdown signal handler
357async fn shutdown_signal(handle: Handle) {
358    // Wait for a Ctrl+C or SIGTERM signal
359    let ctrl_c = async {
360        signal::ctrl_c()
361            .await
362            .expect("Failed to install Ctrl+C handler");
363    };
364
365    #[cfg(unix)]
366    let terminate = async {
367        signal::unix::signal(signal::unix::SignalKind::terminate())
368            .expect("Failed to install signal handler")
369            .recv()
370            .await;
371    };
372
373    #[cfg(not(unix))]
374    let terminate = std::future::pending::<()>();
375
376    tokio::select! {
377        _ = ctrl_c => {},
378        _ = terminate => {},
379    }
380
381    tracing::info!("Signal received, starting graceful shutdown");
382    // Trigger graceful shutdown with a timeout
383    handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
384}