rust_mcp_sdk/mcp_runtimes/
client_runtime.rs

1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3use crate::error::{McpSdkError, SdkResult};
4use crate::id_generator::FastIdGenerator;
5use crate::mcp_traits::{IdGenerator, McpClient, McpClientHandler};
6use crate::utils::ensure_server_protocole_compatibility;
7use crate::{
8    mcp_traits::{RequestIdGen, RequestIdGenNumeric},
9    schema::{
10        schema_utils::{
11            self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
12            ServerMessage, ServerMessages,
13        },
14        InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
15        RequestId, RpcError, ServerResult,
16    },
17};
18use async_trait::async_trait;
19use futures::future::{join_all, try_join_all};
20use futures::StreamExt;
21
22#[cfg(feature = "streamable-http")]
23use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions};
24use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher};
25use std::{collections::HashMap, sync::Arc, time::Duration};
26use tokio::io::{AsyncBufReadExt, BufReader};
27use tokio::sync::{watch, Mutex};
28
/// Key under which the default transport is kept in the runtime's transport map.
///
/// `ClientRuntime::new` registers its transport under this id, and
/// `create_sse_stream` stores the standalone stream under the same id.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
30
// Trait-object form of `TransportDispatcher`, instantiated with the server and
// client message types this runtime exchanges. Naming it once keeps the long
// generic parameter list out of every signature below.
type TransportDispatcherType = dyn TransportDispatcher<
    ServerMessages,
    MessageFromClient,
    ServerMessage,
    ClientMessages,
    ClientMessage,
