rust_mcp_sdk/mcp_runtimes/
server_runtime.rs

1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::auth::AuthInfo;
4use crate::error::SdkResult;
5use crate::mcp_traits::{McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric};
6use crate::schema::{
7    schema_utils::{
8        ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
9        ServerMessages,
10    },
11    InitializeRequestParams, InitializeResult, RequestId, RpcError,
12};
13use crate::utils::AbortTaskOnDrop;
14use async_trait::async_trait;
15use futures::future::try_join_all;
16use futures::{StreamExt, TryFutureExt};
17#[cfg(feature = "hyper-server")]
18use rust_mcp_transport::SessionId;
19use rust_mcp_transport::{IoStream, TransportDispatcher};
20use std::collections::HashMap;
21use std::panic;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard};
26
/// Stream id of the standalone stream. It is the only stream whose transport the
/// runtime stores, and the one used for server-initiated messages.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";

/// Capacity of the channel through which spawned message-handling tasks report
/// their outcome back to the receive loop.
const TASK_CHANNEL_CAPACITY: usize = 500;
29
30// Define a type alias for the TransportDispatcher trait object
31type TransportType = Arc<
32    dyn TransportDispatcher<
33        ClientMessages,
34        MessageFromServer,
35        ClientMessage,
36        ServerMessages,
37        ServerMessage,
38    >,
39>;
40
41/// Struct representing the runtime core of the MCP server, handling transport and client details
42pub struct ServerRuntime {
43    // The handler for processing MCP messages
44    handler: Arc<dyn McpServerHandler>,
45    // Information about the server
46    server_details: Arc<InitializeResult>,
47    #[cfg(feature = "hyper-server")]
48    session_id: Option<SessionId>,
49    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>, //TODO: remove the transport_map, we do not need a hashmap for it
50    request_id_gen: Box<dyn RequestIdGen>,
51    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
52    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
53    auth_info: tokio::sync::RwLock<Option<AuthInfo>>,
54}
55
56#[async_trait]
57impl McpServer for ServerRuntime {
58    /// Set the client details, storing them in client_details
59    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
60        self.client_details_tx
61            .send(Some(client_details))
62            .map_err(|_| {
63                RpcError::internal_error()
64                    .with_message("Failed to set client details".to_string())
65                    .into()
66            })
67    }
68
69    async fn update_auth_info(&self, new_auth_info: Option<AuthInfo>) {
70        let should_update = {
71            let current = self.auth_info.read().await;
72            match (&*current, &new_auth_info) {
73                (None, Some(_)) => true,
74                (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id,
75                (Some(_), None) => true,
76                (None, None) => false,
77            }
78        };
79
80        if should_update {
81            *self.auth_info.write().await = new_auth_info;
82        }
83    }
84
85    async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>> {
86        self.auth_info.read().await
87    }
88    async fn auth_info_cloned(&self) -> Option<AuthInfo> {
89        let guard = self.auth_info.read().await;
90        guard.clone()
91    }
92
93    async fn wait_for_initialization(&self) {
94        loop {
95            if self.client_details_rx.borrow().is_some() {
96                return;
97            }
98            let mut rx = self.client_details_rx.clone();
99            rx.changed().await.ok();
100        }
101    }
102
103    async fn send(
104        &self,
105        message: MessageFromServer,
106        request_id: Option<RequestId>,
107        request_timeout: Option<Duration>,
108    ) -> SdkResult<Option<ClientMessage>> {
109        let transport_map = self.transport_map.read().await;
110        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
111            RpcError::internal_error()
112                .with_message("transport stream does not exists or is closed!".to_string()),
113        )?;
114
115        let outgoing_request_id = self
116            .request_id_gen
117            .request_id_for_message(&message, request_id);
118
119        let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
120
121        let response = transport
122            .send_message(ServerMessages::Single(mcp_message), request_timeout)
123            .await?
124            .map(|res| res.as_single())
125            .transpose()?;
126
127        Ok(response)
128    }
129
130    async fn send_batch(
131        &self,
132        messages: Vec<ServerMessage>,
133        request_timeout: Option<Duration>,
134    ) -> SdkResult<Option<Vec<ClientMessage>>> {
135        let transport_map = self.transport_map.read().await;
136        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
137            RpcError::internal_error()
138                .with_message("transport stream does not exists or is closed!".to_string()),
139        )?;
140
141        transport
142            .send_batch(messages, request_timeout)
143            .map_err(|err| err.into())
144            .await
145    }
146
147    /// Returns the server's details, including server capability,
148    /// instructions, protocol_version , server_info and optional meta data
149    fn server_info(&self) -> &InitializeResult {
150        &self.server_details
151    }
152
153    /// Returns the client information if available, after successful initialization , otherwise returns None
154    fn client_info(&self) -> Option<InitializeRequestParams> {
155        self.client_details_rx.borrow().clone()
156    }
157
158    /// Main runtime loop, processes incoming messages and handles requests
159    async fn start(self: Arc<Self>) -> SdkResult<()> {
160        let self_clone = self.clone();
161        let transport_map = self_clone.transport_map.read().await;
162
163        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
164            RpcError::internal_error()
165                .with_message("transport stream does not exists or is closed!".to_string()),
166        )?;
167
168        let mut stream = transport.start().await?;
169
170        // Create a channel to collect results from spawned tasks
171        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
172
173        // Process incoming messages from the client
174        while let Some(mcp_messages) = stream.next().await {
175            match mcp_messages {
176                ClientMessages::Single(client_message) => {
177                    let transport = transport.clone();
178                    let self = self.clone();
179                    let tx = tx.clone();
180
181                    // Handle incoming messages in a separate task to avoid blocking the stream.
182                    tokio::spawn(async move {
183                        let result = self.handle_message(client_message, &transport).await;
184
185                        let send_result: SdkResult<_> = match result {
186                            Ok(result) => {
187                                if let Some(result) = result {
188                                    transport
189                                        .send_message(ServerMessages::Single(result), None)
190                                        .map_err(|e| e.into())
191                                        .await
192                                } else {
193                                    Ok(None)
194                                }
195                            }
196                            Err(error) => {
197                                tracing::error!("Error handling message : {}", error);
198                                Ok(None)
199                            }
200                        };
201                        // Send result to the main loop
202                        if let Err(error) = tx.send(send_result).await {
203                            tracing::error!("Failed to send result to channel: {}", error);
204                        }
205                    });
206                }
207                ClientMessages::Batch(client_messages) => {
208                    let transport = transport.clone();
209                    let self = self_clone.clone();
210                    let tx = tx.clone();
211
212                    tokio::spawn(async move {
213                        let handling_tasks: Vec<_> = client_messages
214                            .into_iter()
215                            .map(|client_message| self.handle_message(client_message, &transport))
216                            .collect();
217
218                        let send_result = match try_join_all(handling_tasks).await {
219                            Ok(results) => {
220                                let results: Vec<_> = results.into_iter().flatten().collect();
221                                if !results.is_empty() {
222                                    transport
223                                        .send_message(ServerMessages::Batch(results), None)
224                                        .map_err(|e| e.into())
225                                        .await
226                                } else {
227                                    Ok(None)
228                                }
229                            }
230                            Err(error) => Err(error),
231                        };
232
233                        if let Err(error) = tx.send(send_result).await {
234                            tracing::error!("Failed to send batch result to channel: {}", error);
235                        }
236                    });
237                }
238            }
239
240            // Check for results from spawned tasks to propagate errors
241            while let Ok(result) = rx.try_recv() {
242                result?; // Propagate errors
243            }
244        }
245
246        // Drop tx to close the channel and collect remaining results
247        drop(tx);
248        while let Some(result) = rx.recv().await {
249            result?; // Propagate errors
250        }
251
252        return Ok(());
253    }
254
255    async fn stderr_message(&self, message: String) -> SdkResult<()> {
256        let transport_map = self.transport_map.read().await;
257        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
258            RpcError::internal_error()
259                .with_message("transport stream does not exists or is closed!".to_string()),
260        )?;
261        let mut lock = transport.error_stream().write().await;
262
263        if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
264            stderr.write_all(message.as_bytes()).await?;
265            stderr.write_all(b"\n").await?;
266            stderr.flush().await?;
267        }
268        Ok(())
269    }
270
271    #[cfg(feature = "hyper-server")]
272    fn session_id(&self) -> Option<SessionId> {
273        self.session_id.to_owned()
274    }
275}
276
277impl ServerRuntime {
278    pub(crate) async fn consume_payload_string(
279        &self,
280        stream_id: &str,
281        payload: &str,
282    ) -> SdkResult<()> {
283        let transport_map = self.transport_map.read().await;
284
285        let transport = transport_map.get(stream_id).ok_or(
286            RpcError::internal_error()
287                .with_message("stream id does not exists or is closed!".to_string()),
288        )?;
289
290        transport.consume_string_payload(payload).await?;
291
292        Ok(())
293    }
294
295    pub(crate) async fn handle_message(
296        self: &Arc<Self>,
297        message: ClientMessage,
298        transport: &Arc<
299            dyn TransportDispatcher<
300                ClientMessages,
301                MessageFromServer,
302                ClientMessage,
303                ServerMessages,
304                ServerMessage,
305            >,
306        >,
307    ) -> SdkResult<Option<ServerMessage>> {
308        let response = match message {
309            // Handle a client request
310            ClientMessage::Request(client_jsonrpc_request) => {
311                let result = self
312                    .handler
313                    .handle_request(client_jsonrpc_request.request, self.clone())
314                    .await;
315                // create a response to send back to the client
316                let response: MessageFromServer = match result {
317                    Ok(success_value) => success_value.into(),
318                    Err(error_value) => {
319                        // Error occurred during initialization.
320                        // A likely cause could be an unsupported protocol version.
321                        if !self.is_initialized() {
322                            return Err(error_value.into());
323                        }
324                        MessageFromServer::Error(error_value)
325                    }
326                };
327
328                let mpc_message: ServerMessage =
329                    ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
330
331                Some(mpc_message)
332            }
333            ClientMessage::Notification(client_jsonrpc_notification) => {
334                self.handler
335                    .handle_notification(client_jsonrpc_notification.notification, self.clone())
336                    .await?;
337                None
338            }
339            ClientMessage::Error(jsonrpc_error) => {
340                self.handler
341                    .handle_error(&jsonrpc_error.error, self.clone())
342                    .await?;
343                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
344                    tx_response
345                        .send(ClientMessage::Error(jsonrpc_error))
346                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
347                } else {
348                    tracing::warn!(
349                        "Received an error response with no corresponding request {:?}",
350                        &jsonrpc_error.id
351                    );
352                }
353                None
354            }
355            ClientMessage::Response(response) => {
356                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
357                    tx_response
358                        .send(ClientMessage::Response(response))
359                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
360                } else {
361                    tracing::warn!(
362                        "Received a response with no corresponding request: {:?}",
363                        &response.id
364                    );
365                }
366                None
367            }
368        };
369        Ok(response)
370    }
371
372    pub(crate) async fn store_transport(
373        &self,
374        stream_id: &str,
375        transport: Arc<
376            dyn TransportDispatcher<
377                ClientMessages,
378                MessageFromServer,
379                ClientMessage,
380                ServerMessages,
381                ServerMessage,
382            >,
383        >,
384    ) -> SdkResult<()> {
385        if stream_id != DEFAULT_STREAM_ID {
386            return Ok(());
387        }
388        let mut transport_map = self.transport_map.write().await;
389        tracing::trace!("save transport for stream id : {}", stream_id);
390        transport_map.insert(stream_id.to_string(), transport);
391        Ok(())
392    }
393
394    //TODO: re-visit and simplify unnecessary hashmap
395    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
396        if stream_id != DEFAULT_STREAM_ID {
397            return Ok(());
398        }
399        let transport_map = self.transport_map.read().await;
400        tracing::trace!("removing transport for stream id : {}", stream_id);
401        if let Some(transport) = transport_map.get(stream_id) {
402            transport.shut_down().await?;
403        }
404        // transport_map.remove(stream_id);
405        Ok(())
406    }
407
408    pub(crate) async fn shutdown(&self) {
409        let mut transport_map = self.transport_map.write().await;
410        let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
411        drop(transport_map);
412        for item in items {
413            let _ = item.shut_down().await;
414        }
415    }
416
417    pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
418        let transport_map = self.transport_map.read().await;
419        let live_transport = if let Some(t) = transport_map.get(stream_id) {
420            !t.is_shut_down().await
421        } else {
422            false
423        };
424        live_transport
425    }
426
427    pub(crate) async fn start_stream(
428        self: Arc<Self>,
429        transport: Arc<
430            dyn TransportDispatcher<
431                ClientMessages,
432                MessageFromServer,
433                ClientMessage,
434                ServerMessages,
435                ServerMessage,
436            >,
437        >,
438        stream_id: &str,
439        ping_interval: Duration,
440        payload: Option<String>,
441    ) -> SdkResult<()> {
442        let mut stream = transport.start().await?;
443
444        if stream_id == DEFAULT_STREAM_ID {
445            self.store_transport(stream_id, transport.clone()).await?;
446        }
447
448        let self_clone = self.clone();
449
450        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
451        let abort_alive_task = transport
452            .keep_alive(ping_interval, disconnect_tx)
453            .await?
454            .abort_handle();
455
456        // ensure keep_alive task will be aborted
457        let _abort_guard = AbortTaskOnDrop {
458            handle: abort_alive_task,
459        };
460
461        // in case there is a payload, we consume it by transport to get processed
462        // payload would be message payload coming from the client
463        if let Some(payload) = payload {
464            if let Err(err) = transport.consume_string_payload(&payload).await {
465                let _ = self.remove_transport(stream_id).await;
466                return Err(err.into());
467            }
468        }
469
470        // Create a channel to collect results from spawned tasks
471        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
472
473        loop {
474            tokio::select! {
475                Some(mcp_messages) = stream.next() =>{
476
477                    match mcp_messages {
478                        ClientMessages::Single(client_message) => {
479                            let transport = transport.clone();
480                            let self_clone = self.clone();
481                            let tx = tx.clone();
482                            tokio::spawn(async move {
483
484                                let result = self_clone.handle_message(client_message, &transport).await;
485
486                                let send_result: SdkResult<_> = match result {
487                                    Ok(result) => {
488                                        if let Some(result) = result {
489                                            transport
490                                                .send_message(ServerMessages::Single(result), None)
491                                                .map_err(|e| e.into())
492                                                .await
493                                        } else {
494                                            Ok(None)
495                                        }
496                                    }
497                                    Err(error) => {
498                                        tracing::error!("Error handling message : {}", error);
499                                        Ok(None)
500                                    }
501                                };
502                                if let Err(error) = tx.send(send_result).await {
503                                    tracing::error!("Failed to send batch result to channel: {}", error);
504                                }
505                            });
506                        }
507                        ClientMessages::Batch(client_messages) => {
508
509                            let transport = transport.clone();
510                            let self_clone = self_clone.clone();
511                            let tx = tx.clone();
512
513                            tokio::spawn(async move {
514                                let handling_tasks: Vec<_> = client_messages
515                                    .into_iter()
516                                    .map(|client_message| self_clone.handle_message(client_message, &transport))
517                                    .collect();
518
519                                    let send_result = match try_join_all(handling_tasks).await {
520                                         Ok(results) => {
521                                             let results: Vec<_> = results.into_iter().flatten().collect();
522                                             if !results.is_empty() {
523                                                 transport.send_message(ServerMessages::Batch(results), None)
524                                                 .map_err(|e| e.into())
525                                                 .await
526                                             }else {
527                                                 Ok(None)
528                                             }
529                                         },
530                                        Err(error) => Err(error),
531                                    };
532                                    if let Err(error) = tx.send(send_result).await {
533                                        tracing::error!("Failed to send batch result to channel: {}", error);
534                                    }
535                            });
536                        }
537                    }
538
539                    // Check for results from spawned tasks to propagate errors
540                    while let Ok(result) = rx.try_recv() {
541                        result?; // Propagate errors
542                    }
543
544                    // close the stream after all messages are sent, unless it is a standalone stream
545                    if !stream_id.eq(DEFAULT_STREAM_ID){
546                        // Drop tx to close the channel and collect remaining results
547                        drop(tx);
548                        while let Some(result) = rx.recv().await {
549                            result?; // Propagate errors
550                        }
551                        return  Ok(());
552                    }
553                }
554                _ = &mut disconnect_rx => {
555                    // Drop tx to close the channel and collect remaining results
556                    drop(tx);
557                    while let Some(result) = rx.recv().await {
558                        result?; // Propagate errors
559                    }
560                                self.remove_transport(stream_id).await?;
561                                // Disconnection detected by keep-alive task
562                                return Err(SdkError::connection_closed().into());
563
564                }
565            }
566        }
567    }
568
569    #[cfg(feature = "hyper-server")]
570    pub(crate) fn new_instance(
571        server_details: Arc<InitializeResult>,
572        handler: Arc<dyn McpServerHandler>,
573        session_id: SessionId,
574        auth_info: Option<AuthInfo>,
575    ) -> Arc<Self> {
576        use tokio::sync::RwLock;
577
578        let (client_details_tx, client_details_rx) =
579            watch::channel::<Option<InitializeRequestParams>>(None);
580        Arc::new(Self {
581            server_details,
582            handler,
583            session_id: Some(session_id),
584            transport_map: tokio::sync::RwLock::new(HashMap::new()),
585            client_details_tx,
586            client_details_rx,
587            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
588            auth_info: RwLock::new(auth_info),
589        })
590    }
591
592    pub(crate) fn new(
593        server_details: InitializeResult,
594        transport: impl TransportDispatcher<
595            ClientMessages,
596            MessageFromServer,
597            ClientMessage,
598            ServerMessages,
599            ServerMessage,
600        >,
601        handler: Arc<dyn McpServerHandler>,
602    ) -> Arc<Self> {
603        let mut map: HashMap<String, TransportType> = HashMap::new();
604        map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
605        let (client_details_tx, client_details_rx) =
606            watch::channel::<Option<InitializeRequestParams>>(None);
607        Arc::new(Self {
608            server_details: Arc::new(server_details),
609            handler,
610            #[cfg(feature = "hyper-server")]
611            session_id: None,
612            transport_map: tokio::sync::RwLock::new(map),
613            client_details_tx,
614            client_details_rx,
615            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
616            auth_info: RwLock::new(None),
617        })
618    }
619}