rust_mcp_sdk/mcp_runtimes/
server_runtime.rs1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::error::SdkResult;
4use crate::mcp_traits::mcp_handler::McpServerHandler;
5use crate::mcp_traits::mcp_server::McpServer;
6use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
7use crate::schema::{
8 schema_utils::{
9 ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
10 ServerMessages,
11 },
12 InitializeRequestParams, InitializeResult, RequestId, RpcError,
13};
14use crate::utils::AbortTaskOnDrop;
15use async_trait::async_trait;
16use futures::future::try_join_all;
17use futures::{StreamExt, TryFutureExt};
18#[cfg(feature = "hyper-server")]
19use rust_mcp_transport::SessionId;
20use rust_mcp_transport::{IoStream, TransportDispatcher};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{oneshot, watch};
26
/// Stream id under which the standalone transport is stored in
/// `ServerRuntime::transport_map`.
///
/// `ServerRuntime::new` registers its transport under this key, and the
/// `McpServer` methods (`send`, `send_batch`, `start`, `stderr_message`) look it
/// up. In `ServerRuntime::start_stream` this id keeps its message loop running,
/// whereas any other stream id returns after its first message.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
28
/// Shared (`Arc`), type-erased transport stored in `ServerRuntime::transport_map`.
///
/// Its stream yields `ClientMessages`, and replies are written with
/// `send_message(ServerMessages, ..)` (see `ServerRuntime::start`).
type TransportType = Arc<
    dyn TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