>;
// Shared, reference-counted handle to a transport; this is the value type of
// `ClientRuntime::transport_map`.
type TransportType = Arc<TransportDispatcherType>;
40
/// Client-side runtime state: the transports in use, the user supplied handler,
/// the background tasks reading from those transports and the details exchanged
/// with the server during initialization.
pub struct ClientRuntime {
    // Transports keyed by stream id (see `store_transport` / `transport_by_stream`),
    // guarded by an async read-write lock
    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
    // The handler for processing MCP messages
    handler: Box<dyn McpClientHandler>,
    // Information about this client; sent to the server in the InitializeRequest
    client_details: InitializeRequestParams,
    // Join handles of the spawned background tasks; drained and awaited in `shut_down`
    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
    // Generator for unique request IDs
    request_id_gen: Box<dyn RequestIdGen>,
    // Generator for stream IDs
    stream_id_gen: FastIdGenerator,
    #[cfg(feature = "streamable-http")]
    // Optional configuration for streamable transport; when set, `send`, `send_batch`
    // and `start` take the streamable-http code path
    transport_options: Option<StreamableTransportOptions>,
    // Flag indicating whether the client has been shut down
    is_shut_down: Mutex<bool>,
    // Session ID; filled in by `start_stream` from the transport when none is known yet
    session_id: tokio::sync::RwLock<Option<SessionId>>,
    // Details about the connected server: written through the sender in
    // `set_server_details`, read through the receiver in `server_info`
    server_details_tx: watch::Sender<Option<InitializeResult>>,
    server_details_rx: watch::Receiver<Option<InitializeResult>>,
}
64
65impl ClientRuntime {
66    pub(crate) fn new(
67        client_details: InitializeRequestParams,
68        transport: TransportType,
69        handler: Box<dyn McpClientHandler>,
70    ) -> Self {
71        let mut map: HashMap<String, TransportType> = HashMap::new();
72        map.insert(DEFAULT_STREAM_ID.to_string(), transport);
73        let (server_details_tx, server_details_rx) =
74            watch::channel::<Option<InitializeResult>>(None);
75        Self {
76            transport_map: tokio::sync::RwLock::new(map),
77            handler,
78            client_details,
79            handlers: Mutex::new(vec![]),
80            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
81            #[cfg(feature = "streamable-http")]
82            transport_options: None,
83            is_shut_down: Mutex::new(false),
84            session_id: tokio::sync::RwLock::new(None),
85            stream_id_gen: FastIdGenerator::new(Some("s_")),
86            server_details_tx,
87            server_details_rx,
88        }
89    }
90
91    #[cfg(feature = "streamable-http")]
92    pub(crate) fn new_instance(
93        client_details: InitializeRequestParams,
94        transport_options: StreamableTransportOptions,
95        handler: Box<dyn McpClientHandler>,
96    ) -> Self {
97        let map: HashMap<String, TransportType> = HashMap::new();
98        let (server_details_tx, server_details_rx) =
99            watch::channel::<Option<InitializeResult>>(None);
100        Self {
101            transport_map: tokio::sync::RwLock::new(map),
102            handler,
103            client_details,
104            handlers: Mutex::new(vec![]),
105            transport_options: Some(transport_options),
106            is_shut_down: Mutex::new(false),
107            session_id: tokio::sync::RwLock::new(None),
108            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
109            stream_id_gen: FastIdGenerator::new(Some("s_")),
110            server_details_tx,
111            server_details_rx,
112        }
113    }
114
115    async fn initialize_request(self: Arc<Self>) -> SdkResult<()> {
116        let request = InitializeRequest::new(self.client_details.clone());
117        let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
118
119        if let ServerResult::InitializeResult(initialize_result) = result {
120            ensure_server_protocole_compatibility(
121                &self.client_details.protocol_version,
122                &initialize_result.protocol_version,
123            )?;
124            // store server details
125            self.set_server_details(initialize_result)?;
126
127            #[cfg(feature = "streamable-http")]
128            // try to create a sse stream for server initiated messages , if supported by the server
129            if let Err(error) = self.clone().create_sse_stream().await {
130                tracing::warn!("{error}");
131            }
132
133            // send a InitializedNotification to the server
134            self.send_notification(InitializedNotification::new(None).into())
135                .await?;
136        } else {
137            return Err(RpcError::invalid_params()
138                .with_message("Incorrect response to InitializeRequest!".into())
139                .into());
140        }
141
142        Ok(())
143    }
144
145    pub(crate) async fn handle_message(
146        &self,
147        message: ServerMessage,
148        transport: &TransportType,
149    ) -> SdkResult<Option<ClientMessage>> {
150        let response = match message {
151            ServerMessage::Request(jsonrpc_request) => {
152                let result = self
153                    .handler
154                    .handle_request(jsonrpc_request.request, self)
155                    .await;
156
157                // create a response to send back to the server
158                let response: MessageFromClient = match result {
159                    Ok(success_value) => success_value.into(),
160                    Err(error_value) => MessageFromClient::Error(error_value),
161                };
162
163                let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
164                Some(mcp_message)
165            }
166            ServerMessage::Notification(jsonrpc_notification) => {
167                self.handler
168                    .handle_notification(jsonrpc_notification.notification, self)
169                    .await?;
170                None
171            }
172            ServerMessage::Error(jsonrpc_error) => {
173                self.handler
174                    .handle_error(&jsonrpc_error.error, self)
175                    .await?;
176                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
177                    tx_response
178                        .send(ServerMessage::Error(jsonrpc_error))
179                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
180                } else {
181                    tracing::warn!(
182                        "Received an error response with no corresponding request: {:?}",
183                        &jsonrpc_error.id
184                    );
185                }
186                None
187            }
188            ServerMessage::Response(response) => {
189                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
190                    tx_response
191                        .send(ServerMessage::Response(response))
192                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
193                } else {
194                    tracing::warn!(
195                        "Received a response with no corresponding request: {:?}",
196                        &response.id
197                    );
198                }
199                None
200            }
201        };
202        Ok(response)
203    }
204
205    async fn start_standalone(self: Arc<Self>) -> SdkResult<()> {
206        let self_clone = self.clone();
207        let transport_map = self_clone.transport_map.read().await;
208        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
209            RpcError::internal_error()
210                .with_message("transport stream does not exists or is closed!".to_string()),
211        )?;
212
213        //TODO: improve the flow
214        let mut stream = transport.start().await?;
215
216        let transport_clone = transport.clone();
217        let mut error_io_stream = transport.error_stream().write().await;
218        let error_io_stream = error_io_stream.take();
219
220        let self_clone = Arc::clone(&self);
221        let self_clone_err = Arc::clone(&self);
222
223        // task reading from the error stream
224        let err_task = tokio::spawn(async move {
225            let self_ref = &*self_clone_err;
226
227            if let Some(IoStream::Readable(error_input)) = error_io_stream {
228                let mut reader = BufReader::new(error_input).lines();
229                loop {
230                    tokio::select! {
231                        should_break = transport_clone.is_shut_down() =>{
232                            if should_break {
233                                break;
234                            }
235                        }
236                        line = reader.next_line() =>{
237                            match line {
238                                Ok(Some(error_message)) => {
239                                    self_ref
240                                        .handler
241                                        .handle_process_error(error_message, self_ref)
242                                        .await?;
243                                }
244                                Ok(None) => {
245                                    // end of input
246                                    break;
247                                }
248                                Err(e) => {
249                                    tracing::error!("Error reading from std_err: {e}");
250                                    break;
251                                }
252                            }
253                        }
254                    }
255                }
256            }
257
258            Ok::<(), McpSdkError>(())
259        });
260
261        let transport = transport.clone();
262
263        // main task reading from mcp_message stream
264        let main_task = tokio::spawn(async move {
265            while let Some(mcp_messages) = stream.next().await {
266                let self_ref = &*self_clone;
267
268                match mcp_messages {
269                    ServerMessages::Single(server_message) => {
270                        let result = self_ref.handle_message(server_message, &transport).await;
271
272                        match result {
273                            Ok(result) => {
274                                if let Some(result) = result {
275                                    transport
276                                        .send_message(ClientMessages::Single(result), None)
277                                        .await?;
278                                }
279                            }
280                            Err(error) => {
281                                tracing::error!("Error handling message : {}", error)
282                            }
283                        }
284                    }
285                    ServerMessages::Batch(server_messages) => {
286                        let handling_tasks: Vec<_> = server_messages
287                            .into_iter()
288                            .map(|server_message| {
289                                self_ref.handle_message(server_message, &transport)
290                            })
291                            .collect();
292                        let results: Vec<_> = try_join_all(handling_tasks).await?;
293                        let results: Vec<_> = results.into_iter().flatten().collect();
294
295                        if !results.is_empty() {
296                            transport
297                                .send_message(ClientMessages::Batch(results), None)
298                                .await?;
299                        }
300                    }
301                }
302            }
303            Ok::<(), McpSdkError>(())
304        });
305
306        // send initialize request to the MCP server
307        self.clone().initialize_request().await?;
308
309        let mut lock = self.handlers.lock().await;
310        lock.push(main_task);
311        lock.push(err_task);
312        Ok(())
313    }
314
315    pub(crate) async fn store_transport(
316        &self,
317        stream_id: &str,
318        transport: TransportType,
319    ) -> SdkResult<()> {
320        let mut transport_map = self.transport_map.write().await;
321        tracing::trace!("save transport for stream id : {}", stream_id);
322        transport_map.insert(stream_id.to_string(), transport);
323        Ok(())
324    }
325
326    pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult<TransportType> {
327        let transport_map = self.transport_map.read().await;
328        transport_map.get(stream_id).cloned().ok_or_else(|| {
329            RpcError::internal_error()
330                .with_message(format!("Transport for key {stream_id} not found"))
331                .into()
332        })
333    }
334
335    #[cfg(feature = "streamable-http")]
336    pub(crate) async fn new_transport(
337        &self,
338        session_id: Option<SessionId>,
339        standalone: bool,
340    ) -> SdkResult<
341        impl TransportDispatcher<
342            ServerMessages,
343            MessageFromClient,
344            ServerMessage,
345            ClientMessages,
346            ClientMessage,
347        >,
348    > {
349        let options = self
350            .transport_options
351            .as_ref()
352            .ok_or(schema_utils::SdkError::connection_closed())?;
353        let transport = ClientStreamableTransport::new(options, session_id, standalone)?;
354
355        Ok(transport)
356    }
357
358    #[cfg(feature = "streamable-http")]
359    pub(crate) async fn create_sse_stream(self: Arc<Self>) -> SdkResult<()> {
360        let stream_id: StreamId = DEFAULT_STREAM_ID.into();
361        let session_id = self.session_id.read().await.clone();
362        let transport: Arc<
363            dyn TransportDispatcher<
364                ServerMessages,
365                MessageFromClient,
366                ServerMessage,
367                ClientMessages,
368                ClientMessage,
369            >,
370        > = Arc::new(self.new_transport(session_id, true).await?);
371        let mut stream = transport.start().await?;
372        self.store_transport(&stream_id, transport.clone()).await?;
373
374        let self_clone = Arc::clone(&self);
375
376        let main_task = tokio::spawn(async move {
377            loop {
378                if let Some(mcp_messages) = stream.next().await {
379                    match mcp_messages {
380                        ServerMessages::Single(server_message) => {
381                            let result = self.handle_message(server_message, &transport).await?;
382
383                            if let Some(result) = result {
384                                transport
385                                    .send_message(ClientMessages::Single(result), None)
386                                    .await?;
387                            }
388                        }
389                        ServerMessages::Batch(server_messages) => {
390                            let handling_tasks: Vec<_> = server_messages
391                                .into_iter()
392                                .map(|server_message| {
393                                    self.handle_message(server_message, &transport)
394                                })
395                                .collect();
396
397                            let results: Vec<_> = try_join_all(handling_tasks).await?;
398
399                            let results: Vec<_> = results.into_iter().flatten().collect();
400
401                            if !results.is_empty() {
402                                transport
403                                    .send_message(ClientMessages::Batch(results), None)
404                                    .await?;
405                            }
406                        }
407                    }
408                    // close the stream after all messages are sent, unless it is a standalone stream
409                    if !stream_id.eq(DEFAULT_STREAM_ID) {
410                        return Ok::<_, McpSdkError>(());
411                    }
412                } else {
413                    // end of stream
414                    return Ok::<_, McpSdkError>(());
415                }
416            }
417        });
418
419        let mut lock = self_clone.handlers.lock().await;
420        lock.push(main_task);
421
422        Ok(())
423    }
424
425    #[cfg(feature = "streamable-http")]
426    pub(crate) async fn start_stream(
427        &self,
428        messages: ClientMessages,
429        timeout: Option<Duration>,
430    ) -> SdkResult<Option<ServerMessages>> {
431        use futures::stream::{AbortHandle, Abortable};
432        let stream_id: StreamId = self.stream_id_gen.generate();
433        let session_id = self.session_id.read().await.clone();
434        let no_session_id = session_id.is_none();
435
436        let has_request = match &messages {
437            ClientMessages::Single(client_message) => client_message.is_request(),
438            ClientMessages::Batch(client_messages) => {
439                client_messages.iter().any(|m| m.is_request())
440            }
441        };
442
443        let transport = Arc::new(self.new_transport(session_id, false).await?);
444
445        let mut stream = transport.start().await?;
446
447        self.store_transport(&stream_id, transport).await?;
448
449        let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove
450
451        let send_task = async {
452            let result = transport.send_message(messages, timeout).await?;
453
454            if no_session_id {
455                if let Some(request_id) = transport.session_id().await.clone() {
456                    let mut guard = self.session_id.write().await;
457                    *guard = Some(request_id)
458                }
459            }
460
461            Ok::<_, McpSdkError>(result)
462        };
463
464        if !has_request {
465            return send_task.await;
466        }
467
468        let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair();
469
470        let receive_task = async {
471            loop {
472                tokio::select! {
473                    Some(mcp_messages) = stream.next() =>{
474
475                        match mcp_messages {
476                            ServerMessages::Single(server_message) => {
477                                let result = self.handle_message(server_message, &transport).await?;
478                                if let Some(result) = result {
479                                    transport.send_message(ClientMessages::Single(result), None).await?;
480                                }
481                            }
482                            ServerMessages::Batch(server_messages) => {
483
484                                let handling_tasks: Vec<_> = server_messages
485                                    .into_iter()
486                                    .map(|server_message| self.handle_message(server_message, &transport))
487                                    .collect();
488
489                                let results: Vec<_> = try_join_all(handling_tasks).await?;
490
491                                let results: Vec<_> = results.into_iter().flatten().collect();
492
493                                if !results.is_empty() {
494                                    transport.send_message(ClientMessages::Batch(results), None).await?;
495                                }
496                            }
497                        }
498                        // close the stream after all messages are sent, unless it is a standalone stream
499                        if !stream_id.eq(DEFAULT_STREAM_ID){
500                            return  Ok::<_, McpSdkError>(());
501                        }
502                    }
503                }
504            }
505        };
506
507        let receive_task = Abortable::new(receive_task, abort_recv_reg);
508
509        // Pin the tasks to ensure they are not moved
510        tokio::pin!(send_task);
511        tokio::pin!(receive_task);
512
513        // Run both tasks with cancellation logic
514        let (send_res, _) = tokio::select! {
515            res = &mut send_task => {
516                // cancel the receive_task task, to cover the case where send_task returns with error
517                abort_recv_handle.abort();
518                (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation)
519            }
520            res = &mut receive_task => {
521                (send_task.await, res)
522            }
523        };
524        send_res
525    }
526}
527
528#[async_trait]
529impl McpClient for ClientRuntime {
530    async fn send(
531        &self,
532        message: MessageFromClient,
533        request_id: Option<RequestId>,
534        request_timeout: Option<Duration>,
535    ) -> SdkResult<Option<ServerMessage>> {
536        #[cfg(feature = "streamable-http")]
537        {
538            if self.transport_options.is_some() {
539                let outgoing_request_id = self
540                    .request_id_gen
541                    .request_id_for_message(&message, request_id);
542                let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
543
544                let response = self
545                    .start_stream(ClientMessages::Single(mcp_message), request_timeout)
546                    .await?;
547                return response
548                    .map(|r| r.as_single())
549                    .transpose()
550                    .map_err(|err| err.into());
551            }
552        }
553
554        let transport_map = self.transport_map.read().await;
555
556        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
557            RpcError::internal_error()
558                .with_message("transport stream does not exists or is closed!".to_string()),
559        )?;
560
561        let outgoing_request_id = self
562            .request_id_gen
563            .request_id_for_message(&message, request_id);
564
565        let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
566        let response = transport
567            .send_message(ClientMessages::Single(mcp_message), request_timeout)
568            .await?;
569        response
570            .map(|r| r.as_single())
571            .transpose()
572            .map_err(|err| err.into())
573    }
574
575    async fn send_batch(
576        &self,
577        messages: Vec<ClientMessage>,
578        timeout: Option<Duration>,
579    ) -> SdkResult<Option<Vec<ServerMessage>>> {
580        #[cfg(feature = "streamable-http")]
581        {
582            if self.transport_options.is_some() {
583                let result = self
584                    .start_stream(ClientMessages::Batch(messages), timeout)
585                    .await?;
586                // let response = self.start_stream(&stream_id, request_id, message).await?;
587                return result
588                    .map(|r| r.as_batch())
589                    .transpose()
590                    .map_err(|err| err.into());
591            }
592        }
593
594        let transport_map = self.transport_map.read().await;
595        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
596            RpcError::internal_error()
597                .with_message("transport stream does not exists or is closed!".to_string()),
598        )?;
599        transport
600            .send_batch(messages, timeout)
601            .await
602            .map_err(|err| err.into())
603    }
604
605    async fn start(self: Arc<Self>) -> SdkResult<()> {
606        #[cfg(feature = "streamable-http")]
607        {
608            if self.transport_options.is_some() {
609                self.initialize_request().await?;
610                return Ok(());
611            }
612        }
613
614        self.start_standalone().await
615    }
616
617    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
618        self.server_details_tx
619            .send(Some(server_details))
620            .map_err(|_| {
621                RpcError::internal_error()
622                    .with_message("Failed to set server details".to_string())
623                    .into()
624            })
625    }
626
    /// Returns the client details this runtime was constructed with; the same
    /// value is sent to the server in the `InitializeRequest`.
    fn client_info(&self) -> &InitializeRequestParams {
        &self.client_details
    }
630
631    fn server_info(&self) -> Option<InitializeResult> {
632        self.server_details_rx.borrow().clone()
633    }
634
635    async fn is_shut_down(&self) -> bool {
636        let result = self.is_shut_down.lock().await;
637        *result
638    }
639
640    async fn shut_down(&self) -> SdkResult<()> {
641        let mut is_shut_down_lock = self.is_shut_down.lock().await;
642        *is_shut_down_lock = true;
643
644        let mut transport_map = self.transport_map.write().await;
645        let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
646        drop(transport_map);
647        for transport in transports {
648            let _ = transport.shut_down().await;
649        }
650
651        // wait for tasks
652        let mut tasks_lock = self.handlers.lock().await;
653        let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
654        join_all(join_handlers).await;
655
656        Ok(())
657    }
658
659    async fn terminate_session(&self) {
660        #[cfg(feature = "streamable-http")]
661        {
662            if let Some(transport_options) = self.transport_options.as_ref() {
663                let session_id = self.session_id.read().await.clone();
664                transport_options
665                    .terminate_session(session_id.as_ref())
666                    .await;
667                let _ = self.shut_down().await;
668            }
669        }
670        let _ = self.shut_down().await;
671    }
672}