rust_mcp_sdk/mcp_runtimes/
server_runtime.rs1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::error::SdkResult;
4use crate::mcp_traits::mcp_handler::McpServerHandler;
5use crate::mcp_traits::mcp_server::McpServer;
6use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric};
7use crate::schema::{
8 schema_utils::{
9 ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
10 ServerMessages,
11 },
12 InitializeRequestParams, InitializeResult, RequestId, RpcError,
13};
14use crate::utils::AbortTaskOnDrop;
15use async_trait::async_trait;
16use futures::future::try_join_all;
17use futures::{StreamExt, TryFutureExt};
18#[cfg(feature = "hyper-server")]
19use rust_mcp_transport::SessionId;
20use rust_mcp_transport::{IoStream, TransportDispatcher};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{oneshot, watch};
26
/// Key under which the default (standalone) transport is registered in the
/// runtime's transport map; also the stream used by `send`, `start`, etc.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
28
29type TransportType = Arc<
31 dyn TransportDispatcher<
32 ClientMessages,
33 MessageFromServer,
34 ClientMessage,
35 ServerMessages,
36 ServerMessage,
37 >,
38>;
39
40pub struct ServerRuntime {
42 handler: Arc<dyn McpServerHandler>,
44 server_details: Arc<InitializeResult>,
46 #[cfg(feature = "hyper-server")]
47 session_id: Option<SessionId>,
48 transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
49 request_id_gen: Box<dyn RequestIdGen>,
50 client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
51 client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
52}
53
54#[async_trait]
55impl McpServer for ServerRuntime {
56 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
58 self.handler.on_server_started(self).await;
59
60 self.client_details_tx
61 .send(Some(client_details))
62 .map_err(|_| {
63 RpcError::internal_error()
64 .with_message("Failed to set client details".to_string())
65 .into()
66 })
67 }
68
69 async fn wait_for_initialization(&self) {
70 loop {
71 if self.client_details_rx.borrow().is_some() {
72 return;
73 }
74 let mut rx = self.client_details_rx.clone();
75 rx.changed().await.ok();
76 }
77 }
78
79 async fn send(
80 &self,
81 message: MessageFromServer,
82 request_id: Option<RequestId>,
83 request_timeout: Option<Duration>,
84 ) -> SdkResult<Option<ClientMessage>> {
85 let transport_map = self.transport_map.read().await;
86 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
87 RpcError::internal_error()
88 .with_message("transport stream does not exists or is closed!".to_string()),
89 )?;
90
91 let outgoing_request_id = self
92 .request_id_gen
93 .request_id_for_message(&message, request_id);
94
95 let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
96
97 let response = transport
98 .send_message(ServerMessages::Single(mcp_message), request_timeout)
99 .await?
100 .map(|res| res.as_single())
101 .transpose()?;
102
103 Ok(response)
104 }
105
106 async fn send_batch(
107 &self,
108 messages: Vec<ServerMessage>,
109 request_timeout: Option<Duration>,
110 ) -> SdkResult<Option<Vec<ClientMessage>>> {
111 let transport_map = self.transport_map.read().await;
112 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
113 RpcError::internal_error()
114 .with_message("transport stream does not exists or is closed!".to_string()),
115 )?;
116
117 transport
118 .send_batch(messages, request_timeout)
119 .map_err(|err| err.into())
120 .await
121 }
122
123 fn server_info(&self) -> &InitializeResult {
126 &self.server_details
127 }
128
129 fn client_info(&self) -> Option<InitializeRequestParams> {
131 self.client_details_rx.borrow().clone()
132 }
133
134 async fn start(&self) -> SdkResult<()> {
136 let transport_map = self.transport_map.read().await;
137
138 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
139 RpcError::internal_error()
140 .with_message("transport stream does not exists or is closed!".to_string()),
141 )?;
142
143 let mut stream = transport.start().await?;
144
145 while let Some(mcp_messages) = stream.next().await {
147 match mcp_messages {
148 ClientMessages::Single(client_message) => {
149 let result = self.handle_message(client_message, transport).await;
150
151 match result {
152 Ok(result) => {
153 if let Some(result) = result {
154 transport
155 .send_message(ServerMessages::Single(result), None)
156 .await?;
157 }
158 }
159 Err(error) => {
160 tracing::error!("Error handling message : {}", error)
161 }
162 }
163 }
164 ClientMessages::Batch(client_messages) => {
165 let handling_tasks: Vec<_> = client_messages
166 .into_iter()
167 .map(|client_message| self.handle_message(client_message, transport))
168 .collect();
169
170 let results: Vec<_> = try_join_all(handling_tasks).await?;
171
172 let results: Vec<_> = results.into_iter().flatten().collect();
173
174 if !results.is_empty() {
175 transport
176 .send_message(ServerMessages::Batch(results), None)
177 .await?;
178 }
179 }
180 }
181 }
182 return Ok(());
183 }
184
185 async fn stderr_message(&self, message: String) -> SdkResult<()> {
186 let transport_map = self.transport_map.read().await;
187 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
188 RpcError::internal_error()
189 .with_message("transport stream does not exists or is closed!".to_string()),
190 )?;
191 let mut lock = transport.error_stream().write().await;
192
193 if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
194 stderr.write_all(message.as_bytes()).await?;
195 stderr.write_all(b"\n").await?;
196 stderr.flush().await?;
197 }
198 Ok(())
199 }
200}
201
202impl ServerRuntime {
203 pub(crate) async fn consume_payload_string(
204 &self,
205 stream_id: &str,
206 payload: &str,
207 ) -> SdkResult<()> {
208 let transport_map = self.transport_map.read().await;
209
210 let transport = transport_map.get(stream_id).ok_or(
211 RpcError::internal_error()
212 .with_message("stream id does not exists or is closed!".to_string()),
213 )?;
214
215 transport.consume_string_payload(payload).await?;
216
217 Ok(())
218 }
219
220 pub(crate) async fn handle_message(
221 &self,
222 message: ClientMessage,
223 transport: &Arc<
224 dyn TransportDispatcher<
225 ClientMessages,
226 MessageFromServer,
227 ClientMessage,
228 ServerMessages,
229 ServerMessage,
230 >,
231 >,
232 ) -> SdkResult<Option<ServerMessage>> {
233 let response = match message {
234 ClientMessage::Request(client_jsonrpc_request) => {
236 let result = self
237 .handler
238 .handle_request(client_jsonrpc_request.request, self)
239 .await;
240 let response: MessageFromServer = match result {
242 Ok(success_value) => success_value.into(),
243 Err(error_value) => {
244 if !self.is_initialized() {
247 return Err(error_value.into());
248 }
249 MessageFromServer::Error(error_value)
250 }
251 };
252
253 let mpc_message: ServerMessage =
254 ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
255
256 Some(mpc_message)
257 }
258 ClientMessage::Notification(client_jsonrpc_notification) => {
259 self.handler
260 .handle_notification(client_jsonrpc_notification.notification, self)
261 .await?;
262 None
263 }
264 ClientMessage::Error(jsonrpc_error) => {
265 self.handler
266 .handle_error(&jsonrpc_error.error, self)
267 .await?;
268 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
269 tx_response
270 .send(ClientMessage::Error(jsonrpc_error))
271 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
272 } else {
273 tracing::warn!(
274 "Received an error response with no corresponding request {:?}",
275 &jsonrpc_error.id
276 );
277 }
278 None
279 }
280 ClientMessage::Response(response) => {
282 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
283 tx_response
284 .send(ClientMessage::Response(response))
285 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
286 } else {
287 tracing::warn!(
288 "Received a response with no corresponding request: {:?}",
289 &response.id
290 );
291 }
292 None
293 }
294 };
295 Ok(response)
296 }
297
298 pub(crate) async fn store_transport(
299 &self,
300 stream_id: &str,
301 transport: Arc<
302 dyn TransportDispatcher<
303 ClientMessages,
304 MessageFromServer,
305 ClientMessage,
306 ServerMessages,
307 ServerMessage,
308 >,
309 >,
310 ) -> SdkResult<()> {
311 let mut transport_map = self.transport_map.write().await;
312 tracing::trace!("save transport for stream id : {}", stream_id);
313 transport_map.insert(stream_id.to_string(), transport);
314 Ok(())
315 }
316
317 pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
318 let mut transport_map = self.transport_map.write().await;
319 tracing::trace!("removing transport for stream id : {}", stream_id);
320 transport_map.remove(stream_id);
321 Ok(())
322 }
323
324 pub(crate) async fn transport_by_stream(
325 &self,
326 stream_id: &str,
327 ) -> SdkResult<
328 Arc<
329 dyn TransportDispatcher<
330 ClientMessages,
331 MessageFromServer,
332 ClientMessage,
333 ServerMessages,
334 ServerMessage,
335 >,
336 >,
337 > {
338 let transport_map = self.transport_map.read().await;
339 transport_map.get(stream_id).cloned().ok_or_else(|| {
340 RpcError::internal_error()
341 .with_message(format!("Transport for key {stream_id} not found"))
342 .into()
343 })
344 }
345
346 pub(crate) async fn shutdown(&self) {
347 let mut transport_map = self.transport_map.write().await;
348 let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
349 drop(transport_map);
350 for item in items {
351 let _ = item.shut_down().await;
352 }
353 }
354
355 pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
356 let transport_map = self.transport_map.read().await;
357 transport_map.contains_key(stream_id)
358 }
359
360 pub(crate) async fn start_stream(
361 self: Arc<Self>,
362 transport: impl TransportDispatcher<
363 ClientMessages,
364 MessageFromServer,
365 ClientMessage,
366 ServerMessages,
367 ServerMessage,
368 >,
369 stream_id: &str,
370 ping_interval: Duration,
371 payload: Option<String>,
372 ) -> SdkResult<()> {
373 let mut stream = transport.start().await?;
374
375 self.store_transport(stream_id, Arc::new(transport)).await?;
376
377 let transport = self.transport_by_stream(stream_id).await?;
378
379 let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
380 let abort_alive_task = transport
381 .keep_alive(ping_interval, disconnect_tx)
382 .await?
383 .abort_handle();
384
385 let _abort_guard = AbortTaskOnDrop {
387 handle: abort_alive_task,
388 };
389
390 if let Some(payload) = payload {
392 transport.consume_string_payload(&payload).await?;
393 }
394
395 loop {
396 tokio::select! {
397 Some(mcp_messages) = stream.next() =>{
398
399 match mcp_messages {
400 ClientMessages::Single(client_message) => {
401 let result = self.handle_message(client_message, &transport).await?;
402 if let Some(result) = result {
403 transport.send_message(ServerMessages::Single(result), None).await?;
404 }
405 }
406 ClientMessages::Batch(client_messages) => {
407
408 let handling_tasks: Vec<_> = client_messages
409 .into_iter()
410 .map(|client_message| self.handle_message(client_message, &transport))
411 .collect();
412
413 let results: Vec<_> = try_join_all(handling_tasks).await?;
414
415 let results: Vec<_> = results.into_iter().flatten().collect();
416
417
418 if !results.is_empty() {
419 transport.send_message(ServerMessages::Batch(results), None).await?;
420 }
421 }
422 }
423 if !stream_id.eq(DEFAULT_STREAM_ID){
425 return Ok(());
426 }
427 }
428 _ = &mut disconnect_rx => {
429 self.remove_transport(stream_id).await?;
430 return Err(SdkError::connection_closed().into());
432
433 }
434 }
435 }
436 }
437
438 #[cfg(feature = "hyper-server")]
439 pub(crate) async fn session_id(&self) -> Option<SessionId> {
440 self.session_id.to_owned()
441 }
442
443 #[cfg(feature = "hyper-server")]
444 pub(crate) fn new_instance(
445 server_details: Arc<InitializeResult>,
446 handler: Arc<dyn McpServerHandler>,
447 session_id: SessionId,
448 ) -> Self {
449 let (client_details_tx, client_details_rx) =
450 watch::channel::<Option<InitializeRequestParams>>(None);
451 Self {
452 server_details,
453 handler,
454 session_id: Some(session_id),
455 transport_map: tokio::sync::RwLock::new(HashMap::new()),
456 client_details_tx,
457 client_details_rx,
458 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
459 }
460 }
461
462 pub(crate) fn new(
463 server_details: InitializeResult,
464 transport: impl TransportDispatcher<
465 ClientMessages,
466 MessageFromServer,
467 ClientMessage,
468 ServerMessages,
469 ServerMessage,
470 >,
471 handler: Arc<dyn McpServerHandler>,
472 ) -> Self {
473 let mut map: HashMap<String, TransportType> = HashMap::new();
474 map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
475 let (client_details_tx, client_details_rx) =
476 watch::channel::<Option<InitializeRequestParams>>(None);
477 Self {
478 server_details: Arc::new(server_details),
479 handler,
480 #[cfg(feature = "hyper-server")]
481 session_id: None,
482 transport_map: tokio::sync::RwLock::new(map),
483 client_details_tx,
484 client_details_rx,
485 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
486 }
487 }
488}