rust_mcp_sdk/mcp_traits/mcp_client.rs

1use std::{sync::Arc, time::Duration};
2
3use crate::schema::{
4    schema_utils::{
5        self, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
6        ResultFromServer, ServerMessage,
7    },
8    CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
9    CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
10    InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
11    ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
12    ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
13    LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams,
14    RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
15    SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
16    UnsubscribeRequest, UnsubscribeRequestParams,
17};
18use async_trait::async_trait;
19use rust_mcp_transport::{McpDispatch, MessageDispatcher};
20
21use crate::{error::SdkResult, utils::format_assertion_message};
22
23#[async_trait]
24pub trait McpClient: Sync + Send {
25    async fn start(self: Arc<Self>) -> SdkResult<()>;
26    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
27
28    async fn shut_down(&self) -> SdkResult<()>;
29    async fn is_shut_down(&self) -> bool;
30
31    async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>
32    where
33        MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>;
34
35    fn client_info(&self) -> &InitializeRequestParams;
36    fn server_info(&self) -> Option<InitializeResult>;
37
38    #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
39    fn get_client_info(&self) -> &InitializeRequestParams {
40        self.client_info()
41    }
42
43    #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
44    fn get_server_info(&self) -> Option<InitializeResult> {
45        self.server_info()
46    }
47
48    /// Checks whether the server has been initialized with client
49    fn is_initialized(&self) -> bool {
50        self.server_info().is_some()
51    }
52
53    /// Returns the server's name and version information once initialization is complete.
54    /// This method retrieves the server details, if available, after successful initialization.
55    fn server_version(&self) -> Option<Implementation> {
56        self.server_info()
57            .map(|server_details| server_details.server_info)
58    }
59
60    #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")]
61    fn get_server_version(&self) -> Option<Implementation> {
62        self.server_info()
63            .map(|server_details| server_details.server_info)
64    }
65
66    /// Returns the server's capabilities.
67    /// After initialization has completed, this will be populated with the server's reported capabilities.
68    fn server_capabilities(&self) -> Option<ServerCapabilities> {
69        self.server_info().map(|item| item.capabilities)
70    }
71
72    #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")]
73    fn get_server_capabilities(&self) -> Option<ServerCapabilities> {
74        self.server_info().map(|item| item.capabilities)
75    }
76
77    /// Checks if the server has tools available.
78    ///
79    /// This function retrieves the server information and checks if the
80    /// server has tools listed in its capabilities. If the server info
81    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
82    /// `Some(true)` if tools are available, or `Some(false)` if not.
83    ///
84    /// # Returns
85    /// - `None` if server information is not yet available.
86    /// - `Some(true)` if tools are available on the server.
87    /// - `Some(false)` if no tools are available on the server.
88    /// ```rust
89    /// println!("{}",1);
90    /// ```
91    fn server_has_tools(&self) -> Option<bool> {
92        self.server_info()
93            .map(|server_details| server_details.capabilities.tools.is_some())
94    }
95
96    /// Checks if the server has prompts available.
97    ///
98    /// This function retrieves the server information and checks if the
99    /// server has prompts listed in its capabilities. If the server info
100    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
101    /// `Some(true)` if prompts are available, or `Some(false)` if not.
102    ///
103    /// # Returns
104    /// - `None` if server information is not yet available.
105    /// - `Some(true)` if prompts are available on the server.
106    /// - `Some(false)` if no prompts are available on the server.
107    fn server_has_prompts(&self) -> Option<bool> {
108        self.server_info()
109            .map(|server_details| server_details.capabilities.prompts.is_some())
110    }
111
112    /// Checks if the server has experimental capabilities available.
113    ///
114    /// This function retrieves the server information and checks if the
115    /// server has experimental listed in its capabilities. If the server info
116    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
117    /// `Some(true)` if experimental is available, or `Some(false)` if not.
118    ///
119    /// # Returns
120    /// - `None` if server information is not yet available.
121    /// - `Some(true)` if experimental capabilities are available on the server.
122    /// - `Some(false)` if no experimental capabilities are available on the server.
123    fn server_has_experimental(&self) -> Option<bool> {
124        self.server_info()
125            .map(|server_details| server_details.capabilities.experimental.is_some())
126    }
127
128    /// Checks if the server has resources available.
129    ///
130    /// This function retrieves the server information and checks if the
131    /// server has resources listed in its capabilities. If the server info
132    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
133    /// `Some(true)` if resources are available, or `Some(false)` if not.
134    ///
135    /// # Returns
136    /// - `None` if server information is not yet available.
137    /// - `Some(true)` if resources are available on the server.
138    /// - `Some(false)` if no resources are available on the server.
139    fn server_has_resources(&self) -> Option<bool> {
140        self.server_info()
141            .map(|server_details| server_details.capabilities.resources.is_some())
142    }
143
144    /// Checks if the server supports logging.
145    ///
146    /// This function retrieves the server information and checks if the
147    /// server has logging capabilities listed. If the server info has
148    /// not been retrieved yet, it returns `None`. Otherwise, it returns
149    /// `Some(true)` if logging is supported, or `Some(false)` if not.
150    ///
151    /// # Returns
152    /// - `None` if server information is not yet available.
153    /// - `Some(true)` if logging is supported by the server.
154    /// - `Some(false)` if logging is not supported by the server.
155    fn server_supports_logging(&self) -> Option<bool> {
156        self.server_info()
157            .map(|server_details| server_details.capabilities.logging.is_some())
158    }
159    #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")]
160    fn get_instructions(&self) -> Option<String> {
161        self.server_info()?.instructions
162    }
163
164    fn instructions(&self) -> Option<String> {
165        self.server_info()?.instructions
166    }
167
168    /// Sends a request to the server and processes the response.
169    ///
170    /// This function sends a `RequestFromClient` message to the server, waits for the response,
171    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
172    /// Otherwise, it returns the result from the server.
173    async fn request(
174        &self,
175        request: RequestFromClient,
176        timeout: Option<Duration>,
177    ) -> SdkResult<ResultFromServer> {
178        let sender = self.sender().await.read().await;
179        let sender = sender
180            .as_ref()
181            .ok_or(schema_utils::SdkError::connection_closed())?;
182
183        // Send the request and receive the response.
184        let response = sender
185            .send(MessageFromClient::RequestFromClient(request), None, timeout)
186            .await?;
187
188        let server_message = response.ok_or_else(|| {
189            RpcError::internal_error()
190                .with_message("An empty response was received from the server.".to_string())
191        })?;
192
193        if server_message.is_error() {
194            return Err(server_message.as_error()?.error.into());
195        }
196
197        return Ok(server_message.as_response()?.result);
198    }
199
200    /// Sends a notification. This is a one-way message that is not expected
201    /// to return any response. The method asynchronously sends the notification using
202    /// the transport layer and does not wait for any acknowledgement or result.
203    async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
204        let sender = self.sender().await.read().await;
205        let sender = sender
206            .as_ref()
207            .ok_or(schema_utils::SdkError::connection_closed())?;
208        sender
209            .send(
210                MessageFromClient::NotificationFromClient(notification),
211                None,
212                None,
213            )
214            .await?;
215        Ok(())
216    }
217
218    /// A ping request to check that the other party is still alive.
219    /// The receiver must promptly respond, or else may be disconnected.
220    ///
221    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
222    /// Once the response is received, it attempts to convert it into the expected
223    /// result type.
224    ///
225    /// # Returns
226    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
227    /// If the request or conversion fails, an error is returned.
228    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
229        let ping_request = PingRequest::new(None);
230        let response = self.request(ping_request.into(), timeout).await?;
231        Ok(response.try_into()?)
232    }
233
234    async fn complete(
235        &self,
236        params: CompleteRequestParams,
237    ) -> SdkResult<crate::schema::CompleteResult> {
238        let request = CompleteRequest::new(params);
239        let response = self.request(request.into(), None).await?;
240        Ok(response.try_into()?)
241    }
242
243    async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
244        let request = SetLevelRequest::new(SetLevelRequestParams { level });
245        let response = self.request(request.into(), None).await?;
246        Ok(response.try_into()?)
247    }
248
249    async fn get_prompt(
250        &self,
251        params: GetPromptRequestParams,
252    ) -> SdkResult<crate::schema::GetPromptResult> {
253        let request = GetPromptRequest::new(params);
254        let response = self.request(request.into(), None).await?;
255        Ok(response.try_into()?)
256    }
257
258    async fn list_prompts(
259        &self,
260        params: Option<ListPromptsRequestParams>,
261    ) -> SdkResult<crate::schema::ListPromptsResult> {
262        let request = ListPromptsRequest::new(params);
263        let response = self.request(request.into(), None).await?;
264        Ok(response.try_into()?)
265    }
266
267    async fn list_resources(
268        &self,
269        params: Option<ListResourcesRequestParams>,
270    ) -> SdkResult<crate::schema::ListResourcesResult> {
271        // passing ListResourcesRequestParams::default() if params is None
272        // need to investigate more but this could be a inconsistency on some MCP servers
273        // where it is not required for other requests like prompts/list or tools/list etc
274        // that excepts an empty params to be passed (like server-everything)
275        let request =
276            ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
277        let response = self.request(request.into(), None).await?;
278        Ok(response.try_into()?)
279    }
280
281    async fn list_resource_templates(
282        &self,
283        params: Option<ListResourceTemplatesRequestParams>,
284    ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
285        let request = ListResourceTemplatesRequest::new(params);
286        let response = self.request(request.into(), None).await?;
287        Ok(response.try_into()?)
288    }
289
290    async fn read_resource(
291        &self,
292        params: ReadResourceRequestParams,
293    ) -> SdkResult<crate::schema::ReadResourceResult> {
294        let request = ReadResourceRequest::new(params);
295        let response = self.request(request.into(), None).await?;
296        Ok(response.try_into()?)
297    }
298
299    async fn subscribe_resource(
300        &self,
301        params: SubscribeRequestParams,
302    ) -> SdkResult<crate::schema::Result> {
303        let request = SubscribeRequest::new(params);
304        let response = self.request(request.into(), None).await?;
305        Ok(response.try_into()?)
306    }
307
308    async fn unsubscribe_resource(
309        &self,
310        params: UnsubscribeRequestParams,
311    ) -> SdkResult<crate::schema::Result> {
312        let request = UnsubscribeRequest::new(params);
313        let response = self.request(request.into(), None).await?;
314        Ok(response.try_into()?)
315    }
316
317    async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
318        let request = CallToolRequest::new(params);
319        let response = self.request(request.into(), None).await?;
320        Ok(response.try_into()?)
321    }
322
323    async fn list_tools(
324        &self,
325        params: Option<ListToolsRequestParams>,
326    ) -> SdkResult<crate::schema::ListToolsResult> {
327        let request = ListToolsRequest::new(params);
328        let response = self.request(request.into(), None).await?;
329        Ok(response.try_into()?)
330    }
331
332    async fn send_roots_list_changed(
333        &self,
334        params: Option<RootsListChangedNotificationParams>,
335    ) -> SdkResult<()> {
336        let notification = RootsListChangedNotification::new(params);
337        self.send_notification(notification.into()).await
338    }
339
340    /// Asserts that server capabilities support the requested method.
341    ///
342    /// Verifies that the server has the necessary capabilities to handle the given request method.
343    /// If the server is not initialized or lacks a required capability, an error is returned.
344    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
345    fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
346        let entity = "Server";
347
348        let capabilities = self.server_capabilities().ok_or::<RpcError>(
349            RpcError::internal_error().with_message("Server is not initialized!".to_string()),
350        )?;
351
352        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
353            return Err(RpcError::internal_error()
354                .with_message(format_assertion_message(entity, "logging", request_method))
355                .into());
356        }
357
358        if [
359            GetPromptRequest::method_name(),
360            ListPromptsRequest::method_name(),
361        ]
362        .contains(request_method)
363            && capabilities.prompts.is_none()
364        {
365            return Err(RpcError::internal_error()
366                .with_message(format_assertion_message(entity, "prompts", request_method))
367                .into());
368        }
369
370        if [
371            ListResourcesRequest::method_name(),
372            ListResourceTemplatesRequest::method_name(),
373            ReadResourceRequest::method_name(),
374            SubscribeRequest::method_name(),
375            UnsubscribeRequest::method_name(),
376        ]
377        .contains(request_method)
378            && capabilities.resources.is_none()
379        {
380            return Err(RpcError::internal_error()
381                .with_message(format_assertion_message(
382                    entity,
383                    "resources",
384                    request_method,
385                ))
386                .into());
387        }
388
389        if [
390            CallToolRequest::method_name(),
391            ListToolsRequest::method_name(),
392        ]
393        .contains(request_method)
394            && capabilities.tools.is_none()
395        {
396            return Err(RpcError::internal_error()
397                .with_message(format_assertion_message(entity, "tools", request_method))
398                .into());
399        }
400
401        Ok(())
402    }
403
404    fn assert_client_notification_capabilities(
405        &self,
406        notification_method: &String,
407    ) -> std::result::Result<(), RpcError> {
408        let entity = "Client";
409        let capabilities = &self.client_info().capabilities;
410
411        if *notification_method == RootsListChangedNotification::method_name()
412            && capabilities.roots.is_some()
413        {
414            return Err(
415                RpcError::internal_error().with_message(format_assertion_message(
416                    entity,
417                    "roots list changed notifications",
418                    notification_method,
419                )),
420            );
421        }
422
423        Ok(())
424    }
425
426    fn assert_client_request_capabilities(
427        &self,
428        request_method: &String,
429    ) -> std::result::Result<(), RpcError> {
430        let entity = "Client";
431        let capabilities = &self.client_info().capabilities;
432
433        if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
434        {
435            return Err(
436                RpcError::internal_error().with_message(format_assertion_message(
437                    entity,
438                    "sampling capability",
439                    request_method,
440                )),
441            );
442        }
443
444        if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
445            return Err(
446                RpcError::internal_error().with_message(format_assertion_message(
447                    entity,
448                    "roots capability",
449                    request_method,
450                )),
451            );
452        }
453
454        Ok(())
455    }
456}