1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::auth::AuthInfo;
4use crate::error::SdkResult;
5use crate::mcp_traits::{
6 McpObserver, McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric,
7};
8use crate::schema::{
9 schema_utils::{
10 ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
11 ServerMessages,
12 },
13 InitializeRequestParams, InitializeResult, RequestId, RpcError,
14};
15use crate::task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller, TaskStatusUpdate};
16use crate::utils::AbortTaskOnDrop;
17use async_trait::async_trait;
18use futures::future::try_join_all;
19use futures::{StreamExt, TryFutureExt};
20use rust_mcp_schema::{GetTaskParams, GetTaskPayloadParams};
21#[cfg(feature = "hyper-server")]
22use rust_mcp_transport::SessionId;
23use rust_mcp_transport::{IoStream, TaskId, TransportDispatcher};
24use std::panic;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::io::AsyncWriteExt;
28use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard};
29
/// Stream id used when the server runs over a single standalone stream
/// (e.g. stdio) rather than per-session streams.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
/// Bounded capacity of the mpsc channel that collects results from
/// spawned message-handling tasks.
const TASK_CHANNEL_CAPACITY: usize = 500;
32
/// Shared, type-erased transport handle used by the server runtime:
/// dispatches server-originated messages and yields incoming client messages.
type TransportType = Arc<
    dyn TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
