1pub mod mcp_server_runtime;
2pub mod mcp_server_runtime_core;
3use crate::auth::AuthInfo;
4use crate::error::SdkResult;
5use crate::mcp_traits::{McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric};
6use crate::schema::{
7 schema_utils::{
8 ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
9 ServerMessages,
10 },
11 InitializeRequestParams, InitializeResult, RequestId, RpcError,
12};
13use crate::utils::AbortTaskOnDrop;
14use async_trait::async_trait;
15use futures::future::try_join_all;
16use futures::{StreamExt, TryFutureExt};
17#[cfg(feature = "hyper-server")]
18use rust_mcp_transport::SessionId;
19use rust_mcp_transport::{IoStream, TransportDispatcher};
20use std::collections::HashMap;
21use std::panic;
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::AsyncWriteExt;
25use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard};
26
/// Capacity of the channel through which spawned message-handling tasks report their results.
const TASK_CHANNEL_CAPACITY: usize = 500;

/// Key under which the standalone (default) transport is stored in the runtime's transport map.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
29
30type TransportType = Arc<
32 dyn TransportDispatcher<
33 ClientMessages,
34 MessageFromServer,
35 ClientMessage,
36 ServerMessages,
37 ServerMessage,
38 >,
39>;
40
41pub struct ServerRuntime {
43 handler: Arc<dyn McpServerHandler>,
45 server_details: Arc<InitializeResult>,
47 #[cfg(feature = "hyper-server")]
48 session_id: Option<SessionId>,
49 transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>, request_id_gen: Box<dyn RequestIdGen>,
51 client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
52 client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
53 auth_info: tokio::sync::RwLock<Option<AuthInfo>>,
54}
55
56#[async_trait]
57impl McpServer for ServerRuntime {
58 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
60 self.client_details_tx
61 .send(Some(client_details))
62 .map_err(|_| {
63 RpcError::internal_error()
64 .with_message("Failed to set client details".to_string())
65 .into()
66 })
67 }
68
69 async fn update_auth_info(&self, new_auth_info: Option<AuthInfo>) {
70 let should_update = {
71 let current = self.auth_info.read().await;
72 match (&*current, &new_auth_info) {
73 (None, Some(_)) => true,
74 (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id,
75 (Some(_), None) => true,
76 (None, None) => false,
77 }
78 };
79
80 if should_update {
81 *self.auth_info.write().await = new_auth_info;
82 }
83 }
84
85 async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>> {
86 self.auth_info.read().await
87 }
88 async fn auth_info_cloned(&self) -> Option<AuthInfo> {
89 let guard = self.auth_info.read().await;
90 guard.clone()
91 }
92
93 async fn wait_for_initialization(&self) {
94 loop {
95 if self.client_details_rx.borrow().is_some() {
96 return;
97 }
98 let mut rx = self.client_details_rx.clone();
99 rx.changed().await.ok();
100 }
101 }
102
103 async fn send(
104 &self,
105 message: MessageFromServer,
106 request_id: Option<RequestId>,
107 request_timeout: Option<Duration>,
108 ) -> SdkResult<Option<ClientMessage>> {
109 let transport_map = self.transport_map.read().await;
110 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
111 RpcError::internal_error()
112 .with_message("transport stream does not exists or is closed!".to_string()),
113 )?;
114
115 let outgoing_request_id = self
116 .request_id_gen
117 .request_id_for_message(&message, request_id);
118
119 let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
120
121 let response = transport
122 .send_message(ServerMessages::Single(mcp_message), request_timeout)
123 .await?
124 .map(|res| res.as_single())
125 .transpose()?;
126
127 Ok(response)
128 }
129
130 async fn send_batch(
131 &self,
132 messages: Vec<ServerMessage>,
133 request_timeout: Option<Duration>,
134 ) -> SdkResult<Option<Vec<ClientMessage>>> {
135 let transport_map = self.transport_map.read().await;
136 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
137 RpcError::internal_error()
138 .with_message("transport stream does not exists or is closed!".to_string()),
139 )?;
140
141 transport
142 .send_batch(messages, request_timeout)
143 .map_err(|err| err.into())
144 .await
145 }
146
147 fn server_info(&self) -> &InitializeResult {
150 &self.server_details
151 }
152
153 fn client_info(&self) -> Option<InitializeRequestParams> {
155 self.client_details_rx.borrow().clone()
156 }
157
158 async fn start(self: Arc<Self>) -> SdkResult<()> {
160 let self_clone = self.clone();
161 let transport_map = self_clone.transport_map.read().await;
162
163 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
164 RpcError::internal_error()
165 .with_message("transport stream does not exists or is closed!".to_string()),
166 )?;
167
168 let mut stream = transport.start().await?;
169
170 let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
172
173 while let Some(mcp_messages) = stream.next().await {
175 match mcp_messages {
176 ClientMessages::Single(client_message) => {
177 let transport = transport.clone();
178 let self = self.clone();
179 let tx = tx.clone();
180
181 tokio::spawn(async move {
183 let result = self.handle_message(client_message, &transport).await;
184
185 let send_result: SdkResult<_> = match result {
186 Ok(result) => {
187 if let Some(result) = result {
188 transport
189 .send_message(ServerMessages::Single(result), None)
190 .map_err(|e| e.into())
191 .await
192 } else {
193 Ok(None)
194 }
195 }
196 Err(error) => {
197 tracing::error!("Error handling message : {}", error);
198 Ok(None)
199 }
200 };
201 if let Err(error) = tx.send(send_result).await {
203 tracing::error!("Failed to send result to channel: {}", error);
204 }
205 });
206 }
207 ClientMessages::Batch(client_messages) => {
208 let transport = transport.clone();
209 let self = self_clone.clone();
210 let tx = tx.clone();
211
212 tokio::spawn(async move {
213 let handling_tasks: Vec<_> = client_messages
214 .into_iter()
215 .map(|client_message| self.handle_message(client_message, &transport))
216 .collect();
217
218 let send_result = match try_join_all(handling_tasks).await {
219 Ok(results) => {
220 let results: Vec<_> = results.into_iter().flatten().collect();
221 if !results.is_empty() {
222 transport
223 .send_message(ServerMessages::Batch(results), None)
224 .map_err(|e| e.into())
225 .await
226 } else {
227 Ok(None)
228 }
229 }
230 Err(error) => Err(error),
231 };
232
233 if let Err(error) = tx.send(send_result).await {
234 tracing::error!("Failed to send batch result to channel: {}", error);
235 }
236 });
237 }
238 }
239
240 while let Ok(result) = rx.try_recv() {
242 result?; }
244 }
245
246 drop(tx);
248 while let Some(result) = rx.recv().await {
249 result?; }
251
252 return Ok(());
253 }
254
255 async fn stderr_message(&self, message: String) -> SdkResult<()> {
256 let transport_map = self.transport_map.read().await;
257 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
258 RpcError::internal_error()
259 .with_message("transport stream does not exists or is closed!".to_string()),
260 )?;
261 let mut lock = transport.error_stream().write().await;
262
263 if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
264 stderr.write_all(message.as_bytes()).await?;
265 stderr.write_all(b"\n").await?;
266 stderr.flush().await?;
267 }
268 Ok(())
269 }
270
271 #[cfg(feature = "hyper-server")]
272 fn session_id(&self) -> Option<SessionId> {
273 self.session_id.to_owned()
274 }
275}
276
277impl ServerRuntime {
278 pub(crate) async fn consume_payload_string(
279 &self,
280 stream_id: &str,
281 payload: &str,
282 ) -> SdkResult<()> {
283 let transport_map = self.transport_map.read().await;
284
285 let transport = transport_map.get(stream_id).ok_or(
286 RpcError::internal_error()
287 .with_message("stream id does not exists or is closed!".to_string()),
288 )?;
289
290 transport.consume_string_payload(payload).await?;
291
292 Ok(())
293 }
294
295 pub(crate) async fn handle_message(
296 self: &Arc<Self>,
297 message: ClientMessage,
298 transport: &Arc<
299 dyn TransportDispatcher<
300 ClientMessages,
301 MessageFromServer,
302 ClientMessage,
303 ServerMessages,
304 ServerMessage,
305 >,
306 >,
307 ) -> SdkResult<Option<ServerMessage>> {
308 let response = match message {
309 ClientMessage::Request(client_jsonrpc_request) => {
311 let result = self
312 .handler
313 .handle_request(client_jsonrpc_request.request, self.clone())
314 .await;
315 let response: MessageFromServer = match result {
317 Ok(success_value) => success_value.into(),
318 Err(error_value) => {
319 if !self.is_initialized() {
322 return Err(error_value.into());
323 }
324 MessageFromServer::Error(error_value)
325 }
326 };
327
328 let mpc_message: ServerMessage =
329 ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?;
330
331 Some(mpc_message)
332 }
333 ClientMessage::Notification(client_jsonrpc_notification) => {
334 self.handler
335 .handle_notification(client_jsonrpc_notification.notification, self.clone())
336 .await?;
337 None
338 }
339 ClientMessage::Error(jsonrpc_error) => {
340 self.handler
341 .handle_error(&jsonrpc_error.error, self.clone())
342 .await?;
343 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
344 tx_response
345 .send(ClientMessage::Error(jsonrpc_error))
346 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
347 } else {
348 tracing::warn!(
349 "Received an error response with no corresponding request {:?}",
350 &jsonrpc_error.id
351 );
352 }
353 None
354 }
355 ClientMessage::Response(response) => {
356 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
357 tx_response
358 .send(ClientMessage::Response(response))
359 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
360 } else {
361 tracing::warn!(
362 "Received a response with no corresponding request: {:?}",
363 &response.id
364 );
365 }
366 None
367 }
368 };
369 Ok(response)
370 }
371
372 pub(crate) async fn store_transport(
373 &self,
374 stream_id: &str,
375 transport: Arc<
376 dyn TransportDispatcher<
377 ClientMessages,
378 MessageFromServer,
379 ClientMessage,
380 ServerMessages,
381 ServerMessage,
382 >,
383 >,
384 ) -> SdkResult<()> {
385 if stream_id != DEFAULT_STREAM_ID {
386 return Ok(());
387 }
388 let mut transport_map = self.transport_map.write().await;
389 tracing::trace!("save transport for stream id : {}", stream_id);
390 transport_map.insert(stream_id.to_string(), transport);
391 Ok(())
392 }
393
394 pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
396 if stream_id != DEFAULT_STREAM_ID {
397 return Ok(());
398 }
399 let transport_map = self.transport_map.read().await;
400 tracing::trace!("removing transport for stream id : {}", stream_id);
401 if let Some(transport) = transport_map.get(stream_id) {
402 transport.shut_down().await?;
403 }
404 Ok(())
406 }
407
408 pub(crate) async fn shutdown(&self) {
409 let mut transport_map = self.transport_map.write().await;
410 let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
411 drop(transport_map);
412 for item in items {
413 let _ = item.shut_down().await;
414 }
415 }
416
417 pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool {
418 let transport_map = self.transport_map.read().await;
419 let live_transport = if let Some(t) = transport_map.get(stream_id) {
420 !t.is_shut_down().await
421 } else {
422 false
423 };
424 live_transport
425 }
426
427 pub(crate) async fn start_stream(
428 self: Arc<Self>,
429 transport: Arc<
430 dyn TransportDispatcher<
431 ClientMessages,
432 MessageFromServer,
433 ClientMessage,
434 ServerMessages,
435 ServerMessage,
436 >,
437 >,
438 stream_id: &str,
439 ping_interval: Duration,
440 payload: Option<String>,
441 ) -> SdkResult<()> {
442 let mut stream = transport.start().await?;
443
444 if stream_id == DEFAULT_STREAM_ID {
445 self.store_transport(stream_id, transport.clone()).await?;
446 }
447
448 let self_clone = self.clone();
449
450 let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
451 let abort_alive_task = transport
452 .keep_alive(ping_interval, disconnect_tx)
453 .await?
454 .abort_handle();
455
456 let _abort_guard = AbortTaskOnDrop {
458 handle: abort_alive_task,
459 };
460
461 if let Some(payload) = payload {
464 if let Err(err) = transport.consume_string_payload(&payload).await {
465 let _ = self.remove_transport(stream_id).await;
466 return Err(err.into());
467 }
468 }
469
470 let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
472
473 loop {
474 tokio::select! {
475 Some(mcp_messages) = stream.next() =>{
476
477 match mcp_messages {
478 ClientMessages::Single(client_message) => {
479 let transport = transport.clone();
480 let self_clone = self.clone();
481 let tx = tx.clone();
482 tokio::spawn(async move {
483
484 let result = self_clone.handle_message(client_message, &transport).await;
485
486 let send_result: SdkResult<_> = match result {
487 Ok(result) => {
488 if let Some(result) = result {
489 transport
490 .send_message(ServerMessages::Single(result), None)
491 .map_err(|e| e.into())
492 .await
493 } else {
494 Ok(None)
495 }
496 }
497 Err(error) => {
498 tracing::error!("Error handling message : {}", error);
499 Ok(None)
500 }
501 };
502 if let Err(error) = tx.send(send_result).await {
503 tracing::error!("Failed to send batch result to channel: {}", error);
504 }
505 });
506 }
507 ClientMessages::Batch(client_messages) => {
508
509 let transport = transport.clone();
510 let self_clone = self_clone.clone();
511 let tx = tx.clone();
512
513 tokio::spawn(async move {
514 let handling_tasks: Vec<_> = client_messages
515 .into_iter()
516 .map(|client_message| self_clone.handle_message(client_message, &transport))
517 .collect();
518
519 let send_result = match try_join_all(handling_tasks).await {
520 Ok(results) => {
521 let results: Vec<_> = results.into_iter().flatten().collect();
522 if !results.is_empty() {
523 transport.send_message(ServerMessages::Batch(results), None)
524 .map_err(|e| e.into())
525 .await
526 }else {
527 Ok(None)
528 }
529 },
530 Err(error) => Err(error),
531 };
532 if let Err(error) = tx.send(send_result).await {
533 tracing::error!("Failed to send batch result to channel: {}", error);
534 }
535 });
536 }
537 }
538
539 while let Ok(result) = rx.try_recv() {
541 result?; }
543
544 if !stream_id.eq(DEFAULT_STREAM_ID){
546 drop(tx);
548 while let Some(result) = rx.recv().await {
549 result?; }
551 return Ok(());
552 }
553 }
554 _ = &mut disconnect_rx => {
555 drop(tx);
557 while let Some(result) = rx.recv().await {
558 result?; }
560 self.remove_transport(stream_id).await?;
561 return Err(SdkError::connection_closed().into());
563
564 }
565 }
566 }
567 }
568
569 #[cfg(feature = "hyper-server")]
570 pub(crate) fn new_instance(
571 server_details: Arc<InitializeResult>,
572 handler: Arc<dyn McpServerHandler>,
573 session_id: SessionId,
574 auth_info: Option<AuthInfo>,
575 ) -> Arc<Self> {
576 use tokio::sync::RwLock;
577
578 let (client_details_tx, client_details_rx) =
579 watch::channel::<Option<InitializeRequestParams>>(None);
580 Arc::new(Self {
581 server_details,
582 handler,
583 session_id: Some(session_id),
584 transport_map: tokio::sync::RwLock::new(HashMap::new()),
585 client_details_tx,
586 client_details_rx,
587 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
588 auth_info: RwLock::new(auth_info),
589 })
590 }
591
592 pub(crate) fn new(
593 server_details: InitializeResult,
594 transport: impl TransportDispatcher<
595 ClientMessages,
596 MessageFromServer,
597 ClientMessage,
598 ServerMessages,
599 ServerMessage,
600 >,
601 handler: Arc<dyn McpServerHandler>,
602 ) -> Arc<Self> {
603 let mut map: HashMap<String, TransportType> = HashMap::new();
604 map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
605 let (client_details_tx, client_details_rx) =
606 watch::channel::<Option<InitializeRequestParams>>(None);
607 Arc::new(Self {
608 server_details: Arc::new(server_details),
609 handler,
610 #[cfg(feature = "hyper-server")]
611 session_id: None,
612 transport_map: tokio::sync::RwLock::new(map),
613 client_details_tx,
614 client_details_rx,
615 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
616 auth_info: RwLock::new(None),
617 })
618 }
619}