rust_mcp_sdk/mcp_traits/
mcp_server.rs

1use std::time::Duration;
2
3use crate::schema::{
4    schema_utils::{
5        ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
6        ResultFromClient,
7    },
8    CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
9    GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
10    ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest,
11    ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification,
12    LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
13    PromptListChangedNotificationParams, ReadResourceRequest, ResourceListChangedNotification,
14    ResourceListChangedNotificationParams, ResourceUpdatedNotification,
15    ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest,
16    ToolListChangedNotification, ToolListChangedNotificationParams,
17};
18use async_trait::async_trait;
19use rust_mcp_transport::{McpDispatch, MessageDispatcher};
20
21use crate::{error::SdkResult, utils::format_assertion_message};
22
// TODO: support options, such as enforceStrictCapabilities
24#[async_trait]
25pub trait McpServer: Sync + Send {
26    async fn start(&self) -> SdkResult<()>;
27    fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
28    fn server_info(&self) -> &InitializeResult;
29    fn client_info(&self) -> Option<InitializeRequestParams>;
30
31    #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
32    fn get_client_info(&self) -> Option<InitializeRequestParams> {
33        self.client_info()
34    }
35
36    #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
37    fn get_server_info(&self) -> &InitializeResult {
38        self.server_info()
39    }
40
41    async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>
42    where
43        MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>;
44
45    /// Checks whether the server has been initialized with client
46    fn is_initialized(&self) -> bool {
47        self.client_info().is_some()
48    }
49
50    /// Returns the client's name and version information once initialization is complete.
51    /// This method retrieves the client details, if available, after successful initialization.
52    fn client_version(&self) -> Option<Implementation> {
53        self.client_info()
54            .map(|client_details| client_details.client_info)
55    }
56
57    /// Returns the server's capabilities.
58    fn capabilities(&self) -> &ServerCapabilities {
59        &self.server_info().capabilities
60    }
61
62    /// Sends a request to the client and processes the response.
63    ///
64    /// This function sends a `RequestFromServer` message to the client, waits for the response,
65    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
66    /// Otherwise, it returns the result from the client.
67    async fn request(
68        &self,
69        request: RequestFromServer,
70        timeout: Option<Duration>,
71    ) -> SdkResult<ResultFromClient> {
72        let sender = self.sender().await;
73        let sender = sender.read().await;
74        let sender = sender.as_ref().unwrap();
75
76        // Send the request and receive the response.
77        let response = sender
78            .send(MessageFromServer::RequestFromServer(request), None, timeout)
79            .await?;
80        let client_message = response.ok_or_else(|| {
81            RpcError::internal_error()
82                .with_message("An empty response was received from the client.".to_string())
83        })?;
84
85        if client_message.is_error() {
86            return Err(client_message.as_error()?.error.into());
87        }
88
89        return Ok(client_message.as_response()?.result);
90    }
91
92    /// Sends a notification. This is a one-way message that is not expected
93    /// to return any response. The method asynchronously sends the notification using
94    /// the transport layer and does not wait for any acknowledgement or result.
95    async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
96        let sender = self.sender().await;
97        let sender = sender.read().await;
98        let sender = sender.as_ref().unwrap();
99
100        sender
101            .send(
102                MessageFromServer::NotificationFromServer(notification),
103                None,
104                None,
105            )
106            .await?;
107        Ok(())
108    }
109
110    /// Request a list of root URIs from the client. Roots allow
111    /// servers to ask for specific directories or files to operate on. A common example
112    /// for roots is providing a set of repositories or directories a server should operate on.
113    /// This request is typically used when the server needs to understand the file system
114    /// structure or access specific locations that the client has permission to read from
115    async fn list_roots(
116        &self,
117        params: Option<ListRootsRequestParams>,
118    ) -> SdkResult<ListRootsResult> {
119        let request: ListRootsRequest = ListRootsRequest::new(params);
120        let response = self.request(request.into(), None).await?;
121        ListRootsResult::try_from(response).map_err(|err| err.into())
122    }
123
124    /// Send log message notification from server to client.
125    /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.
126    async fn send_logging_message(
127        &self,
128        params: LoggingMessageNotificationParams,
129    ) -> SdkResult<()> {
130        let notification = LoggingMessageNotification::new(params);
131        self.send_notification(notification.into()).await
132    }
133
134    /// An optional notification from the server to the client, informing it that
135    /// the list of prompts it offers has changed.
136    /// This may be issued by servers without any previous subscription from the client.
137    async fn send_prompt_list_changed(
138        &self,
139        params: Option<PromptListChangedNotificationParams>,
140    ) -> SdkResult<()> {
141        let notification = PromptListChangedNotification::new(params);
142        self.send_notification(notification.into()).await
143    }
144
145    /// An optional notification from the server to the client,
146    /// informing it that the list of resources it can read from has changed.
147    /// This may be issued by servers without any previous subscription from the client.
148    async fn send_resource_list_changed(
149        &self,
150        params: Option<ResourceListChangedNotificationParams>,
151    ) -> SdkResult<()> {
152        let notification = ResourceListChangedNotification::new(params);
153        self.send_notification(notification.into()).await
154    }
155
156    /// A notification from the server to the client, informing it that
157    /// a resource has changed and may need to be read again.
158    ///  This should only be sent if the client previously sent a resources/subscribe request.
159    async fn send_resource_updated(
160        &self,
161        params: ResourceUpdatedNotificationParams,
162    ) -> SdkResult<()> {
163        let notification = ResourceUpdatedNotification::new(params);
164        self.send_notification(notification.into()).await
165    }
166
167    /// An optional notification from the server to the client, informing it that
168    /// the list of tools it offers has changed.
169    /// This may be issued by servers without any previous subscription from the client.
170    async fn send_tool_list_changed(
171        &self,
172        params: Option<ToolListChangedNotificationParams>,
173    ) -> SdkResult<()> {
174        let notification = ToolListChangedNotification::new(params);
175        self.send_notification(notification.into()).await
176    }
177
178    /// A ping request to check that the other party is still alive.
179    /// The receiver must promptly respond, or else may be disconnected.
180    ///
181    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
182    /// Once the response is received, it attempts to convert it into the expected
183    /// result type.
184    ///
185    /// # Returns
186    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
187    /// If the request or conversion fails, an error is returned.
188    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
189        let ping_request = PingRequest::new(None);
190        let response = self.request(ping_request.into(), timeout).await?;
191        Ok(response.try_into()?)
192    }
193
194    /// A request from the server to sample an LLM via the client.
195    /// The client has full discretion over which model to select.
196    /// The client should also inform the user before beginning sampling,
197    /// to allow them to inspect the request (human in the loop)
198    /// and decide whether to approve it.
199    async fn create_message(
200        &self,
201        params: CreateMessageRequestParams,
202    ) -> SdkResult<CreateMessageResult> {
203        let ping_request = CreateMessageRequest::new(params);
204        let response = self.request(ping_request.into(), None).await?;
205        Ok(response.try_into()?)
206    }
207
208    /// Checks if the client supports sampling.
209    ///
210    /// This function retrieves the client information and checks if the
211    /// client has sampling capabilities listed. If the client info has
212    /// not been retrieved yet, it returns `None`. Otherwise, it returns
213    /// `Some(true)` if sampling is supported, or `Some(false)` if not.
214    ///
215    /// # Returns
216    /// - `None` if client information is not yet available.
217    /// - `Some(true)` if sampling is supported by the client.
218    /// - `Some(false)` if sampling is not supported by the client.
219    fn client_supports_sampling(&self) -> Option<bool> {
220        self.client_info()
221            .map(|client_details| client_details.capabilities.sampling.is_some())
222    }
223
224    /// Checks if the client supports listing roots.
225    ///
226    /// This function retrieves the client information and checks if the
227    /// client has listing roots capabilities listed. If the client info has
228    /// not been retrieved yet, it returns `None`. Otherwise, it returns
229    /// `Some(true)` if listing roots is supported, or `Some(false)` if not.
230    ///
231    /// # Returns
232    /// - `None` if client information is not yet available.
233    /// - `Some(true)` if listing roots is supported by the client.
234    /// - `Some(false)` if listing roots is not supported by the client.
235    fn client_supports_root_list(&self) -> Option<bool> {
236        self.client_info()
237            .map(|client_details| client_details.capabilities.roots.is_some())
238    }
239
240    /// Checks if the client has experimental capabilities available.
241    ///
242    /// This function retrieves the client information and checks if the
243    /// client has experimental listed in its capabilities. If the client info
244    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
245    /// `Some(true)` if experimental is available, or `Some(false)` if not.
246    ///
247    /// # Returns
248    /// - `None` if client information is not yet available.
249    /// - `Some(true)` if experimental capabilities are available on the client.
250    /// - `Some(false)` if no experimental capabilities are available on the client.
251    fn client_supports_experimental(&self) -> Option<bool> {
252        self.client_info()
253            .map(|client_details| client_details.capabilities.experimental.is_some())
254    }
255
256    /// Sends a message to the standard error output (stderr) asynchronously.
257    async fn stderr_message(&self, message: String) -> SdkResult<()>;
258
259    /// Asserts that client capabilities are available for a given server request.
260    ///
261    /// This method verifies that the client capabilities required to process the specified
262    /// server request have been retrieved and are accessible. It returns an error if the
263    /// capabilities are not available.
264    ///
265    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
266    fn assert_client_capabilities(
267        &self,
268        request_method: &String,
269    ) -> std::result::Result<(), RpcError> {
270        let entity = "Client";
271        if *request_method == CreateMessageRequest::method_name()
272            && !self.client_supports_sampling().unwrap_or(false)
273        {
274            return Err(
275                RpcError::internal_error().with_message(format_assertion_message(
276                    entity,
277                    "sampling",
278                    request_method,
279                )),
280            );
281        }
282        if *request_method == ListRootsRequest::method_name()
283            && !self.client_supports_root_list().unwrap_or(false)
284        {
285            return Err(
286                RpcError::internal_error().with_message(format_assertion_message(
287                    entity,
288                    "listing roots",
289                    request_method,
290                )),
291            );
292        }
293        Ok(())
294    }
295
296    fn assert_server_notification_capabilities(
297        &self,
298        notification_method: &String,
299    ) -> std::result::Result<(), RpcError> {
300        let entity = "Server";
301
302        let capabilities = &self.server_info().capabilities;
303
304        if *notification_method == LoggingMessageNotification::method_name()
305            && capabilities.logging.is_none()
306        {
307            return Err(
308                RpcError::internal_error().with_message(format_assertion_message(
309                    entity,
310                    "logging",
311                    notification_method,
312                )),
313            );
314        }
315        if *notification_method == ResourceUpdatedNotification::method_name()
316            && capabilities.resources.is_none()
317        {
318            return Err(
319                RpcError::internal_error().with_message(format_assertion_message(
320                    entity,
321                    "notifying about resources",
322                    notification_method,
323                )),
324            );
325        }
326        if *notification_method == ToolListChangedNotification::method_name()
327            && capabilities.tools.is_none()
328        {
329            return Err(
330                RpcError::internal_error().with_message(format_assertion_message(
331                    entity,
332                    "notifying of tool list changes",
333                    notification_method,
334                )),
335            );
336        }
337        if *notification_method == PromptListChangedNotification::method_name()
338            && capabilities.prompts.is_none()
339        {
340            return Err(
341                RpcError::internal_error().with_message(format_assertion_message(
342                    entity,
343                    "notifying of prompt list changes",
344                    notification_method,
345                )),
346            );
347        }
348
349        Ok(())
350    }
351
352    fn assert_server_request_capabilities(
353        &self,
354        request_method: &String,
355    ) -> std::result::Result<(), RpcError> {
356        let entity = "Server";
357        let capabilities = &self.server_info().capabilities;
358
359        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
360            return Err(
361                RpcError::internal_error().with_message(format_assertion_message(
362                    entity,
363                    "logging",
364                    request_method,
365                )),
366            );
367        }
368        if [
369            GetPromptRequest::method_name(),
370            ListPromptsRequest::method_name(),
371        ]
372        .contains(request_method)
373            && capabilities.prompts.is_none()
374        {
375            return Err(
376                RpcError::internal_error().with_message(format_assertion_message(
377                    entity,
378                    "prompts",
379                    request_method,
380                )),
381            );
382        }
383        if [
384            ListResourcesRequest::method_name(),
385            ListResourceTemplatesRequest::method_name(),
386            ReadResourceRequest::method_name(),
387        ]
388        .contains(request_method)
389            && capabilities.resources.is_none()
390        {
391            return Err(
392                RpcError::internal_error().with_message(format_assertion_message(
393                    entity,
394                    "resources",
395                    request_method,
396                )),
397            );
398        }
399        if [
400            CallToolRequest::method_name(),
401            ListToolsRequest::method_name(),
402        ]
403        .contains(request_method)
404            && capabilities.tools.is_none()
405        {
406            return Err(
407                RpcError::internal_error().with_message(format_assertion_message(
408                    entity,
409                    "tools",
410                    request_method,
411                )),
412            );
413        }
414        Ok(())
415    }
416}