rust_mcp_sdk/mcp_traits/mcp_client.rs

1use std::{sync::Arc, time::Duration};
2
3use crate::schema::{
4    schema_utils::{
5        self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient,
6        NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages,
7    },
8    CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
9    CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
10    InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
11    ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
12    ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
13    LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
14    RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
15    SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
16    UnsubscribeRequest, UnsubscribeRequestParams,
17};
18use crate::{error::SdkResult, utils::format_assertion_message};
19use async_trait::async_trait;
20use rust_mcp_transport::{McpDispatch, MessageDispatcher};
21
22#[async_trait]
23pub trait McpClient: Sync + Send {
24    async fn start(self: Arc<Self>) -> SdkResult<()>;
25    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
26
27    async fn shut_down(&self) -> SdkResult<()>;
28    async fn is_shut_down(&self) -> bool;
29
30    fn sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>>
31    where
32        MessageDispatcher<ServerMessage>:
33            McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>;
34
35    fn client_info(&self) -> &InitializeRequestParams;
36    fn server_info(&self) -> Option<InitializeResult>;
37
38    /// Checks whether the server has been initialized with client
39    fn is_initialized(&self) -> bool {
40        self.server_info().is_some()
41    }
42
43    /// Returns the server's name and version information once initialization is complete.
44    /// This method retrieves the server details, if available, after successful initialization.
45    fn server_version(&self) -> Option<Implementation> {
46        self.server_info()
47            .map(|server_details| server_details.server_info)
48    }
49
50    /// Returns the server's capabilities.
51    /// After initialization has completed, this will be populated with the server's reported capabilities.
52    fn server_capabilities(&self) -> Option<ServerCapabilities> {
53        self.server_info().map(|item| item.capabilities)
54    }
55
56    /// Checks if the server has tools available.
57    ///
58    /// This function retrieves the server information and checks if the
59    /// server has tools listed in its capabilities. If the server info
60    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
61    /// `Some(true)` if tools are available, or `Some(false)` if not.
62    ///
63    /// # Returns
64    /// - `None` if server information is not yet available.
65    /// - `Some(true)` if tools are available on the server.
66    /// - `Some(false)` if no tools are available on the server.
67    /// ```rust
68    /// println!("{}",1);
69    /// ```
70    fn server_has_tools(&self) -> Option<bool> {
71        self.server_info()
72            .map(|server_details| server_details.capabilities.tools.is_some())
73    }
74
75    /// Checks if the server has prompts available.
76    ///
77    /// This function retrieves the server information and checks if the
78    /// server has prompts listed in its capabilities. If the server info
79    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
80    /// `Some(true)` if prompts are available, or `Some(false)` if not.
81    ///
82    /// # Returns
83    /// - `None` if server information is not yet available.
84    /// - `Some(true)` if prompts are available on the server.
85    /// - `Some(false)` if no prompts are available on the server.
86    fn server_has_prompts(&self) -> Option<bool> {
87        self.server_info()
88            .map(|server_details| server_details.capabilities.prompts.is_some())
89    }
90
91    /// Checks if the server has experimental capabilities available.
92    ///
93    /// This function retrieves the server information and checks if the
94    /// server has experimental listed in its capabilities. If the server info
95    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
96    /// `Some(true)` if experimental is available, or `Some(false)` if not.
97    ///
98    /// # Returns
99    /// - `None` if server information is not yet available.
100    /// - `Some(true)` if experimental capabilities are available on the server.
101    /// - `Some(false)` if no experimental capabilities are available on the server.
102    fn server_has_experimental(&self) -> Option<bool> {
103        self.server_info()
104            .map(|server_details| server_details.capabilities.experimental.is_some())
105    }
106
107    /// Checks if the server has resources available.
108    ///
109    /// This function retrieves the server information and checks if the
110    /// server has resources listed in its capabilities. If the server info
111    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
112    /// `Some(true)` if resources are available, or `Some(false)` if not.
113    ///
114    /// # Returns
115    /// - `None` if server information is not yet available.
116    /// - `Some(true)` if resources are available on the server.
117    /// - `Some(false)` if no resources are available on the server.
118    fn server_has_resources(&self) -> Option<bool> {
119        self.server_info()
120            .map(|server_details| server_details.capabilities.resources.is_some())
121    }
122
123    /// Checks if the server supports logging.
124    ///
125    /// This function retrieves the server information and checks if the
126    /// server has logging capabilities listed. If the server info has
127    /// not been retrieved yet, it returns `None`. Otherwise, it returns
128    /// `Some(true)` if logging is supported, or `Some(false)` if not.
129    ///
130    /// # Returns
131    /// - `None` if server information is not yet available.
132    /// - `Some(true)` if logging is supported by the server.
133    /// - `Some(false)` if logging is not supported by the server.
134    fn server_supports_logging(&self) -> Option<bool> {
135        self.server_info()
136            .map(|server_details| server_details.capabilities.logging.is_some())
137    }
138
139    fn instructions(&self) -> Option<String> {
140        self.server_info()?.instructions
141    }
142
143    /// Sends a request to the server and processes the response.
144    ///
145    /// This function sends a `RequestFromClient` message to the server, waits for the response,
146    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
147    /// Otherwise, it returns the result from the server.
148    async fn request(
149        &self,
150        request: RequestFromClient,
151        timeout: Option<Duration>,
152    ) -> SdkResult<ResultFromServer> {
153        let response = self
154            .send(MessageFromClient::RequestFromClient(request), None, timeout)
155            .await?;
156
157        let server_message = response.ok_or_else(|| {
158            RpcError::internal_error()
159                .with_message("An empty response was received from the client.".to_string())
160        })?;
161
162        if server_message.is_error() {
163            return Err(server_message.as_error()?.error.into());
164        }
165
166        return Ok(server_message.as_response()?.result);
167    }
168
169    async fn send(
170        &self,
171        message: MessageFromClient,
172        request_id: Option<RequestId>,
173        timeout: Option<Duration>,
174    ) -> SdkResult<Option<ServerMessage>>;
175
176    async fn send_batch(
177        &self,
178        messages: Vec<ClientMessage>,
179        timeout: Option<Duration>,
180    ) -> SdkResult<Option<Vec<ServerMessage>>> {
181        let sender = self.sender();
182        let sender = sender.read().await;
183        let sender = sender
184            .as_ref()
185            .ok_or(schema_utils::SdkError::connection_closed())?;
186
187        let response = sender
188            .send_message(ClientMessages::Batch(messages), timeout)
189            .await?;
190
191        match response {
192            Some(res) => {
193                let server_results = res.as_batch()?;
194                Ok(Some(server_results))
195            }
196            None => Ok(None),
197        }
198    }
199
200    /// Sends a notification. This is a one-way message that is not expected
201    /// to return any response. The method asynchronously sends the notification using
202    /// the transport layer and does not wait for any acknowledgement or result.
203    async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
204        let sender = self.sender();
205        let sender = sender.read().await;
206        let sender = sender
207            .as_ref()
208            .ok_or(schema_utils::SdkError::connection_closed())?;
209
210        let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?;
211
212        sender
213            .send_message(ClientMessages::Single(mcp_message), None)
214            .await?;
215        Ok(())
216    }
217
218    /// A ping request to check that the other party is still alive.
219    /// The receiver must promptly respond, or else may be disconnected.
220    ///
221    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
222    /// Once the response is received, it attempts to convert it into the expected
223    /// result type.
224    ///
225    /// # Returns
226    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
227    /// If the request or conversion fails, an error is returned.
228    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
229        let ping_request = PingRequest::new(None);
230        let response = self.request(ping_request.into(), timeout).await?;
231        Ok(response.try_into()?)
232    }
233
234    async fn complete(
235        &self,
236        params: CompleteRequestParams,
237    ) -> SdkResult<crate::schema::CompleteResult> {
238        let request = CompleteRequest::new(params);
239        let response = self.request(request.into(), None).await?;
240        Ok(response.try_into()?)
241    }
242
243    async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
244        let request = SetLevelRequest::new(SetLevelRequestParams { level });
245        let response = self.request(request.into(), None).await?;
246        Ok(response.try_into()?)
247    }
248
249    async fn get_prompt(
250        &self,
251        params: GetPromptRequestParams,
252    ) -> SdkResult<crate::schema::GetPromptResult> {
253        let request = GetPromptRequest::new(params);
254        let response = self.request(request.into(), None).await?;
255        Ok(response.try_into()?)
256    }
257
258    async fn list_prompts(
259        &self,
260        params: Option<ListPromptsRequestParams>,
261    ) -> SdkResult<crate::schema::ListPromptsResult> {
262        let request = ListPromptsRequest::new(params);
263        let response = self.request(request.into(), None).await?;
264        Ok(response.try_into()?)
265    }
266
267    async fn list_resources(
268        &self,
269        params: Option<ListResourcesRequestParams>,
270    ) -> SdkResult<crate::schema::ListResourcesResult> {
271        // passing ListResourcesRequestParams::default() if params is None
272        // need to investigate more but this could be a inconsistency on some MCP servers
273        // where it is not required for other requests like prompts/list or tools/list etc
274        // that excepts an empty params to be passed (like server-everything)
275        let request =
276            ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
277        let response = self.request(request.into(), None).await?;
278        Ok(response.try_into()?)
279    }
280
281    async fn list_resource_templates(
282        &self,
283        params: Option<ListResourceTemplatesRequestParams>,
284    ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
285        let request = ListResourceTemplatesRequest::new(params);
286        let response = self.request(request.into(), None).await?;
287        Ok(response.try_into()?)
288    }
289
290    async fn read_resource(
291        &self,
292        params: ReadResourceRequestParams,
293    ) -> SdkResult<crate::schema::ReadResourceResult> {
294        let request = ReadResourceRequest::new(params);
295        let response = self.request(request.into(), None).await?;
296        Ok(response.try_into()?)
297    }
298
299    async fn subscribe_resource(
300        &self,
301        params: SubscribeRequestParams,
302    ) -> SdkResult<crate::schema::Result> {
303        let request = SubscribeRequest::new(params);
304        let response = self.request(request.into(), None).await?;
305        Ok(response.try_into()?)
306    }
307
308    async fn unsubscribe_resource(
309        &self,
310        params: UnsubscribeRequestParams,
311    ) -> SdkResult<crate::schema::Result> {
312        let request = UnsubscribeRequest::new(params);
313        let response = self.request(request.into(), None).await?;
314        Ok(response.try_into()?)
315    }
316
317    async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
318        let request = CallToolRequest::new(params);
319        let response = self.request(request.into(), None).await?;
320        Ok(response.try_into()?)
321    }
322
323    async fn list_tools(
324        &self,
325        params: Option<ListToolsRequestParams>,
326    ) -> SdkResult<crate::schema::ListToolsResult> {
327        let request = ListToolsRequest::new(params);
328        let response = self.request(request.into(), None).await?;
329        Ok(response.try_into()?)
330    }
331
332    async fn send_roots_list_changed(
333        &self,
334        params: Option<RootsListChangedNotificationParams>,
335    ) -> SdkResult<()> {
336        let notification = RootsListChangedNotification::new(params);
337        self.send_notification(notification.into()).await
338    }
339
340    /// Asserts that server capabilities support the requested method.
341    ///
342    /// Verifies that the server has the necessary capabilities to handle the given request method.
343    /// If the server is not initialized or lacks a required capability, an error is returned.
344    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
345    fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
346        let entity = "Server";
347
348        let capabilities = self.server_capabilities().ok_or::<RpcError>(
349            RpcError::internal_error().with_message("Server is not initialized!".to_string()),
350        )?;
351
352        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
353            return Err(RpcError::internal_error()
354                .with_message(format_assertion_message(entity, "logging", request_method))
355                .into());
356        }
357
358        if [
359            GetPromptRequest::method_name(),
360            ListPromptsRequest::method_name(),
361        ]
362        .contains(request_method)
363            && capabilities.prompts.is_none()
364        {
365            return Err(RpcError::internal_error()
366                .with_message(format_assertion_message(entity, "prompts", request_method))
367                .into());
368        }
369
370        if [
371            ListResourcesRequest::method_name(),
372            ListResourceTemplatesRequest::method_name(),
373            ReadResourceRequest::method_name(),
374            SubscribeRequest::method_name(),
375            UnsubscribeRequest::method_name(),
376        ]
377        .contains(request_method)
378            && capabilities.resources.is_none()
379        {
380            return Err(RpcError::internal_error()
381                .with_message(format_assertion_message(
382                    entity,
383                    "resources",
384                    request_method,
385                ))
386                .into());
387        }
388
389        if [
390            CallToolRequest::method_name(),
391            ListToolsRequest::method_name(),
392        ]
393        .contains(request_method)
394            && capabilities.tools.is_none()
395        {
396            return Err(RpcError::internal_error()
397                .with_message(format_assertion_message(entity, "tools", request_method))
398                .into());
399        }
400
401        Ok(())
402    }
403
404    fn assert_client_notification_capabilities(
405        &self,
406        notification_method: &String,
407    ) -> std::result::Result<(), RpcError> {
408        let entity = "Client";
409        let capabilities = &self.client_info().capabilities;
410
411        if *notification_method == RootsListChangedNotification::method_name()
412            && capabilities.roots.is_some()
413        {
414            return Err(
415                RpcError::internal_error().with_message(format_assertion_message(
416                    entity,
417                    "roots list changed notifications",
418                    notification_method,
419                )),
420            );
421        }
422
423        Ok(())
424    }
425
426    fn assert_client_request_capabilities(
427        &self,
428        request_method: &String,
429    ) -> std::result::Result<(), RpcError> {
430        let entity = "Client";
431        let capabilities = &self.client_info().capabilities;
432
433        if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
434        {
435            return Err(
436                RpcError::internal_error().with_message(format_assertion_message(
437                    entity,
438                    "sampling capability",
439                    request_method,
440                )),
441            );
442        }
443
444        if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
445            return Err(
446                RpcError::internal_error().with_message(format_assertion_message(
447                    entity,
448                    "roots capability",
449                    request_method,
450                )),
451            );
452        }
453
454        Ok(())
455    }
456}