rust_mcp_sdk/mcp_runtimes/
client_runtime.rs

1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3
4use crate::{
5    mcp_traits::{RequestIdGen, RequestIdGenNumeric},
6    schema::{
7        schema_utils::{
8            self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage,
9            ServerMessages,
10        },
11        InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
12        RequestId, RpcError, ServerResult,
13    },
14};
15use async_trait::async_trait;
16use futures::future::{join_all, try_join_all};
17use futures::StreamExt;
18
19use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
20use std::{
21    sync::{Arc, RwLock},
22    time::Duration,
23};
24use tokio::io::{AsyncBufReadExt, BufReader};
25use tokio::sync::Mutex;
26
27use crate::error::{McpSdkError, SdkResult};
28use crate::mcp_traits::mcp_client::McpClient;
29use crate::mcp_traits::mcp_handler::McpClientHandler;
30use crate::utils::ensure_server_protocole_compatibility;
31
32pub struct ClientRuntime {
33    // The transport interface for handling messages between client and server
34    transport: Arc<
35        dyn Transport<
36            ServerMessages,
37            MessageFromClient,
38            ServerMessage,
39            ClientMessages,
40            ClientMessage,
41        >,
42    >,
43    // The handler for processing MCP messages
44    handler: Box<dyn McpClientHandler>,
45    // // Information about the server
46    client_details: InitializeRequestParams,
47    // Details about the connected server
48    server_details: Arc<RwLock<Option<InitializeResult>>>,
49    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
50    request_id_gen: Box<dyn RequestIdGen>,
51}
52
53impl ClientRuntime {
54    pub(crate) fn new(
55        client_details: InitializeRequestParams,
56        transport: impl Transport<
57            ServerMessages,
58            MessageFromClient,
59            ServerMessage,
60            ClientMessages,
61            ClientMessage,
62        >,
63        handler: Box<dyn McpClientHandler>,
64    ) -> Self {
65        Self {
66            transport: Arc::new(transport),
67            handler,
68            client_details,
69            server_details: Arc::new(RwLock::new(None)),
70            handlers: Mutex::new(vec![]),
71            request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
72        }
73    }
74
75    async fn initialize_request(&self) -> SdkResult<()> {
76        let request = InitializeRequest::new(self.client_details.clone());
77        let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
78
79        if let ServerResult::InitializeResult(initialize_result) = result {
80            ensure_server_protocole_compatibility(
81                &self.client_details.protocol_version,
82                &initialize_result.protocol_version,
83            )?;
84
85            // store server details
86            self.set_server_details(initialize_result)?;
87            // send a InitializedNotification to the server
88            self.send_notification(InitializedNotification::new(None).into())
89                .await?;
90        } else {
91            return Err(RpcError::invalid_params()
92                .with_message("Incorrect response to InitializeRequest!".into())
93                .into());
94        }
95        Ok(())
96    }
97
98    pub(crate) async fn handle_message(
99        &self,
100        message: ServerMessage,
101        transport: &Arc<
102            dyn Transport<
103                ServerMessages,
104                MessageFromClient,
105                ServerMessage,
106                ClientMessages,
107                ClientMessage,
108            >,
109        >,
110    ) -> SdkResult<Option<ClientMessage>> {
111        let response = match message {
112            ServerMessage::Request(jsonrpc_request) => {
113                let result = self
114                    .handler
115                    .handle_request(jsonrpc_request.request, self)
116                    .await;
117
118                // create a response to send back to the server
119                let response: MessageFromClient = match result {
120                    Ok(success_value) => success_value.into(),
121                    Err(error_value) => MessageFromClient::Error(error_value),
122                };
123
124                let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
125                Some(mcp_message)
126            }
127            ServerMessage::Notification(jsonrpc_notification) => {
128                self.handler
129                    .handle_notification(jsonrpc_notification.notification, self)
130                    .await?;
131                None
132            }
133            ServerMessage::Error(jsonrpc_error) => {
134                self.handler
135                    .handle_error(&jsonrpc_error.error, self)
136                    .await?;
137                if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
138                    tx_response
139                        .send(ServerMessage::Error(jsonrpc_error))
140                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
141                } else {
142                    tracing::warn!(
143                        "Received an error response with no corresponding request: {:?}",
144                        &jsonrpc_error.id
145                    );
146                }
147                None
148            }
149            ServerMessage::Response(response) => {
150                if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
151                    tx_response
152                        .send(ServerMessage::Response(response))
153                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
154                } else {
155                    tracing::warn!(
156                        "Received a response with no corresponding request: {:?}",
157                        &response.id
158                    );
159                }
160                None
161            }
162        };
163        Ok(response)
164    }
165}
166
167#[async_trait]
168impl McpClient for ClientRuntime {
169    fn sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>>
170    where
171        MessageDispatcher<ServerMessage>:
172            McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>,
173    {
174        (self.transport.message_sender().clone()) as _
175    }
176
177    async fn start(self: Arc<Self>) -> SdkResult<()> {
178        //TODO: improve the flow
179        let mut stream = self.transport.start().await?;
180        let transport = self.transport.clone();
181        let mut error_io_stream = transport.error_stream().write().await;
182        let error_io_stream = error_io_stream.take();
183
184        let self_clone = Arc::clone(&self);
185        let self_clone_err = Arc::clone(&self);
186
187        let err_task = tokio::spawn(async move {
188            let self_ref = &*self_clone_err;
189
190            if let Some(IoStream::Readable(error_input)) = error_io_stream {
191                let mut reader = BufReader::new(error_input).lines();
192                loop {
193                    tokio::select! {
194                        should_break = self_ref.transport.is_shut_down() =>{
195                            if should_break {
196                                break;
197                            }
198                        }
199                        line = reader.next_line() =>{
200                            match line {
201                                Ok(Some(error_message)) => {
202                                    self_ref
203                                        .handler
204                                        .handle_process_error(error_message, self_ref)
205                                        .await?;
206                                }
207                                Ok(None) => {
208                                    // end of input
209                                    break;
210                                }
211                                Err(e) => {
212                                    tracing::error!("Error reading from std_err: {e}");
213                                    break;
214                                }
215                            }
216                        }
217                    }
218                }
219            }
220
221            Ok::<(), McpSdkError>(())
222        });
223
224        let transport = self.transport.clone();
225
226        let main_task = tokio::spawn(async move {
227            let sender = self_clone.sender();
228            let sender = sender.read().await;
229            let sender = sender
230                .as_ref()
231                .ok_or(schema_utils::SdkError::connection_closed())?;
232            while let Some(mcp_messages) = stream.next().await {
233                let self_ref = &*self_clone;
234
235                match mcp_messages {
236                    ServerMessages::Single(server_message) => {
237                        let result = self_ref.handle_message(server_message, &transport).await;
238
239                        match result {
240                            Ok(result) => {
241                                if let Some(result) = result {
242                                    sender
243                                        .send_message(ClientMessages::Single(result), None)
244                                        .await?;
245                                }
246                            }
247                            Err(error) => {
248                                tracing::error!("Error handling message : {}", error)
249                            }
250                        }
251                    }
252                    ServerMessages::Batch(server_messages) => {
253                        let handling_tasks: Vec<_> = server_messages
254                            .into_iter()
255                            .map(|server_message| {
256                                self_ref.handle_message(server_message, &transport)
257                            })
258                            .collect();
259                        let results: Vec<_> = try_join_all(handling_tasks).await?;
260                        let results: Vec<_> = results.into_iter().flatten().collect();
261
262                        if !results.is_empty() {
263                            sender
264                                .send_message(ClientMessages::Batch(results), None)
265                                .await?;
266                        }
267                    }
268                }
269            }
270            Ok::<(), McpSdkError>(())
271        });
272
273        // send initialize request to the MCP server
274        self.initialize_request().await?;
275
276        let mut lock = self.handlers.lock().await;
277        lock.push(main_task);
278        lock.push(err_task);
279
280        Ok(())
281    }
282
283    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
284        match self.server_details.write() {
285            Ok(mut details) => {
286                *details = Some(server_details);
287                Ok(())
288            }
289            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
290            Err(_) => Err(RpcError::internal_error()
291                .with_message("Internal Error: Failed to acquire write lock.".to_string())
292                .into()),
293        }
294    }
295    fn client_info(&self) -> &InitializeRequestParams {
296        &self.client_details
297    }
298    fn server_info(&self) -> Option<InitializeResult> {
299        if let Ok(details) = self.server_details.read() {
300            details.clone()
301        } else {
302            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
303            None
304        }
305    }
306
307    async fn send(
308        &self,
309        message: MessageFromClient,
310        request_id: Option<RequestId>,
311        timeout: Option<Duration>,
312    ) -> SdkResult<Option<ServerMessage>> {
313        let sender = self.sender();
314        let sender = sender.read().await;
315        let sender = sender
316            .as_ref()
317            .ok_or(schema_utils::SdkError::connection_closed())?;
318
319        let outgoing_request_id = self
320            .request_id_gen
321            .request_id_for_message(&message, request_id);
322
323        let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
324
325        let response = sender
326            .send_message(ClientMessages::Single(mcp_message), timeout)
327            .await?
328            .map(|res| res.as_single())
329            .transpose()?;
330
331        Ok(response)
332    }
333
334    async fn is_shut_down(&self) -> bool {
335        self.transport.is_shut_down().await
336    }
337    async fn shut_down(&self) -> SdkResult<()> {
338        self.transport.shut_down().await?;
339
340        // wait for tasks
341        let mut tasks_lock = self.handlers.lock().await;
342        let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
343        join_all(join_handlers).await;
344
345        Ok(())
346    }
347}