>;
43
/// Core runtime for an MCP server: owns the transport, the user handler,
/// and per-connection state (client details, auth info, task stores).
pub struct ServerRuntime {
    // Application handler invoked for each incoming request/notification/error.
    handler: Arc<dyn McpServerHandler>,
    // Server details/capabilities advertised to the client on `initialize`.
    server_details: Arc<InitializeResult>,
    // Session identifier assigned by the hyper server; stdio runtimes have none.
    #[cfg(feature = "hyper-server")]
    session_id: Option<SessionId>,
    // Currently attached transport, if any; `None` once `shutdown()` takes it.
    transport_map: tokio::sync::RwLock<Option<TransportType>>,
    // Id generator for server-originated requests.
    request_id_gen: Box<dyn RequestIdGen>,
    // Watch pair broadcasting the client's `initialize` params once received.
    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
    // Auth details of the current session, updated via `update_auth_info`.
    auth_info: tokio::sync::RwLock<Option<AuthInfo>>,
    // Store for server-side tasks, when task support is enabled.
    task_store: Option<Arc<ServerTaskStore>>,
    // Store for tasks polled on the counterparty's behalf, when enabled.
    client_task_store: Option<Arc<ClientTaskStore>>,
    // Optional observer notified of every message sent/received.
    message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
}
61
/// Configuration bundle consumed by [`ServerRuntime::new`].
pub struct McpServerOptions<T>
where
    T: TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
{
    /// Server details/capabilities advertised during initialization.
    pub server_details: InitializeResult,
    /// Transport the runtime takes ownership of at construction time.
    pub transport: T,
    /// Application handler for incoming requests and notifications.
    pub handler: Arc<dyn McpServerHandler>,
    /// Optional store for server-side tasks.
    pub task_store: Option<Arc<ServerTaskStore>>,
    /// Optional store for tasks tracked on the counterparty's behalf.
    pub client_task_store: Option<Arc<ClientTaskStore>>,
    /// Optional observer for message traffic.
    pub message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
}
79
80#[async_trait]
81impl McpServer for ServerRuntime {
82 fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
83 self.task_store.clone()
84 }
85
86 fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
87 self.client_task_store.clone()
88 }
89
90 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
92 self.client_details_tx
93 .send(Some(client_details))
94 .map_err(|_| {
95 RpcError::internal_error()
96 .with_message("Failed to set client details".to_string())
97 .into()
98 })
99 }
100
101 async fn update_auth_info(&self, new_auth_info: Option<AuthInfo>) {
102 let should_update = {
103 let current = self.auth_info.read().await;
104 match (&*current, &new_auth_info) {
105 (None, Some(_)) => true,
106 (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id,
107 (Some(_), None) => true,
108 (None, None) => false,
109 }
110 };
111
112 if should_update {
113 *self.auth_info.write().await = new_auth_info;
114 }
115 }
116
    async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>> {
        // Borrowing accessor: the caller holds the read lock for the
        // guard's lifetime, blocking writers (`update_auth_info`).
        self.auth_info.read().await
    }
    async fn auth_info_cloned(&self) -> Option<AuthInfo> {
        // Cloning accessor: releases the lock immediately after the copy.
        let guard = self.auth_info.read().await;
        guard.clone()
    }
124
    async fn wait_for_initialization(&self) {
        // Blocks until the client's `initialize` params have been recorded
        // via `set_client_details`.
        loop {
            // Fast path: details already present.
            if self.client_details_rx.borrow().is_some() {
                return;
            }
            // NOTE(review): a fresh receiver clone is taken each round so
            // `changed()` waits for the *next* update relative to the
            // clone's view — presumably to avoid stale "seen" state on the
            // shared receiver; confirm against tokio watch semantics.
            let mut rx = self.client_details_rx.clone();
            rx.changed().await.ok();
        }
    }
134
135 async fn send(
136 &self,
137 message: MessageFromServer,
138 request_id: Option<RequestId>,
139 request_timeout: Option<Duration>,
140 ) -> SdkResult<Option<ClientMessage>> {
141 let transport_map = self.transport_map.read().await;
142 let transport = transport_map.as_ref().ok_or(
143 RpcError::internal_error()
144 .with_message("transport stream does not exists or is closed!".to_string()),
145 )?;
146
147 let outgoing_request_id = self
148 .request_id_gen
149 .request_id_for_message(&message, request_id);
150
151 let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
152
153 if let Some(observer) = self.message_observer.as_ref() {
155 observer.on_send(&mcp_message);
156 }
157
158 let response = transport
159 .send_message(ServerMessages::Single(mcp_message), request_timeout)
160 .await?
161 .map(|res| res.as_single())
162 .transpose()?;
163
164 Ok(response)
165 }
166
167 async fn send_batch(
168 &self,
169 messages: Vec<ServerMessage>,
170 request_timeout: Option<Duration>,
171 ) -> SdkResult<Option<Vec<ClientMessage>>> {
172 let transport_map = self.transport_map.read().await;
173 let transport = transport_map.as_ref().ok_or(
174 RpcError::internal_error()
175 .with_message("transport stream does not exists or is closed!".to_string()),
176 )?;
177
178 if let Some(observer) = self.message_observer.as_ref() {
180 messages.iter().for_each(|msg| observer.on_send(msg));
181 }
182
183 transport
184 .send_batch(messages, request_timeout)
185 .map_err(|err| err.into())
186 .await
187 }
188
189 fn server_info(&self) -> &InitializeResult {
192 &self.server_details
193 }
194
195 fn client_info(&self) -> Option<InitializeRequestParams> {
197 self.client_details_rx.borrow().clone()
198 }
199
200 async fn start(self: Arc<Self>) -> SdkResult<()> {
202 let self_clone = self.clone();
203 let transport_map = self_clone.transport_map.read().await;
204
205 let transport = transport_map.as_ref().ok_or(
206 RpcError::internal_error()
207 .with_message("transport stream does not exists or is closed!".to_string()),
208 )?;
209
210 let mut stream = transport.start().await?;
211
212 let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
214
215 while let Some(mcp_messages) = stream.next().await {
217 match mcp_messages {
218 ClientMessages::Single(client_message) => {
219 let transport = transport.clone();
220 let self = self.clone();
221 let tx = tx.clone();
222
223 tokio::spawn(async move {
225 let result = self.handle_message(client_message, &transport).await;
226
227 let send_result: SdkResult<_> = match result {
228 Ok(result) => {
229 if let Some(result) = result {
230 transport
231 .send_message(ServerMessages::Single(result), None)
232 .map_err(|e| e.into())
233 .await
234 } else {
235 Ok(None)
236 }
237 }
238 Err(error) => {
239 tracing::error!("Error handling message : {}", error);
240 Ok(None)
241 }
242 };
243 if let Err(error) = tx.send(send_result).await {
245 tracing::error!("Failed to send result to channel: {}", error);
246 }
247 });
248 }
249 ClientMessages::Batch(client_messages) => {
250 let transport = transport.clone();
251 let self = self_clone.clone();
252 let tx = tx.clone();
253
254 tokio::spawn(async move {
255 let handling_tasks: Vec<_> = client_messages
256 .into_iter()
257 .map(|client_message| self.handle_message(client_message, &transport))
258 .collect();
259
260 let send_result = match try_join_all(handling_tasks).await {
261 Ok(results) => {
262 let results: Vec<_> = results.into_iter().flatten().collect();
263 if !results.is_empty() {
264 transport
265 .send_message(ServerMessages::Batch(results), None)
266 .map_err(|e| e.into())
267 .await
268 } else {
269 Ok(None)
270 }
271 }
272 Err(error) => Err(error),
273 };
274
275 if let Err(error) = tx.send(send_result).await {
276 tracing::error!("Failed to send batch result to channel: {}", error);
277 }
278 });
279 }
280 }
281
282 while let Ok(result) = rx.try_recv() {
284 result?; }
286 }
287
288 drop(tx);
290 while let Some(result) = rx.recv().await {
291 result?; }
293
294 return Ok(());
295 }
296
297 async fn stderr_message(&self, message: String) -> SdkResult<()> {
298 let transport_map = self.transport_map.read().await;
299 let transport = transport_map.as_ref().ok_or(
300 RpcError::internal_error()
301 .with_message("transport stream does not exists or is closed!".to_string()),
302 )?;
303 let mut lock = transport.error_stream().write().await;
304
305 if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
306 stderr.write_all(message.as_bytes()).await?;
307 stderr.write_all(b"\n").await?;
308 stderr.flush().await?;
309 }
310 Ok(())
311 }
312
313 #[cfg(feature = "hyper-server")]
314 fn session_id(&self) -> Option<SessionId> {
315 self.session_id.to_owned()
316 }
317}
318
319impl ServerRuntime {
320 pub(crate) async fn consume_payload_string(&self, payload: &str) -> SdkResult<()> {
321 let transport_map = self.transport_map.read().await;
322
323 let transport = transport_map.as_ref().ok_or(
324 RpcError::internal_error()
325 .with_message("stream id does not exists or is closed!".to_string()),
326 )?;
327
328 transport.consume_string_payload(payload).await?;
329
330 Ok(())
331 }
332
333 pub(crate) async fn handle_message(
334 self: &Arc<Self>,
335 message: ClientMessage,
336 transport: &Arc<
337 dyn TransportDispatcher<
338 ClientMessages,
339 MessageFromServer,
340 ClientMessage,
341 ServerMessages,
342 ServerMessage,
343 >,
344 >,
345 ) -> SdkResult<Option<ServerMessage>> {
346 if let Some(observer) = self.message_observer.as_ref() {
348 observer.on_receive(&message);
349 }
350
351 let response = match message {
352 ClientMessage::Request(client_jsonrpc_request) => {
354 let request_id = client_jsonrpc_request.request_id().clone();
355
356 let result = self
357 .handler
358 .handle_request(client_jsonrpc_request, self.clone())
359 .await;
360
361 let response: MessageFromServer = match result {
363 Ok(success_value) => success_value.into(),
364 Err(error_value) => {
365 if !self.is_initialized() {
368 return Err(error_value.into());
369 }
370 MessageFromServer::Error(error_value)
371 }
372 };
373
374 let mpc_message: ServerMessage =
375 ServerMessage::from_message(response, Some(request_id))?;
376
377 Some(mpc_message)
378 }
379 ClientMessage::Notification(client_jsonrpc_notification) => {
380 self.handler
381 .handle_notification(client_jsonrpc_notification, self.clone())
382 .await?;
383 None
384 }
385 ClientMessage::Error(jsonrpc_error) => {
386 self.handler
387 .handle_error(&jsonrpc_error.error, self.clone())
388 .await?;
389
390 if let Some(request_id) = jsonrpc_error.id.as_ref() {
391 if let Some(tx_response) = transport.pending_request_tx(request_id).await {
392 tx_response
393 .send(ClientMessage::Error(jsonrpc_error))
394 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
395 } else {
396 tracing::warn!(
397 "Received an error response with no corresponding request {:?}",
398 &jsonrpc_error.id
399 );
400 }
401 }
402 None
403 }
404 ClientMessage::Response(response) => {
405 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
406 tx_response
407 .send(ClientMessage::Response(response))
408 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
409 } else {
410 tracing::warn!(
411 "Received a response with no corresponding request: {:?}",
412 &response.id
413 );
414 }
415 None
416 }
417 };
418 Ok(response)
419 }
420
421 pub(crate) async fn store_transport(
422 &self,
423 stream_id: &str,
424 transport: Arc<
425 dyn TransportDispatcher<
426 ClientMessages,
427 MessageFromServer,
428 ClientMessage,
429 ServerMessages,
430 ServerMessage,
431 >,
432 >,
433 ) -> SdkResult<()> {
434 if stream_id != DEFAULT_STREAM_ID {
435 return Ok(());
436 }
437 let mut transport_map = self.transport_map.write().await;
438 tracing::trace!("save transport for stream id : {}", stream_id);
439 *transport_map = Some(transport);
440 Ok(())
441 }
442
    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
        // Shuts down the standalone transport when its stream closes.
        // Non-default stream ids are not tracked here, so they are a no-op.
        if stream_id != DEFAULT_STREAM_ID {
            return Ok(());
        }
        let transport_map = self.transport_map.read().await;
        tracing::trace!("removing transport for stream id : {}", stream_id);
        // NOTE(review): despite the name, the entry is NOT taken out of
        // `transport_map` (only `shutdown()` does that) — the transport is
        // merely shut down in place. Confirm this is intended.
        if let Some(transport) = transport_map.as_ref() {
            transport.shut_down().await?;
        }
        Ok(())
    }
