rust_mcp_sdk/mcp_traits/
mcp_client.rs

1use crate::schema::{
2    schema_utils::{
3        ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
4        ResultFromServer, ServerMessage,
5    },
6    CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams,
7    CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation,
8    InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
9    ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
10    ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
11    LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
12    RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
13    SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
14    UnsubscribeRequest, UnsubscribeRequestParams,
15};
16use crate::{error::SdkResult, utils::format_assertion_message};
17use async_trait::async_trait;
18use std::{sync::Arc, time::Duration};
19
20#[async_trait]
21pub trait McpClient: Sync + Send {
22    async fn start(self: Arc<Self>) -> SdkResult<()>;
23    fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>;
24
25    async fn terminate_session(&self);
26
27    async fn shut_down(&self) -> SdkResult<()>;
28    async fn is_shut_down(&self) -> bool;
29
30    fn client_info(&self) -> &InitializeRequestParams;
31    fn server_info(&self) -> Option<InitializeResult>;
32
33    /// Checks whether the server has been initialized with client
34    fn is_initialized(&self) -> bool {
35        self.server_info().is_some()
36    }
37
38    /// Returns the server's name and version information once initialization is complete.
39    /// This method retrieves the server details, if available, after successful initialization.
40    fn server_version(&self) -> Option<Implementation> {
41        self.server_info()
42            .map(|server_details| server_details.server_info)
43    }
44
45    /// Returns the server's capabilities.
46    /// After initialization has completed, this will be populated with the server's reported capabilities.
47    fn server_capabilities(&self) -> Option<ServerCapabilities> {
48        self.server_info().map(|item| item.capabilities)
49    }
50
51    /// Checks if the server has tools available.
52    ///
53    /// This function retrieves the server information and checks if the
54    /// server has tools listed in its capabilities. If the server info
55    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
56    /// `Some(true)` if tools are available, or `Some(false)` if not.
57    ///
58    /// # Returns
59    /// - `None` if server information is not yet available.
60    /// - `Some(true)` if tools are available on the server.
61    /// - `Some(false)` if no tools are available on the server.
62    /// ```rust
63    /// println!("{}",1);
64    /// ```
65    fn server_has_tools(&self) -> Option<bool> {
66        self.server_info()
67            .map(|server_details| server_details.capabilities.tools.is_some())
68    }
69
70    /// Checks if the server has prompts available.
71    ///
72    /// This function retrieves the server information and checks if the
73    /// server has prompts listed in its capabilities. If the server info
74    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
75    /// `Some(true)` if prompts are available, or `Some(false)` if not.
76    ///
77    /// # Returns
78    /// - `None` if server information is not yet available.
79    /// - `Some(true)` if prompts are available on the server.
80    /// - `Some(false)` if no prompts are available on the server.
81    fn server_has_prompts(&self) -> Option<bool> {
82        self.server_info()
83            .map(|server_details| server_details.capabilities.prompts.is_some())
84    }
85
86    /// Checks if the server has experimental capabilities available.
87    ///
88    /// This function retrieves the server information and checks if the
89    /// server has experimental listed in its capabilities. If the server info
90    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
91    /// `Some(true)` if experimental is available, or `Some(false)` if not.
92    ///
93    /// # Returns
94    /// - `None` if server information is not yet available.
95    /// - `Some(true)` if experimental capabilities are available on the server.
96    /// - `Some(false)` if no experimental capabilities are available on the server.
97    fn server_has_experimental(&self) -> Option<bool> {
98        self.server_info()
99            .map(|server_details| server_details.capabilities.experimental.is_some())
100    }
101
102    /// Checks if the server has resources available.
103    ///
104    /// This function retrieves the server information and checks if the
105    /// server has resources listed in its capabilities. If the server info
106    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
107    /// `Some(true)` if resources are available, or `Some(false)` if not.
108    ///
109    /// # Returns
110    /// - `None` if server information is not yet available.
111    /// - `Some(true)` if resources are available on the server.
112    /// - `Some(false)` if no resources are available on the server.
113    fn server_has_resources(&self) -> Option<bool> {
114        self.server_info()
115            .map(|server_details| server_details.capabilities.resources.is_some())
116    }
117
118    /// Checks if the server supports logging.
119    ///
120    /// This function retrieves the server information and checks if the
121    /// server has logging capabilities listed. If the server info has
122    /// not been retrieved yet, it returns `None`. Otherwise, it returns
123    /// `Some(true)` if logging is supported, or `Some(false)` if not.
124    ///
125    /// # Returns
126    /// - `None` if server information is not yet available.
127    /// - `Some(true)` if logging is supported by the server.
128    /// - `Some(false)` if logging is not supported by the server.
129    fn server_supports_logging(&self) -> Option<bool> {
130        self.server_info()
131            .map(|server_details| server_details.capabilities.logging.is_some())
132    }
133
134    fn instructions(&self) -> Option<String> {
135        self.server_info()?.instructions
136    }
137
138    /// Sends a request to the server and processes the response.
139    ///
140    /// This function sends a `RequestFromClient` message to the server, waits for the response,
141    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
142    /// Otherwise, it returns the result from the server.
143    async fn request(
144        &self,
145        request: RequestFromClient,
146        timeout: Option<Duration>,
147    ) -> SdkResult<ResultFromServer> {
148        let response = self
149            .send(MessageFromClient::RequestFromClient(request), None, timeout)
150            .await?;
151
152        let server_message = response.ok_or_else(|| {
153            RpcError::internal_error()
154                .with_message("An empty response was received from the client.".to_string())
155        })?;
156
157        if server_message.is_error() {
158            return Err(server_message.as_error()?.error.into());
159        }
160
161        return Ok(server_message.as_response()?.result);
162    }
163
164    async fn send(
165        &self,
166        message: MessageFromClient,
167        request_id: Option<RequestId>,
168        request_timeout: Option<Duration>,
169    ) -> SdkResult<Option<ServerMessage>>;
170
171    async fn send_batch(
172        &self,
173        messages: Vec<ClientMessage>,
174        timeout: Option<Duration>,
175    ) -> SdkResult<Option<Vec<ServerMessage>>>;
176
177    /// Sends a notification. This is a one-way message that is not expected
178    /// to return any response. The method asynchronously sends the notification using
179    /// the transport layer and does not wait for any acknowledgement or result.
180    async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> {
181        self.send(notification.into(), None, None).await?;
182        Ok(())
183    }
184
185    /// A ping request to check that the other party is still alive.
186    /// The receiver must promptly respond, or else may be disconnected.
187    ///
188    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
189    /// Once the response is received, it attempts to convert it into the expected
190    /// result type.
191    ///
192    /// # Returns
193    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
194    /// If the request or conversion fails, an error is returned.
195    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
196        let ping_request = PingRequest::new(None);
197        let response = self.request(ping_request.into(), timeout).await?;
198        Ok(response.try_into()?)
199    }
200
201    async fn complete(
202        &self,
203        params: CompleteRequestParams,
204    ) -> SdkResult<crate::schema::CompleteResult> {
205        let request = CompleteRequest::new(params);
206        let response = self.request(request.into(), None).await?;
207        Ok(response.try_into()?)
208    }
209
210    async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<crate::schema::Result> {
211        let request = SetLevelRequest::new(SetLevelRequestParams { level });
212        let response = self.request(request.into(), None).await?;
213        Ok(response.try_into()?)
214    }
215
216    async fn get_prompt(
217        &self,
218        params: GetPromptRequestParams,
219    ) -> SdkResult<crate::schema::GetPromptResult> {
220        let request = GetPromptRequest::new(params);
221        let response = self.request(request.into(), None).await?;
222        Ok(response.try_into()?)
223    }
224
225    async fn list_prompts(
226        &self,
227        params: Option<ListPromptsRequestParams>,
228    ) -> SdkResult<crate::schema::ListPromptsResult> {
229        let request = ListPromptsRequest::new(params);
230        let response = self.request(request.into(), None).await?;
231        Ok(response.try_into()?)
232    }
233
234    async fn list_resources(
235        &self,
236        params: Option<ListResourcesRequestParams>,
237    ) -> SdkResult<crate::schema::ListResourcesResult> {
238        // passing ListResourcesRequestParams::default() if params is None
239        // need to investigate more but this could be a inconsistency on some MCP servers
240        // where it is not required for other requests like prompts/list or tools/list etc
241        // that excepts an empty params to be passed (like server-everything)
242        let request =
243            ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
244        let response = self.request(request.into(), None).await?;
245        Ok(response.try_into()?)
246    }
247
248    async fn list_resource_templates(
249        &self,
250        params: Option<ListResourceTemplatesRequestParams>,
251    ) -> SdkResult<crate::schema::ListResourceTemplatesResult> {
252        let request = ListResourceTemplatesRequest::new(params);
253        let response = self.request(request.into(), None).await?;
254        Ok(response.try_into()?)
255    }
256
257    async fn read_resource(
258        &self,
259        params: ReadResourceRequestParams,
260    ) -> SdkResult<crate::schema::ReadResourceResult> {
261        let request = ReadResourceRequest::new(params);
262        let response = self.request(request.into(), None).await?;
263        Ok(response.try_into()?)
264    }
265
266    async fn subscribe_resource(
267        &self,
268        params: SubscribeRequestParams,
269    ) -> SdkResult<crate::schema::Result> {
270        let request = SubscribeRequest::new(params);
271        let response = self.request(request.into(), None).await?;
272        Ok(response.try_into()?)
273    }
274
275    async fn unsubscribe_resource(
276        &self,
277        params: UnsubscribeRequestParams,
278    ) -> SdkResult<crate::schema::Result> {
279        let request = UnsubscribeRequest::new(params);
280        let response = self.request(request.into(), None).await?;
281        Ok(response.try_into()?)
282    }
283
284    async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
285        let request = CallToolRequest::new(params);
286        let response = self.request(request.into(), None).await?;
287        Ok(response.try_into()?)
288    }
289
290    async fn list_tools(
291        &self,
292        params: Option<ListToolsRequestParams>,
293    ) -> SdkResult<crate::schema::ListToolsResult> {
294        let request = ListToolsRequest::new(params);
295        let response = self.request(request.into(), None).await?;
296        Ok(response.try_into()?)
297    }
298
299    async fn send_roots_list_changed(
300        &self,
301        params: Option<RootsListChangedNotificationParams>,
302    ) -> SdkResult<()> {
303        let notification = RootsListChangedNotification::new(params);
304        self.send_notification(notification.into()).await
305    }
306
307    /// Asserts that server capabilities support the requested method.
308    ///
309    /// Verifies that the server has the necessary capabilities to handle the given request method.
310    /// If the server is not initialized or lacks a required capability, an error is returned.
311    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
312    fn assert_server_capabilities(&self, request_method: &String) -> SdkResult<()> {
313        let entity = "Server";
314
315        let capabilities = self.server_capabilities().ok_or::<RpcError>(
316            RpcError::internal_error().with_message("Server is not initialized!".to_string()),
317        )?;
318
319        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
320            return Err(RpcError::internal_error()
321                .with_message(format_assertion_message(entity, "logging", request_method))
322                .into());
323        }
324
325        if [
326            GetPromptRequest::method_name(),
327            ListPromptsRequest::method_name(),
328        ]
329        .contains(request_method)
330            && capabilities.prompts.is_none()
331        {
332            return Err(RpcError::internal_error()
333                .with_message(format_assertion_message(entity, "prompts", request_method))
334                .into());
335        }
336
337        if [
338            ListResourcesRequest::method_name(),
339            ListResourceTemplatesRequest::method_name(),
340            ReadResourceRequest::method_name(),
341            SubscribeRequest::method_name(),
342            UnsubscribeRequest::method_name(),
343        ]
344        .contains(request_method)
345            && capabilities.resources.is_none()
346        {
347            return Err(RpcError::internal_error()
348                .with_message(format_assertion_message(
349                    entity,
350                    "resources",
351                    request_method,
352                ))
353                .into());
354        }
355
356        if [
357            CallToolRequest::method_name(),
358            ListToolsRequest::method_name(),
359        ]
360        .contains(request_method)
361            && capabilities.tools.is_none()
362        {
363            return Err(RpcError::internal_error()
364                .with_message(format_assertion_message(entity, "tools", request_method))
365                .into());
366        }
367
368        Ok(())
369    }
370
371    fn assert_client_notification_capabilities(
372        &self,
373        notification_method: &String,
374    ) -> std::result::Result<(), RpcError> {
375        let entity = "Client";
376        let capabilities = &self.client_info().capabilities;
377
378        if *notification_method == RootsListChangedNotification::method_name()
379            && capabilities.roots.is_some()
380        {
381            return Err(
382                RpcError::internal_error().with_message(format_assertion_message(
383                    entity,
384                    "roots list changed notifications",
385                    notification_method,
386                )),
387            );
388        }
389
390        Ok(())
391    }
392
393    fn assert_client_request_capabilities(
394        &self,
395        request_method: &String,
396    ) -> std::result::Result<(), RpcError> {
397        let entity = "Client";
398        let capabilities = &self.client_info().capabilities;
399
400        if *request_method == CreateMessageRequest::method_name() && capabilities.sampling.is_some()
401        {
402            return Err(
403                RpcError::internal_error().with_message(format_assertion_message(
404                    entity,
405                    "sampling capability",
406                    request_method,
407                )),
408            );
409        }
410
411        if *request_method == ListRootsRequest::method_name() && capabilities.roots.is_some() {
412            return Err(
413                RpcError::internal_error().with_message(format_assertion_message(
414                    entity,
415                    "roots capability",
416                    request_method,
417                )),
418            );
419        }
420
421        Ok(())
422    }
423}