// rust_mcp_sdk/mcp_traits/mcp_server.rs

1use crate::schema::{
2    schema_utils::{
3        ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
4        ResultFromClient, ServerMessage,
5    },
6    CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
7    ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest,
8    Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest,
9    ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams,
10    ListRootsResult, ListToolsRequest, LoggingMessageNotification,
11    LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
12    PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
13    ResourceListChangedNotification, ResourceListChangedNotificationParams,
14    ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
15    SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
16};
17use crate::{error::SdkResult, utils::format_assertion_message};
18use async_trait::async_trait;
19use rust_mcp_transport::SessionId;
20use std::{sync::Arc, time::Duration};
21
// TODO: support options, such as enforceStrictCapabilities.
23#[async_trait]
24pub trait McpServer: Sync + Send {
25    async fn start(self: Arc<Self>) -> SdkResult<()>;
26    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
27    fn server_info(&self) -> &InitializeResult;
28    fn client_info(&self) -> Option<InitializeRequestParams>;
29
30    async fn wait_for_initialization(&self);
31
32    async fn send(
33        &self,
34        message: MessageFromServer,
35        request_id: Option<RequestId>,
36        request_timeout: Option<Duration>,
37    ) -> SdkResult<Option<ClientMessage>>;
38
39    async fn send_batch(
40        &self,
41        messages: Vec<ServerMessage>,
42        request_timeout: Option<Duration>,
43    ) -> SdkResult<Option<Vec<ClientMessage>>>;
44
45    /// Checks whether the server has been initialized with client
46    fn is_initialized(&self) -> bool {
47        self.client_info().is_some()
48    }
49
50    /// Returns the client's name and version information once initialization is complete.
51    /// This method retrieves the client details, if available, after successful initialization.
52    fn client_version(&self) -> Option<Implementation> {
53        self.client_info()
54            .map(|client_details| client_details.client_info)
55    }
56
57    /// Returns the server's capabilities.
58    fn capabilities(&self) -> &ServerCapabilities {
59        &self.server_info().capabilities
60    }
61
62    /// Sends an elicitation request to the client to prompt user input and returns the received response.
63    ///
64    /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema.
65    /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only
66    async fn elicit_input(
67        &self,
68        message: String,
69        requested_schema: ElicitRequestedSchema,
70    ) -> SdkResult<ElicitResult> {
71        let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams {
72            message,
73            requested_schema,
74        });
75        let response = self.request(request.into(), None).await?;
76        ElicitResult::try_from(response).map_err(|err| err.into())
77    }
78
79    /// Sends a request to the client and processes the response.
80    ///
81    /// This function sends a `RequestFromServer` message to the client, waits for the response,
82    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
83    /// Otherwise, it returns the result from the client.
84    async fn request(
85        &self,
86        request: RequestFromServer,
87        timeout: Option<Duration>,
88    ) -> SdkResult<ResultFromClient> {
89        // Send the request and receive the response.
90        let response = self
91            .send(MessageFromServer::RequestFromServer(request), None, timeout)
92            .await?;
93
94        let client_message = response.ok_or_else(|| {
95            RpcError::internal_error()
96                .with_message("An empty response was received from the client.".to_string())
97        })?;
98
99        if client_message.is_error() {
100            return Err(client_message.as_error()?.error.into());
101        }
102
103        return Ok(client_message.as_response()?.result);
104    }
105
106    /// Sends a notification. This is a one-way message that is not expected
107    /// to return any response. The method asynchronously sends the notification using
108    /// the transport layer and does not wait for any acknowledgement or result.
109    async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
110        self.send(
111            MessageFromServer::NotificationFromServer(notification),
112            None,
113            None,
114        )
115        .await?;
116        Ok(())
117    }
118
119    /// Request a list of root URIs from the client. Roots allow
120    /// servers to ask for specific directories or files to operate on. A common example
121    /// for roots is providing a set of repositories or directories a server should operate on.
122    /// This request is typically used when the server needs to understand the file system
123    /// structure or access specific locations that the client has permission to read from
124    async fn list_roots(
125        &self,
126        params: Option<ListRootsRequestParams>,
127    ) -> SdkResult<ListRootsResult> {
128        let request: ListRootsRequest = ListRootsRequest::new(params);
129        let response = self.request(request.into(), None).await?;
130        ListRootsResult::try_from(response).map_err(|err| err.into())
131    }
132
133    /// Send log message notification from server to client.
134    /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.
135    async fn send_logging_message(
136        &self,
137        params: LoggingMessageNotificationParams,
138    ) -> SdkResult<()> {
139        let notification = LoggingMessageNotification::new(params);
140        self.send_notification(notification.into()).await
141    }
142
143    /// An optional notification from the server to the client, informing it that
144    /// the list of prompts it offers has changed.
145    /// This may be issued by servers without any previous subscription from the client.
146    async fn send_prompt_list_changed(
147        &self,
148        params: Option<PromptListChangedNotificationParams>,
149    ) -> SdkResult<()> {
150        let notification = PromptListChangedNotification::new(params);
151        self.send_notification(notification.into()).await
152    }
153
154    /// An optional notification from the server to the client,
155    /// informing it that the list of resources it can read from has changed.
156    /// This may be issued by servers without any previous subscription from the client.
157    async fn send_resource_list_changed(
158        &self,
159        params: Option<ResourceListChangedNotificationParams>,
160    ) -> SdkResult<()> {
161        let notification = ResourceListChangedNotification::new(params);
162        self.send_notification(notification.into()).await
163    }
164
165    /// A notification from the server to the client, informing it that
166    /// a resource has changed and may need to be read again.
167    ///  This should only be sent if the client previously sent a resources/subscribe request.
168    async fn send_resource_updated(
169        &self,
170        params: ResourceUpdatedNotificationParams,
171    ) -> SdkResult<()> {
172        let notification = ResourceUpdatedNotification::new(params);
173        self.send_notification(notification.into()).await
174    }
175
176    /// An optional notification from the server to the client, informing it that
177    /// the list of tools it offers has changed.
178    /// This may be issued by servers without any previous subscription from the client.
179    async fn send_tool_list_changed(
180        &self,
181        params: Option<ToolListChangedNotificationParams>,
182    ) -> SdkResult<()> {
183        let notification = ToolListChangedNotification::new(params);
184        self.send_notification(notification.into()).await
185    }
186
187    /// A ping request to check that the other party is still alive.
188    /// The receiver must promptly respond, or else may be disconnected.
189    ///
190    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
191    /// Once the response is received, it attempts to convert it into the expected
192    /// result type.
193    ///
194    /// # Returns
195    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
196    /// If the request or conversion fails, an error is returned.
197    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
198        let ping_request = PingRequest::new(None);
199        let response = self.request(ping_request.into(), timeout).await?;
200        Ok(response.try_into()?)
201    }
202
203    /// A request from the server to sample an LLM via the client.
204    /// The client has full discretion over which model to select.
205    /// The client should also inform the user before beginning sampling,
206    /// to allow them to inspect the request (human in the loop)
207    /// and decide whether to approve it.
208    async fn create_message(
209        &self,
210        params: CreateMessageRequestParams,
211    ) -> SdkResult<CreateMessageResult> {
212        let ping_request = CreateMessageRequest::new(params);
213        let response = self.request(ping_request.into(), None).await?;
214        Ok(response.try_into()?)
215    }
216
217    /// Checks if the client supports sampling.
218    ///
219    /// This function retrieves the client information and checks if the
220    /// client has sampling capabilities listed. If the client info has
221    /// not been retrieved yet, it returns `None`. Otherwise, it returns
222    /// `Some(true)` if sampling is supported, or `Some(false)` if not.
223    ///
224    /// # Returns
225    /// - `None` if client information is not yet available.
226    /// - `Some(true)` if sampling is supported by the client.
227    /// - `Some(false)` if sampling is not supported by the client.
228    fn client_supports_sampling(&self) -> Option<bool> {
229        self.client_info()
230            .map(|client_details| client_details.capabilities.sampling.is_some())
231    }
232
233    /// Checks if the client supports listing roots.
234    ///
235    /// This function retrieves the client information and checks if the
236    /// client has listing roots capabilities listed. If the client info has
237    /// not been retrieved yet, it returns `None`. Otherwise, it returns
238    /// `Some(true)` if listing roots is supported, or `Some(false)` if not.
239    ///
240    /// # Returns
241    /// - `None` if client information is not yet available.
242    /// - `Some(true)` if listing roots is supported by the client.
243    /// - `Some(false)` if listing roots is not supported by the client.
244    fn client_supports_root_list(&self) -> Option<bool> {
245        self.client_info()
246            .map(|client_details| client_details.capabilities.roots.is_some())
247    }
248
249    /// Checks if the client has experimental capabilities available.
250    ///
251    /// This function retrieves the client information and checks if the
252    /// client has experimental listed in its capabilities. If the client info
253    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
254    /// `Some(true)` if experimental is available, or `Some(false)` if not.
255    ///
256    /// # Returns
257    /// - `None` if client information is not yet available.
258    /// - `Some(true)` if experimental capabilities are available on the client.
259    /// - `Some(false)` if no experimental capabilities are available on the client.
260    fn client_supports_experimental(&self) -> Option<bool> {
261        self.client_info()
262            .map(|client_details| client_details.capabilities.experimental.is_some())
263    }
264
265    /// Sends a message to the standard error output (stderr) asynchronously.
266    async fn stderr_message(&self, message: String) -> SdkResult<()>;
267
268    /// Asserts that client capabilities are available for a given server request.
269    ///
270    /// This method verifies that the client capabilities required to process the specified
271    /// server request have been retrieved and are accessible. It returns an error if the
272    /// capabilities are not available.
273    ///
274    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
275    fn assert_client_capabilities(
276        &self,
277        request_method: &String,
278    ) -> std::result::Result<(), RpcError> {
279        let entity = "Client";
280        if *request_method == CreateMessageRequest::method_name()
281            && !self.client_supports_sampling().unwrap_or(false)
282        {
283            return Err(
284                RpcError::internal_error().with_message(format_assertion_message(
285                    entity,
286                    "sampling",
287                    request_method,
288                )),
289            );
290        }
291        if *request_method == ListRootsRequest::method_name()
292            && !self.client_supports_root_list().unwrap_or(false)
293        {
294            return Err(
295                RpcError::internal_error().with_message(format_assertion_message(
296                    entity,
297                    "listing roots",
298                    request_method,
299                )),
300            );
301        }
302        Ok(())
303    }
304
305    fn assert_server_notification_capabilities(
306        &self,
307        notification_method: &String,
308    ) -> std::result::Result<(), RpcError> {
309        let entity = "Server";
310
311        let capabilities = &self.server_info().capabilities;
312
313        if *notification_method == LoggingMessageNotification::method_name()
314            && capabilities.logging.is_none()
315        {
316            return Err(
317                RpcError::internal_error().with_message(format_assertion_message(
318                    entity,
319                    "logging",
320                    notification_method,
321                )),
322            );
323        }
324        if *notification_method == ResourceUpdatedNotification::method_name()
325            && capabilities.resources.is_none()
326        {
327            return Err(
328                RpcError::internal_error().with_message(format_assertion_message(
329                    entity,
330                    "notifying about resources",
331                    notification_method,
332                )),
333            );
334        }
335        if *notification_method == ToolListChangedNotification::method_name()
336            && capabilities.tools.is_none()
337        {
338            return Err(
339                RpcError::internal_error().with_message(format_assertion_message(
340                    entity,
341                    "notifying of tool list changes",
342                    notification_method,
343                )),
344            );
345        }
346        if *notification_method == PromptListChangedNotification::method_name()
347            && capabilities.prompts.is_none()
348        {
349            return Err(
350                RpcError::internal_error().with_message(format_assertion_message(
351                    entity,
352                    "notifying of prompt list changes",
353                    notification_method,
354                )),
355            );
356        }
357
358        Ok(())
359    }
360
361    fn assert_server_request_capabilities(
362        &self,
363        request_method: &String,
364    ) -> std::result::Result<(), RpcError> {
365        let entity = "Server";
366        let capabilities = &self.server_info().capabilities;
367
368        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
369            return Err(
370                RpcError::internal_error().with_message(format_assertion_message(
371                    entity,
372                    "logging",
373                    request_method,
374                )),
375            );
376        }
377        if [
378            GetPromptRequest::method_name(),
379            ListPromptsRequest::method_name(),
380        ]
381        .contains(request_method)
382            && capabilities.prompts.is_none()
383        {
384            return Err(
385                RpcError::internal_error().with_message(format_assertion_message(
386                    entity,
387                    "prompts",
388                    request_method,
389                )),
390            );
391        }
392        if [
393            ListResourcesRequest::method_name(),
394            ListResourceTemplatesRequest::method_name(),
395            ReadResourceRequest::method_name(),
396        ]
397        .contains(request_method)
398            && capabilities.resources.is_none()
399        {
400            return Err(
401                RpcError::internal_error().with_message(format_assertion_message(
402                    entity,
403                    "resources",
404                    request_method,
405                )),
406            );
407        }
408        if [
409            CallToolRequest::method_name(),
410            ListToolsRequest::method_name(),
411        ]
412        .contains(request_method)
413            && capabilities.tools.is_none()
414        {
415            return Err(
416                RpcError::internal_error().with_message(format_assertion_message(
417                    entity,
418                    "tools",
419                    request_method,
420                )),
421            );
422        }
423        Ok(())
424    }
425
426    #[cfg(feature = "hyper-server")]
427    fn session_id(&self) -> Option<SessionId>;
428}