1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3use crate::error::{McpSdkError, SdkResult};
4use crate::id_generator::FastIdGenerator;
5use crate::mcp_traits::{IdGenerator, McpClient, McpClientHandler};
6use crate::utils::ensure_server_protocole_compatibility;
7use crate::{
8 mcp_traits::{RequestIdGen, RequestIdGenNumeric},
9 schema::{
10 schema_utils::{
11 self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
12 ServerMessage, ServerMessages,
13 },
14 InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
15 RequestId, RpcError, ServerResult,
16 },
17};
18use async_trait::async_trait;
19use futures::future::{join_all, try_join_all};
20use futures::StreamExt;
21
22#[cfg(feature = "streamable-http")]
23use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions};
24use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher};
25use std::{collections::HashMap, sync::Arc, time::Duration};
26use tokio::io::{AsyncBufReadExt, BufReader};
27use tokio::sync::{watch, Mutex};
28
/// Key under which the runtime's default transport is kept in its transport map.
///
/// The standalone transport passed to the constructor and the streamable-HTTP SSE
/// transport are both stored under this key.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
30
31type TransportDispatcherType = dyn TransportDispatcher<
33 ServerMessages,
34 MessageFromClient,
35 ServerMessage,
36 ClientMessages,
37 ClientMessage,
38>;
39type TransportType = Arc<TransportDispatcherType>;
40
/// Client-side MCP runtime: owns the transport(s), the user-supplied handler, the
/// initialization parameters and the state shared by background tasks.
pub struct ClientRuntime {
    /// Transports keyed by stream id. The default transport lives under
    /// `DEFAULT_STREAM_ID`; in streamable-HTTP mode `start_stream` adds one entry per call.
    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
    /// Callbacks for server requests, notifications, errors and process stderr lines.
    handler: Box<dyn McpClientHandler>,
    /// Parameters sent in the `InitializeRequest`; also returned by `client_info`.
    client_details: InitializeRequestParams,
    /// Spawned background tasks; drained and joined by `shut_down`.
    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
    /// Assigns request ids to outgoing messages.
    request_id_gen: Box<dyn RequestIdGen>,
    /// Generates ids for per-call streams in streamable-HTTP mode.
    stream_id_gen: FastIdGenerator,
    /// Present only for streamable-HTTP clients; its presence selects the per-stream code paths.
    #[cfg(feature = "streamable-http")]
    transport_options: Option<StreamableTransportOptions>,
    /// Set to `true` by `shut_down`.
    is_shut_down: Mutex<bool>,
    /// Session id reported by the first transport that provides one; reused for later transports.
    session_id: tokio::sync::RwLock<Option<SessionId>>,
    /// Writer side of the channel holding the server's `InitializeResult`.
    server_details_tx: watch::Sender<Option<InitializeResult>>,
    /// Reader side of the channel holding the server's `InitializeResult`.
    server_details_rx: watch::Receiver<Option<InitializeResult>>,
}
64
65impl ClientRuntime {
66 pub(crate) fn new(
67 client_details: InitializeRequestParams,
68 transport: TransportType,
69 handler: Box<dyn McpClientHandler>,
70 ) -> Self {
71 let mut map: HashMap<String, TransportType> = HashMap::new();
72 map.insert(DEFAULT_STREAM_ID.to_string(), transport);
73 let (server_details_tx, server_details_rx) =
74 watch::channel::<Option<InitializeResult>>(None);
75 Self {
76 transport_map: tokio::sync::RwLock::new(map),
77 handler,
78 client_details,
79 handlers: Mutex::new(vec![]),
80 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
81 #[cfg(feature = "streamable-http")]
82 transport_options: None,
83 is_shut_down: Mutex::new(false),
84 session_id: tokio::sync::RwLock::new(None),
85 stream_id_gen: FastIdGenerator::new(Some("s_")),
86 server_details_tx,
87 server_details_rx,
88 }
89 }
90
91 #[cfg(feature = "streamable-http")]
92 pub(crate) fn new_instance(
93 client_details: InitializeRequestParams,
94 transport_options: StreamableTransportOptions,
95 handler: Box<dyn McpClientHandler>,
96 ) -> Self {
97 let map: HashMap<String, TransportType> = HashMap::new();
98 let (server_details_tx, server_details_rx) =
99 watch::channel::<Option<InitializeResult>>(None);
100 Self {
101 transport_map: tokio::sync::RwLock::new(map),
102 handler,
103 client_details,
104 handlers: Mutex::new(vec![]),
105 transport_options: Some(transport_options),
106 is_shut_down: Mutex::new(false),
107 session_id: tokio::sync::RwLock::new(None),
108 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
109 stream_id_gen: FastIdGenerator::new(Some("s_")),
110 server_details_tx,
111 server_details_rx,
112 }
113 }
114
115 async fn initialize_request(self: Arc<Self>) -> SdkResult<()> {
116 let request = InitializeRequest::new(self.client_details.clone());
117 let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
118
119 if let ServerResult::InitializeResult(initialize_result) = result {
120 ensure_server_protocole_compatibility(
121 &self.client_details.protocol_version,
122 &initialize_result.protocol_version,
123 )?;
124 self.set_server_details(initialize_result)?;
126
127 #[cfg(feature = "streamable-http")]
128 if let Err(error) = self.clone().create_sse_stream().await {
130 tracing::warn!("{error}");
131 }
132
133 self.send_notification(InitializedNotification::new(None).into())
135 .await?;
136 } else {
137 return Err(RpcError::invalid_params()
138 .with_message("Incorrect response to InitializeRequest!".into())
139 .into());
140 }
141
142 Ok(())
143 }
144
145 pub(crate) async fn handle_message(
146 &self,
147 message: ServerMessage,
148 transport: &TransportType,
149 ) -> SdkResult<Option<ClientMessage>> {
150 let response = match message {
151 ServerMessage::Request(jsonrpc_request) => {
152 let result = self
153 .handler
154 .handle_request(jsonrpc_request.request, self)
155 .await;
156
157 let response: MessageFromClient = match result {
159 Ok(success_value) => success_value.into(),
160 Err(error_value) => MessageFromClient::Error(error_value),
161 };
162
163 let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
164 Some(mcp_message)
165 }
166 ServerMessage::Notification(jsonrpc_notification) => {
167 self.handler
168 .handle_notification(jsonrpc_notification.notification, self)
169 .await?;
170 None
171 }
172 ServerMessage::Error(jsonrpc_error) => {
173 self.handler
174 .handle_error(&jsonrpc_error.error, self)
175 .await?;
176 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
177 tx_response
178 .send(ServerMessage::Error(jsonrpc_error))
179 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
180 } else {
181 tracing::warn!(
182 "Received an error response with no corresponding request: {:?}",
183 &jsonrpc_error.id
184 );
185 }
186 None
187 }
188 ServerMessage::Response(response) => {
189 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
190 tx_response
191 .send(ServerMessage::Response(response))
192 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
193 } else {
194 tracing::warn!(
195 "Received a response with no corresponding request: {:?}",
196 &response.id
197 );
198 }
199 None
200 }
201 };
202 Ok(response)
203 }
204
205 async fn start_standalone(self: Arc<Self>) -> SdkResult<()> {
206 let self_clone = self.clone();
207 let transport_map = self_clone.transport_map.read().await;
208 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
209 RpcError::internal_error()
210 .with_message("transport stream does not exists or is closed!".to_string()),
211 )?;
212
213 let mut stream = transport.start().await?;
215
216 let transport_clone = transport.clone();
217 let mut error_io_stream = transport.error_stream().write().await;
218 let error_io_stream = error_io_stream.take();
219
220 let self_clone = Arc::clone(&self);
221 let self_clone_err = Arc::clone(&self);
222
223 let err_task = tokio::spawn(async move {
225 let self_ref = &*self_clone_err;
226
227 if let Some(IoStream::Readable(error_input)) = error_io_stream {
228 let mut reader = BufReader::new(error_input).lines();
229 loop {
230 tokio::select! {
231 should_break = transport_clone.is_shut_down() =>{
232 if should_break {
233 break;
234 }
235 }
236 line = reader.next_line() =>{
237 match line {
238 Ok(Some(error_message)) => {
239 self_ref
240 .handler
241 .handle_process_error(error_message, self_ref)
242 .await?;
243 }
244 Ok(None) => {
245 break;
247 }
248 Err(e) => {
249 tracing::error!("Error reading from std_err: {e}");
250 break;
251 }
252 }
253 }
254 }
255 }
256 }
257
258 Ok::<(), McpSdkError>(())
259 });
260
261 let transport = transport.clone();
262
263 let main_task = tokio::spawn(async move {
265 while let Some(mcp_messages) = stream.next().await {
266 let self_ref = &*self_clone;
267
268 match mcp_messages {
269 ServerMessages::Single(server_message) => {
270 let result = self_ref.handle_message(server_message, &transport).await;
271
272 match result {
273 Ok(result) => {
274 if let Some(result) = result {
275 transport
276 .send_message(ClientMessages::Single(result), None)
277 .await?;
278 }
279 }
280 Err(error) => {
281 tracing::error!("Error handling message : {}", error)
282 }
283 }
284 }
285 ServerMessages::Batch(server_messages) => {
286 let handling_tasks: Vec<_> = server_messages
287 .into_iter()
288 .map(|server_message| {
289 self_ref.handle_message(server_message, &transport)
290 })
291 .collect();
292 let results: Vec<_> = try_join_all(handling_tasks).await?;
293 let results: Vec<_> = results.into_iter().flatten().collect();
294
295 if !results.is_empty() {
296 transport
297 .send_message(ClientMessages::Batch(results), None)
298 .await?;
299 }
300 }
301 }
302 }
303 Ok::<(), McpSdkError>(())
304 });
305
306 self.clone().initialize_request().await?;
308
309 let mut lock = self.handlers.lock().await;
310 lock.push(main_task);
311 lock.push(err_task);
312 Ok(())
313 }
314
315 pub(crate) async fn store_transport(
316 &self,
317 stream_id: &str,
318 transport: TransportType,
319 ) -> SdkResult<()> {
320 let mut transport_map = self.transport_map.write().await;
321 tracing::trace!("save transport for stream id : {}", stream_id);
322 transport_map.insert(stream_id.to_string(), transport);
323 Ok(())
324 }
325
326 pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult<TransportType> {
327 let transport_map = self.transport_map.read().await;
328 transport_map.get(stream_id).cloned().ok_or_else(|| {
329 RpcError::internal_error()
330 .with_message(format!("Transport for key {stream_id} not found"))
331 .into()
332 })
333 }
334
335 #[cfg(feature = "streamable-http")]
336 pub(crate) async fn new_transport(
337 &self,
338 session_id: Option<SessionId>,
339 standalone: bool,
340 ) -> SdkResult<
341 impl TransportDispatcher<
342 ServerMessages,
343 MessageFromClient,
344 ServerMessage,
345 ClientMessages,
346 ClientMessage,
347 >,
348 > {
349 let options = self
350 .transport_options
351 .as_ref()
352 .ok_or(schema_utils::SdkError::connection_closed())?;
353 let transport = ClientStreamableTransport::new(options, session_id, standalone)?;
354
355 Ok(transport)
356 }
357
358 #[cfg(feature = "streamable-http")]
359 pub(crate) async fn create_sse_stream(self: Arc<Self>) -> SdkResult<()> {
360 let stream_id: StreamId = DEFAULT_STREAM_ID.into();
361 let session_id = self.session_id.read().await.clone();
362 let transport: Arc<
363 dyn TransportDispatcher<
364 ServerMessages,
365 MessageFromClient,
366 ServerMessage,
367 ClientMessages,
368 ClientMessage,
369 >,
370 > = Arc::new(self.new_transport(session_id, true).await?);
371 let mut stream = transport.start().await?;
372 self.store_transport(&stream_id, transport.clone()).await?;
373
374 let self_clone = Arc::clone(&self);
375
376 let main_task = tokio::spawn(async move {
377 loop {
378 if let Some(mcp_messages) = stream.next().await {
379 match mcp_messages {
380 ServerMessages::Single(server_message) => {
381 let result = self.handle_message(server_message, &transport).await?;
382
383 if let Some(result) = result {
384 transport
385 .send_message(ClientMessages::Single(result), None)
386 .await?;
387 }
388 }
389 ServerMessages::Batch(server_messages) => {
390 let handling_tasks: Vec<_> = server_messages
391 .into_iter()
392 .map(|server_message| {
393 self.handle_message(server_message, &transport)
394 })
395 .collect();
396
397 let results: Vec<_> = try_join_all(handling_tasks).await?;
398
399 let results: Vec<_> = results.into_iter().flatten().collect();
400
401 if !results.is_empty() {
402 transport
403 .send_message(ClientMessages::Batch(results), None)
404 .await?;
405 }
406 }
407 }
408 if !stream_id.eq(DEFAULT_STREAM_ID) {
410 return Ok::<_, McpSdkError>(());
411 }
412 } else {
413 return Ok::<_, McpSdkError>(());
415 }
416 }
417 });
418
419 let mut lock = self_clone.handlers.lock().await;
420 lock.push(main_task);
421
422 Ok(())
423 }
424
425 #[cfg(feature = "streamable-http")]
426 pub(crate) async fn start_stream(
427 &self,
428 messages: ClientMessages,
429 timeout: Option<Duration>,
430 ) -> SdkResult<Option<ServerMessages>> {
431 use futures::stream::{AbortHandle, Abortable};
432 let stream_id: StreamId = self.stream_id_gen.generate();
433 let session_id = self.session_id.read().await.clone();
434 let no_session_id = session_id.is_none();
435
436 let has_request = match &messages {
437 ClientMessages::Single(client_message) => client_message.is_request(),
438 ClientMessages::Batch(client_messages) => {
439 client_messages.iter().any(|m| m.is_request())
440 }
441 };
442
443 let transport = Arc::new(self.new_transport(session_id, false).await?);
444
445 let mut stream = transport.start().await?;
446
447 self.store_transport(&stream_id, transport).await?;
448
449 let transport = self.transport_by_stream(&stream_id).await?; let send_task = async {
452 let result = transport.send_message(messages, timeout).await?;
453
454 if no_session_id {
455 if let Some(request_id) = transport.session_id().await.clone() {
456 let mut guard = self.session_id.write().await;
457 *guard = Some(request_id)
458 }
459 }
460
461 Ok::<_, McpSdkError>(result)
462 };
463
464 if !has_request {
465 return send_task.await;
466 }
467
468 let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair();
469
470 let receive_task = async {
471 loop {
472 tokio::select! {
473 Some(mcp_messages) = stream.next() =>{
474
475 match mcp_messages {
476 ServerMessages::Single(server_message) => {
477 let result = self.handle_message(server_message, &transport).await?;
478 if let Some(result) = result {
479 transport.send_message(ClientMessages::Single(result), None).await?;
480 }
481 }
482 ServerMessages::Batch(server_messages) => {
483
484 let handling_tasks: Vec<_> = server_messages
485 .into_iter()
486 .map(|server_message| self.handle_message(server_message, &transport))
487 .collect();
488
489 let results: Vec<_> = try_join_all(handling_tasks).await?;
490
491 let results: Vec<_> = results.into_iter().flatten().collect();
492
493 if !results.is_empty() {
494 transport.send_message(ClientMessages::Batch(results), None).await?;
495 }
496 }
497 }
498 if !stream_id.eq(DEFAULT_STREAM_ID){
500 return Ok::<_, McpSdkError>(());
501 }
502 }
503 }
504 }
505 };
506
507 let receive_task = Abortable::new(receive_task, abort_recv_reg);
508
509 tokio::pin!(send_task);
511 tokio::pin!(receive_task);
512
513 let (send_res, _) = tokio::select! {
515 res = &mut send_task => {
516 abort_recv_handle.abort();
518 (res, receive_task.await) }
520 res = &mut receive_task => {
521 (send_task.await, res)
522 }
523 };
524 send_res
525 }
526}
527
528#[async_trait]
529impl McpClient for ClientRuntime {
530 async fn send(
531 &self,
532 message: MessageFromClient,
533 request_id: Option<RequestId>,
534 request_timeout: Option<Duration>,
535 ) -> SdkResult<Option<ServerMessage>> {
536 #[cfg(feature = "streamable-http")]
537 {
538 if self.transport_options.is_some() {
539 let outgoing_request_id = self
540 .request_id_gen
541 .request_id_for_message(&message, request_id);
542 let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
543
544 let response = self
545 .start_stream(ClientMessages::Single(mcp_message), request_timeout)
546 .await?;
547 return response
548 .map(|r| r.as_single())
549 .transpose()
550 .map_err(|err| err.into());
551 }
552 }
553
554 let transport_map = self.transport_map.read().await;
555
556 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
557 RpcError::internal_error()
558 .with_message("transport stream does not exists or is closed!".to_string()),
559 )?;
560
561 let outgoing_request_id = self
562 .request_id_gen
563 .request_id_for_message(&message, request_id);
564
565 let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
566 let response = transport
567 .send_message(ClientMessages::Single(mcp_message), request_timeout)
568 .await?;
569 response
570 .map(|r| r.as_single())
571 .transpose()
572 .map_err(|err| err.into())
573 }
574
575 async fn send_batch(
576 &self,
577 messages: Vec<ClientMessage>,
578 timeout: Option<Duration>,
579 ) -> SdkResult<Option<Vec<ServerMessage>>> {
580 #[cfg(feature = "streamable-http")]
581 {
582 if self.transport_options.is_some() {
583 let result = self
584 .start_stream(ClientMessages::Batch(messages), timeout)
585 .await?;
586 return result
588 .map(|r| r.as_batch())
589 .transpose()
590 .map_err(|err| err.into());
591 }
592 }
593
594 let transport_map = self.transport_map.read().await;
595 let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
596 RpcError::internal_error()
597 .with_message("transport stream does not exists or is closed!".to_string()),
598 )?;
599 transport
600 .send_batch(messages, timeout)
601 .await
602 .map_err(|err| err.into())
603 }
604
605 async fn start(self: Arc<Self>) -> SdkResult<()> {
606 #[cfg(feature = "streamable-http")]
607 {
608 if self.transport_options.is_some() {
609 self.initialize_request().await?;
610 return Ok(());
611 }
612 }
613
614 self.start_standalone().await
615 }
616
617 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
618 self.server_details_tx
619 .send(Some(server_details))
620 .map_err(|_| {
621 RpcError::internal_error()
622 .with_message("Failed to set server details".to_string())
623 .into()
624 })
625 }
626
627 fn client_info(&self) -> &InitializeRequestParams {
628 &self.client_details
629 }
630
631 fn server_info(&self) -> Option<InitializeResult> {
632 self.server_details_rx.borrow().clone()
633 }
634
635 async fn is_shut_down(&self) -> bool {
636 let result = self.is_shut_down.lock().await;
637 *result
638 }
639
640 async fn shut_down(&self) -> SdkResult<()> {
641 let mut is_shut_down_lock = self.is_shut_down.lock().await;
642 *is_shut_down_lock = true;
643
644 let mut transport_map = self.transport_map.write().await;
645 let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
646 drop(transport_map);
647 for transport in transports {
648 let _ = transport.shut_down().await;
649 }
650
651 let mut tasks_lock = self.handlers.lock().await;
653 let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
654 join_all(join_handlers).await;
655
656 Ok(())
657 }
658
659 async fn terminate_session(&self) {
660 #[cfg(feature = "streamable-http")]
661 {
662 if let Some(transport_options) = self.transport_options.as_ref() {
663 let session_id = self.session_id.read().await.clone();
664 transport_options
665 .terminate_session(session_id.as_ref())
666 .await;
667 let _ = self.shut_down().await;
668 }
669 }
670 let _ = self.shut_down().await;
671 }
672}