rust_mcp_sdk/hyper_servers/
server.rs1use crate::{
2 error::SdkResult, mcp_server::hyper_runtime::HyperRuntime,
3 mcp_traits::mcp_handler::McpServerHandler,
4};
5#[cfg(feature = "ssl")]
6use axum_server::tls_rustls::RustlsConfig;
7use axum_server::Handle;
8use std::{
9 net::{SocketAddr, ToSocketAddrs},
10 path::Path,
11 sync::Arc,
12 time::Duration,
13};
14use tokio::signal;
15
16use super::{
17 app_state::AppState,
18 error::{TransportServerError, TransportServerResult},
19 routes::app_routes,
20 IdGenerator, InMemorySessionStore, UuidGenerator,
21};
22use crate::schema::InitializeResult;
23use axum::Router;
24use rust_mcp_transport::TransportOptions;
25
/// Ping interval used by `HyperServerOptions::default()`.
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);

/// Timeout, in seconds, handed to the server handle's graceful shutdown.
// NOTE(review): the identifier is misspelled ("TMEOUT"); it is kept as-is because
// other code in this file refers to it by this exact name.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;

/// Path of the SSE endpoint when no custom path is configured.
const DEFAULT_SSE_ENDPOINT: &str = "/sse";

/// Path of the SSE messages endpoint when no custom path is configured.
const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";

