rust_mcp_sdk/mcp_runtimes/server_runtime.rs
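//! Server-side MCP runtime: drives the transport stream(s), dispatches incoming
//! client messages to the registered `McpServerHandler`, and sends the resulting
//! responses back to the client.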

pub mod mcp_server_runtime;
pub mod mcp_server_runtime_core;
use crate::error::SdkResult;
use crate::mcp_traits::mcp_handler::McpServerHandler;
use crate::mcp_traits::mcp_server::McpServer;
use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
use crate::schema::{
    schema_utils::{
        ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
        ServerMessages,
    },
    InitializeRequestParams, InitializeResult, RequestId, RpcError,
};
use crate::utils::AbortTaskOnDrop;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{StreamExt, TryFutureExt};
#[cfg(feature = "hyper-server")]
use rust_mcp_transport::SessionId;
use rust_mcp_transport::{IoStream, TransportDispatcher};
use std::collections::HashMap;
use std::panic;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;

use tokio::sync::{mpsc, oneshot, watch};

pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
const TASK_CHANNEL_CAPACITY: usize = 500;

// Define a type alias for the TransportDispatcher trait object
type TransportType = Arc<
    dyn TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
>;

/// Struct representing the runtime core of the MCP server, handling transport and client details
pub struct ServerRuntime {
    // The handler for processing MCP messages
    handler: Arc<dyn McpServerHandler>,
    // Information about the server
    server_details: Arc<InitializeResult>,
    #[cfg(feature = "hyper-server")]
    session_id: Option<SessionId>,
    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>, //TODO: remove the transport_map, we do not need a hashmap for it
    request_id_gen: Box<dyn RequestIdGen>,
    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
}

#[async_trait]
impl McpServer for ServerRuntime {
    /// Set the client details, storing them in client_details
    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
        self.client_details_tx
            .send(Some(client_details))
            .map_err(|_| {
                RpcError::internal_error()
                    .with_message("Failed to set client details".to_string())
                    .into()
            })
    }

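    /// Waits until the client details have been received, i.e. until initialization completes.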
    async fn wait_for_initialization(&self) {
        loop {
            if self.client_details_rx.borrow().is_some() {
                return;
            }
            let mut rx = self.client_details_rx.clone();
            rx.changed().await.ok();
        }
    }

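    /// Sends a single message to the client over the default transport stream,
    /// returning the client's response if the transport produces one
    /// (bounded by `request_timeout` when provided).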
    async fn send(
        &self,
        message: MessageFromServer,
        request_id: Option<RequestId>,
        request_timeout: Option<Duration>,
    ) -> SdkResult<Option<ClientMessage>> {
        let transport_map = self.transport_map.read().await;
        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
            RpcError::internal_error()
                .with_message("transport stream does not exist or is closed".to_string()),
        )?;

        let outgoing_request_id = self
            .request_id_gen
            .request_id_for_message(&message, request_id);

        let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;

        let response = transport
            .send_message(ServerMessages::Single(mcp_message), request_timeout)
            .await?
            .map(|res| res.as_single())
            .transpose()?;

        Ok(response)
    }

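    /// Sends a batch of messages to the client over the default transport stream,
    /// returning the client's responses if any are produced.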
    async fn send_batch(
        &self,
        messages: Vec<ServerMessage>,
        request_timeout: Option<Duration>,
    ) -> SdkResult<Option<Vec<ClientMessage>>> {
        let transport_map = self.transport_map.read().await;
        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
            RpcError::internal_error()
                .with_message("transport stream does not exist or is closed".to_string()),
        )?;

        transport
            .send_batch(messages, request_timeout)
            .map_err(|err| err.into())
            .await
    }

    /// Returns the server's details, including server capabilities,
    /// instructions, protocol_version, server_info, and optional metadata
    fn server_info(&self) -> &InitializeResult {
        &self.server_details
    }

    /// Returns the client information if available (after successful initialization); otherwise returns None
    fn client_info(&self) -> Option<InitializeRequestParams> {
        self.client_details_rx.borrow().clone()
    }

    /// Main runtime loop, processes incoming messages and handles requests
    async fn start(self: Arc<Self>) -> SdkResult<()> {
        let self_clone = self.clone();
        let transport_map = self_clone.transport_map.read().await;

        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
            RpcError::internal_error()
                .with_message("transport stream does not exist or is closed".to_string()),
        )?;

        let mut stream = transport.start().await?;

        // Create a channel to collect results from spawned tasks
        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);

        // Process incoming messages from the client
        while let Some(mcp_messages) = stream.next().await {
            match mcp_messages {
                ClientMessages::Single(client_message) => {
                    let transport = transport.clone();
                    let runtime = self.clone();
                    let tx = tx.clone();

                    // Handle incoming messages in a separate task to avoid blocking the stream.
                    tokio::spawn(async move {
                        let result = runtime.handle_message(client_message, &transport).await;

                        let send_result: SdkResult<_> = match result {
                            Ok(result) => {
                                if let Some(result) = result {
                                    transport
                                        .send_message(ServerMessages::Single(result), None)
                                        .map_err(|e| e.into())
                                        .await
                                } else {
                                    Ok(None)
                                }
                            }
                            Err(error) => {
                                tracing::error!("Error handling message: {}", error);
                                Ok(None)
                            }
                        };
                        // Send result to the main loop
                        if let Err(error) = tx.send(send_result).await {
                            tracing::error!("Failed to send result to channel: {}", error);
                        }
                    });
                }
                ClientMessages::Batch(client_messages) => {
                    let transport = transport.clone();
                    let runtime = self_clone.clone();
                    let tx = tx.clone();

                    tokio::spawn(async move {
                        let handling_tasks: Vec<_> = client_messages
                            .into_iter()
                            .map(|client_message| runtime.handle_message(client_message, &transport))
                            .collect();

                        let send_result = match try_join_all(handling_tasks).await {
                            Ok(results) => {
                                let results: Vec<_> = results.into_iter().flatten().collect();
                                if !results.is_empty() {
                                    transport
                                        .send_message(ServerMessages::Batch(results), None)
                                        .map_err(|e| e.into())
                                        .await
                                } else {
                                    Ok(None)
                                }
                            }
                            Err(error) => Err(error),
                        };

                        if let Err(error) = tx.send(send_result).await {
                            tracing::error!("Failed to send batch result to channel: {}", error);
                        }
                    });
                }
            }

            // Check for results from spawned tasks to propagate errors
            while let Ok(result) = rx.try_recv() {
                result?; // Propagate errors
            }
        }

        // Drop tx to close the channel and collect remaining results
        drop(tx);
        while let Some(result) = rx.recv().await {
            result?; // Propagate errors
        }

        Ok(())
    }

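    /// Writes a line to the transport's error stream (e.g. stderr for stdio transports), if one is available.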
    async fn stderr_message(&self, message: String) -> SdkResult<()> {
        let transport_map = self.transport_map.read().await;
        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
            RpcError::internal_error()
                .with_message("transport stream does not exist or is closed".to_string()),
        )?;
        let mut lock = transport.error_stream().write().await;

        if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
            stderr.write_all(message.as_bytes()).await?;
            stderr.write_all(b"\n").await?;
            stderr.flush().await?;
        }
        Ok(())
    }

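    /// Returns the session id associated with this runtime, if it was created for a hyper-server session.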
    #[cfg(feature = "hyper-server")]
    fn session_id(&self) -> Option<SessionId> {
        self.session_id.to_owned()
    }
}

impl ServerRuntime {
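    /// Feeds a raw payload string received for `stream_id` into the corresponding transport for processing.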
    pub(crate) async fn consume_payload_string(
        &self,
        stream_id: &str,
        payload: &str,
    ) -> SdkResult<()> {
        let transport_map = self.transport_map.read().await;

        let transport = transport_map.get(stream_id).ok_or(
            RpcError::internal_error()
                .with_message("stream id does not exist or is closed".to_string()),
        )?;

        transport.consume_string_payload(payload).await?;

        Ok(())
    }

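    /// Dispatches a single client message: requests are passed to the handler and produce a
    /// response message, notifications and errors are forwarded to the handler, and
    /// responses/errors are routed to any pending outgoing request awaiting them.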
    pub(crate) async fn handle_message(
        self: &Arc<Self>,
        message: ClientMessage,
        transport: &Arc<
            dyn TransportDispatcher<
                ClientMessages,
                MessageFromServer,
                ClientMessage,
                ServerMessages,
                ServerMessage,
            >,
        >,
    ) -> SdkResult<Option<ServerMessage>> {
        let response = match message {
            // Handle a client request
            ClientMessage::Request(client_jsonrpc_request) => {
                let result = self
                    .handler
                    .handle_request(client_jsonrpc_request.request, self.clone())
                    .await;
                // create a response to send back to the client
                let response: MessageFromServer = match result {
                    Ok(success_value) => success_value.into(),
                    Err(error_value) => {
                        // Error occurred during initialization.
                        // A likely cause could be an unsupported protocol version.
                        if !self.is_initialized() {
                            return Err(error_value.into());
                        }
                        MessageFromServer::Error(error_value)
                    }
                };

                let mcp_message: ServerMessage =
                    ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;

                Some(mcp_message)
            }
            ClientMessage::Notification(client_jsonrpc_notification) => {
                self.handler
                    .handle_notification(client_jsonrpc_notification.notification, self.clone())
                    .await?;
                None
            }
            ClientMessage::Error(jsonrpc_error) => {
                self.handler
                    .handle_error(&jsonrpc_error.error, self.clone())
                    .await?;
                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
                    tx_response
                        .send(ClientMessage::Error(jsonrpc_error))
                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
                } else {
                    tracing::warn!(
                        "Received an error response with no corresponding request: {:?}",
                        &jsonrpc_error.id
                    );
                }
                None
            }
            ClientMessage::Response(response) => {
                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
                    tx_response
                        .send(ClientMessage::Response(response))
                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
                } else {
                    tracing::warn!(
                        "Received a response with no corresponding request: {:?}",
                        &response.id
                    );
                }
                None
            }
        };
        Ok(response)
    }

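    /// Registers the transport for the given stream id (currently only the default standalone stream is stored).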
    pub(crate) async fn store_transport(
        &self,
        stream_id: &str,
        transport: Arc<
            dyn TransportDispatcher<
                ClientMessages,
                MessageFromServer,
                ClientMessage,
                ServerMessages,
                ServerMessage,
            >,
        >,
    ) -> SdkResult<()> {
        if stream_id != DEFAULT_STREAM_ID {
            return Ok(());
        }
        let mut transport_map = self.transport_map.write().await;
        tracing::trace!("save transport for stream id: {}", stream_id);
        transport_map.insert(stream_id.to_string(), transport);
        Ok(())
    }

    //TODO: re-visit and simplify unnecessary hashmap
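    /// Shuts down the transport associated with the given stream id (only the default standalone stream is tracked).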
    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
        if stream_id != DEFAULT_STREAM_ID {
            return Ok(());
        }
        let transport_map = self.transport_map.read().await;
        tracing::trace!("removing transport for stream id: {}", stream_id);
        if let Some(transport) = transport_map.get(stream_id) {
            transport.shut_down().await?;
        }
        // transport_map.remove(stream_id);
        Ok(())
    }

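    /// Shuts down all stored transports and clears the transport map.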
    pub(crate) async fn shutdown(&self) {
        let mut transport_map = self.transport_map.write().await;
        let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
        drop(transport_map);
        for item in items {
            let _ = item.shut_down().await;
        }
    }

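    /// Returns true if a transport for the given stream id exists and has not been shut down.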
    pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
        let transport_map = self.transport_map.read().await;
        if let Some(t) = transport_map.get(stream_id) {
            !t.is_shut_down().await
        } else {
            false
        }
    }

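    /// Runs the message loop for a single transport stream: starts the transport, keeps the
    /// connection alive with periodic pings, optionally consumes an initial payload, and
    /// processes incoming messages until the stream ends or the client disconnects.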
    pub(crate) async fn start_stream(
        self: Arc<Self>,
        transport: Arc<
            dyn TransportDispatcher<
                ClientMessages,
                MessageFromServer,
                ClientMessage,
                ServerMessages,
                ServerMessage,
            >,
        >,
        stream_id: &str,
        ping_interval: Duration,
        payload: Option<String>,
    ) -> SdkResult<()> {
        let mut stream = transport.start().await?;

        if stream_id == DEFAULT_STREAM_ID {
            self.store_transport(stream_id, transport.clone()).await?;
        }

        let self_clone = self.clone();

        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
        let abort_alive_task = transport
            .keep_alive(ping_interval, disconnect_tx)
            .await?
            .abort_handle();

        // ensure the keep_alive task is aborted when this function returns
        let _abort_guard = AbortTaskOnDrop {
            handle: abort_alive_task,
        };

        // If a payload was provided (a message payload coming from the client),
        // pass it to the transport so it gets processed.
        if let Some(payload) = payload {
            if let Err(err) = transport.consume_string_payload(&payload).await {
                let _ = self.remove_transport(stream_id).await;
                return Err(err.into());
            }
        }

        // Create a channel to collect results from spawned tasks
        let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);

        loop {
            tokio::select! {
                Some(mcp_messages) = stream.next() => {
                    match mcp_messages {
                        ClientMessages::Single(client_message) => {
                            let transport = transport.clone();
                            let self_clone = self.clone();
                            let tx = tx.clone();
                            tokio::spawn(async move {
                                let result = self_clone.handle_message(client_message, &transport).await;

                                let send_result: SdkResult<_> = match result {
                                    Ok(result) => {
                                        if let Some(result) = result {
                                            transport
                                                .send_message(ServerMessages::Single(result), None)
                                                .map_err(|e| e.into())
                                                .await
                                        } else {
                                            Ok(None)
                                        }
                                    }
                                    Err(error) => {
                                        tracing::error!("Error handling message: {}", error);
                                        Ok(None)
                                    }
                                };
                                if let Err(error) = tx.send(send_result).await {
                                    tracing::error!("Failed to send result to channel: {}", error);
                                }
                            });
                        }
                        ClientMessages::Batch(client_messages) => {
                            let transport = transport.clone();
                            let self_clone = self_clone.clone();
                            let tx = tx.clone();

                            tokio::spawn(async move {
                                let handling_tasks: Vec<_> = client_messages
                                    .into_iter()
                                    .map(|client_message| self_clone.handle_message(client_message, &transport))
                                    .collect();

                                let send_result = match try_join_all(handling_tasks).await {
                                    Ok(results) => {
                                        let results: Vec<_> = results.into_iter().flatten().collect();
                                        if !results.is_empty() {
                                            transport
                                                .send_message(ServerMessages::Batch(results), None)
                                                .map_err(|e| e.into())
                                                .await
                                        } else {
                                            Ok(None)
                                        }
                                    }
                                    Err(error) => Err(error),
                                };
                                if let Err(error) = tx.send(send_result).await {
                                    tracing::error!("Failed to send batch result to channel: {}", error);
                                }
                            });
                        }
                    }

                    // Check for results from spawned tasks to propagate errors
                    while let Ok(result) = rx.try_recv() {
                        result?; // Propagate errors
                    }

                    // close the stream after all messages are sent, unless it is a standalone stream
                    if stream_id != DEFAULT_STREAM_ID {
                        // Drop tx to close the channel and collect remaining results
                        drop(tx);
                        while let Some(result) = rx.recv().await {
                            result?; // Propagate errors
                        }
                        return Ok(());
                    }
                }
                _ = &mut disconnect_rx => {
                    // Drop tx to close the channel and collect remaining results
                    drop(tx);
                    while let Some(result) = rx.recv().await {
                        result?; // Propagate errors
                    }
                    self.remove_transport(stream_id).await?;
                    // Disconnection detected by keep-alive task
                    return Err(SdkError::connection_closed().into());
                }
            }
        }
    }

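    /// Creates a new `ServerRuntime` for a hyper-server session; transports are attached later per stream.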
    #[cfg(feature = "hyper-server")]
    pub(crate) fn new_instance(
        server_details: Arc<InitializeResult>,
        handler: Arc<dyn McpServerHandler>,
        session_id: SessionId,
    ) -> Arc<Self> {
        let (client_details_tx, client_details_rx) =
            watch::channel::<Option<InitializeRequestParams>>(None);
        Arc::new(Self {
            server_details,
            handler,
            session_id: Some(session_id),
            transport_map: tokio::sync::RwLock::new(HashMap::new()),
            client_details_tx,
            client_details_rx,
            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
        })
    }

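    /// Creates a new `ServerRuntime` bound to a single transport, registered under the default standalone stream.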
    pub(crate) fn new(
        server_details: InitializeResult,
        transport: impl TransportDispatcher<
            ClientMessages,
            MessageFromServer,
            ClientMessage,
            ServerMessages,
            ServerMessage,
        >,
        handler: Arc<dyn McpServerHandler>,
    ) -> Arc<Self> {
        let mut map: HashMap<String, TransportType> = HashMap::new();
        map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
        let (client_details_tx, client_details_rx) =
            watch::channel::<Option<InitializeRequestParams>>(None);
        Arc::new(Self {
            server_details: Arc::new(server_details),
            handler,
            #[cfg(feature = "hyper-server")]
            session_id: None,
            transport_map: tokio::sync::RwLock::new(map),
            client_details_tx,
            client_details_rx,
            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
        })
    }
}