rust_mcp_sdk/mcp_runtimes/
server_runtime.rs1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::error::SdkResult;
4use crate::mcp_traits::mcp_handler::McpServerHandler;
5use crate::mcp_traits::mcp_server::McpServer;
6use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
7use crate::schema::{
8 schema_utils::{
9 ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
10 ServerMessages,
11 },
12 InitializeRequestParams, InitializeResult, RequestId, RpcError,
13};
14use crate::utils::AbortTaskOnDrop;
15use async_trait::async_trait;
16use futures::future::try_join_all;
17use futures::{StreamExt, TryFutureExt};
18#[cfg(feature = "hyper-server")]
19use rust_mcp_transport::SessionId;
20use rust_mcp_transport::{IoStream, TransportDispatcher};
21use std::collections::HashMap;
22use std::panic;
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::io::AsyncWriteExt;
26
27use tokio::sync::{mpsc, oneshot, watch};
28
/// Capacity of the mpsc channel through which spawned message-handling tasks report
/// their outcome back to the runtime loop.
const TASK_CHANNEL_CAPACITY: usize = 500;

/// Stream id under which the runtime's primary transport is stored.
///
/// `store_transport` and `remove_transport` return early for any other id.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
31
32type TransportType = Arc<
34 dyn TransportDispatcher<
35 ClientMessages,
36 MessageFromServer,
37 ClientMessage,
38 ServerMessages,
39 ServerMessage,
40 >,
41>;
42
43pub struct ServerRuntime {
45 handler: Arc<dyn McpServerHandler>,
47 server_details: Arc<InitializeResult>,
49 #[cfg(feature = "hyper-server")]
50 session_id: Option<SessionId>,
51 transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>, request_id_gen: Box<dyn RequestIdGen>,
53 client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
54 client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
55}
56
57#[async_trait]
58impl McpServer for ServerRuntime {
59 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
61 self.client_details_tx
62 .send(Some(client_details))
63 .map_err(|_| {
64 RpcError::internal_error()
65 .with_message("Failed to set client details".to_string())
66 .into()
67 })
68 }
69
70 async fn wait_for_initialization(&self) {
71 loop {
72 if self.client_details_rx.borrow().is_some() {
73 return;
74 }
75 let mut rx = self.client_details_rx.clone();
76 rx.changed().await.ok();
77 }
78 }
79
80 async fn send(
81 &self,
82 message: MessageFromServer,
83 request_id: Option<RequestId>,
84 request_timeout: Option<Duration>,
85 ) -> SdkResult<Option<ClientMessage>> {
86 let transport_map = self.transport_map.read().await;
87 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
88 RpcError::internal_error()
89 .with_message("transport stream does not exists or is closed!".to_string()),
90 )?;
91
92 let outgoing_request_id = self
93 .request_id_gen
94 .request_id_for_message(&message, request_id);
95
96 let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
97
98 let response = transport
99 .send_message(ServerMessages::Single(mcp_message), request_timeout)
100 .await?
101 .map(|res| res.as_single())
102 .transpose()?;
103
104 Ok(response)
105 }
106
107 async fn send_batch(
108 &self,
109 messages: Vec<ServerMessage>,
110 request_timeout: Option<Duration>,
111 ) -> SdkResult<Option<Vec<ClientMessage>>> {
112 let transport_map = self.transport_map.read().await;
113 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
114 RpcError::internal_error()
115 .with_message("transport stream does not exists or is closed!".to_string()),
116 )?;
117
118 transport
119 .send_batch(messages, request_timeout)
120 .map_err(|err| err.into())
121 .await
122 }
123
124 fn server_info(&self) -> &InitializeResult {
127 &self.server_details
128 }
129
130 fn client_info(&self) -> Option<InitializeRequestParams> {
132 self.client_details_rx.borrow().clone()
133 }
134
135 async fn start(self: Arc<Self>) -> SdkResult<()> {
137 let self_clone = self.clone();
138 let transport_map = self_clone.transport_map.read().await;
139
140 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
141 RpcError::internal_error()
142 .with_message("transport stream does not exists or is closed!".to_string()),
143 )?;
144
145 let mut stream = transport.start().await?;
146
147 let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
149
150 while let Some(mcp_messages) = stream.next().await {
152 match mcp_messages {
153 ClientMessages::Single(client_message) => {
154 let transport = transport.clone();
155 let self = self.clone();
156 let tx = tx.clone();
157
158 tokio::spawn(async move {
160 let result = self.handle_message(client_message, &transport).await;
161
162 let send_result: SdkResult<_> = match result {
163 Ok(result) => {
164 if let Some(result) = result {
165 transport
166 .send_message(ServerMessages::Single(result), None)
167 .map_err(|e| e.into())
168 .await
169 } else {
170 Ok(None)
171 }
172 }
173 Err(error) => {
174 tracing::error!("Error handling message : {}", error);
175 Ok(None)
176 }
177 };
178 if let Err(error) = tx.send(send_result).await {
180 tracing::error!("Failed to send result to channel: {}", error);
181 }
182 });
183 }
184 ClientMessages::Batch(client_messages) => {
185 let transport = transport.clone();
186 let self = self_clone.clone();
187 let tx = tx.clone();
188
189 tokio::spawn(async move {
190 let handling_tasks: Vec<_> = client_messages
191 .into_iter()
192 .map(|client_message| self.handle_message(client_message, &transport))
193 .collect();
194
195 let send_result = match try_join_all(handling_tasks).await {
196 Ok(results) => {
197 let results: Vec<_> = results.into_iter().flatten().collect();
198 if !results.is_empty() {
199 transport
200 .send_message(ServerMessages::Batch(results), None)
201 .map_err(|e| e.into())
202 .await
203 } else {
204 Ok(None)
205 }
206 }
207 Err(error) => Err(error),
208 };
209
210 if let Err(error) = tx.send(send_result).await {
211 tracing::error!("Failed to send batch result to channel: {}", error);
212 }
213 });
214 }
215 }
216
217 while let Ok(result) = rx.try_recv() {
219 result?; }
221 }
222
223 drop(tx);
225 while let Some(result) = rx.recv().await {
226 result?; }
228
229 return Ok(());
230 }
231
232 async fn stderr_message(&self, message: String) -> SdkResult<()> {
233 let transport_map = self.transport_map.read().await;
234 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
235 RpcError::internal_error()
236 .with_message("transport stream does not exists or is closed!".to_string()),
237 )?;
238 let mut lock = transport.error_stream().write().await;
239
240 if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
241 stderr.write_all(message.as_bytes()).await?;
242 stderr.write_all(b"\n").await?;
243 stderr.flush().await?;
244 }
245 Ok(())
246 }
247
248 #[cfg(feature = "hyper-server")]
249 fn session_id(&self) -> Option<SessionId> {
250 self.session_id.to_owned()
251 }
252}
253
254impl ServerRuntime {
255 pub(crate) async fn consume_payload_string(
256 &self,
257 stream_id: &str,
258 payload: &str,
259 ) -> SdkResult<()> {
260 let transport_map = self.transport_map.read().await;
261
262 let transport = transport_map.get(stream_id).ok_or(
263 RpcError::internal_error()
264 .with_message("stream id does not exists or is closed!".to_string()),
265 )?;
266
267 transport.consume_string_payload(payload).await?;
268
269 Ok(())
270 }
271
272 pub(crate) async fn handle_message(
273 self: &Arc<Self>,
274 message: ClientMessage,
275 transport: &Arc<
276 dyn TransportDispatcher<
277 ClientMessages,
278 MessageFromServer,
279 ClientMessage,
280 ServerMessages,
281 ServerMessage,
282 >,
283 >,
284 ) -> SdkResult<Option<ServerMessage>> {
285 let response = match message {
286 ClientMessage::Request(client_jsonrpc_request) => {
288 let result = self
289 .handler
290 .handle_request(client_jsonrpc_request.request, self.clone())
291 .await;
292 let response: MessageFromServer = match result {
294 Ok(success_value) => success_value.into(),
295 Err(error_value) => {
296 if !self.is_initialized() {
299 return Err(error_value.into());
300 }
301 MessageFromServer::Error(error_value)
302 }
303 };
304
305 let mpc_message: ServerMessage =
306 ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
307
308 Some(mpc_message)
309 }
310 ClientMessage::Notification(client_jsonrpc_notification) => {
311 self.handler
312 .handle_notification(client_jsonrpc_notification.notification, self.clone())
313 .await?;
314 None
315 }
316 ClientMessage::Error(jsonrpc_error) => {
317 self.handler
318 .handle_error(&jsonrpc_error.error, self.clone())
319 .await?;
320 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
321 tx_response
322 .send(ClientMessage::Error(jsonrpc_error))
323 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
324 } else {
325 tracing::warn!(
326 "Received an error response with no corresponding request {:?}",
327 &jsonrpc_error.id
328 );
329 }
330 None
331 }
332 ClientMessage::Response(response) => {
333 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
334 tx_response
335 .send(ClientMessage::Response(response))
336 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
337 } else {
338 tracing::warn!(
339 "Received a response with no corresponding request: {:?}",
340 &response.id
341 );
342 }
343 None
344 }
345 };
346 Ok(response)
347 }
348
349 pub(crate) async fn store_transport(
350 &self,
351 stream_id: &str,
352 transport: Arc<
353 dyn TransportDispatcher<
354 ClientMessages,
355 MessageFromServer,
356 ClientMessage,
357 ServerMessages,
358 ServerMessage,
359 >,
360 >,
361 ) -> SdkResult<()> {
362 if stream_id != DEFAULT_STREAM_ID {
363 return Ok(());
364 }
365 let mut transport_map = self.transport_map.write().await;
366 tracing::trace!("save transport for stream id : {}", stream_id);
367 transport_map.insert(stream_id.to_string(), transport);
368 Ok(())
369 }
370
371 pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
373 if stream_id != DEFAULT_STREAM_ID {
374 return Ok(());
375 }
376 let transport_map = self.transport_map.read().await;
377 tracing::trace!("removing transport for stream id : {}", stream_id);
378 if let Some(transport) = transport_map.get(stream_id) {
379 transport.shut_down().await?;
380 }
381 Ok(())
383 }
384
385 pub(crate) async fn shutdown(&self) {
386 let mut transport_map = self.transport_map.write().await;
387 let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
388 drop(transport_map);
389 for item in items {
390 let _ = item.shut_down().await;
391 }
392 }
393
394 pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
395 let transport_map = self.transport_map.read().await;
396 let live_transport = if let Some(t) = transport_map.get(stream_id) {
397 !t.is_shut_down().await
398 } else {
399 false
400 };
401 live_transport
402 }
403
404 pub(crate) async fn start_stream(
405 self: Arc<Self>,
406 transport: Arc<
407 dyn TransportDispatcher<
408 ClientMessages,
409 MessageFromServer,
410 ClientMessage,
411 ServerMessages,
412 ServerMessage,
413 >,
414 >,
415 stream_id: &str,
416 ping_interval: Duration,
417 payload: Option<String>,
418 ) -> SdkResult<()> {
419 let mut stream = transport.start().await?;
420
421 if stream_id == DEFAULT_STREAM_ID {
422 self.store_transport(stream_id, transport.clone()).await?;
423 }
424
425 let self_clone = self.clone();
426
427 let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
428 let abort_alive_task = transport
429 .keep_alive(ping_interval, disconnect_tx)
430 .await?
431 .abort_handle();
432
433 let _abort_guard = AbortTaskOnDrop {
435 handle: abort_alive_task,
436 };
437
438 if let Some(payload) = payload {
441 if let Err(err) = transport.consume_string_payload(&payload).await {
442 let _ = self.remove_transport(stream_id).await;
443 return Err(err.into());
444 }
445 }
446
447 let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
449
450 loop {
451 tokio::select! {
452 Some(mcp_messages) = stream.next() =>{
453
454 match mcp_messages {
455 ClientMessages::Single(client_message) => {
456 let transport = transport.clone();
457 let self_clone = self.clone();
458 let tx = tx.clone();
459 tokio::spawn(async move {
460
461 let result = self_clone.handle_message(client_message, &transport).await;
462
463 let send_result: SdkResult<_> = match result {
464 Ok(result) => {
465 if let Some(result) = result {
466 transport
467 .send_message(ServerMessages::Single(result), None)
468 .map_err(|e| e.into())
469 .await
470 } else {
471 Ok(None)
472 }
473 }
474 Err(error) => {
475 tracing::error!("Error handling message : {}", error);
476 Ok(None)
477 }
478 };
479 if let Err(error) = tx.send(send_result).await {
480 tracing::error!("Failed to send batch result to channel: {}", error);
481 }
482 });
483 }
484 ClientMessages::Batch(client_messages) => {
485
486 let transport = transport.clone();
487 let self_clone = self_clone.clone();
488 let tx = tx.clone();
489
490 tokio::spawn(async move {
491 let handling_tasks: Vec<_> = client_messages
492 .into_iter()
493 .map(|client_message| self_clone.handle_message(client_message, &transport))
494 .collect();
495
496 let send_result = match try_join_all(handling_tasks).await {
497 Ok(results) => {
498 let results: Vec<_> = results.into_iter().flatten().collect();
499 if !results.is_empty() {
500 transport.send_message(ServerMessages::Batch(results), None)
501 .map_err(|e| e.into())
502 .await
503 }else {
504 Ok(None)
505 }
506 },
507 Err(error) => Err(error),
508 };
509 if let Err(error) = tx.send(send_result).await {
510 tracing::error!("Failed to send batch result to channel: {}", error);
511 }
512 });
513 }
514 }
515
516 while let Ok(result) = rx.try_recv() {
518 result?; }
520
521 if !stream_id.eq(DEFAULT_STREAM_ID){
523 drop(tx);
525 while let Some(result) = rx.recv().await {
526 result?; }
528 return Ok(());
529 }
530 }
531 _ = &mut disconnect_rx => {
532 drop(tx);
534 while let Some(result) = rx.recv().await {
535 result?; }
537 self.remove_transport(stream_id).await?;
538 return Err(SdkError::connection_closed().into());
540
541 }
542 }
543 }
544 }
545
546 #[cfg(feature = "hyper-server")]
547 pub(crate) fn new_instance(
548 server_details: Arc<InitializeResult>,
549 handler: Arc<dyn McpServerHandler>,
550 session_id: SessionId,
551 ) -> Arc<Self> {
552 let (client_details_tx, client_details_rx) =
553 watch::channel::<Option<InitializeRequestParams>>(None);
554 Arc::new(Self {
555 server_details,
556 handler,
557 session_id: Some(session_id),
558 transport_map: tokio::sync::RwLock::new(HashMap::new()),
559 client_details_tx,
560 client_details_rx,
561 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
562 })
563 }
564
565 pub(crate) fn new(
566 server_details: InitializeResult,
567 transport: impl TransportDispatcher<
568 ClientMessages,
569 MessageFromServer,
570 ClientMessage,
571 ServerMessages,
572 ServerMessage,
573 >,
574 handler: Arc<dyn McpServerHandler>,
575 ) -> Arc<Self> {
576 let mut map: HashMap<String, TransportType> = HashMap::new();
577 map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
578 let (client_details_tx, client_details_rx) =
579 watch::channel::<Option<InitializeRequestParams>>(None);
580 Arc::new(Self {
581 server_details: Arc::new(server_details),
582 handler,
583 #[cfg(feature = "hyper-server")]
584 session_id: None,
585 transport_map: tokio::sync::RwLock::new(map),
586 client_details_tx,
587 client_details_rx,
588 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
589 })
590 }
591}