rust_mcp_sdk/mcp_traits/
mcp_client.rs

1use crate::schema::{
2    schema_utils::{
3        ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
4        ResultFromServer, ServerMessage,
5    },
6    CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
7    CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
8    InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
9    ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
10    ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
11    LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
12    RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
13    SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
14    UnsubscribeRequest, UnsubscribeRequestParams,
15};
16use crate::{error::SdkResult, utils::format_assertion_message};
17use async_trait::async_trait;
18use std::{sync::Arc, time::Duration};
19
20#[async_trait]
21pub trait McpClient: Sync + Send {
22    async fn start(self: Arc<Self>) -> SdkResult<()>;
23    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
24
25    async fn terminate_session(&self);
26
27    async fn shut_down(&self) -> SdkResult<()>;
28    async fn is_shut_down(&self) -> bool;
29
30    fn client_info(&self) -> &InitializeRequestParams;
31    fn server_info(&self) -> Option<InitializeResult>;
32
33    /// Checks whether the server has been initialized with client
34    fn is_initialized(&self) -> bool {
35        self.server_info().is_some()
36    }
37
38    /// Returns the server's name and version information once initialization is complete.
39    /// This method retrieves the server details, if available, after successful initialization.
40    fn server_version(&self) -> Option<Implementation> {
41        self.server_info()
42            .map(|server_details| server_details.server_info)
43    }
44
45    /// Returns the server's capabilities.
46    /// After initialization has completed, this will be populated with the server's reported capabilities.
47    fn server_capabilities(&self) -> Option<ServerCapabilities> {
48        self.server_info().map(|item| item.capabilities)
49    }
50
51    /// Checks if the server has tools available.
52    ///
53    /// This function retrieves the server information and checks if the
54    /// server has tools listed in its capabilities. If the server info
55    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
56    /// `Some(true)` if tools are available, or `Some(false)` if not.
57    ///
58    /// # Returns
59    /// - `None` if server information is not yet available.
60    /// - `Some(true)` if tools are available on the server.
61    /// - `Some(false)` if no tools are available on the server.
62    /// ```rust
63    /// println!("{}",1);
64    /// ```
65    fn server_has_tools(&self) -> Option<bool> {
66        self.server_info()
67            .map(|server_details| server_details.capabilities.tools.is_some())
68    }
69
70    /// Checks if the server has prompts available.
71    ///
72    /// This function retrieves the server information and checks if the
73    /// server has prompts listed in its capabilities. If the server info
74    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
75    /// `Some(true)` if prompts are available, or `Some(false)` if not.
76    ///
77    /// # Returns
78    /// - `None` if server information is not yet available.
79    /// - `Some(true)` if prompts are available on the server.
80    /// - `Some(false)` if no prompts are available on the server.
81    fn server_has_prompts(&self) -> Option<bool> {
82        self.server_info()
83            .map(|server_details| server_details.capabilities.prompts.is_some())
84    }
85
86    /// Checks if the server has experimental capabilities available.
87    ///
88    /// This function retrieves the server information and checks if the
89    /// server has experimental listed in its capabilities. If the server info
90    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
91    /// `Some(true)` if experimental is available, or `Some(false)` if not.
92    ///
93    /// # Returns
94    /// - `None` if server information is not yet available.
95    /// - `Some(true)` if experimental capabilities are available on the server.
96    /// - `Some(false)` if no experimental capabilities are available on the server.
97    fn server_has_experimental(&self) -> Option<bool> {
98        self.server_info()
99            .map(|server_details| server_details.capabilities.experimental.is_some())
100    }
101
102    /// Checks if the server has resources available.
103    ///
104    /// This function retrieves the server information and checks if the
105    /// server has resources listed in its capabilities. If the server info
106    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
107    /// `Some(true)` if resources are available, or `Some(false)` if not.
108    ///
109    /// # Returns
110    /// - `None` if server information is not yet available.
111    /// - `Some(true)` if resources are available on the server.
112    /// - `Some(false)` if no resources are available on the server.
113    fn server_has_resources(&self) -> Option<bool> {
114        self.server_info()
115            .map(|server_details| server_details.capabilities.resources.is_some())
116    }
117
118    /// Checks if the server supports logging.
119    ///
120    /// This function retrieves the server information and checks if the
121    /// server has logging capabilities listed. If the server info has
122    /// not been retrieved yet, it returns `None`. Otherwise, it returns
123    /// `Some(true)` if logging is supported, or `Some(false)` if not.
124    ///
125    /// # Returns
126    /// - `None` if server information is not yet available.
127    /// - `Some(true)` if logging is supported by the server.
128    /// - `Some(false)` if logging is not supported by the server.
129    fn server_supports_logging(&self) -> Option<bool> {
130        self.server_info()
131            .map(|server_details| server_details.capabilities.logging.is_some())
132    }
133
134    /// Checks if the server supports argument autocompletion suggestions.
135    ///
136    /// This function retrieves the server information and checks if the
137    /// server has completions capabilities listed. If the server info has
138    /// not been retrieved yet, it returns `None`. Otherwise, it returns
139    /// `Some(true)` if completions is supported, or `Some(false)` if not.
140    ///
141    /// # Returns
142    /// - `None` if server information is not yet available.
143    /// - `Some(true)` if completions is supported by the server.
144    /// - `Some(false)` if completions is not supported by the server.
145    fn server_supports_completion(&self) -> Option<bool> {
146        self.server_info()
147            .map(|server_details| server_details.capabilities.completions.is_some())
148    }
149
150    fn instructions(&self) -> Option<String> {
151        self.server_info()?.instructions
152    }
153
154    /// Sends a request to the server and processes the response.
155    ///
156    /// This function sends a `RequestFromClient` message to the server, waits for the response,
157    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
158    /// Otherwise, it returns the result from the server.
159    async fn request(
160        &self,
161        request: RequestFromClient,
162        timeout: Option<Duration>,
163    ) -> SdkResult<ResultFromServer> {
164        let response = self
165            .send(MessageFromClient::RequestFromClient(request), None, timeout)
166            .await?;
167
168        let server_message = response.ok_or_else(|| {
169            RpcError::internal_error()
170                .with_message("An empty response was received from the client.".to_string())
171        })?;
172
173        if server_message.is_error() {
174            return Err(server_message.as_error()?.error.into());
175        }
176
177        return Ok(server_message.as_response()?.result);
178    }
179
180    async fn send(
181        &self,
182        message: MessageFromClient,
183        request_id: Option<RequestId>,
184        request_timeout: Option<Duration>,
185    ) -> SdkResult<Option<ServerMessage>>;
186
187    async fn send_batch(
188        &self,
189        messages: Vec<ClientMessage>,
190        timeout: Option<Duration>,
191    ) -> SdkResult<Option<Vec<ServerMessage>>>;
192
193    /// Sends a notification. This is a one-way message that is not expected
194    /// to return any response. The method asynchronously sends the notification using
195    /// the transport layer and does not wait for any acknowledgement or result.
196    async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
197        self.send(notification.into(), None, None).await?;
198        Ok(())
199    }
200
201    /// A ping request to check that the other party is still alive.
202    /// The receiver must promptly respond, or else may be disconnected.
203    ///
204    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
205    /// Once the response is received, it attempts to convert it into the expected
206    /// result type.
207    ///
208    /// # Returns
209    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
210    /// If the request or conversion fails, an error is returned.
211    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
212        let ping_request = PingRequest::new(None);
213        let response = self.request(ping_request.into(), timeout).await?;
214        Ok(response.try_into()?)
215    }
216
217    async fn complete(
218        &self,
219        params: CompleteRequestParams,
220    ) -> SdkResult<crate::schema::CompleteResult> {
221        let request = CompleteRequest::new(params);
222        let response = self.request(request.into(), None).await?;
223        Ok(response.try_into()?)
224    }
225
226    async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
227        let request = SetLevelRequest::new(SetLevelRequestParams { level });
228        let response = self.request(request.into(), None).await?;
229        Ok(response.try_into()?)
230    }
231
232    async fn get_prompt(
233        &self,
234        params: GetPromptRequestParams,
235    ) -> SdkResult<crate::schema::GetPromptResult> {
236        let request = GetPromptRequest::new(params);
237        let response = self.request(request.into(), None).await?;
238        Ok(response.try_into()?)
239    }
240
241    async fn list_prompts(
242        &self,
243        params: Option<ListPromptsRequestParams>,
244    ) -> SdkResult<crate::schema::ListPromptsResult> {
245        let request = ListPromptsRequest::new(params);
246        let response = self.request(request.into(), None).await?;
247        Ok(response.try_into()?)
248    }
249
250    async fn list_resources(
251        &self,
252        params: Option<ListResourcesRequestParams>,
253    ) -> SdkResult<crate::schema::ListResourcesResult> {
254        // passing ListResourcesRequestParams::default() if params is None
255        // need to investigate more but this could be a inconsistency on some MCP servers
256        // where it is not required for other requests like prompts/list or tools/list etc
257        // that excepts an empty params to be passed (like server-everything)
258        let request =
259            ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
260        let response = self.request(request.into(), None).await?;
261        Ok(response.try_into()?)
262    }
263
264    async fn list_resource_templates(
265        &self,
266        params: Option<ListResourceTemplatesRequestParams>,
267    ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
268        let request = ListResourceTemplatesRequest::new(params);
269        let response = self.request(request.into(), None).await?;
270        Ok(response.try_into()?)
271    }
272
273    async fn read_resource(
274        &self,
275        params: ReadResourceRequestParams,
276    ) -> SdkResult<crate::schema::ReadResourceResult> {
277        let request = ReadResourceRequest::new(params);
278        let response = self.request(request.into(), None).await?;
279        Ok(response.try_into()?)
280    }
281
282    async fn subscribe_resource(
283        &self,
284        params: SubscribeRequestParams,
285    ) -> SdkResult<crate::schema::Result> {
286        let request = SubscribeRequest::new(params);
287        let response = self.request(request.into(), None).await?;
288        Ok(response.try_into()?)
289    }
290
291    async fn unsubscribe_resource(
292        &self,
293        params: UnsubscribeRequestParams,
294    ) -> SdkResult<crate::schema::Result> {
295        let request = UnsubscribeRequest::new(params);
296        let response = self.request(request.into(), None).await?;
297        Ok(response.try_into()?)
298    }
299
300    async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
301        let request = CallToolRequest::new(params);
302        let response = self.request(request.into(), None).await?;
303        Ok(response.try_into()?)
304    }
305
306    async fn list_tools(
307        &self,
308        params: Option<ListToolsRequestParams>,
309    ) -> SdkResult<crate::schema::ListToolsResult> {
310        let request = ListToolsRequest::new(params);
311        let response = self.request(request.into(), None).await?;
312        Ok(response.try_into()?)
313    }
314
315    async fn send_roots_list_changed(
316        &self,
317        params: Option<RootsListChangedNotificationParams>,
318    ) -> SdkResult<()> {
319        let notification = RootsListChangedNotification::new(params);
320        self.send_notification(notification.into()).await
321    }
322
323    /// Asserts that server capabilities support the requested method.
324    ///
325    /// Verifies that the server has the necessary capabilities to handle the given request method.
326    /// If the server is not initialized or lacks a required capability, an error is returned.
327    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
328    fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
329        let entity = "Server";
330
331        let capabilities = self.server_capabilities().ok_or::<RpcError>(
332            RpcError::internal_error().with_message("Server is not initialized!".to_string()),
333        )?;
334
335        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
336            return Err(RpcError::internal_error()
337                .with_message(format_assertion_message(entity, "logging", request_method))
338                .into());
339        }
340
341        if [
342            GetPromptRequest::method_name(),
343            ListPromptsRequest::method_name(),
344        ]
345        .contains(request_method)
346            && capabilities.prompts.is_none()
347        {
348            return Err(RpcError::internal_error()
349                .with_message(format_assertion_message(entity, "prompts", request_method))
350                .into());
351        }
352
353        if [
354            ListResourcesRequest::method_name(),
355            ListResourceTemplatesRequest::method_name(),
356            ReadResourceRequest::method_name(),
357            SubscribeRequest::method_name(),
358            UnsubscribeRequest::method_name(),
359        ]
360        .contains(request_method)
361            && capabilities.resources.is_none()
362        {
363            return Err(RpcError::internal_error()
364                .with_message(format_assertion_message(
365                    entity,
366                    "resources",
367                    request_method,
368                ))
369                .into());
370        }
371
372        if [
373            CallToolRequest::method_name(),
374            ListToolsRequest::method_name(),
375        ]
376        .contains(request_method)
377            && capabilities.tools.is_none()
378        {
379            return Err(RpcError::internal_error()
380                .with_message(format_assertion_message(entity, "tools", request_method))
381                .into());
382        }
383
384        Ok(())
385    }
386
387    fn assert_client_notification_capabilities(
388        &self,
389        notification_method: &String,
390    ) -> std::result::Result<(), RpcError> {
391        let entity = "Client";
392        let capabilities = &self.client_info().capabilities;
393
394        if *notification_method == RootsListChangedNotification::method_name()
395            && capabilities.roots.is_some()
396        {
397            return Err(
398                RpcError::internal_error().with_message(format_assertion_message(
399                    entity,
400                    "roots list changed notifications",
401                    notification_method,
402                )),
403            );
404        }
405
406        Ok(())
407    }
408
409    fn assert_client_request_capabilities(
410        &self,
411        request_method: &String,
412    ) -> std::result::Result<(), RpcError> {
413        let entity = "Client";
414        let capabilities = &self.client_info().capabilities;
415
416        if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
417        {
418            return Err(
419                RpcError::internal_error().with_message(format_assertion_message(
420                    entity,
421                    "sampling capability",
422                    request_method,
423                )),
424            );
425        }
426
427        if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
428            return Err(
429                RpcError::internal_error().with_message(format_assertion_message(
430                    entity,
431                    "roots capability",
432                    request_method,
433                )),
434            );
435        }
436
437        Ok(())
438    }
439}