rust_mcp_sdk/hyper_servers/
server.rs1use crate::mcp_traits::mcp_handler::McpServerHandler;
2#[cfg(feature = "ssl")]
3use axum_server::tls_rustls::RustlsConfig;
4use axum_server::Handle;
5use std::{
6 net::{SocketAddr, ToSocketAddrs},
7 path::Path,
8 sync::Arc,
9 time::Duration,
10};
11use tokio::signal;
12
13use super::{
14 app_state::AppState,
15 error::{TransportServerError, TransportServerResult},
16 routes::app_routes,
17 IdGenerator, InMemorySessionStore, UuidGenerator,
18};
19use axum::Router;
20use rust_mcp_schema::InitializeResult;
21use rust_mcp_transport::TransportOptions;
22
/// Default client ping interval (12 seconds), used by `HyperServerOptions::default()`.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
/// Grace period, in seconds, passed to `Handle::graceful_shutdown` by `shutdown_signal`.
/// NOTE(review): "TMEOUT" is a misspelling of "TIMEOUT"; renaming requires updating its use site.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 30;
/// SSE endpoint path used when `HyperServerOptions::custom_sse_endpoint` is `None`.
const DEFAULT_SSE_ENDPOINT: &str = "/sse";
/// Messages endpoint path used when `HyperServerOptions::custom_messages_endpoint` is `None`.
const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
30
/// Configuration for [`HyperServer`]: bind address, endpoint paths, client ping interval,
/// optional TLS settings, transport options and session-id generation.
pub struct HyperServerOptions {
    /// Host to bind to. A leading `http://` or `https://` is stripped before resolution.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// Overrides the SSE endpoint path (default: `/sse`).
    pub custom_sse_endpoint: Option<String>,
    /// Overrides the messages endpoint path (default: `/messages`).
    pub custom_messages_endpoint: Option<String>,
    /// Client ping interval, stored in the shared `AppState` (default: 12 seconds).
    pub ping_interval: Duration,
    /// Serve over TLS. Requires `ssl_cert_path` and `ssl_key_path` and the `ssl` feature.
    pub enable_ssl: bool,
    /// Path to the PEM certificate file; required when `enable_ssl` is `true`.
    pub ssl_cert_path: Option<String>,
    /// Path to the PEM private-key file; required when `enable_ssl` is `true`.
    pub ssl_key_path: Option<String>,
    /// Transport options shared with the application state.
    pub transport_options: Arc<TransportOptions>,
    /// Custom session-id generator; a `UuidGenerator` is used when `None`.
    pub session_id_generator: Option<Arc<dyn IdGenerator>>,
}
57
58impl HyperServerOptions {
59 pub fn validate(&self) -> TransportServerResult<()> {
66 if self.enable_ssl {
67 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
68 return Err(TransportServerError::InvalidServerOptions(
69 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
70 .into(),
71 ));
72 }
73
74 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
75 return Err(TransportServerError::InvalidServerOptions(
76 "'ssl_cert_path' does not point to a valid or existing file.".into(),
77 ));
78 }
79
80 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
81 return Err(TransportServerError::InvalidServerOptions(
82 "'ssl_key_path' does not point to a valid or existing file.".into(),
83 ));
84 }
85 }
86
87 Ok(())
88 }
89
90 async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
98 self.validate()?;
99
100 let mut host = self.host.to_string();
101 if let Some(stripped) = self.host.strip_prefix("http://") {
102 if self.enable_ssl {
103 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
104 }
105 host = stripped.to_string();
106 } else if let Some(stripped) = host.strip_prefix("https://") {
107 host = stripped.to_string();
108 }
109
110 let addr = {
111 let mut iter = (host, self.port)
112 .to_socket_addrs()
113 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
114 match iter.next() {
115 Some(addr) => addr,
116 None => format!("{}:{}", self.host, self.port).parse().map_err(
117 |err: std::net::AddrParseError| {
118 TransportServerError::ServerStartError(err.to_string())
119 },
120 )?,
121 }
122 };
123 Ok(addr)
124 }
125
126 pub fn sse_endpoint(&self) -> &str {
127 self.custom_sse_endpoint
128 .as_deref()
129 .unwrap_or(DEFAULT_SSE_ENDPOINT)
130 }
131
132 pub fn sse_messages_endpoint(&self) -> &str {
133 self.custom_messages_endpoint
134 .as_deref()
135 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
136 }
137}
138
139impl Default for HyperServerOptions {
144 fn default() -> Self {
145 Self {
146 host: "127.0.0.1".to_string(),
147 port: 8080,
148 custom_sse_endpoint: None,
149 custom_messages_endpoint: None,
150 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
151 transport_options: Default::default(),
152 enable_ssl: false,
153 ssl_cert_path: None,
154 ssl_key_path: None,
155 session_id_generator: None,
156 }
157 }
158}
159
/// HTTP/HTTPS server hosting the MCP SSE routes, configured by [`HyperServerOptions`].
pub struct HyperServer {
    /// Axum router with the MCP routes plus any routes added via `with_route`.
    app: Router,
    /// Shared application state given to the routes.
    state: Arc<AppState>,
    /// Options the server was built with.
    options: HyperServerOptions,
    /// Server handle, used to trigger graceful shutdown.
    handle: Handle,
}
167
168impl HyperServer {
169 pub(crate) fn new(
181 server_details: InitializeResult,
182 handler: Arc<dyn McpServerHandler + 'static>,
183 mut server_options: HyperServerOptions,
184 ) -> Self {
185 let state: Arc<AppState> = Arc::new(AppState {
186 session_store: Arc::new(InMemorySessionStore::new()),
187 id_generator: server_options
188 .session_id_generator
189 .take()
190 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
191 server_details: Arc::new(server_details),
192 handler,
193 ping_interval: server_options.ping_interval,
194 sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
195 transport_options: Arc::clone(&server_options.transport_options),
196 });
197 let app = app_routes(Arc::clone(&state), &server_options);
198 Self {
199 app,
200 state,
201 options: server_options,
202 handle: Handle::new(),
203 }
204 }
205
206 pub fn state(&self) -> Arc<AppState> {
211 Arc::clone(&self.state)
212 }
213
214 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
223 self.app = self.app.route(path, route);
224 self
225 }
226
227 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
237 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
238 let server_type = if self.options.enable_ssl {
239 "SSL server"
240 } else {
241 "Server"
242 };
243 let protocol = if self.options.enable_ssl {
244 "https"
245 } else {
246 "http"
247 };
248
249 let server_url = format!(
250 "{} is available at {}://{}{}",
251 server_type,
252 protocol,
253 addr,
254 self.options.sse_endpoint()
255 );
256
257 Ok(server_url)
258 }
259
260 #[cfg(feature = "ssl")]
277 async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
278 let config = RustlsConfig::from_pem_file(
279 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
280 self.options.ssl_key_path.as_deref().unwrap_or_default(),
281 )
282 .await
283 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
284
285 tracing::info!("{}", self.server_info(Some(addr)).await?);
286
287 let handle_clone = self.handle.clone();
289 tokio::spawn(async move {
290 shutdown_signal(handle_clone).await;
291 });
292
293 let handle_clone = self.handle.clone();
294 axum_server::bind_rustls(addr, config)
295 .handle(handle_clone)
296 .serve(self.app.into_make_service())
297 .await
298 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
299 }
300
301 pub fn server_handle(&self) -> Handle {
303 self.handle.clone()
304 }
305
306 async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
314 tracing::info!("{}", self.server_info(Some(addr)).await?);
315
316 let handle_clone = self.handle.clone();
318 tokio::spawn(async move {
319 shutdown_signal(handle_clone).await;
320 });
321
322 let handle_clone = self.handle.clone();
323 axum_server::bind(addr)
324 .handle(handle_clone)
325 .serve(self.app.into_make_service())
326 .await
327 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
328 }
329
330 pub async fn start(self) -> TransportServerResult<()> {
338 let addr = self.options.resolve_server_address().await?;
339
340 #[cfg(feature = "ssl")]
341 if self.options.enable_ssl {
342 self.start_ssl(addr).await
343 } else {
344 self.start_http(addr).await
345 }
346
347 #[cfg(not(feature = "ssl"))]
348 if self.options.enable_ssl {
349 panic!("SSL requested but the 'ssl' feature is not enabled");
350 } else {
351 self.start_http(addr).await
352 }
353 }
354}
355
356async fn shutdown_signal(handle: Handle) {
358 let ctrl_c = async {
360 signal::ctrl_c()
361 .await
362 .expect("Failed to install Ctrl+C handler");
363 };
364
365 #[cfg(unix)]
366 let terminate = async {
367 signal::unix::signal(signal::unix::SignalKind::terminate())
368 .expect("Failed to install signal handler")
369 .recv()
370 .await;
371 };
372
373 #[cfg(not(unix))]
374 let terminate = std::future::pending::<()>();
375
376 tokio::select! {
377 _ = ctrl_c => {},
378 _ = terminate => {},
379 }
380
381 tracing::info!("Signal received, starting graceful shutdown");
382 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
384}