rust_mcp_sdk/mcp_runtimes/client_runtime.rs

1pub mod mcp_client_runtime;
2pub mod mcp_client_runtime_core;
3
4use crate::schema::schema_utils::{self, MessageFromClient, ServerMessage};
5use crate::schema::{
6    InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
7    RpcError, ServerResult,
8};
9use async_trait::async_trait;
10use futures::future::join_all;
11use futures::StreamExt;
12use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
13use std::sync::{Arc, RwLock};
14use tokio::io::{AsyncBufReadExt, BufReader};
15use tokio::sync::Mutex;
16
17use crate::error::{McpSdkError, SdkResult};
18use crate::mcp_traits::mcp_client::McpClient;
19use crate::mcp_traits::mcp_handler::McpClientHandler;
20use crate::utils::ensure_server_protocole_compatibility;
21
22pub struct ClientRuntime {
23    // The transport interface for handling messages between client and server
24    transport: Box<dyn Transport<ServerMessage, MessageFromClient>>,
25    // The handler for processing MCP messages
26    handler: Box<dyn McpClientHandler>,
27    // // Information about the server
28    client_details: InitializeRequestParams,
29    // Details about the connected server
30    server_details: Arc<RwLock<Option<InitializeResult>>>,
31    message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>,
32    handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
33}
34
35impl ClientRuntime {
36    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ServerMessage>) {
37        let mut lock = self.message_sender.write().await;
38        *lock = Some(sender);
39    }
40
41    pub(crate) fn new(
42        client_details: InitializeRequestParams,
43        transport: impl Transport<ServerMessage, MessageFromClient>,
44        handler: Box<dyn McpClientHandler>,
45    ) -> Self {
46        Self {
47            transport: Box::new(transport),
48            handler,
49            client_details,
50            server_details: Arc::new(RwLock::new(None)),
51            message_sender: tokio::sync::RwLock::new(None),
52            handlers: Mutex::new(vec![]),
53        }
54    }
55
56    async fn initialize_request(&self) -> SdkResult<()> {
57        let request = InitializeRequest::new(self.client_details.clone());
58        let result: ServerResult = self.request(request.into(), None).await?.try_into()?;
59
60        if let ServerResult::InitializeResult(initialize_result) = result {
61            ensure_server_protocole_compatibility(
62                &self.client_details.protocol_version,
63                &initialize_result.protocol_version,
64            )?;
65
66            // store server details
67            self.set_server_details(initialize_result)?;
68            // send a InitializedNotification to the server
69            self.send_notification(InitializedNotification::new(None).into())
70                .await?;
71        } else {
72            return Err(RpcError::invalid_params()
73                .with_message("Incorrect response to InitializeRequest!".into())
74                .into());
75        }
76        Ok(())
77    }
78}
79
80#[async_trait]
81impl McpClient for ClientRuntime {
82    async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>
83    where
84        MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>,
85    {
86        (&self.message_sender) as _
87    }
88
89    async fn start(self: Arc<Self>) -> SdkResult<()> {
90        let (mut stream, sender, error_io) = self.transport.start().await?;
91        self.set_message_sender(sender).await;
92
93        let self_clone = Arc::clone(&self);
94        let self_clone_err = Arc::clone(&self);
95
96        let err_task = tokio::spawn(async move {
97            let self_ref = &*self_clone_err;
98
99            if let IoStream::Readable(error_input) = error_io {
100                let mut reader = BufReader::new(error_input).lines();
101                loop {
102                    tokio::select! {
103                        should_break = self_ref.transport.is_shut_down() =>{
104                            if should_break {
105                                break;
106                            }
107                        }
108                        line = reader.next_line() =>{
109                            match line {
110                                Ok(Some(error_message)) => {
111                                    self_ref
112                                        .handler
113                                        .handle_process_error(error_message, self_ref)
114                                        .await?;
115                                }
116                                Ok(None) => {
117                                    // end of input
118                                    break;
119                                }
120                                Err(e) => {
121                                    eprintln!("Error reading from std_err: {}", e);
122                                    break;
123                                }
124                            }
125                        }
126                    }
127                }
128            }
129            Ok::<(), McpSdkError>(())
130        });
131
132        // send initialize request to the MCP server
133        self_clone.initialize_request().await?;
134
135        let main_task = tokio::spawn(async move {
136            let sender = self_clone.sender().await.read().await;
137            let sender = sender
138                .as_ref()
139                .ok_or(schema_utils::SdkError::connection_closed())?;
140            while let Some(mcp_message) = stream.next().await {
141                let self_ref = &*self_clone;
142
143                match mcp_message {
144                    ServerMessage::Request(jsonrpc_request) => {
145                        let result = self_ref
146                            .handler
147                            .handle_request(jsonrpc_request.request, self_ref)
148                            .await;
149
150                        // create a response to send back to the server
151                        let response: MessageFromClient = match result {
152                            Ok(success_value) => success_value.into(),
153                            Err(error_value) => MessageFromClient::Error(error_value),
154                        };
155                        // send the response back with corresponding request id
156                        sender
157                            .send(response, Some(jsonrpc_request.id), None)
158                            .await?;
159                    }
160                    ServerMessage::Notification(jsonrpc_notification) => {
161                        self_ref
162                            .handler
163                            .handle_notification(jsonrpc_notification.notification, self_ref)
164                            .await?;
165                    }
166                    ServerMessage::Error(jsonrpc_error) => {
167                        self_ref
168                            .handler
169                            .handle_error(jsonrpc_error.error, self_ref)
170                            .await?;
171                    }
172                    // The response is the result of a request, it is processed at the transport level.
173                    ServerMessage::Response(_) => {}
174                }
175            }
176            Ok::<(), McpSdkError>(())
177        });
178
179        let mut lock = self.handlers.lock().await;
180        lock.push(main_task);
181        lock.push(err_task);
182
183        Ok(())
184    }
185
186    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> {
187        match self.server_details.write() {
188            Ok(mut details) => {
189                *details = Some(server_details);
190                Ok(())
191            }
192            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
193            Err(_) => Err(RpcError::internal_error()
194                .with_message("Internal Error: Failed to acquire write lock.".to_string())
195                .into()),
196        }
197    }
198    fn client_info(&self) -> &InitializeRequestParams {
199        &self.client_details
200    }
201    fn server_info(&self) -> Option<InitializeResult> {
202        if let Ok(details) = self.server_details.read() {
203            details.clone()
204        } else {
205            // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
206            None
207        }
208    }
209
210    async fn is_shut_down(&self) -> bool {
211        self.transport.is_shut_down().await
212    }
213    async fn shut_down(&self) -> SdkResult<()> {
214        self.transport.shut_down().await?;
215
216        // wait for tasks
217        let mut tasks_lock = self.handlers.lock().await;
218        let join_handlers: Vec<_> = tasks_lock.drain(..).collect();
219        join_all(join_handlers).await;
220
221        Ok(())
222    }
223}