456
457 pub(crate) async fn shutdown(&self) {
458 let mut transport_map = self.transport_map.write().await;
459 let transport_option = transport_map.take();
460 drop(transport_map);
461 if let Some(transport) = transport_option {
462 let _ = transport.shut_down().await;
463 }
464 }
465
466 pub(crate) async fn default_stream_exists(&self) -> bool {
467 let transport_map = self.transport_map.read().await;
468 let live_transport = if let Some(t) = transport_map.as_ref() {
469 !t.is_shut_down().await
470 } else {
471 false
472 };
473 live_transport
474 }
475
    /// Runs the receive loop for one stream: starts the transport, keeps the
    /// connection alive with pings, handles incoming messages concurrently,
    /// and tears the stream down on disconnect.
    pub(crate) async fn start_stream(
        self: Arc<Self>,
        transport: Arc<
            dyn TransportDispatcher<
                ClientMessages,
                MessageFromServer,
                ClientMessage,
                ServerMessages,
                ServerMessage,
            >,
        >,
        stream_id: &str,
        ping_interval: Duration,
        payload: Option<String>,
    ) -> SdkResult<()> {
        let mut stream = transport.start().await?;

        // Only the standalone stream's transport is retained on the runtime.
        if stream_id == DEFAULT_STREAM_ID {
            self.store_transport(stream_id, transport.clone()).await?;
        }

        let self_clone = self.clone();

        // Keep-alive pings; `disconnect_rx` fires when the peer stops responding.
        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
        let abort_alive_task = transport
            .keep_alive(ping_interval, disconnect_tx)
            .await?
            .abort_handle();

        // Aborts the keep-alive task when this function returns, on any path.
        let _abort_guard = AbortTaskOnDrop {
            handle: abort_alive_task,
        };

        // An initial payload (e.g. the HTTP request body that opened this
        // stream) is consumed before entering the loop.
        if let Some(payload) = payload {
            if let Err(err) = transport.consume_string_payload(&payload).await {
                let _ = self.remove_transport(stream_id).await;
                return Err(err.into());
            }
        }

        // Spawned handler tasks report their outcome through this channel.
        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);

        loop {
            tokio::select! {
                Some(mcp_messages) = stream.next() => {

                    match mcp_messages {
                        ClientMessages::Single(client_message) => {
                            let transport = transport.clone();
                            let self_clone = self.clone();
                            let tx = tx.clone();
                            tokio::spawn(async move {

                                let result = self_clone.handle_message(client_message, &transport).await;

                                let send_result: SdkResult<_> = match result {
                                    Ok(result) => {
                                        // Forward the response, if the handler produced one.
                                        if let Some(result) = result {
                                            transport
                                                .send_message(ServerMessages::Single(result), None)
                                                .map_err(|e| e.into())
                                                .await
                                        } else {
                                            Ok(None)
                                        }
                                    }
                                    Err(error) => {
                                        // Handler errors are logged, not fatal to the loop.
                                        tracing::error!("Error handling message : {}", error);
                                        Ok(None)
                                    }
                                };
                                // NOTE(review): this log line says "batch" but is
                                // on the single-message path (copy-paste).
                                if let Err(error) = tx.send(send_result).await {
                                    tracing::error!("Failed to send batch result to channel: {}", error);
                                }
                            });
                        }
                        ClientMessages::Batch(client_messages) => {

                            let transport = transport.clone();
                            let self_clone = self_clone.clone();
                            let tx = tx.clone();

                            tokio::spawn(async move {
                                // Handle all batch entries concurrently; fail fast
                                // on the first handler error.
                                let handling_tasks: Vec<_> = client_messages
                                    .into_iter()
                                    .map(|client_message| self_clone.handle_message(client_message, &transport))
                                    .collect();

                                let send_result = match try_join_all(handling_tasks).await {
                                    Ok(results) => {
                                        // Drop `None`s: only actual responses go back.
                                        let results: Vec<_> = results.into_iter().flatten().collect();
                                        if !results.is_empty() {
                                            transport.send_message(ServerMessages::Batch(results), None)
                                                .map_err(|e| e.into())
                                                .await
                                        } else {
                                            Ok(None)
                                        }
                                    },
                                    Err(error) => Err(error),
                                };
                                if let Err(error) = tx.send(send_result).await {
                                    tracing::error!("Failed to send batch result to channel: {}", error);
                                }
                            });
                        }
                    }

                    // Surface errors from already-finished tasks, non-blocking.
                    while let Ok(result) = rx.try_recv() {
                        result?;
                    }

                    // Non-default streams serve a single round of messages:
                    // close our sender, drain in-flight tasks, and finish.
                    if !stream_id.eq(DEFAULT_STREAM_ID) {
                        drop(tx);
                        while let Some(result) = rx.recv().await {
                            result?;
                        }
                        return Ok(());
                    }
                }
                // Keep-alive detected a dead peer: drain tasks, tear down.
                _ = &mut disconnect_rx => {
                    drop(tx);
                    while let Some(result) = rx.recv().await {
                        result?;
                    }
                    self.remove_transport(stream_id).await?;
                    return Err(SdkError::connection_closed().into());

                }
            }
        }
    }
