1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3use crate::error::{McpSdkError, SdkResult};
4use crate::id_generator::FastIdGenerator;
5use crate::mcp_traits::mcp_client::McpClient;
6use crate::mcp_traits::mcp_handler::McpClientHandler;
7use crate::mcp_traits::IdGenerator;
8use crate::utils::ensure_server_protocole_compatibility;
9use crate::{
10 mcp_traits::{RequestIdGen, RequestIdGenNumeric},
11 schema::{
12 schema_utils::{
13 self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
14 ServerMessage, ServerMessages,
15 },
16 InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
17 RequestId, RpcError, ServerResult,
18 },
19};
20use async_trait::async_trait;
21use futures::future::{join_all, try_join_all};
22use futures::StreamExt;
23
24#[cfg(feature = "streamable-http")]
25use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions};
26use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher};
27use std::{collections::HashMap, sync::Arc, time::Duration};
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::sync::{watch, Mutex};
30
/// Key under which the standalone transport is stored in the runtime's transport map.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";

/// Trait-object form of the client-side transport dispatcher, parameterised over the
/// message types this runtime exchanges (incoming `ServerMessages`/`ServerMessage`,
/// outgoing `MessageFromClient`/`ClientMessages`/`ClientMessage`).
type TransportDispatcherType = dyn TransportDispatcher<
    ServerMessages,
    MessageFromClient,
    ServerMessage,
    ClientMessages,
    ClientMessage,
