rust_mcp_sdk/mcp_runtimes/
client_runtime.rs1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3
4use crate::{
5 mcp_traits::{RequestIdGen, RequestIdGenNumeric},
6 schema::{
7 schema_utils::{
8 self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage,
9 ServerMessages,
10 },
11 InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
12 RequestId, RpcError, ServerResult,
13 },
14};
15use async_trait::async_trait;
16use futures::future::{join_all, try_join_all};
17use futures::StreamExt;
18
19use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
20use std::{
21 sync::{Arc, RwLock},
22 time::Duration,
23};
24use tokio::io::{AsyncBufReadExt, BufReader};
25use tokio::sync::Mutex;
26
27use crate::error::{McpSdkError, SdkResult};
28use crate::mcp_traits::mcp_client::McpClient;
29use crate::mcp_traits::mcp_handler::McpClientHandler;
30use crate::utils::ensure_server_protocole_compatibility;
31
32pub struct ClientRuntime {
33 transport: Arc<
35 dyn Transport<
36 ServerMessages,
37 MessageFromClient,
38 ServerMessage,
39 ClientMessages,
40 ClientMessage,
41 >,
42 >,
43 handler: Box<dyn McpClientHandler>,
45 client_details: InitializeRequestParams,
47 server_details: Arc<RwLock<Option<InitializeResult>>>,
49 handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
50 request_id_gen: Box<dyn RequestIdGen>,
51}
52
53impl ClientRuntime {
54 pub(crate) fn new(
55 client_details: InitializeRequestParams,
56 transport: impl Transport<
57 ServerMessages,
58 MessageFromClient,
59 ServerMessage,
60 ClientMessages,
61 ClientMessage,
62 >,
63 handler: Box<dyn McpClientHandler>,
64 ) -> Self {
65 Self {
66 transport: Arc::new(transport),
67 handler,
68 client_details,
69 server_details: Arc::new(RwLock::new(None)),
70 handlers: Mutex::new(vec![]),
71 request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
72 }
73 }
74
75 async fn initialize_request(&self) -> SdkResult<()> {
76 let request = InitializeRequest::new(self.client_details.clone());
77 let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
78
79 if let ServerResult::InitializeResult(initialize_result) = result {
80 ensure_server_protocole_compatibility(
81 &self.client_details.protocol_version,
82 &initialize_result.protocol_version,
83 )?;
84
85 self.set_server_details(initialize_result)?;
87 self.send_notification(InitializedNotification::new(None).into())
89 .await?;
90 } else {
91 return Err(RpcError::invalid_params()
92 .with_message("Incorrect response to InitializeRequest!".into())
93 .into());
94 }
95 Ok(())
96 }
97
98 pub(crate) async fn handle_message(
99 &self,
100 message: ServerMessage,
101 transport: &Arc<
102 dyn Transport<
103 ServerMessages,
104 MessageFromClient,
105 ServerMessage,
106 ClientMessages,
107 ClientMessage,
108 >,
109 >,
110 ) -> SdkResult<Option<ClientMessage>> {
111 let response = match message {
112 ServerMessage::Request(jsonrpc_request) => {
113 let result = self
114 .handler
115 .handle_request(jsonrpc_request.request, self)
116 .await;
117
118 let response: MessageFromClient = match result {
120 Ok(success_value) => success_value.into(),
121 Err(error_value) => MessageFromClient::Error(error_value),
122 };
123
124 let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?;
125 Some(mcp_message)
126 }
127 ServerMessage::Notification(jsonrpc_notification) => {
128 self.handler
129 .handle_notification(jsonrpc_notification.notification, self)
130 .await?;
131 None
132 }
133 ServerMessage::Error(jsonrpc_error) => {
134 self.handler
135 .handle_error(&jsonrpc_error.error, self)
136 .await?;
137 if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await {
138 tx_response
139 .send(ServerMessage::Error(jsonrpc_error))
140 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
141 } else {
142 tracing::warn!(
143 "Received an error response with no corresponding request: {:?}",
144 &jsonrpc_error.id
145 );
146 }
147 None
148 }
149 ServerMessage::Response(response) => {
150 if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
151 tx_response
152 .send(ServerMessage::Response(response))
153 .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
154 } else {
155 tracing::warn!(
156 "Received a response with no corresponding request: {:?}",
157 &response.id
158 );
159 }
160 None
161 }
162 };
163 Ok(response)
164 }
165}
166
167#[async_trait]
168impl McpClient for ClientRuntime {
169 fn sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>>
170 where
171 MessageDispatcher<ServerMessage>:
172 McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>,
173 {
174 (self.transport.message_sender().clone()) as _
175 }
176
177 async fn start(self: Arc<Self>) -> SdkResult<()> {
178 let mut stream = self.transport.start().await?;
180 let transport = self.transport.clone();
181 let mut error_io_stream = transport.error_stream().write().await;
182 let error_io_stream = error_io_stream.take();
183
184 let self_clone = Arc::clone(&self);
185 let self_clone_err = Arc::clone(&self);
186
187 let err_task = tokio::spawn(async move {
188 let self_ref = &*self_clone_err;
189
190 if let Some(IoStream::Readable(error_input)) = error_io_stream {
191 let mut reader = BufReader::new(error_input).lines();
192 loop {
193 tokio::select! {
194 should_break = self_ref.transport.is_shut_down() =>{
195 if should_break {
196 break;
197 }
198 }
199 line = reader.next_line() =>{
200 match line {
201 Ok(Some(error_message)) => {
202 self_ref
203 .handler
204 .handle_process_error(error_message, self_ref)
205 .await?;
206 }
207 Ok(None) => {
208 break;
210 }
211 Err(e) => {
212 tracing::error!("Error reading from std_err: {e}");
213 break;
214 }
215 }
216 }
217 }
218 }
219 }
220
221 Ok::<(), McpSdkError>(())
222 });
223
224 let transport = self.transport.clone();
225
226 let main_task = tokio::spawn(async move {
227 let sender = self_clone.sender();
228 let sender = sender.read().await;
229 let sender = sender
230 .as_ref()
231 .ok_or(schema_utils::SdkError::connection_closed())?;
232 while let Some(mcp_messages) = stream.next().await {
233 let self_ref = &*self_clone;
234
235 match mcp_messages {
236 ServerMessages::Single(server_message) => {
237 let result = self_ref.handle_message(server_message, &transport).await;
238
239 match result {
240 Ok(result) => {
241 if let Some(result) = result {
242 sender
243 .send_message(ClientMessages::Single(result), None)
244 .await?;
245 }
246 }
247 Err(error) => {
248 tracing::error!("Error handling message : {}", error)
249 }
250 }
251 }
252 ServerMessages::Batch(server_messages) => {
253 let handling_tasks: Vec<_> = server_messages
254 .into_iter()
255 .map(|server_message| {
256 self_ref.handle_message(server_message, &transport)
257 })
258 .collect();
259 let results: Vec<_> = try_join_all(handling_tasks).await?;
260 let results: Vec<_> = results.into_iter().flatten().collect();
261
262 if !results.is_empty() {
263 sender
264 .send_message(ClientMessages::Batch(results), None)
265 .await?;
266 }
267 }
268 }
269 }
270 Ok::<(), McpSdkError>(())
271 });
272
273 self.initialize_request().await?;
275
276 let mut lock = self.handlers.lock().await;
277 lock.push(main_task);
278 lock.push(err_task);
279
280 Ok(())
281 }
282
283 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
284 match self.server_details.write() {
285 Ok(mut details) => {
286 *details = Some(server_details);
287 Ok(())
288 }
289 Err(_) => Err(RpcError::internal_error()
291 .with_message("Internal Error: Failed to acquire write lock.".to_string())
292 .into()),
293 }
294 }
295 fn client_info(&self) -> &InitializeRequestParams {
296 &self.client_details
297 }
298 fn server_info(&self) -> Option<InitializeResult> {
299 if let Ok(details) = self.server_details.read() {
300 details.clone()
301 } else {
302 None
304 }
305 }
306
307 async fn send(
308 &self,
309 message: MessageFromClient,
310 request_id: Option<RequestId>,
311 timeout: Option<Duration>,
312 ) -> SdkResult<Option<ServerMessage>> {
313 let sender = self.sender();
314 let sender = sender.read().await;
315 let sender = sender
316 .as_ref()
317 .ok_or(schema_utils::SdkError::connection_closed())?;
318
319 let outgoing_request_id = self
320 .request_id_gen
321 .request_id_for_message(&message, request_id);
322
323 let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
324
325 let response = sender
326 .send_message(ClientMessages::Single(mcp_message), timeout)
327 .await?
328 .map(|res| res.as_single())
329 .transpose()?;
330
331 Ok(response)
332 }
333
334 async fn is_shut_down(&self) -> bool {
335 self.transport.is_shut_down().await
336 }
337 async fn shut_down(&self) -> SdkResult<()> {
338 self.transport.shut_down().await?;
339
340 let mut tasks_lock = self.handlers.lock().await;
342 let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
343 join_all(join_handlers).await;
344
345 Ok(())
346 }
347}