// rust_mcp_sdk/mcp_runtimes/server_runtime.rs

1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3
4use crate::schema::schema_utils::{self, MessageFromServer};
5use crate::schema::{InitializeRequestParams, InitializeResult, RpcError};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
9use schema_utils::ClientMessage;
10use std::pin::Pin;
11use std::sync::{Arc, RwLock};
12use tokio::io::AsyncWriteExt;
13
14use crate::error::SdkResult;
15use crate::mcp_traits::mcp_handler::McpServerHandler;
16use crate::mcp_traits::mcp_server::McpServer;
17#[cfg(feature = "hyper-server")]
18use rust_mcp_transport::SessionId;
19
20/// Struct representing the runtime core of the MCP server, handling transport and client details
21pub struct ServerRuntime {
22    // The transport interface for handling messages between client and server
23    transport: Box<dyn Transport<ClientMessage, MessageFromServer>>,
24    // The handler for processing MCP messages
25    handler: Arc<dyn McpServerHandler>,
26    // Information about the server
27    server_details: Arc<InitializeResult>,
28    // Details about the connected client
29    client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
30
31    message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>,
32    error_stream: tokio::sync::RwLock<Option<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
33    #[cfg(feature = "hyper-server")]
34    session_id: Option<SessionId>,
35}
36
37#[async_trait]
38impl McpServer for ServerRuntime {
39    /// Set the client details, storing them in client_details
40    fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
41        match self.client_details.write() {
42            Ok(mut details) => {
43                *details = Some(client_details);
44                Ok(())
45            }
46            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
47            Err(_) => Err(RpcError::internal_error()
48                .with_message("Internal Error: Failed to acquire write lock.".to_string())
49                .into()),
50        }
51    }
52
53    /// Returns the server's details, including server capability,
54    /// instructions, protocol_version , server_info and optional meta data
55    fn server_info(&self) -> &InitializeResult {
56        &self.server_details
57    }
58
59    /// Returns the client information if available, after successful initialization , otherwise returns None
60    fn client_info(&self) -> Option<InitializeRequestParams> {
61        if let Ok(details) = self.client_details.read() {
62            details.clone()
63        } else {
64            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
65            None
66        }
67    }
68
69    async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>
70    where
71        MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>,
72    {
73        (&self.message_sender) as _
74    }
75
76    /// Main runtime loop, processes incoming messages and handles requests
77    async fn start(&self) -> SdkResult<()> {
78        // Start the transport layer to begin handling messages
79        // self.transport.start().await?;
80        // Open the transport stream
81        // let mut stream = self.transport.open();
82        let (mut stream, sender, error_io) = self.transport.start().await?;
83
84        self.set_message_sender(sender).await;
85
86        if let IoStream::Writable(error_stream) = error_io {
87            self.set_error_stream(error_stream).await;
88        }
89
90        let sender = self.sender().await.read().await;
91        let sender = sender
92            .as_ref()
93            .ok_or(schema_utils::SdkError::connection_closed())?;
94
95        self.handler.on_server_started(self).await;
96
97        // Process incoming messages from the client
98        while let Some(mcp_message) = stream.next().await {
99            match mcp_message {
100                // Handle a client request
101                ClientMessage::Request(client_jsonrpc_request) => {
102                    let result = self
103                        .handler
104                        .handle_request(client_jsonrpc_request.request, self)
105                        .await;
106                    // create a response to send back to the client
107                    let response: MessageFromServer = match result {
108                        Ok(success_value) => success_value.into(),
109                        Err(error_value) => {
110                            // Error occurred during initialization.
111                            // A likely cause could be an unsupported protocol version.
112                            if !self.is_initialized() {
113                                return Err(error_value.into());
114                            }
115                            MessageFromServer::Error(error_value)
116                        }
117                    };
118
119                    // send the response back with corresponding request id
120                    sender
121                        .send(response, Some(client_jsonrpc_request.id), None)
122                        .await?;
123                }
124                ClientMessage::Notification(client_jsonrpc_notification) => {
125                    self.handler
126                        .handle_notification(client_jsonrpc_notification.notification, self)
127                        .await?;
128                }
129                ClientMessage::Error(jsonrpc_error) => {
130                    self.handler.handle_error(jsonrpc_error.error, self).await?;
131                }
132                // The response is the result of a request, it is processed at the transport level.
133                ClientMessage::Response(_) => {}
134            }
135        }
136
137        return Ok(());
138    }
139
140    async fn stderr_message(&self, message: String) -> SdkResult<()> {
141        let mut lock = self.error_stream.write().await;
142        if let Some(stderr) = lock.as_mut() {
143            stderr.write_all(message.as_bytes()).await?;
144            stderr.write_all(b"\n").await?;
145            stderr.flush().await?;
146        }
147        Ok(())
148    }
149}
150
151impl ServerRuntime {
152    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ClientMessage>) {
153        let mut lock = self.message_sender.write().await;
154        *lock = Some(sender);
155    }
156
157    #[cfg(feature = "hyper-server")]
158    pub(crate) async fn session_id(&self) -> Option<SessionId> {
159        self.session_id.to_owned()
160    }
161
162    pub(crate) async fn set_error_stream(
163        &self,
164        error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
165    ) {
166        let mut lock = self.error_stream.write().await;
167        *lock = Some(error_stream);
168    }
169
170    #[cfg(feature = "hyper-server")]
171    pub(crate) fn new_instance(
172        server_details: Arc<InitializeResult>,
173        transport: impl Transport<ClientMessage, MessageFromServer>,
174        handler: Arc<dyn McpServerHandler>,
175        session_id: SessionId,
176    ) -> Self {
177        Self {
178            server_details,
179            client_details: Arc::new(RwLock::new(None)),
180            transport: Box::new(transport),
181            handler,
182            message_sender: tokio::sync::RwLock::new(None),
183            error_stream: tokio::sync::RwLock::new(None),
184            session_id: Some(session_id),
185        }
186    }
187
188    pub(crate) fn new(
189        server_details: InitializeResult,
190        transport: impl Transport<ClientMessage, MessageFromServer>,
191        handler: Arc<dyn McpServerHandler>,
192    ) -> Self {
193        Self {
194            server_details: Arc::new(server_details),
195            client_details: Arc::new(RwLock::new(None)),
196            transport: Box::new(transport),
197            handler,
198            message_sender: tokio::sync::RwLock::new(None),
199            error_stream: tokio::sync::RwLock::new(None),
200            #[cfg(feature = "hyper-server")]
201            session_id: None,
202        }
203    }
204}