// rust_mcp_sdk/hyper_servers/server.rs
use crate::{
2 error::SdkResult,
3 id_generator::{FastIdGenerator, UuidGenerator},
4 mcp_http::{
5 utils::{
6 DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
7 },
8 McpAppState,
9 },
10 mcp_server::hyper_runtime::HyperRuntime,
11 mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
12 session_store::InMemorySessionStore,
13};
14#[cfg(feature = "ssl")]
15use axum_server::tls_rustls::RustlsConfig;
16use axum_server::Handle;
17use std::{
18 net::{SocketAddr, ToSocketAddrs},
19 path::Path,
20 sync::Arc,
21 time::Duration,
22};
23use tokio::signal;
24
25use super::{
26 error::{TransportServerError, TransportServerResult},
27 routes::app_routes,
28};
29use crate::schema::InitializeResult;
30use axum::Router;
31use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions};
32
/// Ping interval used by [`HyperServerOptions::default`] (12 seconds).
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);

/// Number of seconds handed to the server handle as the graceful-shutdown
/// timeout once a termination signal has been received.
const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5;
36
37pub struct HyperServerOptions {
40 pub host: String,
42
43 pub port: u16,
45
46 pub session_id_generator: Option<Arc<dyn IdGenerator<SessionId>>>,
48
49 pub custom_streamable_http_endpoint: Option<String>,
51
52 pub transport_options: Arc<TransportOptions>,
54
55 pub event_store: Option<Arc<dyn EventStore>>,
58
59 pub enable_json_response: Option<bool>,
64
65 pub ping_interval: Duration,
67
68 pub enable_ssl: bool,
70
71 pub ssl_cert_path: Option<String>,
74
75 pub ssl_key_path: Option<String>,
78
79 pub allowed_hosts: Option<Vec<String>>,
82
83 pub allowed_origins: Option<Vec<String>>,
86
87 pub dns_rebinding_protection: bool,
90
91 pub sse_support: bool,
93
94 pub custom_sse_endpoint: Option<String>,
97
98 pub custom_messages_endpoint: Option<String>,
101}
102
103impl HyperServerOptions {
104 pub fn validate(&self) -> TransportServerResult<()> {
111 if self.enable_ssl {
112 if self.ssl_cert_path.is_none() || self.ssl_key_path.is_none() {
113 return Err(TransportServerError::InvalidServerOptions(
114 "Both 'ssl_cert_path' and 'ssl_key_path' must be provided when SSL is enabled."
115 .into(),
116 ));
117 }
118
119 if !Path::new(self.ssl_cert_path.as_deref().unwrap_or("")).is_file() {
120 return Err(TransportServerError::InvalidServerOptions(
121 "'ssl_cert_path' does not point to a valid or existing file.".into(),
122 ));
123 }
124
125 if !Path::new(self.ssl_key_path.as_deref().unwrap_or("")).is_file() {
126 return Err(TransportServerError::InvalidServerOptions(
127 "'ssl_key_path' does not point to a valid or existing file.".into(),
128 ));
129 }
130 }
131
132 Ok(())
133 }
134
135 pub(crate) async fn resolve_server_address(&self) -> TransportServerResult<SocketAddr> {
143 self.validate()?;
144
145 let mut host = self.host.to_string();
146 if let Some(stripped) = self.host.strip_prefix("http://") {
147 if self.enable_ssl {
148 tracing::warn!("Warning: Ignoring http:// scheme for SSL; using hostname only");
149 }
150 host = stripped.to_string();
151 } else if let Some(stripped) = host.strip_prefix("https://") {
152 host = stripped.to_string();
153 }
154
155 let addr = {
156 let mut iter = (host, self.port)
157 .to_socket_addrs()
158 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))?;
159 match iter.next() {
160 Some(addr) => addr,
161 None => format!("{}:{}", self.host, self.port).parse().map_err(
162 |err: std::net::AddrParseError| {
163 TransportServerError::ServerStartError(err.to_string())
164 },
165 )?,
166 }
167 };
168 Ok(addr)
169 }
170
171 pub fn base_url(&self) -> String {
172 format!(
173 "{}://{}:{}",
174 if self.enable_ssl { "https" } else { "http" },
175 self.host,
176 self.port
177 )
178 }
179 pub fn streamable_http_url(&self) -> String {
180 format!("{}{}", self.base_url(), self.streamable_http_endpoint())
181 }
182 pub fn sse_url(&self) -> String {
183 format!("{}{}", self.base_url(), self.sse_endpoint())
184 }
185 pub fn sse_message_url(&self) -> String {
186 format!("{}{}", self.base_url(), self.sse_messages_endpoint())
187 }
188
189 pub fn sse_endpoint(&self) -> &str {
190 self.custom_sse_endpoint
191 .as_deref()
192 .unwrap_or(DEFAULT_SSE_ENDPOINT)
193 }
194
195 pub fn sse_messages_endpoint(&self) -> &str {
196 self.custom_messages_endpoint
197 .as_deref()
198 .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
199 }
200
201 pub fn streamable_http_endpoint(&self) -> &str {
202 self.custom_messages_endpoint
203 .as_deref()
204 .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
205 }
206}
207
208impl Default for HyperServerOptions {
213 fn default() -> Self {
214 Self {
215 host: "127.0.0.1".to_string(),
216 port: 8080,
217 custom_sse_endpoint: None,
218 custom_streamable_http_endpoint: None,
219 custom_messages_endpoint: None,
220 ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
221 transport_options: Default::default(),
222 enable_ssl: false,
223 ssl_cert_path: None,
224 ssl_key_path: None,
225 session_id_generator: None,
226 enable_json_response: None,
227 sse_support: true,
228 allowed_hosts: None,
229 allowed_origins: None,
230 dns_rebinding_protection: false,
231 event_store: None,
232 }
233 }
234}
235
236pub struct HyperServer {
238 app: Router,
239 state: Arc<McpAppState>,
240 pub(crate) options: HyperServerOptions,
241 handle: Handle,
242}
243
244impl HyperServer {
245 pub(crate) fn new(
257 server_details: InitializeResult,
258 handler: Arc<dyn McpServerHandler + 'static>,
259 mut server_options: HyperServerOptions,
260 ) -> Self {
261 let state: Arc<McpAppState> = Arc::new(McpAppState {
262 session_store: Arc::new(InMemorySessionStore::new()),
263 id_generator: server_options
264 .session_id_generator
265 .take()
266 .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
267 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
268 server_details: Arc::new(server_details),
269 handler,
270 ping_interval: server_options.ping_interval,
271 transport_options: Arc::clone(&server_options.transport_options),
272 enable_json_response: server_options.enable_json_response.unwrap_or(false),
273 allowed_hosts: server_options.allowed_hosts.take(),
274 allowed_origins: server_options.allowed_origins.take(),
275 dns_rebinding_protection: server_options.dns_rebinding_protection,
276 event_store: server_options.event_store.as_ref().map(Arc::clone),
277 });
278 let app = app_routes(Arc::clone(&state), &server_options);
279 Self {
280 app,
281 state,
282 options: server_options,
283 handle: Handle::new(),
284 }
285 }
286
287 pub fn state(&self) -> Arc<McpAppState> {
292 Arc::clone(&self.state)
293 }
294
295 pub fn with_route(mut self, path: &'static str, route: axum::routing::MethodRouter) -> Self {
304 self.app = self.app.route(path, route);
305 self
306 }
307
308 pub async fn server_info(&self, addr: Option<SocketAddr>) -> TransportServerResult<String> {
318 let addr = addr.unwrap_or(self.options.resolve_server_address().await?);
319 let server_type = if self.options.enable_ssl {
320 "SSL server"
321 } else {
322 "Server"
323 };
324 let protocol = if self.options.enable_ssl {
325 "https"
326 } else {
327 "http"
328 };
329
330 let mut server_url = format!(
331 "\n• Streamable HTTP {} is available at {}://{}{}",
332 server_type,
333 protocol,
334 addr,
335 self.options.streamable_http_endpoint()
336 );
337
338 if self.options.sse_support {
339 let sse_url = format!(
340 "\n• SSE {} is available at {}://{}{}",
341 server_type,
342 protocol,
343 addr,
344 self.options.sse_endpoint()
345 );
346 server_url.push_str(&sse_url);
347 };
348
349 Ok(server_url)
350 }
351
352 pub fn options(&self) -> &HyperServerOptions {
353 &self.options
354 }
355
356 #[cfg(feature = "ssl")]
373 pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> {
374 let config = RustlsConfig::from_pem_file(
375 self.options.ssl_cert_path.as_deref().unwrap_or_default(),
376 self.options.ssl_key_path.as_deref().unwrap_or_default(),
377 )
378 .await
379 .map_err(|err| TransportServerError::SslCertError(err.to_string()))?;
380
381 tracing::info!("{}", self.server_info(Some(addr)).await?);
382
383 let handle_clone = self.handle.clone();
385 let state_clone = self.state().clone();
386 tokio::spawn(async move {
387 shutdown_signal(handle_clone, state_clone).await;
388 });
389
390 let handle_clone = self.handle.clone();
391 axum_server::bind_rustls(addr, config)
392 .handle(handle_clone)
393 .serve(self.app.into_make_service())
394 .await
395 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
396 }
397
398 pub fn server_handle(&self) -> Handle {
400 self.handle.clone()
401 }
402
403 pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
411 tracing::info!("{}", self.server_info(Some(addr)).await?);
412
413 let handle_clone = self.handle.clone();
415 tokio::spawn(async move {
416 shutdown_signal(handle_clone, self.state.clone()).await;
417 });
418
419 let handle_clone = self.handle.clone();
420 axum_server::bind(addr)
421 .handle(handle_clone)
422 .serve(self.app.into_make_service())
423 .await
424 .map_err(|err| TransportServerError::ServerStartError(err.to_string()))
425 }
426
427 pub async fn start(self) -> SdkResult<()> {
435 let runtime = HyperRuntime::create(self).await?;
436 runtime.await_server().await
437 }
438
439 pub async fn start_runtime(self) -> SdkResult<HyperRuntime> {
446 HyperRuntime::create(self).await
447 }
448}
449
450async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
452 let ctrl_c = async {
454 signal::ctrl_c()
455 .await
456 .expect("Failed to install Ctrl+C handler");
457 };
458
459 #[cfg(unix)]
460 let terminate = async {
461 signal::unix::signal(signal::unix::SignalKind::terminate())
462 .expect("Failed to install signal handler")
463 .recv()
464 .await;
465 };
466
467 #[cfg(not(unix))]
468 let terminate = std::future::pending::<()>();
469
470 tokio::select! {
471 _ = ctrl_c => {},
472 _ = terminate => {},
473 }
474
475 tracing::info!("Signal received, starting graceful shutdown");
476 state.session_store.clear().await;
477 handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
479}