rust_mcp_sdk/mcp_runtimes/
client_runtime.rs1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
use futures::future::join_all;
use futures::StreamExt;
use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::sync::Mutex;

use crate::error::{McpSdkError, SdkResult};
use crate::mcp_traits::mcp_client::McpClient;
use crate::mcp_traits::mcp_handler::McpClientHandler;
use crate::schema::schema_utils::{self, MessageFromClient, ServerMessage};
use crate::schema::{
    InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
    RpcError, ServerResult,
};
use crate::utils::ensure_server_protocole_compatibility;

/// Runtime that drives an MCP client connection to a server.
///
/// It owns the transport and the handler supplied by the caller, performs the
/// initialize handshake, and keeps the join handles of the background tasks it spawns.
pub struct ClientRuntime {
    /// Transport that yields [`ServerMessage`]s and carries [`MessageFromClient`]s out.
    transport: Box<dyn Transport<ServerMessage, MessageFromClient>>,
    /// Handler invoked for server requests, notifications and errors, and for lines
    /// read from the transport's error stream.
    handler: Box<dyn McpClientHandler>,
    /// Parameters sent to the server in the `InitializeRequest`.
    client_details: InitializeRequestParams,
    /// The server's `InitializeResult`; `None` until the handshake has stored it.
    server_details: Arc<RwLock<Option<InitializeResult>>>,
    /// Dispatcher used to send messages to the server; `None` until it is installed
    /// via `set_message_sender` (done when the runtime starts).
    message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>,
    /// Join handles of the background tasks spawned by `start`; drained and awaited
    /// by `shut_down`.
    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
}

35impl ClientRuntime {
36 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ServerMessage>) {
37 let mut lock = self.message_sender.write().await;
38 *lock = Some(sender);
39 }
40
41 pub(crate) fn new(
42 client_details: InitializeRequestParams,
43 transport: impl Transport<ServerMessage, MessageFromClient>,
44 handler: Box<dyn McpClientHandler>,
45 ) -> Self {
46 Self {
47 transport: Box::new(transport),
48 handler,
49 client_details,
50 server_details: Arc::new(RwLock::new(None)),
51 message_sender: tokio::sync::RwLock::new(None),
52 handlers: Mutex::new(vec![]),
53 }
54 }
55
56 async fn initialize_request(&self) -> SdkResult<()> {
57 let request = InitializeRequest::new(self.client_details.clone());
58 let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
59
60 if let ServerResult::InitializeResult(initialize_result) = result {
61 ensure_server_protocole_compatibility(
62 &self.client_details.protocol_version,
63 &initialize_result.protocol_version,
64 )?;
65
66 self.set_server_details(initialize_result)?;
68 self.send_notification(InitializedNotification::new(None).into())
70 .await?;
71 } else {
72 return Err(RpcError::invalid_params()
73 .with_message("Incorrect response to InitializeRequest!".into())
74 .into());
75 }
76 Ok(())
77 }
78}
79
80#[async_trait]
81impl McpClient for ClientRuntime {
82 async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>
83 where
84 MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>,
85 {
86 (&self.message_sender) as _
87 }
88
89 async fn start(self: Arc<Self>) -> SdkResult<()> {
90 let (mut stream, sender, error_io) = self.transport.start().await?;
91 self.set_message_sender(sender).await;
92
93 let self_clone = Arc::clone(&self);
94 let self_clone_err = Arc::clone(&self);
95
96 let err_task = tokio::spawn(async move {
97 let self_ref = &*self_clone_err;
98
99 if let IoStream::Readable(error_input) = error_io {
100 let mut reader = BufReader::new(error_input).lines();
101 loop {
102 tokio::select! {
103 should_break = self_ref.transport.is_shut_down() =>{
104 if should_break {
105 break;
106 }
107 }
108 line = reader.next_line() =>{
109 match line {
110 Ok(Some(error_message)) => {
111 self_ref
112 .handler
113 .handle_process_error(error_message, self_ref)
114 .await?;
115 }
116 Ok(None) => {
117 break;
119 }
120 Err(e) => {
121 eprintln!("Error reading from std_err: {}", e);
122 break;
123 }
124 }
125 }
126 }
127 }
128 }
129 Ok::<(), McpSdkError>(())
130 });
131
132 self_clone.initialize_request().await?;
134
135 let main_task = tokio::spawn(async move {
136 let sender = self_clone.sender().await.read().await;
137 let sender = sender
138 .as_ref()
139 .ok_or(schema_utils::SdkError::connection_closed())?;
140 while let Some(mcp_message) = stream.next().await {
141 let self_ref = &*self_clone;
142
143 match mcp_message {
144 ServerMessage::Request(jsonrpc_request) => {
145 let result = self_ref
146 .handler
147 .handle_request(jsonrpc_request.request, self_ref)
148 .await;
149
150 let response: MessageFromClient = match result {
152 Ok(success_value) => success_value.into(),
153 Err(error_value) => MessageFromClient::Error(error_value),
154 };
155 sender
157 .send(response, Some(jsonrpc_request.id), None)
158 .await?;
159 }
160 ServerMessage::Notification(jsonrpc_notification) => {
161 self_ref
162 .handler
163 .handle_notification(jsonrpc_notification.notification, self_ref)
164 .await?;
165 }
166 ServerMessage::Error(jsonrpc_error) => {
167 self_ref
168 .handler
169 .handle_error(jsonrpc_error.error, self_ref)
170 .await?;
171 }
172 ServerMessage::Response(_) => {}
174 }
175 }
176 Ok::<(), McpSdkError>(())
177 });
178
179 let mut lock = self.handlers.lock().await;
180 lock.push(main_task);
181 lock.push(err_task);
182
183 Ok(())
184 }
185
186 fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
187 match self.server_details.write() {
188 Ok(mut details) => {
189 *details = Some(server_details);
190 Ok(())
191 }
192 Err(_) => Err(RpcError::internal_error()
194 .with_message("Internal Error: Failed to acquire write lock.".to_string())
195 .into()),
196 }
197 }
198 fn client_info(&self) -> &InitializeRequestParams {
199 &self.client_details
200 }
201 fn server_info(&self) -> Option<InitializeResult> {
202 if let Ok(details) = self.server_details.read() {
203 details.clone()
204 } else {
205 None
207 }
208 }
209
210 async fn is_shut_down(&self) -> bool {
211 self.transport.is_shut_down().await
212 }
213 async fn shut_down(&self) -> SdkResult<()> {
214 self.transport.shut_down().await?;
215
216 let mut tasks_lock = self.handlers.lock().await;
218 let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
219 join_all(join_handlers).await;
220
221 Ok(())
222 }
223}