>;
39
/// Runtime state of one MCP server instance.
///
/// Bundles the user-supplied handler, the server's `InitializeResult`, the
/// transports keyed by stream id, an id generator for outgoing messages, and a
/// `watch` channel that carries the client's `InitializeRequestParams` once
/// `set_client_details` has been called.
pub struct ServerRuntime {
    // Receives client requests, notifications and errors (see `handle_message`),
    // plus the `on_server_started` callback.
    handler: Arc<dyn McpServerHandler>,
    // Returned by reference from `server_info()`.
    server_details: Arc<InitializeResult>,
    // `Some` when built through `new_instance`, `None` when built through `new`.
    #[cfg(feature = "hyper-server")]
    session_id: Option<SessionId>,
    // Transports keyed by stream id; `DEFAULT_STREAM_ID` is the standalone stream.
    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
    // Chooses the request id attached to outgoing messages in `send`.
    request_id_gen: Box<dyn RequestIdGen>,
    // Both ends of the watch channel holding the client's initialize params;
    // the value is `None` until `set_client_details` publishes it.
    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
}
53
54#[async_trait]
55impl McpServer for ServerRuntime {
56 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
58 self.handler.on_server_started(self).await;
59
60 self.client_details_tx
61 .send(Some(client_details))
62 .map_err(|_| {
63 RpcError::internal_error()
64 .with_message("Failed to set client details".to_string())
65 .into()
66 })
67 }
68
69 async fn wait_for_initialization(&self) {
70 loop {
71 if self.client_details_rx.borrow().is_some() {
72 return;
73 }
74 let mut rx = self.client_details_rx.clone();
75 rx.changed().await.ok();
76 }
77 }
78
79 async fn send(
80 &self,
81 message: MessageFromServer,
82 request_id: Option<RequestId>,
83 request_timeout: Option<Duration>,
84 ) -> SdkResult<Option<ClientMessage>> {
85 let transport_map = self.transport_map.read().await;
86 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
87 RpcError::internal_error()
88 .with_message("transport stream does not exists or is closed!".to_string()),
89 )?;
90
91 let outgoing_request_id = self
92 .request_id_gen
93 .request_id_for_message(&message, request_id);
94
95 let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
96
97 let response = transport
98 .send_message(ServerMessages::Single(mcp_message), request_timeout)
99 .await?
100 .map(|res| res.as_single())
101 .transpose()?;
102
103 Ok(response)
104 }
105
106 async fn send_batch(
107 &self,
108 messages: Vec<ServerMessage>,
109 request_timeout: Option<Duration>,
110 ) -> SdkResult<Option<Vec<ClientMessage>>> {
111 let transport_map = self.transport_map.read().await;
112 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
113 RpcError::internal_error()
114 .with_message("transport stream does not exists or is closed!".to_string()),
115 )?;
116
117 transport
118 .send_batch(messages, request_timeout)
119 .map_err(|err| err.into())
120 .await
121 }
122
123 fn server_info(&self) -> &InitializeResult {
126 &self.server_details
127 }
128
129 fn client_info(&self) -> Option<InitializeRequestParams> {
131 self.client_details_rx.borrow().clone()
132 }
133
134 async fn start(&self) -> SdkResult<()> {
136 let transport_map = self.transport_map.read().await;
137
138 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
139 RpcError::internal_error()
140 .with_message("transport stream does not exists or is closed!".to_string()),
141 )?;
142
143 let mut stream = transport.start().await?;
144
145 while let Some(mcp_messages) = stream.next().await {
147 match mcp_messages {
148 ClientMessages::Single(client_message) => {
149 let result = self.handle_message(client_message, transport).await;
150
151 match result {
152 Ok(result) => {
153 if let Some(result) = result {
154 transport
155 .send_message(ServerMessages::Single(result), None)
156 .await?;
157 }
158 }
159 Err(error) => {
160 tracing::error!("Error handling message : {}", error)
161 }
162 }
163 }
164 ClientMessages::Batch(client_messages) => {
165 let handling_tasks: Vec<_> = client_messages
166 .into_iter()
167 .map(|client_message| self.handle_message(client_message, transport))
168 .collect();
169
170 let results: Vec<_> = try_join_all(handling_tasks).await?;
171
172 let results: Vec<_> = results.into_iter().flatten().collect();
173
174 if !results.is_empty() {
175 transport
176 .send_message(ServerMessages::Batch(results), None)
177 .await?;
178 }
179 }
180 }
181 }
182 return Ok(());
183 }
184
185 async fn stderr_message(&self, message: String) -> SdkResult<()> {
186 let transport_map = self.transport_map.read().await;
187 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
188 RpcError::internal_error()
189 .with_message("transport stream does not exists or is closed!".to_string()),
190 )?;
191 let mut lock = transport.error_stream().write().await;
192
193 if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
194 stderr.write_all(message.as_bytes()).await?;
195 stderr.write_all(b"\n").await?;
196 stderr.flush().await?;
197 }
198 Ok(())
199 }
200
201 #[cfg(feature = "hyper-server")]
202 fn session_id(&self) -> Option<SessionId> {
203 self.session_id.to_owned()
204 }
205}
206
207impl ServerRuntime {
208 pub(crate) async fn consume_payload_string(
209 &self,
210 stream_id: &str,
211 payload: &str,
212 ) -> SdkResult<()> {
213 let transport_map = self.transport_map.read().await;
214
215 let transport = transport_map.get(stream_id).ok_or(
216 RpcError::internal_error()
217 .with_message("stream id does not exists or is closed!".to_string()),
218 )?;
219
220 transport.consume_string_payload(payload).await?;
221
222 Ok(())
223 }
224
225 pub(crate) async fn handle_message(
226 &self,
227 message: ClientMessage,
228 transport: &Arc<
229 dyn TransportDispatcher<
230 ClientMessages,
231 MessageFromServer,
232 ClientMessage,
233 ServerMessages,
234 ServerMessage,
235 >,
236 >,
237 ) -> SdkResult<Option<ServerMessage>> {
238 let response = match message {
239 ClientMessage::Request(client_jsonrpc_request) => {
241 let result = self
242 .handler
243 .handle_request(client_jsonrpc_request.request, self)
244 .await;
245 let response: MessageFromServer = match result {
247 Ok(success_value) => success_value.into(),
248 Err(error_value) => {
249 if !self.is_initialized() {
252 return Err(error_value.into());
253 }
254 MessageFromServer::Error(error_value)
255 }
256 };
257
258 let mpc_message: ServerMessage =
259 ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
260
261 Some(mpc_message)
262 }
263 ClientMessage::Notification(client_jsonrpc_notification) => {
264 self.handler
265 .handle_notification(client_jsonrpc_notification.notification, self)
266 .await?;
267 None
268 }
269 ClientMessage::Error(jsonrpc_error) => {
270 self.handler
271 .handle_error(&jsonrpc_error.error, self)
272 .await?;
273 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
274 tx_response
275 .send(ClientMessage::Error(jsonrpc_error))
276 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
277 } else {
278 tracing::warn!(
279 "Received an error response with no corresponding request {:?}",
280 &jsonrpc_error.id
281 );
282 }
283 None
284 }
285 ClientMessage::Response(response) => {
287 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
288 tx_response
289 .send(ClientMessage::Response(response))
290 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
291 } else {
292 tracing::warn!(
293 "Received a response with no corresponding request: {:?}",
294 &response.id
295 );
296 }
297 None
298 }
299 };
300 Ok(response)
301 }
302
303 pub(crate) async fn store_transport(
304 &self,
305 stream_id: &str,
306 transport: Arc<
307 dyn TransportDispatcher<
308 ClientMessages,
309 MessageFromServer,
310 ClientMessage,
311 ServerMessages,
312 ServerMessage,
313 >,
314 >,
315 ) -> SdkResult<()> {
316 let mut transport_map = self.transport_map.write().await;
317 tracing::trace!("save transport for stream id : {}", stream_id);
318 transport_map.insert(stream_id.to_string(), transport);
319 Ok(())
320 }
321
322 pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
323 let mut transport_map = self.transport_map.write().await;
324 tracing::trace!("removing transport for stream id : {}", stream_id);
325 transport_map.remove(stream_id);
326 Ok(())
327 }
328
329 pub(crate) async fn transport_by_stream(
330 &self,
331 stream_id: &str,
332 ) -> SdkResult<
333 Arc<
334 dyn TransportDispatcher<
335 ClientMessages,
336 MessageFromServer,
337 ClientMessage,
338 ServerMessages,
339 ServerMessage,
340 >,
341 >,
342 > {
343 let transport_map = self.transport_map.read().await;
344 transport_map.get(stream_id).cloned().ok_or_else(|| {
345 RpcError::internal_error()
346 .with_message(format!("Transport for key {stream_id} not found"))
347 .into()
348 })
349 }
350
351 pub(crate) async fn shutdown(&self) {
352 let mut transport_map = self.transport_map.write().await;
353 let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
354 drop(transport_map);
355 for item in items {
356 let _ = item.shut_down().await;
357 }
358 }
359
360 pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
361 let transport_map = self.transport_map.read().await;
362 transport_map.contains_key(stream_id)
363 }
364
365 pub(crate) async fn start_stream(
366 self: Arc<Self>,
367 transport: impl TransportDispatcher<
368 ClientMessages,
369 MessageFromServer,
370 ClientMessage,
371 ServerMessages,
372 ServerMessage,
373 >,
374 stream_id: &str,
375 ping_interval: Duration,
376 payload: Option<String>,
377 ) -> SdkResult<()> {
378 let mut stream = transport.start().await?;
379
380 self.store_transport(stream_id, Arc::new(transport)).await?;
381
382 let transport = self.transport_by_stream(stream_id).await?;
383
384 let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
385 let abort_alive_task = transport
386 .keep_alive(ping_interval, disconnect_tx)
387 .await?
388 .abort_handle();
389
390 let _abort_guard = AbortTaskOnDrop {
392 handle: abort_alive_task,
393 };
394
395 if let Some(payload) = payload {
397 transport.consume_string_payload(&payload).await?;
398 }
399
400 loop {
401 tokio::select! {
402 Some(mcp_messages) = stream.next() =>{
403
404 match mcp_messages {
405 ClientMessages::Single(client_message) => {
406 let result = self.handle_message(client_message, &transport).await?;
407 if let Some(result) = result {
408 transport.send_message(ServerMessages::Single(result), None).await?;
409 }
410 }
411 ClientMessages::Batch(client_messages) => {
412
413 let handling_tasks: Vec<_> = client_messages
414 .into_iter()
415 .map(|client_message| self.handle_message(client_message, &transport))
416 .collect();
417
418 let results: Vec<_> = try_join_all(handling_tasks).await?;
419
420 let results: Vec<_> = results.into_iter().flatten().collect();
421
422
423 if !results.is_empty() {
424 transport.send_message(ServerMessages::Batch(results), None).await?;
425 }
426 }
427 }
428 if !stream_id.eq(DEFAULT_STREAM_ID){
430 return Ok(());
431 }
432 }
433 _ = &mut disconnect_rx => {
434 self.remove_transport(stream_id).await?;
435 return Err(SdkError::connection_closed().into());
437
438 }
439 }
440 }
441 }
442
443 #[cfg(feature = "hyper-server")]
444 pub(crate) fn new_instance(
445 server_details: Arc<InitializeResult>,
446 handler: Arc<dyn McpServerHandler>,
447 session_id: SessionId,
448 ) -> Self {
449 let (client_details_tx, client_details_rx) =
450 watch::channel::<Option<InitializeRequestParams>>(None);
451 Self {
452 server_details,
453 handler,
454 session_id: Some(session_id),
455 transport_map: tokio::sync::RwLock::new(HashMap::new()),
456 client_details_tx,
457 client_details_rx,
458 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
459 }
460 }
461
462 pub(crate) fn new(
463 server_details: InitializeResult,
464 transport: impl TransportDispatcher<
465 ClientMessages,
466 MessageFromServer,
467 ClientMessage,
468 ServerMessages,
469 ServerMessage,
470 >,
471 handler: Arc<dyn McpServerHandler>,
472 ) -> Self {
473 let mut map: HashMap<String, TransportType> = HashMap::new();
474 map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
475 let (client_details_tx, client_details_rx) =
476 watch::channel::<Option<InitializeRequestParams>>(None);
477 Self {
478 server_details: Arc::new(server_details),
479 handler,
480 #[cfg(feature = "hyper-server")]
481 session_id: None,
482 transport_map: tokio::sync::RwLock::new(map),
483 client_details_tx,
484 client_details_rx,
485 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
486 }
487 }
488}