// rust_mcp_sdk/mcp_runtimes/client_runtime.rs

1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3use crate::error::{McpSdkError, SdkResult};
4use crate::id_generator::FastIdGenerator;
5use crate::mcp_traits::mcp_client::McpClient;
6use crate::mcp_traits::mcp_handler::McpClientHandler;
7use crate::mcp_traits::IdGenerator;
8use crate::utils::ensure_server_protocole_compatibility;
9use crate::{
10    mcp_traits::{RequestIdGen, RequestIdGenNumeric},
11    schema::{
12        schema_utils::{
13            self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
14            ServerMessage, ServerMessages,
15        },
16        InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
17        RequestId, RpcError, ServerResult,
18    },
19};
20use async_trait::async_trait;
21use futures::future::{join_all, try_join_all};
22use futures::StreamExt;
23
24#[cfg(feature = "streamable-http")]
25use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions};
26use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher};
27use std::{collections::HashMap, sync::Arc, time::Duration};
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::sync::{watch, Mutex};
30
/// Key in the runtime's transport map under which the standalone (long-lived)
/// transport is registered.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
32
33// Define a type alias for the TransportDispatcher trait object
34type TransportDispatcherType = dyn TransportDispatcher<
35    ServerMessages,
36    MessageFromClient,
37    ServerMessage,
38    ClientMessages,
39    ClientMessage,
40>;
41type TransportType = Arc<TransportDispatcherType>;
42
43pub struct ClientRuntime {
44    // A thread-safe map storing transport types
45    transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
46    // The handler for processing MCP messages
47    handler: Box<dyn McpClientHandler>,
48    // Information about the server
49    client_details: InitializeRequestParams,
50    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
51    // Generator for unique request IDs
52    request_id_gen: Box<dyn RequestIdGen>,
53    // Generator for stream IDs
54    stream_id_gen: FastIdGenerator,
55    #[cfg(feature = "streamable-http")]
56    // Optional configuration for streamable transport
57    transport_options: Option<StreamableTransportOptions>,
58    // Flag indicating whether the client has been shut down
59    is_shut_down: Mutex<bool>,
60    // Session ID
61    session_id: tokio::sync::RwLock<Option<SessionId>>,
62    // Details about the connected server
63    server_details_tx: watch::Sender<Option<InitializeResult>>,
64    server_details_rx: watch::Receiver<Option<InitializeResult>>,
65}
66
67impl ClientRuntime {
68    pub(crate) fn new(
69        client_details: InitializeRequestParams,
70        transport: TransportType,
71        handler: Box<dyn McpClientHandler>,
72    ) -> Self {
73        let mut map: HashMap<String, TransportType> = HashMap::new();
74        map.insert(DEFAULT_STREAM_ID.to_string(), transport);
75        let (server_details_tx, server_details_rx) =
76            watch::channel::<Option<InitializeResult>>(None);
77        Self {
78            transport_map: tokio::sync::RwLock::new(map),
79            handler,
80            client_details,
81            handlers: Mutex::new(vec![]),
82            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
83            #[cfg(feature = "streamable-http")]
84            transport_options: None,
85            is_shut_down: Mutex::new(false),
86            session_id: tokio::sync::RwLock::new(None),
87            stream_id_gen: FastIdGenerator::new(Some("s_")),
88            server_details_tx,
89            server_details_rx,
90        }
91    }
92
93    #[cfg(feature = "streamable-http")]
94    pub(crate) fn new_instance(
95        client_details: InitializeRequestParams,
96        transport_options: StreamableTransportOptions,
97        handler: Box<dyn McpClientHandler>,
98    ) -> Self {
99        let map: HashMap<String, TransportType> = HashMap::new();
100        let (server_details_tx, server_details_rx) =
101            watch::channel::<Option<InitializeResult>>(None);
102        Self {
103            transport_map: tokio::sync::RwLock::new(map),
104            handler,
105            client_details,
106            handlers: Mutex::new(vec![]),
107            transport_options: Some(transport_options),
108            is_shut_down: Mutex::new(false),
109            session_id: tokio::sync::RwLock::new(None),
110            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
111            stream_id_gen: FastIdGenerator::new(Some("s_")),
112            server_details_tx,
113            server_details_rx,
114        }
115    }
116
117    async fn initialize_request(self: Arc<Self>) -> SdkResult<()> {
118        let request = InitializeRequest::new(self.client_details.clone());
119        let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
120
121        if let ServerResult::InitializeResult(initialize_result) = result {
122            ensure_server_protocole_compatibility(
123                &self.client_details.protocol_version,
124                &initialize_result.protocol_version,
125            )?;
126            // store server details
127            self.set_server_details(initialize_result)?;
128
129            #[cfg(feature = "streamable-http")]
130            // try to create a sse stream for server initiated messages , if supported by the server
131            if let Err(error) = self.clone().create_sse_stream().await {
132                tracing::warn!("{error}");
133            }
134
135            // send a InitializedNotification to the server
136            self.send_notification(InitializedNotification::new(None).into())
137                .await?;
138        } else {
139            return Err(RpcError::invalid_params()
140                .with_message("Incorrect response to InitializeRequest!".into())
141                .into());
142        }
143
144        Ok(())
145    }
146
147    pub(crate) async fn handle_message(
148        &self,
149        message: ServerMessage,
150        transport: &TransportType,
151    ) -> SdkResult<Option<ClientMessage>> {
152        let response = match message {
153            ServerMessage::Request(jsonrpc_request) => {
154                let result = self
155                    .handler
156                    .handle_request(jsonrpc_request.request, self)
157                    .await;
158
159                // create a response to send back to the server
160                let response: MessageFromClient = match result {
161                    Ok(success_value) => success_value.into(),
162                    Err(error_value) => MessageFromClient::Error(error_value),
163                };
164
165                let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
166                Some(mcp_message)
167            }
168            ServerMessage::Notification(jsonrpc_notification) => {
169                self.handler
170                    .handle_notification(jsonrpc_notification.notification, self)
171                    .await?;
172                None
173            }
174            ServerMessage::Error(jsonrpc_error) => {
175                self.handler
176                    .handle_error(&jsonrpc_error.error, self)
177                    .await?;
178                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
179                    tx_response
180                        .send(ServerMessage::Error(jsonrpc_error))
181                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
182                } else {
183                    tracing::warn!(
184                        "Received an error response with no corresponding request: {:?}",
185                        &jsonrpc_error.id
186                    );
187                }
188                None
189            }
190            ServerMessage::Response(response) => {
191                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
192                    tx_response
193                        .send(ServerMessage::Response(response))
194                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
195                } else {
196                    tracing::warn!(
197                        "Received a response with no corresponding request: {:?}",
198                        &response.id
199                    );
200                }
201                None
202            }
203        };
204        Ok(response)
205    }
206
207    async fn start_standalone(self: Arc<Self>) -> SdkResult<()> {
208        let self_clone = self.clone();
209        let transport_map = self_clone.transport_map.read().await;
210        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
211            RpcError::internal_error()
212                .with_message("transport stream does not exists or is closed!".to_string()),
213        )?;
214
215        //TODO: improve the flow
216        let mut stream = transport.start().await?;
217
218        let transport_clone = transport.clone();
219        let mut error_io_stream = transport.error_stream().write().await;
220        let error_io_stream = error_io_stream.take();
221
222        let self_clone = Arc::clone(&self);
223        let self_clone_err = Arc::clone(&self);
224
225        // task reading from the error stream
226        let err_task = tokio::spawn(async move {
227            let self_ref = &*self_clone_err;
228
229            if let Some(IoStream::Readable(error_input)) = error_io_stream {
230                let mut reader = BufReader::new(error_input).lines();
231                loop {
232                    tokio::select! {
233                        should_break = transport_clone.is_shut_down() =>{
234                            if should_break {
235                                break;
236                            }
237                        }
238                        line = reader.next_line() =>{
239                            match line {
240                                Ok(Some(error_message)) => {
241                                    self_ref
242                                        .handler
243                                        .handle_process_error(error_message, self_ref)
244                                        .await?;
245                                }
246                                Ok(None) => {
247                                    // end of input
248                                    break;
249                                }
250                                Err(e) => {
251                                    tracing::error!("Error reading from std_err: {e}");
252                                    break;
253                                }
254                            }
255                        }
256                    }
257                }
258            }
259
260            Ok::<(), McpSdkError>(())
261        });
262
263        let transport = transport.clone();
264
265        // main task reading from mcp_message stream
266        let main_task = tokio::spawn(async move {
267            while let Some(mcp_messages) = stream.next().await {
268                let self_ref = &*self_clone;
269
270                match mcp_messages {
271                    ServerMessages::Single(server_message) => {
272                        let result = self_ref.handle_message(server_message, &transport).await;
273
274                        match result {
275                            Ok(result) => {
276                                if let Some(result) = result {
277                                    transport
278                                        .send_message(ClientMessages::Single(result), None)
279                                        .await?;
280                                }
281                            }
282                            Err(error) => {
283                                tracing::error!("Error handling message : {}", error)
284                            }
285                        }
286                    }
287                    ServerMessages::Batch(server_messages) => {
288                        let handling_tasks: Vec<_> = server_messages
289                            .into_iter()
290                            .map(|server_message| {
291                                self_ref.handle_message(server_message, &transport)
292                            })
293                            .collect();
294                        let results: Vec<_> = try_join_all(handling_tasks).await?;
295                        let results: Vec<_> = results.into_iter().flatten().collect();
296
297                        if !results.is_empty() {
298                            transport
299                                .send_message(ClientMessages::Batch(results), None)
300                                .await?;
301                        }
302                    }
303                }
304            }
305            Ok::<(), McpSdkError>(())
306        });
307
308        // send initialize request to the MCP server
309        self.clone().initialize_request().await?;
310
311        let mut lock = self.handlers.lock().await;
312        lock.push(main_task);
313        lock.push(err_task);
314        Ok(())
315    }
316
317    pub(crate) async fn store_transport(
318        &self,
319        stream_id: &str,
320        transport: TransportType,
321    ) -> SdkResult<()> {
322        let mut transport_map = self.transport_map.write().await;
323        tracing::trace!("save transport for stream id : {}", stream_id);
324        transport_map.insert(stream_id.to_string(), transport);
325        Ok(())
326    }
327
328    pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult<TransportType> {
329        let transport_map = self.transport_map.read().await;
330        transport_map.get(stream_id).cloned().ok_or_else(|| {
331            RpcError::internal_error()
332                .with_message(format!("Transport for key {stream_id} not found"))
333                .into()
334        })
335    }
336
337    #[cfg(feature = "streamable-http")]
338    pub(crate) async fn new_transport(
339        &self,
340        session_id: Option<SessionId>,
341        standalone: bool,
342    ) -> SdkResult<
343        impl TransportDispatcher<
344            ServerMessages,
345            MessageFromClient,
346            ServerMessage,
347            ClientMessages,
348            ClientMessage,
349        >,
350    > {
351        let options = self
352            .transport_options
353            .as_ref()
354            .ok_or(schema_utils::SdkError::connection_closed())?;
355        let transport = ClientStreamableTransport::new(options, session_id, standalone)?;
356
357        Ok(transport)
358    }
359
360    #[cfg(feature = "streamable-http")]
361    pub(crate) async fn create_sse_stream(self: Arc<Self>) -> SdkResult<()> {
362        let stream_id: StreamId = DEFAULT_STREAM_ID.into();
363        let session_id = self.session_id.read().await.clone();
364        let transport: Arc<
365            dyn TransportDispatcher<
366                ServerMessages,
367                MessageFromClient,
368                ServerMessage,
369                ClientMessages,
370                ClientMessage,
371            >,
372        > = Arc::new(self.new_transport(session_id, true).await?);
373        let mut stream = transport.start().await?;
374        self.store_transport(&stream_id, transport.clone()).await?;
375
376        let self_clone = Arc::clone(&self);
377
378        let main_task = tokio::spawn(async move {
379            loop {
380                if let Some(mcp_messages) = stream.next().await {
381                    match mcp_messages {
382                        ServerMessages::Single(server_message) => {
383                            let result = self.handle_message(server_message, &transport).await?;
384
385                            if let Some(result) = result {
386                                transport
387                                    .send_message(ClientMessages::Single(result), None)
388                                    .await?;
389                            }
390                        }
391                        ServerMessages::Batch(server_messages) => {
392                            let handling_tasks: Vec<_> = server_messages
393                                .into_iter()
394                                .map(|server_message| {
395                                    self.handle_message(server_message, &transport)
396                                })
397                                .collect();
398
399                            let results: Vec<_> = try_join_all(handling_tasks).await?;
400
401                            let results: Vec<_> = results.into_iter().flatten().collect();
402
403                            if !results.is_empty() {
404                                transport
405                                    .send_message(ClientMessages::Batch(results), None)
406                                    .await?;
407                            }
408                        }
409                    }
410                    // close the stream after all messages are sent, unless it is a standalone stream
411                    if !stream_id.eq(DEFAULT_STREAM_ID) {
412                        return Ok::<_, McpSdkError>(());
413                    }
414                } else {
415                    // end of stream
416                    return Ok::<_, McpSdkError>(());
417                }
418            }
419        });
420
421        let mut lock = self_clone.handlers.lock().await;
422        lock.push(main_task);
423
424        Ok(())
425    }
426
427    #[cfg(feature = "streamable-http")]
428    pub(crate) async fn start_stream(
429        &self,
430        messages: ClientMessages,
431        timeout: Option<Duration>,
432    ) -> SdkResult<Option<ServerMessages>> {
433        use futures::stream::{AbortHandle, Abortable};
434        let stream_id: StreamId = self.stream_id_gen.generate();
435        let session_id = self.session_id.read().await.clone();
436        let no_session_id = session_id.is_none();
437
438        let has_request = match &messages {
439            ClientMessages::Single(client_message) => client_message.is_request(),
440            ClientMessages::Batch(client_messages) => {
441                client_messages.iter().any(|m| m.is_request())
442            }
443        };
444
445        let transport = Arc::new(self.new_transport(session_id, false).await?);
446
447        let mut stream = transport.start().await?;
448
449        self.store_transport(&stream_id, transport).await?;
450
451        let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove
452
453        let send_task = async {
454            let result = transport.send_message(messages, timeout).await?;
455
456            if no_session_id {
457                if let Some(request_id) = transport.session_id().await.clone() {
458                    let mut guard = self.session_id.write().await;
459                    *guard = Some(request_id)
460                }
461            }
462
463            Ok::<_, McpSdkError>(result)
464        };
465
466        if !has_request {
467            return send_task.await;
468        }
469
470        let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair();
471
472        let receive_task = async {
473            loop {
474                tokio::select! {
475                    Some(mcp_messages) = stream.next() =>{
476
477                        match mcp_messages {
478                            ServerMessages::Single(server_message) => {
479                                let result = self.handle_message(server_message, &transport).await?;
480                                if let Some(result) = result {
481                                    transport.send_message(ClientMessages::Single(result), None).await?;
482                                }
483                            }
484                            ServerMessages::Batch(server_messages) => {
485
486                                let handling_tasks: Vec<_> = server_messages
487                                    .into_iter()
488                                    .map(|server_message| self.handle_message(server_message, &transport))
489                                    .collect();
490
491                                let results: Vec<_> = try_join_all(handling_tasks).await?;
492
493                                let results: Vec<_> = results.into_iter().flatten().collect();
494
495                                if !results.is_empty() {
496                                    transport.send_message(ClientMessages::Batch(results), None).await?;
497                                }
498                            }
499                        }
500                        // close the stream after all messages are sent, unless it is a standalone stream
501                        if !stream_id.eq(DEFAULT_STREAM_ID){
502                            return  Ok::<_, McpSdkError>(());
503                        }
504                    }
505                }
506            }
507        };
508
509        let receive_task = Abortable::new(receive_task, abort_recv_reg);
510
511        // Pin the tasks to ensure they are not moved
512        tokio::pin!(send_task);
513        tokio::pin!(receive_task);
514
515        // Run both tasks with cancellation logic
516        let (send_res, _) = tokio::select! {
517            res = &mut send_task => {
518                // cancel the receive_task task, to cover the case where send_task returns with error
519                abort_recv_handle.abort();
520                (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation)
521            }
522            res = &mut receive_task => {
523                (send_task.await, res)
524            }
525        };
526        send_res
527    }
528}
529
530#[async_trait]
531impl McpClient for ClientRuntime {
532    async fn send(
533        &self,
534        message: MessageFromClient,
535        request_id: Option<RequestId>,
536        request_timeout: Option<Duration>,
537    ) -> SdkResult<Option<ServerMessage>> {
538        #[cfg(feature = "streamable-http")]
539        {
540            if self.transport_options.is_some() {
541                let outgoing_request_id = self
542                    .request_id_gen
543                    .request_id_for_message(&message, request_id);
544                let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
545
546                let response = self
547                    .start_stream(ClientMessages::Single(mcp_message), request_timeout)
548                    .await?;
549                return response
550                    .map(|r| r.as_single())
551                    .transpose()
552                    .map_err(|err| err.into());
553            }
554        }
555
556        let transport_map = self.transport_map.read().await;
557
558        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
559            RpcError::internal_error()
560                .with_message("transport stream does not exists or is closed!".to_string()),
561        )?;
562
563        let outgoing_request_id = self
564            .request_id_gen
565            .request_id_for_message(&message, request_id);
566
567        let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
568        let response = transport
569            .send_message(ClientMessages::Single(mcp_message), request_timeout)
570            .await?;
571        response
572            .map(|r| r.as_single())
573            .transpose()
574            .map_err(|err| err.into())
575    }
576
577    async fn send_batch(
578        &self,
579        messages: Vec<ClientMessage>,
580        timeout: Option<Duration>,
581    ) -> SdkResult<Option<Vec<ServerMessage>>> {
582        #[cfg(feature = "streamable-http")]
583        {
584            if self.transport_options.is_some() {
585                let result = self
586                    .start_stream(ClientMessages::Batch(messages), timeout)
587                    .await?;
588                // let response = self.start_stream(&stream_id, request_id, message).await?;
589                return result
590                    .map(|r| r.as_batch())
591                    .transpose()
592                    .map_err(|err| err.into());
593            }
594        }
595
596        let transport_map = self.transport_map.read().await;
597        let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
598            RpcError::internal_error()
599                .with_message("transport stream does not exists or is closed!".to_string()),
600        )?;
601        transport
602            .send_batch(messages, timeout)
603            .await
604            .map_err(|err| err.into())
605    }
606
607    async fn start(self: Arc<Self>) -> SdkResult<()> {
608        #[cfg(feature = "streamable-http")]
609        {
610            if self.transport_options.is_some() {
611                self.initialize_request().await?;
612                return Ok(());
613            }
614        }
615
616        self.start_standalone().await
617    }
618
619    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
620        self.server_details_tx
621            .send(Some(server_details))
622            .map_err(|_| {
623                RpcError::internal_error()
624                    .with_message("Failed to set server details".to_string())
625                    .into()
626            })
627    }
628
629    fn client_info(&self) -> &InitializeRequestParams {
630        &self.client_details
631    }
632
633    fn server_info(&self) -> Option<InitializeResult> {
634        self.server_details_rx.borrow().clone()
635    }
636
637    async fn is_shut_down(&self) -> bool {
638        let result = self.is_shut_down.lock().await;
639        *result
640    }
641
642    async fn shut_down(&self) -> SdkResult<()> {
643        let mut is_shut_down_lock = self.is_shut_down.lock().await;
644        *is_shut_down_lock = true;
645
646        let mut transport_map = self.transport_map.write().await;
647        let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect();
648        drop(transport_map);
649        for transport in transports {
650            let _ = transport.shut_down().await;
651        }
652
653        // wait for tasks
654        let mut tasks_lock = self.handlers.lock().await;
655        let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
656        join_all(join_handlers).await;
657
658        Ok(())
659    }
660
661    async fn terminate_session(&self) {
662        #[cfg(feature = "streamable-http")]
663        {
664            if let Some(transport_options) = self.transport_options.as_ref() {
665                let session_id = self.session_id.read().await.clone();
666                transport_options
667                    .terminate_session(session_id.as_ref())
668                    .await;
669                let _ = self.shut_down().await;
670            }
671        }
672        let _ = self.shut_down().await;
673    }
674}