Skip to main content

rust_mcp_sdk/mcp_runtimes/
server_runtime.rs

1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::auth::AuthInfo;
4use crate::error::SdkResult;
5use crate::mcp_traits::{
6    McpObserver, McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric,
7};
8use crate::schema::{
9    schema_utils::{
10        ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
11        ServerMessages,
12    },
13    InitializeRequestParams, InitializeResult, RequestId, RpcError,
14};
15use crate::task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller, TaskStatusUpdate};
16use crate::utils::AbortTaskOnDrop;
17use async_trait::async_trait;
18use futures::future::try_join_all;
19use futures::{StreamExt, TryFutureExt};
20use rust_mcp_schema::{GetTaskParams, GetTaskPayloadParams};
21#[cfg(feature = "hyper-server")]
22use rust_mcp_transport::SessionId;
23use rust_mcp_transport::{IoStream, TaskId, TransportDispatcher};
24use std::panic;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::io::AsyncWriteExt;
28use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard};
29
/// Stream id of the standalone (long-lived) stream; only this stream's transport is stored
/// in the runtime and kept open after a message has been processed.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";

/// Capacity of the channel through which spawned message-handling tasks report their results.
const TASK_CHANNEL_CAPACITY: usize = 500;
32
33// Define a type alias for the TransportDispatcher trait object
34type TransportType = Arc<
35    dyn TransportDispatcher<
36        ClientMessages,
37        MessageFromServer,
38        ClientMessage,
39        ServerMessages,
40        ServerMessage,
41    >,
42>;
43
44/// Struct representing the runtime core of the MCP server, handling transport and client details
45pub struct ServerRuntime {
46    // The handler for processing MCP messages
47    handler: Arc<dyn McpServerHandler>,
48    // Information about the server
49    server_details: Arc<InitializeResult>,
50    #[cfg(feature = "hyper-server")]
51    session_id: Option<SessionId>,
52    transport_map: tokio::sync::RwLock<Option<TransportType>>,
53    request_id_gen: Box<dyn RequestIdGen>,
54    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
55    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
56    auth_info: tokio::sync::RwLock<Option<AuthInfo>>,
57    task_store: Option<Arc<ServerTaskStore>>,
58    client_task_store: Option<Arc<ClientTaskStore>>,
59    message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
60}
61
62pub struct McpServerOptions<T>
63where
64    T: TransportDispatcher<
65        ClientMessages,
66        MessageFromServer,
67        ClientMessage,
68        ServerMessages,
69        ServerMessage,
70    >,
71{
72    pub server_details: InitializeResult,
73    pub transport: T,
74    pub handler: Arc<dyn McpServerHandler>,
75    pub task_store: Option<Arc<ServerTaskStore>>,
76    pub client_task_store: Option<Arc<ClientTaskStore>>,
77    pub message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
78}
79
80#[async_trait]
81impl McpServer for ServerRuntime {
82    fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
83        self.task_store.clone()
84    }
85
86    fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
87        self.client_task_store.clone()
88    }
89
90    /// Set the client details, storing them in client_details
91    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
92        self.client_details_tx
93            .send(Some(client_details))
94            .map_err(|_| {
95                RpcError::internal_error()
96                    .with_message("Failed to set client details".to_string())
97                    .into()
98            })
99    }
100
101    async fn update_auth_info(&self, new_auth_info: Option<AuthInfo>) {
102        let should_update = {
103            let current = self.auth_info.read().await;
104            match (&*current, &new_auth_info) {
105                (None, Some(_)) => true,
106                (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id,
107                (Some(_), None) => true,
108                (None, None) => false,
109            }
110        };
111
112        if should_update {
113            *self.auth_info.write().await = new_auth_info;
114        }
115    }
116
117    async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>> {
118        self.auth_info.read().await
119    }
120    async fn auth_info_cloned(&self) -> Option<AuthInfo> {
121        let guard = self.auth_info.read().await;
122        guard.clone()
123    }
124
125    async fn wait_for_initialization(&self) {
126        loop {
127            if self.client_details_rx.borrow().is_some() {
128                return;
129            }
130            let mut rx = self.client_details_rx.clone();
131            rx.changed().await.ok();
132        }
133    }
134
135    async fn send(
136        &self,
137        message: MessageFromServer,
138        request_id: Option<RequestId>,
139        request_timeout: Option<Duration>,
140    ) -> SdkResult<Option<ClientMessage>> {
141        let transport_map = self.transport_map.read().await;
142        let transport = transport_map.as_ref().ok_or(
143            RpcError::internal_error()
144                .with_message("transport stream does not exists or is closed!".to_string()),
145        )?;
146
147        let outgoing_request_id = self
148            .request_id_gen
149            .request_id_for_message(&message, request_id);
150
151        let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
152
153        // telemetry
154        if let Some(observer) = self.message_observer.as_ref() {
155            observer.on_send(&mcp_message);
156        }
157
158        let response = transport
159            .send_message(ServerMessages::Single(mcp_message), request_timeout)
160            .await?
161            .map(|res| res.as_single())
162            .transpose()?;
163
164        Ok(response)
165    }
166
167    async fn send_batch(
168        &self,
169        messages: Vec<ServerMessage>,
170        request_timeout: Option<Duration>,
171    ) -> SdkResult<Option<Vec<ClientMessage>>> {
172        let transport_map = self.transport_map.read().await;
173        let transport = transport_map.as_ref().ok_or(
174            RpcError::internal_error()
175                .with_message("transport stream does not exists or is closed!".to_string()),
176        )?;
177
178        // telemetry
179        if let Some(observer) = self.message_observer.as_ref() {
180            messages.iter().for_each(|msg| observer.on_send(msg));
181        }
182
183        transport
184            .send_batch(messages, request_timeout)
185            .map_err(|err| err.into())
186            .await
187    }
188
189    /// Returns the server's details, including server capability,
190    /// instructions, protocol_version , server_info and optional meta data
191    fn server_info(&self) -> &InitializeResult {
192        &self.server_details
193    }
194
195    /// Returns the client information if available, after successful initialization , otherwise returns None
196    fn client_info(&self) -> Option<InitializeRequestParams> {
197        self.client_details_rx.borrow().clone()
198    }
199
200    /// Main runtime loop, processes incoming messages and handles requests
201    async fn start(self: Arc<Self>) -> SdkResult<()> {
202        let self_clone = self.clone();
203        let transport_map = self_clone.transport_map.read().await;
204
205        let transport = transport_map.as_ref().ok_or(
206            RpcError::internal_error()
207                .with_message("transport stream does not exists or is closed!".to_string()),
208        )?;
209
210        let mut stream = transport.start().await?;
211
212        // Create a channel to collect results from spawned tasks
213        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
214
215        // Process incoming messages from the client
216        while let Some(mcp_messages) = stream.next().await {
217            match mcp_messages {
218                ClientMessages::Single(client_message) => {
219                    let transport = transport.clone();
220                    let self = self.clone();
221                    let tx = tx.clone();
222
223                    // Handle incoming messages in a separate task to avoid blocking the stream.
224                    tokio::spawn(async move {
225                        let result = self.handle_message(client_message, &transport).await;
226
227                        let send_result: SdkResult<_> = match result {
228                            Ok(result) => {
229                                if let Some(result) = result {
230                                    transport
231                                        .send_message(ServerMessages::Single(result), None)
232                                        .map_err(|e| e.into())
233                                        .await
234                                } else {
235                                    Ok(None)
236                                }
237                            }
238                            Err(error) => {
239                                tracing::error!("Error handling message : {}", error);
240                                Ok(None)
241                            }
242                        };
243                        // Send result to the main loop
244                        if let Err(error) = tx.send(send_result).await {
245                            tracing::error!("Failed to send result to channel: {}", error);
246                        }
247                    });
248                }
249                ClientMessages::Batch(client_messages) => {
250                    let transport = transport.clone();
251                    let self = self_clone.clone();
252                    let tx = tx.clone();
253
254                    tokio::spawn(async move {
255                        let handling_tasks: Vec<_> = client_messages
256                            .into_iter()
257                            .map(|client_message| self.handle_message(client_message, &transport))
258                            .collect();
259
260                        let send_result = match try_join_all(handling_tasks).await {
261                            Ok(results) => {
262                                let results: Vec<_> = results.into_iter().flatten().collect();
263                                if !results.is_empty() {
264                                    transport
265                                        .send_message(ServerMessages::Batch(results), None)
266                                        .map_err(|e| e.into())
267                                        .await
268                                } else {
269                                    Ok(None)
270                                }
271                            }
272                            Err(error) => Err(error),
273                        };
274
275                        if let Err(error) = tx.send(send_result).await {
276                            tracing::error!("Failed to send batch result to channel: {}", error);
277                        }
278                    });
279                }
280            }
281
282            // Check for results from spawned tasks to propagate errors
283            while let Ok(result) = rx.try_recv() {
284                result?; // Propagate errors
285            }
286        }
287
288        // Drop tx to close the channel and collect remaining results
289        drop(tx);
290        while let Some(result) = rx.recv().await {
291            result?; // Propagate errors
292        }
293
294        return Ok(());
295    }
296
297    async fn stderr_message(&self, message: String) -> SdkResult<()> {
298        let transport_map = self.transport_map.read().await;
299        let transport = transport_map.as_ref().ok_or(
300            RpcError::internal_error()
301                .with_message("transport stream does not exists or is closed!".to_string()),
302        )?;
303        let mut lock = transport.error_stream().write().await;
304
305        if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
306            stderr.write_all(message.as_bytes()).await?;
307            stderr.write_all(b"\n").await?;
308            stderr.flush().await?;
309        }
310        Ok(())
311    }
312
313    #[cfg(feature = "hyper-server")]
314    fn session_id(&self) -> Option<SessionId> {
315        self.session_id.to_owned()
316    }
317}
318
319impl ServerRuntime {
320    pub(crate) async fn consume_payload_string(&self, payload: &str) -> SdkResult<()> {
321        let transport_map = self.transport_map.read().await;
322
323        let transport = transport_map.as_ref().ok_or(
324            RpcError::internal_error()
325                .with_message("stream id does not exists or is closed!".to_string()),
326        )?;
327
328        transport.consume_string_payload(payload).await?;
329
330        Ok(())
331    }
332
333    pub(crate) async fn handle_message(
334        self: &Arc<Self>,
335        message: ClientMessage,
336        transport: &Arc<
337            dyn TransportDispatcher<
338                ClientMessages,
339                MessageFromServer,
340                ClientMessage,
341                ServerMessages,
342                ServerMessage,
343            >,
344        >,
345    ) -> SdkResult<Option<ServerMessage>> {
346        // telemetry
347        if let Some(observer) = self.message_observer.as_ref() {
348            observer.on_receive(&message);
349        }
350
351        let response = match message {
352            // Handle a client request
353            ClientMessage::Request(client_jsonrpc_request) => {
354                let request_id = client_jsonrpc_request.request_id().clone();
355
356                let result = self
357                    .handler
358                    .handle_request(client_jsonrpc_request, self.clone())
359                    .await;
360
361                // create a response to send back to the client
362                let response: MessageFromServer = match result {
363                    Ok(success_value) => success_value.into(),
364                    Err(error_value) => {
365                        // Error occurred during initialization.
366                        // A likely cause could be an unsupported protocol version.
367                        if !self.is_initialized() {
368                            return Err(error_value.into());
369                        }
370                        MessageFromServer::Error(error_value)
371                    }
372                };
373
374                let mpc_message: ServerMessage =
375                    ServerMessage::from_message(response, Some(request_id))?;
376
377                Some(mpc_message)
378            }
379            ClientMessage::Notification(client_jsonrpc_notification) => {
380                self.handler
381                    .handle_notification(client_jsonrpc_notification, self.clone())
382                    .await?;
383                None
384            }
385            ClientMessage::Error(jsonrpc_error) => {
386                self.handler
387                    .handle_error(&jsonrpc_error.error, self.clone())
388                    .await?;
389
390                if let Some(request_id) = jsonrpc_error.id.as_ref() {
391                    if let Some(tx_response) = transport.pending_request_tx(request_id).await {
392                        tx_response
393                            .send(ClientMessage::Error(jsonrpc_error))
394                            .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
395                    } else {
396                        tracing::warn!(
397                            "Received an error response with no corresponding request {:?}",
398                            &jsonrpc_error.id
399                        );
400                    }
401                }
402                None
403            }
404            ClientMessage::Response(response) => {
405                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
406                    tx_response
407                        .send(ClientMessage::Response(response))
408                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
409                } else {
410                    tracing::warn!(
411                        "Received a response with no corresponding request: {:?}",
412                        &response.id
413                    );
414                }
415                None
416            }
417        };
418        Ok(response)
419    }
420
421    pub(crate) async fn store_transport(
422        &self,
423        stream_id: &str,
424        transport: Arc<
425            dyn TransportDispatcher<
426                ClientMessages,
427                MessageFromServer,
428                ClientMessage,
429                ServerMessages,
430                ServerMessage,
431            >,
432        >,
433    ) -> SdkResult<()> {
434        if stream_id != DEFAULT_STREAM_ID {
435            return Ok(());
436        }
437        let mut transport_map = self.transport_map.write().await;
438        tracing::trace!("save transport for stream id : {}", stream_id);
439        *transport_map = Some(transport);
440        Ok(())
441    }
442
443    //TODO: re-visit and simplify unnecessary hashmap
444    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
445        if stream_id != DEFAULT_STREAM_ID {
446            return Ok(());
447        }
448        let transport_map = self.transport_map.read().await;
449        tracing::trace!("removing transport for stream id : {}", stream_id);
450        if let Some(transport) = transport_map.as_ref() {
451            transport.shut_down().await?;
452        }
453        // transport_map.remove(stream_id);
454        Ok(())
455    }
456
457    pub(crate) async fn shutdown(&self) {
458        let mut transport_map = self.transport_map.write().await;
459        let transport_option = transport_map.take();
460        drop(transport_map);
461        if let Some(transport) = transport_option {
462            let _ = transport.shut_down().await;
463        }
464    }
465
466    pub(crate) async fn default_stream_exists(&self) -> bool {
467        let transport_map = self.transport_map.read().await;
468        let live_transport = if let Some(t) = transport_map.as_ref() {
469            !t.is_shut_down().await
470        } else {
471            false
472        };
473        live_transport
474    }
475
476    pub(crate) async fn start_stream(
477        self: Arc<Self>,
478        transport: Arc<
479            dyn TransportDispatcher<
480                ClientMessages,
481                MessageFromServer,
482                ClientMessage,
483                ServerMessages,
484                ServerMessage,
485            >,
486        >,
487        stream_id: &str,
488        ping_interval: Duration,
489        payload: Option<String>,
490    ) -> SdkResult<()> {
491        let mut stream = transport.start().await?;
492
493        if stream_id == DEFAULT_STREAM_ID {
494            self.store_transport(stream_id, transport.clone()).await?;
495        }
496
497        let self_clone = self.clone();
498
499        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
500        let abort_alive_task = transport
501            .keep_alive(ping_interval, disconnect_tx)
502            .await?
503            .abort_handle();
504
505        // ensure keep_alive task will be aborted
506        let _abort_guard = AbortTaskOnDrop {
507            handle: abort_alive_task,
508        };
509
510        // in case there is a payload, we consume it by transport to get processed
511        // payload would be message payload coming from the client
512        if let Some(payload) = payload {
513            if let Err(err) = transport.consume_string_payload(&payload).await {
514                let _ = self.remove_transport(stream_id).await;
515                return Err(err.into());
516            }
517        }
518
519        // Create a channel to collect results from spawned tasks
520        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
521
522        loop {
523            tokio::select! {
524                Some(mcp_messages) = stream.next() =>{
525
526                    match mcp_messages {
527                        ClientMessages::Single(client_message) => {
528                            let transport = transport.clone();
529                            let self_clone = self.clone();
530                            let tx = tx.clone();
531                            tokio::spawn(async move {
532
533                                let result = self_clone.handle_message(client_message, &transport).await;
534
535                                let send_result: SdkResult<_> = match result {
536                                    Ok(result) => {
537                                        if let Some(result) = result {
538                                            transport
539                                                .send_message(ServerMessages::Single(result), None)
540                                                .map_err(|e| e.into())
541                                                .await
542                                        } else {
543                                            Ok(None)
544                                        }
545                                    }
546                                    Err(error) => {
547                                        tracing::error!("Error handling message : {}", error);
548                                        Ok(None)
549                                    }
550                                };
551                                if let Err(error) = tx.send(send_result).await {
552                                    tracing::error!("Failed to send batch result to channel: {}", error);
553                                }
554                            });
555                        }
556                        ClientMessages::Batch(client_messages) => {
557
558                            let transport = transport.clone();
559                            let self_clone = self_clone.clone();
560                            let tx = tx.clone();
561
562                            tokio::spawn(async move {
563                                let handling_tasks: Vec<_> = client_messages
564                                    .into_iter()
565                                    .map(|client_message| self_clone.handle_message(client_message, &transport))
566                                    .collect();
567
568                                    let send_result = match try_join_all(handling_tasks).await {
569                                         Ok(results) => {
570                                             let results: Vec<_> = results.into_iter().flatten().collect();
571                                             if !results.is_empty() {
572                                                 transport.send_message(ServerMessages::Batch(results), None)
573                                                 .map_err(|e| e.into())
574                                                 .await
575                                             }else {
576                                                 Ok(None)
577                                             }
578                                         },
579                                        Err(error) => Err(error),
580                                    };
581                                    if let Err(error) = tx.send(send_result).await {
582                                        tracing::error!("Failed to send batch result to channel: {}", error);
583                                    }
584                            });
585                        }
586                    }
587
588                    // Check for results from spawned tasks to propagate errors
589                    while let Ok(result) = rx.try_recv() {
590                        result?; // Propagate errors
591                    }
592
593                    // close the stream after all messages are sent, unless it is a standalone stream
594                    if !stream_id.eq(DEFAULT_STREAM_ID){
595                        // Drop tx to close the channel and collect remaining results
596                        drop(tx);
597                        while let Some(result) = rx.recv().await {
598                            result?; // Propagate errors
599                        }
600                        return  Ok(());
601                    }
602                }
603                _ = &mut disconnect_rx => {
604                    // Drop tx to close the channel and collect remaining results
605                    drop(tx);
606                    while let Some(result) = rx.recv().await {
607                        result?; // Propagate errors
608                    }
609                                self.remove_transport(stream_id).await?;
610                                // Disconnection detected by keep-alive task
611                                return Err(SdkError::connection_closed().into());
612
613                }
614            }
615        }
616    }
617
618    #[cfg(feature = "hyper-server")]
619    pub(crate) fn new_instance(
620        server_details: Arc<InitializeResult>,
621        handler: Arc<dyn McpServerHandler>,
622        session_id: SessionId,
623        auth_info: Option<AuthInfo>,
624        task_store: Option<Arc<ServerTaskStore>>,
625        client_task_store: Option<Arc<ClientTaskStore>>,
626        message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
627    ) -> Arc<Self> {
628        use tokio::sync::RwLock;
629
630        let (client_details_tx, client_details_rx) =
631            watch::channel::<Option<InitializeRequestParams>>(None);
632        Arc::new(Self {
633            server_details,
634            handler,
635            session_id: Some(session_id),
636            transport_map: tokio::sync::RwLock::new(None),
637            client_details_tx,
638            client_details_rx,
639            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
640            auth_info: RwLock::new(auth_info),
641            task_store,
642            client_task_store,
643            message_observer,
644        })
645    }
646
647    pub(crate) async fn poll_task_status(
648        self: Arc<ServerRuntime>,
649        task_id: TaskId,
650        session_id: Option<String>,
651        task_store: Arc<ClientTaskStore>,
652    ) -> SdkResult<TaskStatusUpdate> {
653        let result = self
654            .request_get_task(GetTaskParams {
655                task_id: task_id.to_string(),
656            })
657            .await?;
658
659        if result.is_terminal() {
660            let task_payload = self
661                .request_get_task_payload(GetTaskPayloadParams {
662                    task_id: task_id.clone(),
663                })
664                .await?;
665
666            task_store
667                .store_task_result(
668                    task_id.as_str(),
669                    result.status,
670                    task_payload.into(),
671                    session_id.as_ref(),
672                )
673                .await;
674        }
675        Ok((result.status, result.poll_interval))
676    }
677
678    pub(crate) fn new<T>(options: McpServerOptions<T>) -> Arc<Self>
679    where
680        T: TransportDispatcher<
681            ClientMessages,
682            MessageFromServer,
683            ClientMessage,
684            ServerMessages,
685            ServerMessage,
686        >,
687    {
688        let (client_details_tx, client_details_rx) =
689            watch::channel::<Option<InitializeRequestParams>>(None);
690
691        let runtime = Arc::new(Self {
692            server_details: Arc::new(options.server_details),
693            handler: options.handler,
694            #[cfg(feature = "hyper-server")]
695            session_id: None,
696            transport_map: tokio::sync::RwLock::new(Some(Arc::new(options.transport))),
697            client_details_tx,
698            client_details_rx,
699            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
700            auth_info: RwLock::new(None),
701            task_store: options.task_store,
702            client_task_store: options.client_task_store,
703            message_observer: options.message_observer,
704        });
705
706        let runtime_clone = runtime.clone();
707        if let Some(task_store) = runtime_clone.task_store() {
708            // send TaskStatusNotification  if task_store is present and supports subscribe()
709            if let Some(mut stream) = task_store.subscribe() {
710                tokio::spawn(async move {
711                    while let Some((params, _)) = stream.next().await {
712                        let _ = runtime_clone.notify_task_status(params).await;
713                    }
714                });
715            }
716        }
717
718        // Task polling for server initiated tasks
719        if let Some(client_task_store) = runtime.client_task_store.clone() {
720            let task_store_clone = client_task_store.clone();
721            let runtime_clone = runtime.clone();
722
723            let callback: TaskStatusPoller = Box::new(move |task_id, session_id| {
724                let task_store_clone = client_task_store.clone();
725                let runtime_clone = runtime_clone.clone();
726
727                Box::pin(async move {
728                    runtime_clone
729                        .poll_task_status(task_id, session_id, task_store_clone)
730                        .await
731                })
732            });
733
734            if let Err(error) = task_store_clone.start_task_polling(callback) {
735                tracing::error!("Failed to start task polling: {error}");
736            }
737        }
738
739        runtime
740    }
741}