rust_mcp_sdk/mcp_runtimes/server_runtime.rs

1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::error::SdkResult;
4use crate::mcp_traits::mcp_handler::McpServerHandler;
5use crate::mcp_traits::mcp_server::McpServer;
6use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
7use crate::schema::{
8    schema_utils::{
9        ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
10        ServerMessages,
11    },
12    InitializeRequestParams, InitializeResult, RequestId, RpcError,
13};
14use crate::utils::AbortTaskOnDrop;
15use async_trait::async_trait;
16use futures::future::try_join_all;
17use futures::{StreamExt, TryFutureExt};
18#[cfg(feature = "hyper-server")]
19use rust_mcp_transport::SessionId;
20use rust_mcp_transport::{IoStream, TransportDispatcher};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{oneshot, watch};
26
/// Key under which the standalone stream's transport is kept in the runtime's transport map.
///
/// `ServerRuntime::new` registers its transport under this key, and the `McpServer` methods
/// (`send`, `send_batch`, `start`, `stderr_message`) look the transport up here. Streams with
/// any other id are closed once their messages have been answered.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
28
29// Define a type alias for the TransportDispatcher trait object
30type TransportType = Arc<
31    dyn TransportDispatcher<
32        ClientMessages,
33        MessageFromServer,
34        ClientMessage,
35        ServerMessages,
36        ServerMessage,
37    >,
38>;
39
40/// Struct representing the runtime core of the MCP server, handling transport and client details
41pub struct ServerRuntime {
42    // The handler for processing MCP messages
43    handler: Arc<dyn McpServerHandler>,
44    // Information about the server
45    server_details: Arc<InitializeResult>,
46    #[cfg(feature = "hyper-server")]
47    session_id: Option<SessionId>,
48    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
49    request_id_gen: Box<dyn RequestIdGen>,
50    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
51    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
52}
53
54#[async_trait]
55impl McpServer for ServerRuntime {
56    /// Set the client details, storing them in client_details
57    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
58        self.handler.on_server_started(self).await;
59
60        self.client_details_tx
61            .send(Some(client_details))
62            .map_err(|_| {
63                RpcError::internal_error()
64                    .with_message("Failed to set client details".to_string())
65                    .into()
66            })
67    }
68
69    async fn wait_for_initialization(&self) {
70        loop {
71            if self.client_details_rx.borrow().is_some() {
72                return;
73            }
74            let mut rx = self.client_details_rx.clone();
75            rx.changed().await.ok();
76        }
77    }
78
79    async fn send(
80        &self,
81        message: MessageFromServer,
82        request_id: Option<RequestId>,
83        request_timeout: Option<Duration>,
84    ) -> SdkResult<Option<ClientMessage>> {
85        let transport_map = self.transport_map.read().await;
86        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
87            RpcError::internal_error()
88                .with_message("transport stream does not exists or is closed!".to_string()),
89        )?;
90
91        let outgoing_request_id = self
92            .request_id_gen
93            .request_id_for_message(&message, request_id);
94
95        let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
96
97        let response = transport
98            .send_message(ServerMessages::Single(mcp_message), request_timeout)
99            .await?
100            .map(|res| res.as_single())
101            .transpose()?;
102
103        Ok(response)
104    }
105
106    async fn send_batch(
107        &self,
108        messages: Vec<ServerMessage>,
109        request_timeout: Option<Duration>,
110    ) -> SdkResult<Option<Vec<ClientMessage>>> {
111        let transport_map = self.transport_map.read().await;
112        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
113            RpcError::internal_error()
114                .with_message("transport stream does not exists or is closed!".to_string()),
115        )?;
116
117        transport
118            .send_batch(messages, request_timeout)
119            .map_err(|err| err.into())
120            .await
121    }
122
123    /// Returns the server's details, including server capability,
124    /// instructions, protocol_version , server_info and optional meta data
125    fn server_info(&self) -> &InitializeResult {
126        &self.server_details
127    }
128
129    /// Returns the client information if available, after successful initialization , otherwise returns None
130    fn client_info(&self) -> Option<InitializeRequestParams> {
131        self.client_details_rx.borrow().clone()
132    }
133
134    /// Main runtime loop, processes incoming messages and handles requests
135    async fn start(&self) -> SdkResult<()> {
136        let transport_map = self.transport_map.read().await;
137
138        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
139            RpcError::internal_error()
140                .with_message("transport stream does not exists or is closed!".to_string()),
141        )?;
142
143        let mut stream = transport.start().await?;
144
145        // Process incoming messages from the client
146        while let Some(mcp_messages) = stream.next().await {
147            match mcp_messages {
148                ClientMessages::Single(client_message) => {
149                    let result = self.handle_message(client_message, transport).await;
150
151                    match result {
152                        Ok(result) => {
153                            if let Some(result) = result {
154                                transport
155                                    .send_message(ServerMessages::Single(result), None)
156                                    .await?;
157                            }
158                        }
159                        Err(error) => {
160                            tracing::error!("Error handling message : {}", error)
161                        }
162                    }
163                }
164                ClientMessages::Batch(client_messages) => {
165                    let handling_tasks: Vec<_> = client_messages
166                        .into_iter()
167                        .map(|client_message| self.handle_message(client_message, transport))
168                        .collect();
169
170                    let results: Vec<_> = try_join_all(handling_tasks).await?;
171
172                    let results: Vec<_> = results.into_iter().flatten().collect();
173
174                    if !results.is_empty() {
175                        transport
176                            .send_message(ServerMessages::Batch(results), None)
177                            .await?;
178                    }
179                }
180            }
181        }
182        return Ok(());
183    }
184
185    async fn stderr_message(&self, message: String) -> SdkResult<()> {
186        let transport_map = self.transport_map.read().await;
187        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
188            RpcError::internal_error()
189                .with_message("transport stream does not exists or is closed!".to_string()),
190        )?;
191        let mut lock = transport.error_stream().write().await;
192
193        if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
194            stderr.write_all(message.as_bytes()).await?;
195            stderr.write_all(b"\n").await?;
196            stderr.flush().await?;
197        }
198        Ok(())
199    }
200
201    #[cfg(feature = "hyper-server")]
202    fn session_id(&self) -> Option<SessionId> {
203        self.session_id.to_owned()
204    }
205}
206
207impl ServerRuntime {
208    pub(crate) async fn consume_payload_string(
209        &self,
210        stream_id: &str,
211        payload: &str,
212    ) -> SdkResult<()> {
213        let transport_map = self.transport_map.read().await;
214
215        let transport = transport_map.get(stream_id).ok_or(
216            RpcError::internal_error()
217                .with_message("stream id does not exists or is closed!".to_string()),
218        )?;
219
220        transport.consume_string_payload(payload).await?;
221
222        Ok(())
223    }
224
225    pub(crate) async fn handle_message(
226        &self,
227        message: ClientMessage,
228        transport: &Arc<
229            dyn TransportDispatcher<
230                ClientMessages,
231                MessageFromServer,
232                ClientMessage,
233                ServerMessages,
234                ServerMessage,
235            >,
236        >,
237    ) -> SdkResult<Option<ServerMessage>> {
238        let response = match message {
239            // Handle a client request
240            ClientMessage::Request(client_jsonrpc_request) => {
241                let result = self
242                    .handler
243                    .handle_request(client_jsonrpc_request.request, self)
244                    .await;
245                // create a response to send back to the client
246                let response: MessageFromServer = match result {
247                    Ok(success_value) => success_value.into(),
248                    Err(error_value) => {
249                        // Error occurred during initialization.
250                        // A likely cause could be an unsupported protocol version.
251                        if !self.is_initialized() {
252                            return Err(error_value.into());
253                        }
254                        MessageFromServer::Error(error_value)
255                    }
256                };
257
258                let mpc_message: ServerMessage =
259                    ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
260
261                Some(mpc_message)
262            }
263            ClientMessage::Notification(client_jsonrpc_notification) => {
264                self.handler
265                    .handle_notification(client_jsonrpc_notification.notification, self)
266                    .await?;
267                None
268            }
269            ClientMessage::Error(jsonrpc_error) => {
270                self.handler
271                    .handle_error(&jsonrpc_error.error, self)
272                    .await?;
273                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
274                    tx_response
275                        .send(ClientMessage::Error(jsonrpc_error))
276                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
277                } else {
278                    tracing::warn!(
279                        "Received an error response with no corresponding request {:?}",
280                        &jsonrpc_error.id
281                    );
282                }
283                None
284            }
285            // The response is the result of a request, it is processed at the transport level.
286            ClientMessage::Response(response) => {
287                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
288                    tx_response
289                        .send(ClientMessage::Response(response))
290                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
291                } else {
292                    tracing::warn!(
293                        "Received a response with no corresponding request: {:?}",
294                        &response.id
295                    );
296                }
297                None
298            }
299        };
300        Ok(response)
301    }
302
303    pub(crate) async fn store_transport(
304        &self,
305        stream_id: &str,
306        transport: Arc<
307            dyn TransportDispatcher<
308                ClientMessages,
309                MessageFromServer,
310                ClientMessage,
311                ServerMessages,
312                ServerMessage,
313            >,
314        >,
315    ) -> SdkResult<()> {
316        let mut transport_map = self.transport_map.write().await;
317        tracing::trace!("save transport for stream id : {}", stream_id);
318        transport_map.insert(stream_id.to_string(), transport);
319        Ok(())
320    }
321
322    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
323        let mut transport_map = self.transport_map.write().await;
324        tracing::trace!("removing transport for stream id : {}", stream_id);
325        transport_map.remove(stream_id);
326        Ok(())
327    }
328
329    pub(crate) async fn transport_by_stream(
330        &self,
331        stream_id: &str,
332    ) -> SdkResult<
333        Arc<
334            dyn TransportDispatcher<
335                ClientMessages,
336                MessageFromServer,
337                ClientMessage,
338                ServerMessages,
339                ServerMessage,
340            >,
341        >,
342    > {
343        let transport_map = self.transport_map.read().await;
344        transport_map.get(stream_id).cloned().ok_or_else(|| {
345            RpcError::internal_error()
346                .with_message(format!("Transport for key {stream_id} not found"))
347                .into()
348        })
349    }
350
351    pub(crate) async fn shutdown(&self) {
352        let mut transport_map = self.transport_map.write().await;
353        let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
354        drop(transport_map);
355        for item in items {
356            let _ = item.shut_down().await;
357        }
358    }
359
360    pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
361        let transport_map = self.transport_map.read().await;
362        transport_map.contains_key(stream_id)
363    }
364
365    pub(crate) async fn start_stream(
366        self: Arc<Self>,
367        transport: impl TransportDispatcher<
368            ClientMessages,
369            MessageFromServer,
370            ClientMessage,
371            ServerMessages,
372            ServerMessage,
373        >,
374        stream_id: &str,
375        ping_interval: Duration,
376        payload: Option<String>,
377    ) -> SdkResult<()> {
378        let mut stream = transport.start().await?;
379
380        self.store_transport(stream_id, Arc::new(transport)).await?;
381
382        let transport = self.transport_by_stream(stream_id).await?;
383
384        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
385        let abort_alive_task = transport
386            .keep_alive(ping_interval, disconnect_tx)
387            .await?
388            .abort_handle();
389
390        // ensure keep_alive task will be aborted
391        let _abort_guard = AbortTaskOnDrop {
392            handle: abort_alive_task,
393        };
394
395        // in case there is a payload, we consume it by transport to get processed
396        if let Some(payload) = payload {
397            transport.consume_string_payload(&payload).await?;
398        }
399
400        loop {
401            tokio::select! {
402                Some(mcp_messages) = stream.next() =>{
403
404                    match mcp_messages {
405                        ClientMessages::Single(client_message) => {
406                            let result = self.handle_message(client_message, &transport).await?;
407                            if let Some(result) = result {
408                                transport.send_message(ServerMessages::Single(result), None).await?;
409                            }
410                        }
411                        ClientMessages::Batch(client_messages) => {
412
413                            let handling_tasks: Vec<_> = client_messages
414                                .into_iter()
415                                .map(|client_message| self.handle_message(client_message, &transport))
416                                .collect();
417
418                            let results: Vec<_> = try_join_all(handling_tasks).await?;
419
420                            let results: Vec<_> = results.into_iter().flatten().collect();
421
422
423                            if !results.is_empty() {
424                                transport.send_message(ServerMessages::Batch(results), None).await?;
425                            }
426                        }
427                    }
428                    // close the stream after all messages are sent, unless it is a standalone stream
429                    if !stream_id.eq(DEFAULT_STREAM_ID){
430                        return  Ok(());
431                    }
432                }
433                _ = &mut disconnect_rx => {
434                                self.remove_transport(stream_id).await?;
435                                // Disconnection detected by keep-alive task
436                                return Err(SdkError::connection_closed().into());
437
438                }
439            }
440        }
441    }
442
443    #[cfg(feature = "hyper-server")]
444    pub(crate) fn new_instance(
445        server_details: Arc<InitializeResult>,
446        handler: Arc<dyn McpServerHandler>,
447        session_id: SessionId,
448    ) -> Self {
449        let (client_details_tx, client_details_rx) =
450            watch::channel::<Option<InitializeRequestParams>>(None);
451        Self {
452            server_details,
453            handler,
454            session_id: Some(session_id),
455            transport_map: tokio::sync::RwLock::new(HashMap::new()),
456            client_details_tx,
457            client_details_rx,
458            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
459        }
460    }
461
462    pub(crate) fn new(
463        server_details: InitializeResult,
464        transport: impl TransportDispatcher<
465            ClientMessages,
466            MessageFromServer,
467            ClientMessage,
468            ServerMessages,
469            ServerMessage,
470        >,
471        handler: Arc<dyn McpServerHandler>,
472    ) -> Self {
473        let mut map: HashMap<String, TransportType> = HashMap::new();
474        map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
475        let (client_details_tx, client_details_rx) =
476            watch::channel::<Option<InitializeRequestParams>>(None);
477        Self {
478            server_details: Arc::new(server_details),
479            handler,
480            #[cfg(feature = "hyper-server")]
481            session_id: None,
482            transport_map: tokio::sync::RwLock::new(map),
483            client_details_tx,
484            client_details_rx,
485            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
486        }
487    }
488}