rust_mcp_sdk/mcp_runtimes/
server_runtime.rs1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3
4use crate::schema::schema_utils::{self, MessageFromServer};
5use crate::schema::{InitializeRequestParams, InitializeResult, RpcError};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
9use schema_utils::ClientMessage;
10use std::pin::Pin;
11use std::sync::{Arc, RwLock};
12use tokio::io::AsyncWriteExt;
13
14use crate::error::SdkResult;
15use crate::mcp_traits::mcp_handler::McpServerHandler;
16use crate::mcp_traits::mcp_server::McpServer;
17#[cfg(feature = "hyper-server")]
18use rust_mcp_transport::SessionId;
19
20pub struct ServerRuntime {
22 transport: Box<dyn Transport<ClientMessage, MessageFromServer>>,
24 handler: Arc<dyn McpServerHandler>,
26 server_details: Arc<InitializeResult>,
28 client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
30
31 message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>,
32 error_stream: tokio::sync::RwLock<Option<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
33 #[cfg(feature = "hyper-server")]
34 session_id: Option<SessionId>,
35}
36
37#[async_trait]
38impl McpServer for ServerRuntime {
39 fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
41 match self.client_details.write() {
42 Ok(mut details) => {
43 *details = Some(client_details);
44 Ok(())
45 }
46 Err(_) => Err(RpcError::internal_error()
48 .with_message("Internal Error: Failed to acquire write lock.".to_string())
49 .into()),
50 }
51 }
52
53 fn server_info(&self) -> &InitializeResult {
56 &self.server_details
57 }
58
59 fn client_info(&self) -> Option<InitializeRequestParams> {
61 if let Ok(details) = self.client_details.read() {
62 details.clone()
63 } else {
64 None
66 }
67 }
68
69 async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>
70 where
71 MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>,
72 {
73 (&self.message_sender) as _
74 }
75
76 async fn start(&self) -> SdkResult<()> {
78 let (mut stream, sender, error_io) = self.transport.start().await?;
83
84 self.set_message_sender(sender).await;
85
86 if let IoStream::Writable(error_stream) = error_io {
87 self.set_error_stream(error_stream).await;
88 }
89
90 let sender = self.sender().await.read().await;
91 let sender = sender
92 .as_ref()
93 .ok_or(schema_utils::SdkError::connection_closed())?;
94
95 self.handler.on_server_started(self).await;
96
97 while let Some(mcp_message) = stream.next().await {
99 match mcp_message {
100 ClientMessage::Request(client_jsonrpc_request) => {
102 let result = self
103 .handler
104 .handle_request(client_jsonrpc_request.request, self)
105 .await;
106 let response: MessageFromServer = match result {
108 Ok(success_value) => success_value.into(),
109 Err(error_value) => MessageFromServer::Error(error_value),
110 };
111
112 sender
114 .send(response, Some(client_jsonrpc_request.id), None)
115 .await?;
116 }
117 ClientMessage::Notification(client_jsonrpc_notification) => {
118 self.handler
119 .handle_notification(client_jsonrpc_notification.notification, self)
120 .await?;
121 }
122 ClientMessage::Error(jsonrpc_error) => {
123 self.handler.handle_error(jsonrpc_error.error, self).await?;
124 }
125 ClientMessage::Response(_) => {}
127 }
128 }
129
130 return Ok(());
131 }
132
133 async fn stderr_message(&self, message: String) -> SdkResult<()> {
134 let mut lock = self.error_stream.write().await;
135 if let Some(stderr) = lock.as_mut() {
136 stderr.write_all(message.as_bytes()).await?;
137 stderr.write_all(b"\n").await?;
138 stderr.flush().await?;
139 }
140 Ok(())
141 }
142}
143
144impl ServerRuntime {
145 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ClientMessage>) {
146 let mut lock = self.message_sender.write().await;
147 *lock = Some(sender);
148 }
149
150 #[cfg(feature = "hyper-server")]
151 pub(crate) async fn session_id(&self) -> Option<SessionId> {
152 self.session_id.to_owned()
153 }
154
155 pub(crate) async fn set_error_stream(
156 &self,
157 error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
158 ) {
159 let mut lock = self.error_stream.write().await;
160 *lock = Some(error_stream);
161 }
162
163 #[cfg(feature = "hyper-server")]
164 pub(crate) fn new_instance(
165 server_details: Arc<InitializeResult>,
166 transport: impl Transport<ClientMessage, MessageFromServer>,
167 handler: Arc<dyn McpServerHandler>,
168 session_id: SessionId,
169 ) -> Self {
170 Self {
171 server_details,
172 client_details: Arc::new(RwLock::new(None)),
173 transport: Box::new(transport),
174 handler,
175 message_sender: tokio::sync::RwLock::new(None),
176 error_stream: tokio::sync::RwLock::new(None),
177 session_id: Some(session_id),
178 }
179 }
180
181 pub(crate) fn new(
182 server_details: InitializeResult,
183 transport: impl Transport<ClientMessage, MessageFromServer>,
184 handler: Arc<dyn McpServerHandler>,
185 ) -> Self {
186 Self {
187 server_details: Arc::new(server_details),
188 client_details: Arc::new(RwLock::new(None)),
189 transport: Box::new(transport),
190 handler,
191 message_sender: tokio::sync::RwLock::new(None),
192 error_stream: tokio::sync::RwLock::new(None),
193 #[cfg(feature = "hyper-server")]
194 session_id: None,
195 }
196 }
197}