617
618 #[cfg(feature = "hyper-server")]
619 pub(crate) fn new_instance(
620 server_details: Arc<InitializeResult>,
621 handler: Arc<dyn McpServerHandler>,
622 session_id: SessionId,
623 auth_info: Option<AuthInfo>,
624 task_store: Option<Arc<ServerTaskStore>>,
625 client_task_store: Option<Arc<ClientTaskStore>>,
626 message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
627 ) -> Arc<Self> {
628 use tokio::sync::RwLock;
629
630 let (client_details_tx, client_details_rx) =
631 watch::channel::<Option<InitializeRequestParams>>(None);
632 Arc::new(Self {
633 server_details,
634 handler,
635 session_id: Some(session_id),
636 transport_map: tokio::sync::RwLock::new(None),
637 client_details_tx,
638 client_details_rx,
639 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
640 auth_info: RwLock::new(auth_info),
641 task_store,
642 client_task_store,
643 message_observer,
644 })
645 }
646
647 pub(crate) async fn poll_task_status(
648 self: Arc<ServerRuntime>,
649 task_id: TaskId,
650 session_id: Option<String>,
651 task_store: Arc<ClientTaskStore>,
652 ) -> SdkResult<TaskStatusUpdate> {
653 let result = self
654 .request_get_task(GetTaskParams {
655 task_id: task_id.to_string(),
656 })
657 .await?;
658
659 if result.is_terminal() {
660 let task_payload = self
661 .request_get_task_payload(GetTaskPayloadParams {
662 task_id: task_id.clone(),
663 })
664 .await?;
665
666 task_store
667 .store_task_result(
668 task_id.as_str(),
669 result.status,
670 task_payload.into(),
671 session_id.as_ref(),
672 )
673 .await;
674 }
675 Ok((result.status, result.poll_interval))
676 }
677
    /// Builds a runtime around the given options, taking ownership of the
    /// transport immediately and wiring up task-status plumbing.
    pub(crate) fn new<T>(options: McpServerOptions<T>) -> Arc<Self>
    where
        T: TransportDispatcher<
            ClientMessages,
            MessageFromServer,
            ClientMessage,
            ServerMessages,
            ServerMessage,
        >,
    {
        let (client_details_tx, client_details_rx) =
            watch::channel::<Option<InitializeRequestParams>>(None);

        let runtime = Arc::new(Self {
            server_details: Arc::new(options.server_details),
            handler: options.handler,
            // Hyper-only field: standalone runtimes carry no session id.
            #[cfg(feature = "hyper-server")]
            session_id: None,
            transport_map: tokio::sync::RwLock::new(Some(Arc::new(options.transport))),
            client_details_tx,
            client_details_rx,
            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
            auth_info: RwLock::new(None),
            task_store: options.task_store,
            client_task_store: options.client_task_store,
            message_observer: options.message_observer,
        });

        // Forward server-side task status updates to the client as
        // notifications, for as long as the store's stream yields them.
        let runtime_clone = runtime.clone();
        if let Some(task_store) = runtime_clone.task_store() {
            if let Some(mut stream) = task_store.subscribe() {
                tokio::spawn(async move {
                    while let Some((params, _)) = stream.next().await {
                        // Best effort: notification failures are ignored.
                        let _ = runtime_clone.notify_task_status(params).await;
                    }
                });
            }
        }

        // Start polling tasks tracked in the client task store; the callback
        // clones its captures per invocation so each poll owns its handles.
        if let Some(client_task_store) = runtime.client_task_store.clone() {
            let task_store_clone = client_task_store.clone();
            let runtime_clone = runtime.clone();

            let callback: TaskStatusPoller = Box::new(move |task_id, session_id| {
                let task_store_clone = client_task_store.clone();
                let runtime_clone = runtime_clone.clone();

                Box::pin(async move {
                    runtime_clone
                        .poll_task_status(task_id, session_id, task_store_clone)
                        .await
                })
            });

            if let Err(error) = task_store_clone.start_task_polling(callback) {
                tracing::error!("Failed to start task polling: {error}");
            }
        }

        runtime
    }
741}