rust_mcp_sdk/mcp_runtimes/client_runtime.rs

1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3
4use async_trait::async_trait;
5use futures::future::join_all;
6use futures::StreamExt;
7use rust_mcp_schema::schema_utils::{self, MessageFromClient, ServerMessage};
8use rust_mcp_schema::{
9    InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
10    RpcError, ServerResult,
11};
12use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
13use std::sync::{Arc, RwLock};
14use tokio::io::{AsyncBufReadExt, BufReader};
15use tokio::sync::Mutex;
16
17use crate::error::{McpSdkError, SdkResult};
18use crate::mcp_traits::mcp_client::McpClient;
19use crate::mcp_traits::mcp_handler::McpClientHandler;
20
21pub struct ClientRuntime {
22    // The transport interface for handling messages between client and server
23    transport: Box<dyn Transport<ServerMessage, MessageFromClient>>,
24    // The handler for processing MCP messages
25    handler: Box<dyn McpClientHandler>,
26    // // Information about the server
27    client_details: InitializeRequestParams,
28    // Details about the connected server
29    server_details: Arc<RwLock<Option<InitializeResult>>>,
30    message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>,
31    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
32}
33
34impl ClientRuntime {
35    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ServerMessage>) {
36        let mut lock = self.message_sender.write().await;
37        *lock = Some(sender);
38    }
39
40    pub(crate) fn new(
41        client_details: InitializeRequestParams,
42        transport: impl Transport<ServerMessage, MessageFromClient>,
43        handler: Box<dyn McpClientHandler>,
44    ) -> Self {
45        Self {
46            transport: Box::new(transport),
47            handler,
48            client_details,
49            server_details: Arc::new(RwLock::new(None)),
50            message_sender: tokio::sync::RwLock::new(None),
51            handlers: Mutex::new(vec![]),
52        }
53    }
54
55    async fn initialize_request(&self) -> SdkResult<()> {
56        let request = InitializeRequest::new(self.client_details.clone());
57        let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
58
59        if let ServerResult::InitializeResult(initialize_result) = result {
60            // store server details
61            self.set_server_details(initialize_result)?;
62            // send a InitializedNotification to the server
63            self.send_notification(InitializedNotification::new(None).into())
64                .await?;
65        } else {
66            return Err(RpcError::invalid_params()
67                .with_message("Incorrect response to InitializeRequest!".into())
68                .into());
69        }
70        Ok(())
71    }
72}
73
74#[async_trait]
75impl McpClient for ClientRuntime {
76    async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>
77    where
78        MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>,
79    {
80        (&self.message_sender) as _
81    }
82
83    async fn start(self: Arc<Self>) -> SdkResult<()> {
84        let (mut stream, sender, error_io) = self.transport.start().await?;
85        self.set_message_sender(sender).await;
86
87        let self_clone = Arc::clone(&self);
88        let self_clone_err = Arc::clone(&self);
89
90        let err_task = tokio::spawn(async move {
91            let self_ref = &*self_clone_err;
92
93            if let IoStream::Readable(error_input) = error_io {
94                let mut reader = BufReader::new(error_input).lines();
95                loop {
96                    tokio::select! {
97                        should_break = self_ref.transport.is_shut_down() =>{
98                            if should_break {
99                                break;
100                            }
101                        }
102                        line = reader.next_line() =>{
103                            match line {
104                                Ok(Some(error_message)) => {
105                                    self_ref
106                                        .handler
107                                        .handle_process_error(error_message, self_ref)
108                                        .await?;
109                                }
110                                Ok(None) => {
111                                    // end of input
112                                    break;
113                                }
114                                Err(e) => {
115                                    eprintln!("Error reading from std_err: {}", e);
116                                    break;
117                                }
118                            }
119                        }
120                    }
121                }
122            }
123            Ok::<(), McpSdkError>(())
124        });
125
126        // send initialize request to the MCP server
127        self_clone.initialize_request().await?;
128
129        let main_task = tokio::spawn(async move {
130            let sender = self_clone.sender().await.read().await;
131            let sender = sender
132                .as_ref()
133                .ok_or(schema_utils::SdkError::connection_closed())?;
134            while let Some(mcp_message) = stream.next().await {
135                let self_ref = &*self_clone;
136
137                match mcp_message {
138                    ServerMessage::Request(jsonrpc_request) => {
139                        let result = self_ref
140                            .handler
141                            .handle_request(jsonrpc_request.request, self_ref)
142                            .await;
143
144                        // create a response to send back to the server
145                        let response: MessageFromClient = match result {
146                            Ok(success_value) => success_value.into(),
147                            Err(error_value) => MessageFromClient::Error(error_value),
148                        };
149                        // send the response back with corresponding request id
150                        sender
151                            .send(response, Some(jsonrpc_request.id), None)
152                            .await?;
153                    }
154                    ServerMessage::Notification(jsonrpc_notification) => {
155                        self_ref
156                            .handler
157                            .handle_notification(jsonrpc_notification.notification, self_ref)
158                            .await?;
159                    }
160                    ServerMessage::Error(jsonrpc_error) => {
161                        self_ref
162                            .handler
163                            .handle_error(jsonrpc_error.error, self_ref)
164                            .await?;
165                    }
166                    // The response is the result of a request, it is processed at the transport level.
167                    ServerMessage::Response(_) => {}
168                }
169            }
170            Ok::<(), McpSdkError>(())
171        });
172
173        let mut lock = self.handlers.lock().await;
174        lock.push(main_task);
175        lock.push(err_task);
176
177        Ok(())
178    }
179
180    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
181        match self.server_details.write() {
182            Ok(mut details) => {
183                *details = Some(server_details);
184                Ok(())
185            }
186            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
187            Err(_) => Err(RpcError::internal_error()
188                .with_message("Internal Error: Failed to acquire write lock.".to_string())
189                .into()),
190        }
191    }
192    fn client_info(&self) -> &InitializeRequestParams {
193        &self.client_details
194    }
195    fn server_info(&self) -> Option<InitializeResult> {
196        if let Ok(details) = self.server_details.read() {
197            details.clone()
198        } else {
199            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
200            None
201        }
202    }
203
204    async fn is_shut_down(&self) -> bool {
205        self.transport.is_shut_down().await
206    }
207    async fn shut_down(&self) -> SdkResult<()> {
208        self.transport.shut_down().await?;
209
210        // wait for tasks
211        let mut tasks_lock = self.handlers.lock().await;
212        let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
213        join_all(join_handlers).await;
214
215        Ok(())
216    }
217}