/// Path of the Streamable HTTP endpoint when no custom path is configured.
const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
35
36pub struct HyperServerOptions {
39 pub host: String,
41
42 pub port: u16,
44
45 pub session_id_generator: Option<Arc<dyn IdGenerator>>,
47
48 pub custom_streamable_http_endpoint: Option<String>,
50
51 pub transport_options: Arc<TransportOptions>,
53
54 pub enable_json_response: Option<bool>,
59
60 pub ping_interval: Duration,
62
63 pub enable_ssl: bool,
65
66 pub ssl_cert_path: Option<String>,
69
70 pub ssl_key_path: Option<String>,
73
74 pub allowed_hosts: Option<Vec<String>>,
77
78 pub allowed_origins: Option<Vec<String>>,
81
82 pub dns_rebinding_protection: bool,
85
86 pub sse_support: bool,
88
89 pub custom_sse_endpoint: Option<String>,
92
93 pub custom_messages_endpoint: Option<String>,
96}
97
98impl HyperServerOptions {
99 pub fn validate(&self) -> TransportServerResult<()> {
106 if self.enable_ssl {
107 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
108 return Err(TransportServerError::InvalidServerOptions(
109 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
110 .into(),
111 ));
112 }
113
114 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
115 return Err(TransportServerError::InvalidServerOptions(
116 "'ssl_cert_path' does not point to a valid or existing file.".into(),
117 ));
118 }
119
120 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
121 return Err(TransportServerError::InvalidServerOptions(
122 "'ssl_key_path' does not point to a valid or existing file.".into(),
123 ));
124 }
125 }
126
127 Ok(())
128 }
129
130 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
138 self.validate()?;
139
140 let mut host = self.host.to_string();
141 if let Some(stripped) = self.host.strip_prefix("http://") {
142 if self.enable_ssl {
143 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
144 }
145 host = stripped.to_string();
146 } else if let Some(stripped) = host.strip_prefix("https://") {
147 host = stripped.to_string();
148 }
149
150 let addr = {
151 let mut iter = (host, self.port)
152 .to_socket_addrs()
153 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
154 match iter.next() {
155 Some(addr) => addr,
156 None => format!("{}:{}", self.host, self.port).parse().map_err(
157 |err: std::net::AddrParseError| {
158 TransportServerError::ServerStartError(err.to_string())
159 },
160 )?,
161 }
162 };
163 Ok(addr)
164 }
165
166 pub fn base_url(&self) -> String {
167 format!(
168 "{}://{}:{}",
169 if self.enable_ssl { "https" } else { "http" },
170 self.host,
171 self.port
172 )
173 }
174 pub fn streamable_http_url(&self) -> String {
175 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
176 }
177 pub fn sse_url(&self) -> String {
178 format!("{}{}", self.base_url(), self.sse_endpoint())
179 }
180 pub fn sse_message_url(&self) -> String {
181 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
182 }
183
184 pub fn sse_endpoint(&self) -> &str {
185 self.custom_sse_endpoint
186 .as_deref()
187 .unwrap_or(DEFAULT_SSE_ENDPOINT)
188 }
189
190 pub fn sse_messages_endpoint(&self) -> &str {
191 self.custom_messages_endpoint
192 .as_deref()
193 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
194 }
195
196 pub fn streamable_http_endpoint(&self) -> &str {
197 self.custom_messages_endpoint
198 .as_deref()
199 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
200 }
201}
202
203impl Default for HyperServerOptions {
208 fn default() -> Self {
209 Self {
210 host: "127.0.0.1".to_string(),
211 port: 8080,
212 custom_sse_endpoint: None,
213 custom_streamable_http_endpoint: None,
214 custom_messages_endpoint: None,
215 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
216 transport_options: Default::default(),
217 enable_ssl: false,
218 ssl_cert_path: None,
219 ssl_key_path: None,
220 session_id_generator: None,
221 enable_json_response: None,
222 sse_support: true,
223 allowed_hosts: None,
224 allowed_origins: None,
225 dns_rebinding_protection: false,
226 }
227 }
228}
229
230pub struct HyperServer {
232 app: Router,
233 state: Arc<AppState>,
234 pub(crate) options: HyperServerOptions,
235 handle: Handle,
236}
237
238impl HyperServer {
239 pub(crate) fn new(
251 server_details: InitializeResult,
252 handler: Arc<dyn McpServerHandler + 'static>,
253 mut server_options: HyperServerOptions,
254 ) -> Self {
255 let state: Arc<AppState> = Arc::new(AppState {
256 session_store: Arc::new(InMemorySessionStore::new()),
257 id_generator: server_options
258 .session_id_generator
259 .take()
260 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
261 server_details: Arc::new(server_details),
262 handler,
263 ping_interval: server_options.ping_interval,
264 sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
265 http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(),
266 transport_options: Arc::clone(&server_options.transport_options),
267 enable_json_response: server_options.enable_json_response.unwrap_or(false),
268 allowed_hosts: server_options.allowed_hosts.take(),
269 allowed_origins: server_options.allowed_origins.take(),
270 dns_rebinding_protection: server_options.dns_rebinding_protection,
271 });
272 let app = app_routes(Arc::clone(&state), &server_options);
273 Self {
274 app,
275 state,
276 options: server_options,
277 handle: Handle::new(),
278 }
279 }
280
281 pub fn state(&self) -> Arc<AppState> {
286 Arc::clone(&self.state)
287 }
288
289 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
298 self.app = self.app.route(path, route);
299 self
300 }
301
302 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
312 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
313 let server_type = if self.options.enable_ssl {
314 "SSL server"
315 } else {
316 "Server"
317 };
318 let protocol = if self.options.enable_ssl {
319 "https"
320 } else {
321 "http"
322 };
323
324 let mut server_url = format!(
325 "\n• Streamable HTTP {} is available at {}://{}{}",
326 server_type,
327 protocol,
328 addr,
329 self.options.streamable_http_endpoint()
330 );
331
332 if self.options.sse_support {
333 let sse_url = format!(
334 "\n• SSE {} is available at {}://{}{}",
335 server_type,
336 protocol,
337 addr,
338 self.options.sse_endpoint()
339 );
340 server_url.push_str(&sse_url);
341 };
342
343 Ok(server_url)
344 }
345
346 pub fn options(&self) -> &HyperServerOptions {
347 &self.options
348 }
349
350 #[cfg(feature = "ssl")]
367 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
368 let config = RustlsConfig::from_pem_file(
369 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
370 self.options.ssl_key_path.as_deref().unwrap_or_default(),
371 )
372 .await
373 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
374
375 tracing::info!("{}", self.server_info(Some(addr)).await?);
376
377 let handle_clone = self.handle.clone();
379 let state_clone = self.state().clone();
380 tokio::spawn(async move {
381 shutdown_signal(handle_clone, state_clone).await;
382 });
383
384 let handle_clone = self.handle.clone();
385 axum_server::bind_rustls(addr, config)
386 .handle(handle_clone)
387 .serve(self.app.into_make_service())
388 .await
389 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
390 }
391
392 pub fn server_handle(&self) -> Handle {
394 self.handle.clone()
395 }
396
397 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
405 tracing::info!("{}", self.server_info(Some(addr)).await?);
406
407 let handle_clone = self.handle.clone();
409 tokio::spawn(async move {
410 shutdown_signal(handle_clone, self.state.clone()).await;
411 });
412
413 let handle_clone = self.handle.clone();
414 axum_server::bind(addr)
415 .handle(handle_clone)
416 .serve(self.app.into_make_service())
417 .await
418 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
419 }
420
421 pub async fn start(self) -> SdkResult<()> {
429 let runtime = HyperRuntime::create(self).await?;
430 runtime.await_server().await
431 }
432
433 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
440 HyperRuntime::create(self).await
441 }
442}
443
444async fn shutdown_signal(handle: Handle, state: Arc<AppState>) {
446 let ctrl_c = async {
448 signal::ctrl_c()
449 .await
450 .expect("Failed to install Ctrl+C handler");
451 };
452
453 #[cfg(unix)]
454 let terminate = async {
455 signal::unix::signal(signal::unix::SignalKind::terminate())
456 .expect("Failed to install signal handler")
457 .recv()
458 .await;
459 };
460
461 #[cfg(not(unix))]
462 let terminate = std::future::pending::<()>();
463
464 tokio::select! {
465 _ = ctrl_c => {},
466 _ = terminate => {},
467 }
468
469 tracing::info!("Signal received, starting graceful shutdown");
470 state.session_store.clear().await;
471 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
473}