>;
/// Shared, reference-counted handle to a client transport; cloned freely out of the map.
type TransportType = Arc<TransportDispatcherType>;
42
/// Client-side runtime that performs the MCP `initialize` handshake and routes messages
/// between the transport(s) and the configured [`McpClientHandler`].
pub struct ClientRuntime {
    /// Transports keyed by stream id. `DEFAULT_STREAM_ID` holds the standalone transport
    /// (or the SSE stream in Streamable HTTP mode); `start_stream` adds per-request
    /// transports under generated ids.
    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
    /// Handler passed to the constructor; receives server requests, notifications,
    /// errors and process stderr lines.
    handler: Box<dyn McpClientHandler>,
    /// Parameters sent in the `initialize` request.
    client_details: InitializeRequestParams,
    /// Background tasks spawned by the runtime; drained and joined by `shut_down`.
    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
    /// Produces ids for outgoing requests.
    request_id_gen: Box<dyn RequestIdGen>,
    /// Produces ids for per-request Streamable HTTP streams (configured with the `s_` prefix).
    stream_id_gen: FastIdGenerator,
    /// Streamable HTTP options; `Some` switches `start`/`send`/`send_batch` to the
    /// stream-per-message mode.
    #[cfg(feature = "streamable-http")]
    transport_options: Option<StreamableTransportOptions>,
    /// Set to `true` by `shut_down`.
    is_shut_down: Mutex<bool>,
    /// Session id obtained from the server (Streamable HTTP), passed to new transports.
    session_id: tokio::sync::RwLock<Option<SessionId>>,
    /// Publishes the server's `InitializeResult` once the handshake succeeds.
    server_details_tx: watch::Sender<Option<InitializeResult>>,
    /// Read side used by `server_info`; kept alive so `server_details_tx.send` has a receiver.
    server_details_rx: watch::Receiver<Option<InitializeResult>>,
}
66
67impl ClientRuntime {
68 pub(crate) fn new(
69 client_details: InitializeRequestParams,
70 transport: TransportType,
71 handler: Box<dyn McpClientHandler>,
72 ) -> Self {
73 let mut map: HashMap<String, TransportType> = HashMap::new();
74 map.insert(DEFAULT_STREAM_ID.to_string(), transport);
75 let (server_details_tx, server_details_rx) =
76 watch::channel::<Option<InitializeResult>>(None);
77 Self {
78 transport_map: tokio::sync::RwLock::new(map),
79 handler,
80 client_details,
81 handlers: Mutex::new(vec![]),
82 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
83 #[cfg(feature = "streamable-http")]
84 transport_options: None,
85 is_shut_down: Mutex::new(false),
86 session_id: tokio::sync::RwLock::new(None),
87 stream_id_gen: FastIdGenerator::new(Some("s_")),
88 server_details_tx,
89 server_details_rx,
90 }
91 }
92
93 #[cfg(feature = "streamable-http")]
94 pub(crate) fn new_instance(
95 client_details: InitializeRequestParams,
96 transport_options: StreamableTransportOptions,
97 handler: Box<dyn McpClientHandler>,
98 ) -> Self {
99 let map: HashMap<String, TransportType> = HashMap::new();
100 let (server_details_tx, server_details_rx) =
101 watch::channel::<Option<InitializeResult>>(None);
102 Self {
103 transport_map: tokio::sync::RwLock::new(map),
104 handler,
105 client_details,
106 handlers: Mutex::new(vec![]),
107 transport_options: Some(transport_options),
108 is_shut_down: Mutex::new(false),
109 session_id: tokio::sync::RwLock::new(None),
110 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
111 stream_id_gen: FastIdGenerator::new(Some("s_")),
112 server_details_tx,
113 server_details_rx,
114 }
115 }
116
117 async fn initialize_request(self: Arc<Self>) -> SdkResult<()> {
118 let request = InitializeRequest::new(self.client_details.clone());
119 let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
120
121 if let ServerResult::InitializeResult(initialize_result) = result {
122 ensure_server_protocole_compatibility(
123 &self.client_details.protocol_version,
124 &initialize_result.protocol_version,
125 )?;
126 self.set_server_details(initialize_result)?;
128
129 #[cfg(feature = "streamable-http")]
130 if let Err(error) = self.clone().create_sse_stream().await {
132 tracing::warn!("{error}");
133 }
134
135 self.send_notification(InitializedNotification::new(None).into())
137 .await?;
138 } else {
139 return Err(RpcError::invalid_params()
140 .with_message("Incorrect response to InitializeRequest!".into())
141 .into());
142 }
143
144 Ok(())
145 }
146
147 pub(crate) async fn handle_message(
148 &self,
149 message: ServerMessage,
150 transport: &TransportType,
151 ) -> SdkResult<Option<ClientMessage>> {
152 let response = match message {
153 ServerMessage::Request(jsonrpc_request) => {
154 let result = self
155 .handler
156 .handle_request(jsonrpc_request.request, self)
157 .await;
158
159 let response: MessageFromClient = match result {
161 Ok(success_value) => success_value.into(),
162 Err(error_value) => MessageFromClient::Error(error_value),
163 };
164
165 let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
166 Some(mcp_message)
167 }
168 ServerMessage::Notification(jsonrpc_notification) => {
169 self.handler
170 .handle_notification(jsonrpc_notification.notification, self)
171 .await?;
172 None
173 }
174 ServerMessage::Error(jsonrpc_error) => {
175 self.handler
176 .handle_error(&jsonrpc_error.error, self)
177 .await?;
178 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
179 tx_response
180 .send(ServerMessage::Error(jsonrpc_error))
181 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
182 } else {
183 tracing::warn!(
184 "Received an error response with no corresponding request: {:?}",
185 &jsonrpc_error.id
186 );
187 }
188 None
189 }
190 ServerMessage::Response(response) => {
191 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
192 tx_response
193 .send(ServerMessage::Response(response))
194 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
195 } else {
196 tracing::warn!(
197 "Received a response with no corresponding request: {:?}",
198 &response.id
199 );
200 }
201 None
202 }
203 };
204 Ok(response)
205 }
206
207 async fn start_standalone(self: Arc<Self>) -> SdkResult<()> {
208 let self_clone = self.clone();
209 let transport_map = self_clone.transport_map.read().await;
210 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
211 RpcError::internal_error()
212 .with_message("transport stream does not exists or is closed!".to_string()),
213 )?;
214
215 let mut stream = transport.start().await?;
217
218 let transport_clone = transport.clone();
219 let mut error_io_stream = transport.error_stream().write().await;
220 let error_io_stream = error_io_stream.take();
221
222 let self_clone = Arc::clone(&self);
223 let self_clone_err = Arc::clone(&self);
224
225 let err_task = tokio::spawn(async move {
227 let self_ref = &*self_clone_err;
228
229 if let Some(IoStream::Readable(error_input)) = error_io_stream {
230 let mut reader = BufReader::new(error_input).lines();
231 loop {
232 tokio::select! {
233 should_break = transport_clone.is_shut_down() =>{
234 if should_break {
235 break;
236 }
237 }
238 line = reader.next_line() =>{
239 match line {
240 Ok(Some(error_message)) => {
241 self_ref
242 .handler
243 .handle_process_error(error_message, self_ref)
244 .await?;
245 }
246 Ok(None) => {
247 break;
249 }
250 Err(e) => {
251 tracing::error!("Error reading from std_err: {e}");
252 break;
253 }
254 }
255 }
256 }
257 }
258 }
259
260 Ok::<(), McpSdkError>(())
261 });
262
263 let transport = transport.clone();
264
265 let main_task = tokio::spawn(async move {
267 while let Some(mcp_messages) = stream.next().await {
268 let self_ref = &*self_clone;
269
270 match mcp_messages {
271 ServerMessages::Single(server_message) => {
272 let result = self_ref.handle_message(server_message, &transport).await;
273
274 match result {
275 Ok(result) => {
276 if let Some(result) = result {
277 transport
278 .send_message(ClientMessages::Single(result), None)
279 .await?;
280 }
281 }
282 Err(error) => {
283 tracing::error!("Error handling message : {}", error)
284 }
285 }
286 }
287 ServerMessages::Batch(server_messages) => {
288 let handling_tasks: Vec<_> = server_messages
289 .into_iter()
290 .map(|server_message| {
291 self_ref.handle_message(server_message, &transport)
292 })
293 .collect();
294 let results: Vec<_> = try_join_all(handling_tasks).await?;
295 let results: Vec<_> = results.into_iter().flatten().collect();
296
297 if !results.is_empty() {
298 transport
299 .send_message(ClientMessages::Batch(results), None)
300 .await?;
301 }
302 }
303 }
304 }
305 Ok::<(), McpSdkError>(())
306 });
307
308 self.clone().initialize_request().await?;
310
311 let mut lock = self.handlers.lock().await;
312 lock.push(main_task);
313 lock.push(err_task);
314 Ok(())
315 }
316
317 pub(crate) async fn store_transport(
318 &self,
319 stream_id: &str,
320 transport: TransportType,
321 ) -> SdkResult<()> {
322 let mut transport_map = self.transport_map.write().await;
323 tracing::trace!("save transport for stream id : {}", stream_id);
324 transport_map.insert(stream_id.to_string(), transport);
325 Ok(())
326 }
327
328 pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult<TransportType> {
329 let transport_map = self.transport_map.read().await;
330 transport_map.get(stream_id).cloned().ok_or_else(|| {
331 RpcError::internal_error()
332 .with_message(format!("Transport for key {stream_id} not found"))
333 .into()
334 })
335 }
336
337 #[cfg(feature = "streamable-http")]
338 pub(crate) async fn new_transport(
339 &self,
340 session_id: Option<SessionId>,
341 standalone: bool,
342 ) -> SdkResult<
343 impl TransportDispatcher<
344 ServerMessages,
345 MessageFromClient,
346 ServerMessage,
347 ClientMessages,
348 ClientMessage,
349 >,
350 > {
351 let options = self
352 .transport_options
353 .as_ref()
354 .ok_or(schema_utils::SdkError::connection_closed())?;
355 let transport = ClientStreamableTransport::new(options, session_id, standalone)?;
356
357 Ok(transport)
358 }
359
360 #[cfg(feature = "streamable-http")]
361 pub(crate) async fn create_sse_stream(self: Arc<Self>) -> SdkResult<()> {
362 let stream_id: StreamId = DEFAULT_STREAM_ID.into();
363 let session_id = self.session_id.read().await.clone();
364 let transport: Arc<
365 dyn TransportDispatcher<
366 ServerMessages,
367 MessageFromClient,
368 ServerMessage,
369 ClientMessages,
370 ClientMessage,
371 >,
372 > = Arc::new(self.new_transport(session_id, true).await?);
373 let mut stream = transport.start().await?;
374 self.store_transport(&stream_id, transport.clone()).await?;
375
376 let self_clone = Arc::clone(&self);
377
378 let main_task = tokio::spawn(async move {
379 loop {
380 if let Some(mcp_messages) = stream.next().await {
381 match mcp_messages {
382 ServerMessages::Single(server_message) => {
383 let result = self.handle_message(server_message, &transport).await?;
384
385 if let Some(result) = result {
386 transport
387 .send_message(ClientMessages::Single(result), None)
388 .await?;
389 }
390 }
391 ServerMessages::Batch(server_messages) => {
392 let handling_tasks: Vec<_> = server_messages
393 .into_iter()
394 .map(|server_message| {
395 self.handle_message(server_message, &transport)
396 })
397 .collect();
398
399 let results: Vec<_> = try_join_all(handling_tasks).await?;
400
401 let results: Vec<_> = results.into_iter().flatten().collect();
402
403 if !results.is_empty() {
404 transport
405 .send_message(ClientMessages::Batch(results), None)
406 .await?;
407 }
408 }
409 }
410 if !stream_id.eq(DEFAULT_STREAM_ID) {
412 return Ok::<_, McpSdkError>(());
413 }
414 } else {
415 return Ok::<_, McpSdkError>(());
417 }
418 }
419 });
420
421 let mut lock = self_clone.handlers.lock().await;
422 lock.push(main_task);
423
424 Ok(())
425 }
426
427 #[cfg(feature = "streamable-http")]
428 pub(crate) async fn start_stream(
429 &self,
430 messages: ClientMessages,
431 timeout: Option<Duration>,
432 ) -> SdkResult<Option<ServerMessages>> {
433 use futures::stream::{AbortHandle, Abortable};
434 let stream_id: StreamId = self.stream_id_gen.generate();
435 let session_id = self.session_id.read().await.clone();
436 let no_session_id = session_id.is_none();
437
438 let has_request = match &messages {
439 ClientMessages::Single(client_message) => client_message.is_request(),
440 ClientMessages::Batch(client_messages) => {
441 client_messages.iter().any(|m| m.is_request())
442 }
443 };
444
445 let transport = Arc::new(self.new_transport(session_id, false).await?);
446
447 let mut stream = transport.start().await?;
448
449 self.store_transport(&stream_id, transport).await?;
450
451 let transport = self.transport_by_stream(&stream_id).await?; let send_task = async {
454 let result = transport.send_message(messages, timeout).await?;
455
456 if no_session_id {
457 if let Some(request_id) = transport.session_id().await.clone() {
458 let mut guard = self.session_id.write().await;
459 *guard = Some(request_id)
460 }
461 }
462
463 Ok::<_, McpSdkError>(result)
464 };
465
466 if !has_request {
467 return send_task.await;
468 }
469
470 let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair();
471
472 let receive_task = async {
473 loop {
474 tokio::select! {
475 Some(mcp_messages) = stream.next() =>{
476
477 match mcp_messages {
478 ServerMessages::Single(server_message) => {
479 let result = self.handle_message(server_message, &transport).await?;
480 if let Some(result) = result {
481 transport.send_message(ClientMessages::Single(result), None).await?;
482 }
483 }
484 ServerMessages::Batch(server_messages) => {
485
486 let handling_tasks: Vec<_> = server_messages
487 .into_iter()
488 .map(|server_message| self.handle_message(server_message, &transport))
489 .collect();
490
491 let results: Vec<_> = try_join_all(handling_tasks).await?;
492
493 let results: Vec<_> = results.into_iter().flatten().collect();
494
495 if !results.is_empty() {
496 transport.send_message(ClientMessages::Batch(results), None).await?;
497 }
498 }
499 }
500 if !stream_id.eq(DEFAULT_STREAM_ID){
502 return Ok::<_, McpSdkError>(());
503 }
504 }
505 }
506 }
507 };
508
509 let receive_task = Abortable::new(receive_task, abort_recv_reg);
510
511 tokio::pin!(send_task);
513 tokio::pin!(receive_task);
514
515 let (send_res, _) = tokio::select! {
517 res = &mut send_task => {
518 abort_recv_handle.abort();
520 (res, receive_task.await) }
522 res = &mut receive_task => {
523 (send_task.await, res)
524 }
525 };
526 send_res
527 }
528}
529
530#[async_trait]
531impl McpClient for ClientRuntime {
532 async fn send(
533 &self,
534 message: MessageFromClient,
535 request_id: Option<RequestId>,
536 request_timeout: Option<Duration>,
537 ) -> SdkResult<Option<ServerMessage>> {
538 #[cfg(feature = "streamable-http")]
539 {
540 if self.transport_options.is_some() {
541 let outgoing_request_id = self
542 .request_id_gen
543 .request_id_for_message(&message, request_id);
544 let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
545
546 let response = self
547 .start_stream(ClientMessages::Single(mcp_message), request_timeout)
548 .await?;
549 return response
550 .map(|r| r.as_single())
551 .transpose()
552 .map_err(|err| err.into());
553 }
554 }
555
556 let transport_map = self.transport_map.read().await;
557
558 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
559 RpcError::internal_error()
560 .with_message("transport stream does not exists or is closed!".to_string()),
561 )?;
562
563 let outgoing_request_id = self
564 .request_id_gen
565 .request_id_for_message(&message, request_id);
566
567 let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
568 let response = transport
569 .send_message(ClientMessages::Single(mcp_message), request_timeout)
570 .await?;
571 response
572 .map(|r| r.as_single())
573 .transpose()
574 .map_err(|err| err.into())
575 }
576
577 async fn send_batch(
578 &self,
579 messages: Vec<ClientMessage>,
580 timeout: Option<Duration>,
581 ) -> SdkResult<Option<Vec<ServerMessage>>> {
582 #[cfg(feature = "streamable-http")]
583 {
584 if self.transport_options.is_some() {
585 let result = self
586 .start_stream(ClientMessages::Batch(messages), timeout)
587 .await?;
588 return result
590 .map(|r| r.as_batch())
591 .transpose()
592 .map_err(|err| err.into());
593 }
594 }
595
596 let transport_map = self.transport_map.read().await;
597 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
598 RpcError::internal_error()
599 .with_message("transport stream does not exists or is closed!".to_string()),
600 )?;
601 transport
602 .send_batch(messages, timeout)
603 .await
604 .map_err(|err| err.into())
605 }
606
607 async fn start(self: Arc<Self>) -> SdkResult<()> {
608 #[cfg(feature = "streamable-http")]
609 {
610 if self.transport_options.is_some() {
611 self.initialize_request().await?;
612 return Ok(());
613 }
614 }
615
616 self.start_standalone().await
617 }
618
619 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
620 self.server_details_tx
621 .send(Some(server_details))
622 .map_err(|_| {
623 RpcError::internal_error()
624 .with_message("Failed to set server details".to_string())
625 .into()
626 })
627 }
628
629 fn client_info(&self) -> &InitializeRequestParams {
630 &self.client_details
631 }
632
633 fn server_info(&self) -> Option<InitializeResult> {
634 self.server_details_rx.borrow().clone()
635 }
636
637 async fn is_shut_down(&self) -> bool {
638 let result = self.is_shut_down.lock().await;
639 *result
640 }
641
642 async fn shut_down(&self) -> SdkResult<()> {
643 let mut is_shut_down_lock = self.is_shut_down.lock().await;
644 *is_shut_down_lock = true;
645
646 let mut transport_map = self.transport_map.write().await;
647 let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
648 drop(transport_map);
649 for transport in transports {
650 let _ = transport.shut_down().await;
651 }
652
653 let mut tasks_lock = self.handlers.lock().await;
655 let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
656 join_all(join_handlers).await;
657
658 Ok(())
659 }
660
661 async fn terminate_session(&self) {
662 #[cfg(feature = "streamable-http")]
663 {
664 if let Some(transport_options) = self.transport_options.as_ref() {
665 let session_id = self.session_id.read().await.clone();
666 transport_options
667 .terminate_session(session_id.as_ref())
668 .await;
669 let _ = self.shut_down().await;
670 }
671 }
672 let _ = self.shut_down().await;
673 }
674}