rust_mcp_sdk/mcp_runtimes/
server_runtime.rs1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3
4use crate::schema::schema_utils::{self, MessageFromServer};
5use crate::schema::{InitializeRequestParams, InitializeResult, RpcError};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
9use schema_utils::ClientMessage;
10use std::pin::Pin;
11use std::sync::{Arc, RwLock};
12use tokio::io::AsyncWriteExt;
13
14use crate::error::SdkResult;
15use crate::mcp_traits::mcp_handler::McpServerHandler;
16use crate::mcp_traits::mcp_server::McpServer;
17#[cfg(feature = "hyper-server")]
18use rust_mcp_transport::SessionId;
19
20pub struct ServerRuntime {
22 transport: Box<dyn Transport<ClientMessage, MessageFromServer>>,
24 handler: Arc<dyn McpServerHandler>,
26 server_details: Arc<InitializeResult>,
28 client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
30
31 message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>,
32 error_stream: tokio::sync::RwLock<Option<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
33 #[cfg(feature = "hyper-server")]
34 session_id: Option<SessionId>,
35}
36
37#[async_trait]
38impl McpServer for ServerRuntime {
39 fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
41 match self.client_details.write() {
42 Ok(mut details) => {
43 *details = Some(client_details);
44 Ok(())
45 }
46 Err(_) => Err(RpcError::internal_error()
48 .with_message("Internal Error: Failed to acquire write lock.".to_string())
49 .into()),
50 }
51 }
52
53 fn server_info(&self) -> &InitializeResult {
56 &self.server_details
57 }
58
59 fn client_info(&self) -> Option<InitializeRequestParams> {
61 if let Ok(details) = self.client_details.read() {
62 details.clone()
63 } else {
64 None
66 }
67 }
68
69 async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>
70 where
71 MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>,
72 {
73 (&self.message_sender) as _
74 }
75
76 async fn start(&self) -> SdkResult<()> {
78 let (mut stream, sender, error_io) = self.transport.start().await?;
83
84 self.set_message_sender(sender).await;
85
86 if let IoStream::Writable(error_stream) = error_io {
87 self.set_error_stream(error_stream).await;
88 }
89
90 let sender = self.sender().await.read().await;
91 let sender = sender
92 .as_ref()
93 .ok_or(schema_utils::SdkError::connection_closed())?;
94
95 self.handler.on_server_started(self).await;
96
97 while let Some(mcp_message) = stream.next().await {
99 match mcp_message {
100 ClientMessage::Request(client_jsonrpc_request) => {
102 let result = self
103 .handler
104 .handle_request(client_jsonrpc_request.request, self)
105 .await;
106 let response: MessageFromServer = match result {
108 Ok(success_value) => success_value.into(),
109 Err(error_value) => {
110 if !self.is_initialized() {
113 return Err(error_value.into());
114 }
115 MessageFromServer::Error(error_value)
116 }
117 };
118
119 sender
121 .send(response, Some(client_jsonrpc_request.id), None)
122 .await?;
123 }
124 ClientMessage::Notification(client_jsonrpc_notification) => {
125 self.handler
126 .handle_notification(client_jsonrpc_notification.notification, self)
127 .await?;
128 }
129 ClientMessage::Error(jsonrpc_error) => {
130 self.handler.handle_error(jsonrpc_error.error, self).await?;
131 }
132 ClientMessage::Response(_) => {}
134 }
135 }
136
137 return Ok(());
138 }
139
140 async fn stderr_message(&self, message: String) -> SdkResult<()> {
141 let mut lock = self.error_stream.write().await;
142 if let Some(stderr) = lock.as_mut() {
143 stderr.write_all(message.as_bytes()).await?;
144 stderr.write_all(b"\n").await?;
145 stderr.flush().await?;
146 }
147 Ok(())
148 }
149}
150
151impl ServerRuntime {
152 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ClientMessage>) {
153 let mut lock = self.message_sender.write().await;
154 *lock = Some(sender);
155 }
156
157 #[cfg(feature = "hyper-server")]
158 pub(crate) async fn session_id(&self) -> Option<SessionId> {
159 self.session_id.to_owned()
160 }
161
162 pub(crate) async fn set_error_stream(
163 &self,
164 error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
165 ) {
166 let mut lock = self.error_stream.write().await;
167 *lock = Some(error_stream);
168 }
169
170 #[cfg(feature = "hyper-server")]
171 pub(crate) fn new_instance(
172 server_details: Arc<InitializeResult>,
173 transport: impl Transport<ClientMessage, MessageFromServer>,
174 handler: Arc<dyn McpServerHandler>,
175 session_id: SessionId,
176 ) -> Self {
177 Self {
178 server_details,
179 client_details: Arc::new(RwLock::new(None)),
180 transport: Box::new(transport),
181 handler,
182 message_sender: tokio::sync::RwLock::new(None),
183 error_stream: tokio::sync::RwLock::new(None),
184 session_id: Some(session_id),
185 }
186 }
187
188 pub(crate) fn new(
189 server_details: InitializeResult,
190 transport: impl Transport<ClientMessage, MessageFromServer>,
191 handler: Arc<dyn McpServerHandler>,
192 ) -> Self {
193 Self {
194 server_details: Arc::new(server_details),
195 client_details: Arc::new(RwLock::new(None)),
196 transport: Box::new(transport),
197 handler,
198 message_sender: tokio::sync::RwLock::new(None),
199 error_stream: tokio::sync::RwLock::new(None),
200 #[cfg(feature = "hyper-server")]
201 session_id: None,
202 }
203 }
204}