rust_mcp_sdk/mcp_traits/mcp_server.rs

1use std::time::Duration;
2
3use crate::schema::{
4    schema_utils::{
5        ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer,
6        RequestFromServer, ResultFromClient, ServerMessage,
7    },
8    CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
9    GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
10    ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest,
11    ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification,
12    LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
13    PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
14    ResourceListChangedNotification, ResourceListChangedNotificationParams,
15    ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
16    SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
17};
18use async_trait::async_trait;
19
20use crate::{error::SdkResult, utils::format_assertion_message};
21
22//TODO: support options , such as enforceStrictCapabilities
23#[async_trait]
24pub trait McpServer: Sync + Send {
25    async fn start(&self) -> SdkResult<()>;
26    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
27    fn server_info(&self) -> &InitializeResult;
28    fn client_info(&self) -> Option<InitializeRequestParams>;
29
30    async fn wait_for_initialization(&self);
31
32    #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
33    fn get_client_info(&self) -> Option<InitializeRequestParams> {
34        self.client_info()
35    }
36
37    #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
38    fn get_server_info(&self) -> &InitializeResult {
39        self.server_info()
40    }
41
42    async fn send(
43        &self,
44        message: MessageFromServer,
45        request_id: Option<RequestId>,
46        request_timeout: Option<Duration>,
47    ) -> SdkResult<Option<ClientMessages>>;
48
49    async fn send_batch(
50        &self,
51        messages: Vec<ServerMessage>,
52        request_timeout: Option<Duration>,
53    ) -> SdkResult<Option<Vec<ClientMessage>>>;
54
55    /// Checks whether the server has been initialized with client
56    fn is_initialized(&self) -> bool {
57        self.client_info().is_some()
58    }
59
60    /// Returns the client's name and version information once initialization is complete.
61    /// This method retrieves the client details, if available, after successful initialization.
62    fn client_version(&self) -> Option<Implementation> {
63        self.client_info()
64            .map(|client_details| client_details.client_info)
65    }
66
67    /// Returns the server's capabilities.
68    fn capabilities(&self) -> &ServerCapabilities {
69        &self.server_info().capabilities
70    }
71
72    /// Sends a request to the client and processes the response.
73    ///
74    /// This function sends a `RequestFromServer` message to the client, waits for the response,
75    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
76    /// Otherwise, it returns the result from the client.
77    async fn request(
78        &self,
79        request: RequestFromServer,
80        timeout: Option<Duration>,
81    ) -> SdkResult<ResultFromClient> {
82        // Send the request and receive the response.
83        let response = self
84            .send(MessageFromServer::RequestFromServer(request), None, timeout)
85            .await?;
86
87        let client_messages = response.ok_or_else(|| {
88            RpcError::internal_error()
89                .with_message("An empty response was received from the client.".to_string())
90        })?;
91
92        let client_message = client_messages.as_single()?;
93
94        if client_message.is_error() {
95            return Err(client_message.as_error()?.error.into());
96        }
97
98        return Ok(client_message.as_response()?.result);
99    }
100
101    /// Sends a notification. This is a one-way message that is not expected
102    /// to return any response. The method asynchronously sends the notification using
103    /// the transport layer and does not wait for any acknowledgement or result.
104    async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
105        self.send(
106            MessageFromServer::NotificationFromServer(notification),
107            None,
108            None,
109        )
110        .await?;
111        Ok(())
112    }
113
114    /// Request a list of root URIs from the client. Roots allow
115    /// servers to ask for specific directories or files to operate on. A common example
116    /// for roots is providing a set of repositories or directories a server should operate on.
117    /// This request is typically used when the server needs to understand the file system
118    /// structure or access specific locations that the client has permission to read from
119    async fn list_roots(
120        &self,
121        params: Option<ListRootsRequestParams>,
122    ) -> SdkResult<ListRootsResult> {
123        let request: ListRootsRequest = ListRootsRequest::new(params);
124        let response = self.request(request.into(), None).await?;
125        ListRootsResult::try_from(response).map_err(|err| err.into())
126    }
127
128    /// Send log message notification from server to client.
129    /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.
130    async fn send_logging_message(
131        &self,
132        params: LoggingMessageNotificationParams,
133    ) -> SdkResult<()> {
134        let notification = LoggingMessageNotification::new(params);
135        self.send_notification(notification.into()).await
136    }
137
138    /// An optional notification from the server to the client, informing it that
139    /// the list of prompts it offers has changed.
140    /// This may be issued by servers without any previous subscription from the client.
141    async fn send_prompt_list_changed(
142        &self,
143        params: Option<PromptListChangedNotificationParams>,
144    ) -> SdkResult<()> {
145        let notification = PromptListChangedNotification::new(params);
146        self.send_notification(notification.into()).await
147    }
148
149    /// An optional notification from the server to the client,
150    /// informing it that the list of resources it can read from has changed.
151    /// This may be issued by servers without any previous subscription from the client.
152    async fn send_resource_list_changed(
153        &self,
154        params: Option<ResourceListChangedNotificationParams>,
155    ) -> SdkResult<()> {
156        let notification = ResourceListChangedNotification::new(params);
157        self.send_notification(notification.into()).await
158    }
159
160    /// A notification from the server to the client, informing it that
161    /// a resource has changed and may need to be read again.
162    ///  This should only be sent if the client previously sent a resources/subscribe request.
163    async fn send_resource_updated(
164        &self,
165        params: ResourceUpdatedNotificationParams,
166    ) -> SdkResult<()> {
167        let notification = ResourceUpdatedNotification::new(params);
168        self.send_notification(notification.into()).await
169    }
170
171    /// An optional notification from the server to the client, informing it that
172    /// the list of tools it offers has changed.
173    /// This may be issued by servers without any previous subscription from the client.
174    async fn send_tool_list_changed(
175        &self,
176        params: Option<ToolListChangedNotificationParams>,
177    ) -> SdkResult<()> {
178        let notification = ToolListChangedNotification::new(params);
179        self.send_notification(notification.into()).await
180    }
181
182    /// A ping request to check that the other party is still alive.
183    /// The receiver must promptly respond, or else may be disconnected.
184    ///
185    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
186    /// Once the response is received, it attempts to convert it into the expected
187    /// result type.
188    ///
189    /// # Returns
190    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
191    /// If the request or conversion fails, an error is returned.
192    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
193        let ping_request = PingRequest::new(None);
194        let response = self.request(ping_request.into(), timeout).await?;
195        Ok(response.try_into()?)
196    }
197
198    /// A request from the server to sample an LLM via the client.
199    /// The client has full discretion over which model to select.
200    /// The client should also inform the user before beginning sampling,
201    /// to allow them to inspect the request (human in the loop)
202    /// and decide whether to approve it.
203    async fn create_message(
204        &self,
205        params: CreateMessageRequestParams,
206    ) -> SdkResult<CreateMessageResult> {
207        let ping_request = CreateMessageRequest::new(params);
208        let response = self.request(ping_request.into(), None).await?;
209        Ok(response.try_into()?)
210    }
211
212    /// Checks if the client supports sampling.
213    ///
214    /// This function retrieves the client information and checks if the
215    /// client has sampling capabilities listed. If the client info has
216    /// not been retrieved yet, it returns `None`. Otherwise, it returns
217    /// `Some(true)` if sampling is supported, or `Some(false)` if not.
218    ///
219    /// # Returns
220    /// - `None` if client information is not yet available.
221    /// - `Some(true)` if sampling is supported by the client.
222    /// - `Some(false)` if sampling is not supported by the client.
223    fn client_supports_sampling(&self) -> Option<bool> {
224        self.client_info()
225            .map(|client_details| client_details.capabilities.sampling.is_some())
226    }
227
228    /// Checks if the client supports listing roots.
229    ///
230    /// This function retrieves the client information and checks if the
231    /// client has listing roots capabilities listed. If the client info has
232    /// not been retrieved yet, it returns `None`. Otherwise, it returns
233    /// `Some(true)` if listing roots is supported, or `Some(false)` if not.
234    ///
235    /// # Returns
236    /// - `None` if client information is not yet available.
237    /// - `Some(true)` if listing roots is supported by the client.
238    /// - `Some(false)` if listing roots is not supported by the client.
239    fn client_supports_root_list(&self) -> Option<bool> {
240        self.client_info()
241            .map(|client_details| client_details.capabilities.roots.is_some())
242    }
243
244    /// Checks if the client has experimental capabilities available.
245    ///
246    /// This function retrieves the client information and checks if the
247    /// client has experimental listed in its capabilities. If the client info
248    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
249    /// `Some(true)` if experimental is available, or `Some(false)` if not.
250    ///
251    /// # Returns
252    /// - `None` if client information is not yet available.
253    /// - `Some(true)` if experimental capabilities are available on the client.
254    /// - `Some(false)` if no experimental capabilities are available on the client.
255    fn client_supports_experimental(&self) -> Option<bool> {
256        self.client_info()
257            .map(|client_details| client_details.capabilities.experimental.is_some())
258    }
259
260    /// Sends a message to the standard error output (stderr) asynchronously.
261    async fn stderr_message(&self, message: String) -> SdkResult<()>;
262
263    /// Asserts that client capabilities are available for a given server request.
264    ///
265    /// This method verifies that the client capabilities required to process the specified
266    /// server request have been retrieved and are accessible. It returns an error if the
267    /// capabilities are not available.
268    ///
269    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
270    fn assert_client_capabilities(
271        &self,
272        request_method: &String,
273    ) -> std::result::Result<(), RpcError> {
274        let entity = "Client";
275        if *request_method == CreateMessageRequest::method_name()
276            && !self.client_supports_sampling().unwrap_or(false)
277        {
278            return Err(
279                RpcError::internal_error().with_message(format_assertion_message(
280                    entity,
281                    "sampling",
282                    request_method,
283                )),
284            );
285        }
286        if *request_method == ListRootsRequest::method_name()
287            && !self.client_supports_root_list().unwrap_or(false)
288        {
289            return Err(
290                RpcError::internal_error().with_message(format_assertion_message(
291                    entity,
292                    "listing roots",
293                    request_method,
294                )),
295            );
296        }
297        Ok(())
298    }
299
300    fn assert_server_notification_capabilities(
301        &self,
302        notification_method: &String,
303    ) -> std::result::Result<(), RpcError> {
304        let entity = "Server";
305
306        let capabilities = &self.server_info().capabilities;
307
308        if *notification_method == LoggingMessageNotification::method_name()
309            && capabilities.logging.is_none()
310        {
311            return Err(
312                RpcError::internal_error().with_message(format_assertion_message(
313                    entity,
314                    "logging",
315                    notification_method,
316                )),
317            );
318        }
319        if *notification_method == ResourceUpdatedNotification::method_name()
320            && capabilities.resources.is_none()
321        {
322            return Err(
323                RpcError::internal_error().with_message(format_assertion_message(
324                    entity,
325                    "notifying about resources",
326                    notification_method,
327                )),
328            );
329        }
330        if *notification_method == ToolListChangedNotification::method_name()
331            && capabilities.tools.is_none()
332        {
333            return Err(
334                RpcError::internal_error().with_message(format_assertion_message(
335                    entity,
336                    "notifying of tool list changes",
337                    notification_method,
338                )),
339            );
340        }
341        if *notification_method == PromptListChangedNotification::method_name()
342            && capabilities.prompts.is_none()
343        {
344            return Err(
345                RpcError::internal_error().with_message(format_assertion_message(
346                    entity,
347                    "notifying of prompt list changes",
348                    notification_method,
349                )),
350            );
351        }
352
353        Ok(())
354    }
355
356    fn assert_server_request_capabilities(
357        &self,
358        request_method: &String,
359    ) -> std::result::Result<(), RpcError> {
360        let entity = "Server";
361        let capabilities = &self.server_info().capabilities;
362
363        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
364            return Err(
365                RpcError::internal_error().with_message(format_assertion_message(
366                    entity,
367                    "logging",
368                    request_method,
369                )),
370            );
371        }
372        if [
373            GetPromptRequest::method_name(),
374            ListPromptsRequest::method_name(),
375        ]
376        .contains(request_method)
377            && capabilities.prompts.is_none()
378        {
379            return Err(
380                RpcError::internal_error().with_message(format_assertion_message(
381                    entity,
382                    "prompts",
383                    request_method,
384                )),
385            );
386        }
387        if [
388            ListResourcesRequest::method_name(),
389            ListResourceTemplatesRequest::method_name(),
390            ReadResourceRequest::method_name(),
391        ]
392        .contains(request_method)
393            && capabilities.resources.is_none()
394        {
395            return Err(
396                RpcError::internal_error().with_message(format_assertion_message(
397                    entity,
398                    "resources",
399                    request_method,
400                )),
401            );
402        }
403        if [
404            CallToolRequest::method_name(),
405            ListToolsRequest::method_name(),
406        ]
407        .contains(request_method)
408            && capabilities.tools.is_none()
409        {
410            return Err(
411                RpcError::internal_error().with_message(format_assertion_message(
412                    entity,
413                    "tools",
414                    request_method,
415                )),
416            );
417        }
418        Ok(())
419    }
420}