rust_mcp_sdk/mcp_runtimes/
server_runtime.rs

1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::error::SdkResult;
4use crate::mcp_traits::mcp_handler::McpServerHandler;
5use crate::mcp_traits::mcp_server::McpServer;
6use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
7use crate::schema::{
8    schema_utils::{
9        ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
10        ServerMessages,
11    },
12    InitializeRequestParams, InitializeResult, RequestId, RpcError,
13};
14use crate::utils::AbortTaskOnDrop;
15use async_trait::async_trait;
16use futures::future::try_join_all;
17use futures::{StreamExt, TryFutureExt};
18#[cfg(feature = "hyper-server")]
19use rust_mcp_transport::SessionId;
20use rust_mcp_transport::{IoStream, TransportDispatcher};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{oneshot, watch};
26
/// Key of the standalone (default) stream in the runtime's transport map.
///
/// `ServerRuntime::new` registers its transport under this key, and the `McpServer`
/// methods (`send`, `start`, ...) always look the transport up by it.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
28
29// Define a type alias for the TransportDispatcher trait object
30type TransportType = Arc<
31    dyn TransportDispatcher<
32        ClientMessages,
33        MessageFromServer,
34        ClientMessage,
35        ServerMessages,
36        ServerMessage,
37    >,
38>;
39
40/// Struct representing the runtime core of the MCP server, handling transport and client details
41pub struct ServerRuntime {
42    // The handler for processing MCP messages
43    handler: Arc<dyn McpServerHandler>,
44    // Information about the server
45    server_details: Arc<InitializeResult>,
46    #[cfg(feature = "hyper-server")]
47    session_id: Option<SessionId>,
48    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
49    request_id_gen: Box<dyn RequestIdGen>,
50    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
51    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
52}
53
54#[async_trait]
55impl McpServer for ServerRuntime {
56    /// Set the client details, storing them in client_details
57    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
58        self.handler.on_server_started(self).await;
59
60        self.client_details_tx
61            .send(Some(client_details))
62            .map_err(|_| {
63                RpcError::internal_error()
64                    .with_message("Failed to set client details".to_string())
65                    .into()
66            })
67    }
68
69    async fn wait_for_initialization(&self) {
70        loop {
71            if self.client_details_rx.borrow().is_some() {
72                return;
73            }
74            let mut rx = self.client_details_rx.clone();
75            rx.changed().await.ok();
76        }
77    }
78
79    async fn send(
80        &self,
81        message: MessageFromServer,
82        request_id: Option<RequestId>,
83        request_timeout: Option<Duration>,
84    ) -> SdkResult<Option<ClientMessage>> {
85        let transport_map = self.transport_map.read().await;
86        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
87            RpcError::internal_error()
88                .with_message("transport stream does not exists or is closed!".to_string()),
89        )?;
90
91        let outgoing_request_id = self
92            .request_id_gen
93            .request_id_for_message(&message, request_id);
94
95        let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
96
97        let response = transport
98            .send_message(ServerMessages::Single(mcp_message), request_timeout)
99            .await?
100            .map(|res| res.as_single())
101            .transpose()?;
102
103        Ok(response)
104    }
105
106    async fn send_batch(
107        &self,
108        messages: Vec<ServerMessage>,
109        request_timeout: Option<Duration>,
110    ) -> SdkResult<Option<Vec<ClientMessage>>> {
111        let transport_map = self.transport_map.read().await;
112        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
113            RpcError::internal_error()
114                .with_message("transport stream does not exists or is closed!".to_string()),
115        )?;
116
117        transport
118            .send_batch(messages, request_timeout)
119            .map_err(|err| err.into())
120            .await
121    }
122
123    /// Returns the server's details, including server capability,
124    /// instructions, protocol_version , server_info and optional meta data
125    fn server_info(&self) -> &InitializeResult {
126        &self.server_details
127    }
128
129    /// Returns the client information if available, after successful initialization , otherwise returns None
130    fn client_info(&self) -> Option<InitializeRequestParams> {
131        self.client_details_rx.borrow().clone()
132    }
133
134    /// Main runtime loop, processes incoming messages and handles requests
135    async fn start(&self) -> SdkResult<()> {
136        let transport_map = self.transport_map.read().await;
137
138        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
139            RpcError::internal_error()
140                .with_message("transport stream does not exists or is closed!".to_string()),
141        )?;
142
143        let mut stream = transport.start().await?;
144
145        // Process incoming messages from the client
146        while let Some(mcp_messages) = stream.next().await {
147            match mcp_messages {
148                ClientMessages::Single(client_message) => {
149                    let result = self.handle_message(client_message, transport).await;
150
151                    match result {
152                        Ok(result) => {
153                            if let Some(result) = result {
154                                transport
155                                    .send_message(ServerMessages::Single(result), None)
156                                    .await?;
157                            }
158                        }
159                        Err(error) => {
160                            tracing::error!("Error handling message : {}", error)
161                        }
162                    }
163                }
164                ClientMessages::Batch(client_messages) => {
165                    let handling_tasks: Vec<_> = client_messages
166                        .into_iter()
167                        .map(|client_message| self.handle_message(client_message, transport))
168                        .collect();
169
170                    let results: Vec<_> = try_join_all(handling_tasks).await?;
171
172                    let results: Vec<_> = results.into_iter().flatten().collect();
173
174                    if !results.is_empty() {
175                        transport
176                            .send_message(ServerMessages::Batch(results), None)
177                            .await?;
178                    }
179                }
180            }
181        }
182        return Ok(());
183    }
184
185    async fn stderr_message(&self, message: String) -> SdkResult<()> {
186        let transport_map = self.transport_map.read().await;
187        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
188            RpcError::internal_error()
189                .with_message("transport stream does not exists or is closed!".to_string()),
190        )?;
191        let mut lock = transport.error_stream().write().await;
192
193        if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
194            stderr.write_all(message.as_bytes()).await?;
195            stderr.write_all(b"\n").await?;
196            stderr.flush().await?;
197        }
198        Ok(())
199    }
200}
201
202impl ServerRuntime {
203    pub(crate) async fn consume_payload_string(
204        &self,
205        stream_id: &str,
206        payload: &str,
207    ) -> SdkResult<()> {
208        let transport_map = self.transport_map.read().await;
209
210        let transport = transport_map.get(stream_id).ok_or(
211            RpcError::internal_error()
212                .with_message("stream id does not exists or is closed!".to_string()),
213        )?;
214
215        transport.consume_string_payload(payload).await?;
216
217        Ok(())
218    }
219
220    pub(crate) async fn handle_message(
221        &self,
222        message: ClientMessage,
223        transport: &Arc<
224            dyn TransportDispatcher<
225                ClientMessages,
226                MessageFromServer,
227                ClientMessage,
228                ServerMessages,
229                ServerMessage,
230            >,
231        >,
232    ) -> SdkResult<Option<ServerMessage>> {
233        let response = match message {
234            // Handle a client request
235            ClientMessage::Request(client_jsonrpc_request) => {
236                let result = self
237                    .handler
238                    .handle_request(client_jsonrpc_request.request, self)
239                    .await;
240                // create a response to send back to the client
241                let response: MessageFromServer = match result {
242                    Ok(success_value) => success_value.into(),
243                    Err(error_value) => {
244                        // Error occurred during initialization.
245                        // A likely cause could be an unsupported protocol version.
246                        if !self.is_initialized() {
247                            return Err(error_value.into());
248                        }
249                        MessageFromServer::Error(error_value)
250                    }
251                };
252
253                let mpc_message: ServerMessage =
254                    ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
255
256                Some(mpc_message)
257            }
258            ClientMessage::Notification(client_jsonrpc_notification) => {
259                self.handler
260                    .handle_notification(client_jsonrpc_notification.notification, self)
261                    .await?;
262                None
263            }
264            ClientMessage::Error(jsonrpc_error) => {
265                self.handler
266                    .handle_error(&jsonrpc_error.error, self)
267                    .await?;
268                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
269                    tx_response
270                        .send(ClientMessage::Error(jsonrpc_error))
271                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
272                } else {
273                    tracing::warn!(
274                        "Received an error response with no corresponding request {:?}",
275                        &jsonrpc_error.id
276                    );
277                }
278                None
279            }
280            // The response is the result of a request, it is processed at the transport level.
281            ClientMessage::Response(response) => {
282                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
283                    tx_response
284                        .send(ClientMessage::Response(response))
285                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
286                } else {
287                    tracing::warn!(
288                        "Received a response with no corresponding request: {:?}",
289                        &response.id
290                    );
291                }
292                None
293            }
294        };
295        Ok(response)
296    }
297
298    pub(crate) async fn store_transport(
299        &self,
300        stream_id: &str,
301        transport: Arc<
302            dyn TransportDispatcher<
303                ClientMessages,
304                MessageFromServer,
305                ClientMessage,
306                ServerMessages,
307                ServerMessage,
308            >,
309        >,
310    ) -> SdkResult<()> {
311        let mut transport_map = self.transport_map.write().await;
312        tracing::trace!("save transport for stream id : {}", stream_id);
313        transport_map.insert(stream_id.to_string(), transport);
314        Ok(())
315    }
316
317    pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
318        let mut transport_map = self.transport_map.write().await;
319        tracing::trace!("removing transport for stream id : {}", stream_id);
320        transport_map.remove(stream_id);
321        Ok(())
322    }
323
324    pub(crate) async fn transport_by_stream(
325        &self,
326        stream_id: &str,
327    ) -> SdkResult<
328        Arc<
329            dyn TransportDispatcher<
330                ClientMessages,
331                MessageFromServer,
332                ClientMessage,
333                ServerMessages,
334                ServerMessage,
335            >,
336        >,
337    > {
338        let transport_map = self.transport_map.read().await;
339        transport_map.get(stream_id).cloned().ok_or_else(|| {
340            RpcError::internal_error()
341                .with_message(format!("Transport for key {stream_id} not found"))
342                .into()
343        })
344    }
345
346    pub(crate) async fn shutdown(&self) {
347        let mut transport_map = self.transport_map.write().await;
348        let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
349        drop(transport_map);
350        for item in items {
351            let _ = item.shut_down().await;
352        }
353    }
354
355    pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
356        let transport_map = self.transport_map.read().await;
357        transport_map.contains_key(stream_id)
358    }
359
360    pub(crate) async fn start_stream(
361        self: Arc<Self>,
362        transport: impl TransportDispatcher<
363            ClientMessages,
364            MessageFromServer,
365            ClientMessage,
366            ServerMessages,
367            ServerMessage,
368        >,
369        stream_id: &str,
370        ping_interval: Duration,
371        payload: Option<String>,
372    ) -> SdkResult<()> {
373        let mut stream = transport.start().await?;
374
375        self.store_transport(stream_id, Arc::new(transport)).await?;
376
377        let transport = self.transport_by_stream(stream_id).await?;
378
379        let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
380        let abort_alive_task = transport
381            .keep_alive(ping_interval, disconnect_tx)
382            .await?
383            .abort_handle();
384
385        // ensure keep_alive task will be aborted
386        let _abort_guard = AbortTaskOnDrop {
387            handle: abort_alive_task,
388        };
389
390        // in case there is a payload, we consume it by transport to get processed
391        if let Some(payload) = payload {
392            transport.consume_string_payload(&payload).await?;
393        }
394
395        loop {
396            tokio::select! {
397                Some(mcp_messages) = stream.next() =>{
398
399                    match mcp_messages {
400                        ClientMessages::Single(client_message) => {
401                            let result = self.handle_message(client_message, &transport).await?;
402                            if let Some(result) = result {
403                                transport.send_message(ServerMessages::Single(result), None).await?;
404                            }
405                        }
406                        ClientMessages::Batch(client_messages) => {
407
408                            let handling_tasks: Vec<_> = client_messages
409                                .into_iter()
410                                .map(|client_message| self.handle_message(client_message, &transport))
411                                .collect();
412
413                            let results: Vec<_> = try_join_all(handling_tasks).await?;
414
415                            let results: Vec<_> = results.into_iter().flatten().collect();
416
417
418                            if !results.is_empty() {
419                                transport.send_message(ServerMessages::Batch(results), None).await?;
420                            }
421                        }
422                    }
423                    // close the stream after all messages are sent, unless it is a standalone stream
424                    if !stream_id.eq(DEFAULT_STREAM_ID){
425                        return  Ok(());
426                    }
427                }
428                _ = &mut disconnect_rx => {
429                                self.remove_transport(stream_id).await?;
430                                // Disconnection detected by keep-alive task
431                                return Err(SdkError::connection_closed().into());
432
433                }
434            }
435        }
436    }
437
438    #[cfg(feature = "hyper-server")]
439    pub(crate) async fn session_id(&self) -> Option<SessionId> {
440        self.session_id.to_owned()
441    }
442
443    #[cfg(feature = "hyper-server")]
444    pub(crate) fn new_instance(
445        server_details: Arc<InitializeResult>,
446        handler: Arc<dyn McpServerHandler>,
447        session_id: SessionId,
448    ) -> Self {
449        let (client_details_tx, client_details_rx) =
450            watch::channel::<Option<InitializeRequestParams>>(None);
451        Self {
452            server_details,
453            handler,
454            session_id: Some(session_id),
455            transport_map: tokio::sync::RwLock::new(HashMap::new()),
456            client_details_tx,
457            client_details_rx,
458            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
459        }
460    }
461
462    pub(crate) fn new(
463        server_details: InitializeResult,
464        transport: impl TransportDispatcher<
465            ClientMessages,
466            MessageFromServer,
467            ClientMessage,
468            ServerMessages,
469            ServerMessage,
470        >,
471        handler: Arc<dyn McpServerHandler>,
472    ) -> Self {
473        let mut map: HashMap<String, TransportType> = HashMap::new();
474        map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
475        let (client_details_tx, client_details_rx) =
476            watch::channel::<Option<InitializeRequestParams>>(None);
477        Self {
478            server_details: Arc::new(server_details),
479            handler,
480            #[cfg(feature = "hyper-server")]
481            session_id: None,
482            transport_map: tokio::sync::RwLock::new(map),
483            client_details_tx,
484            client_details_rx,
485            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
486        }
487    }
488}