rust_mcp_sdk/mcp_traits/
mcp_server.rs

1use crate::auth::AuthInfo;
2use crate::{error::SdkResult, utils::format_assertion_message};
3
4use crate::schema::{
5    schema_utils::{
6        ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
7        ResultFromClient, ServerMessage,
8    },
9    CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
10    ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest,
11    Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest,
12    ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams,
13    ListRootsResult, ListToolsRequest, LoggingMessageNotification,
14    LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
15    PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
16    ResourceListChangedNotification, ResourceListChangedNotificationParams,
17    ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
18    SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
19};
20use async_trait::async_trait;
21use rust_mcp_transport::SessionId;
22use std::{sync::Arc, time::Duration};
23use tokio::sync::RwLockReadGuard;
24
25//TODO: support options , such as enforceStrictCapabilities
26#[async_trait]
27pub trait McpServer: Sync + Send {
28    async fn start(self: Arc<Self>) -> SdkResult<()>;
29    async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
30    fn server_info(&self) -> &InitializeResult;
31    fn client_info(&self) -> Option<InitializeRequestParams>;
32
33    async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>>;
34    async fn auth_info_cloned(&self) -> Option<AuthInfo>;
35    async fn update_auth_info(&self, auth_info: Option<AuthInfo>);
36
37    async fn wait_for_initialization(&self);
38
39    async fn send(
40        &self,
41        message: MessageFromServer,
42        request_id: Option<RequestId>,
43        request_timeout: Option<Duration>,
44    ) -> SdkResult<Option<ClientMessage>>;
45
46    async fn send_batch(
47        &self,
48        messages: Vec<ServerMessage>,
49        request_timeout: Option<Duration>,
50    ) -> SdkResult<Option<Vec<ClientMessage>>>;
51
52    /// Checks whether the server has been initialized with client
53    fn is_initialized(&self) -> bool {
54        self.client_info().is_some()
55    }
56
57    /// Returns the client's name and version information once initialization is complete.
58    /// This method retrieves the client details, if available, after successful initialization.
59    fn client_version(&self) -> Option<Implementation> {
60        self.client_info()
61            .map(|client_details| client_details.client_info)
62    }
63
64    /// Returns the server's capabilities.
65    fn capabilities(&self) -> &ServerCapabilities {
66        &self.server_info().capabilities
67    }
68
69    /// Sends an elicitation request to the client to prompt user input and returns the received response.
70    ///
71    /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema.
72    /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only
73    async fn elicit_input(
74        &self,
75        message: String,
76        requested_schema: ElicitRequestedSchema,
77    ) -> SdkResult<ElicitResult> {
78        let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams {
79            message,
80            requested_schema,
81        });
82        let response = self.request(request.into(), None).await?;
83        ElicitResult::try_from(response).map_err(|err| err.into())
84    }
85
86    /// Sends a request to the client and processes the response.
87    ///
88    /// This function sends a `RequestFromServer` message to the client, waits for the response,
89    /// and handles the result. If the response is empty or of an invalid type, an error is returned.
90    /// Otherwise, it returns the result from the client.
91    async fn request(
92        &self,
93        request: RequestFromServer,
94        timeout: Option<Duration>,
95    ) -> SdkResult<ResultFromClient> {
96        // Send the request and receive the response.
97        let response = self
98            .send(MessageFromServer::RequestFromServer(request), None, timeout)
99            .await?;
100
101        let client_message = response.ok_or_else(|| {
102            RpcError::internal_error()
103                .with_message("An empty response was received from the client.".to_string())
104        })?;
105
106        if client_message.is_error() {
107            return Err(client_message.as_error()?.error.into());
108        }
109
110        return Ok(client_message.as_response()?.result);
111    }
112
113    /// Sends a notification. This is a one-way message that is not expected
114    /// to return any response. The method asynchronously sends the notification using
115    /// the transport layer and does not wait for any acknowledgement or result.
116    async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
117        self.send(
118            MessageFromServer::NotificationFromServer(notification),
119            None,
120            None,
121        )
122        .await?;
123        Ok(())
124    }
125
126    /// Request a list of root URIs from the client. Roots allow
127    /// servers to ask for specific directories or files to operate on. A common example
128    /// for roots is providing a set of repositories or directories a server should operate on.
129    /// This request is typically used when the server needs to understand the file system
130    /// structure or access specific locations that the client has permission to read from
131    async fn list_roots(
132        &self,
133        params: Option<ListRootsRequestParams>,
134    ) -> SdkResult<ListRootsResult> {
135        let request: ListRootsRequest = ListRootsRequest::new(params);
136        let response = self.request(request.into(), None).await?;
137        ListRootsResult::try_from(response).map_err(|err| err.into())
138    }
139
140    /// Send log message notification from server to client.
141    /// If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.
142    async fn send_logging_message(
143        &self,
144        params: LoggingMessageNotificationParams,
145    ) -> SdkResult<()> {
146        let notification = LoggingMessageNotification::new(params);
147        self.send_notification(notification.into()).await
148    }
149
150    /// An optional notification from the server to the client, informing it that
151    /// the list of prompts it offers has changed.
152    /// This may be issued by servers without any previous subscription from the client.
153    async fn send_prompt_list_changed(
154        &self,
155        params: Option<PromptListChangedNotificationParams>,
156    ) -> SdkResult<()> {
157        let notification = PromptListChangedNotification::new(params);
158        self.send_notification(notification.into()).await
159    }
160
161    /// An optional notification from the server to the client,
162    /// informing it that the list of resources it can read from has changed.
163    /// This may be issued by servers without any previous subscription from the client.
164    async fn send_resource_list_changed(
165        &self,
166        params: Option<ResourceListChangedNotificationParams>,
167    ) -> SdkResult<()> {
168        let notification = ResourceListChangedNotification::new(params);
169        self.send_notification(notification.into()).await
170    }
171
172    /// A notification from the server to the client, informing it that
173    /// a resource has changed and may need to be read again.
174    ///  This should only be sent if the client previously sent a resources/subscribe request.
175    async fn send_resource_updated(
176        &self,
177        params: ResourceUpdatedNotificationParams,
178    ) -> SdkResult<()> {
179        let notification = ResourceUpdatedNotification::new(params);
180        self.send_notification(notification.into()).await
181    }
182
183    /// An optional notification from the server to the client, informing it that
184    /// the list of tools it offers has changed.
185    /// This may be issued by servers without any previous subscription from the client.
186    async fn send_tool_list_changed(
187        &self,
188        params: Option<ToolListChangedNotificationParams>,
189    ) -> SdkResult<()> {
190        let notification = ToolListChangedNotification::new(params);
191        self.send_notification(notification.into()).await
192    }
193
194    /// A ping request to check that the other party is still alive.
195    /// The receiver must promptly respond, or else may be disconnected.
196    ///
197    /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response
198    /// Once the response is received, it attempts to convert it into the expected
199    /// result type.
200    ///
201    /// # Returns
202    /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
203    /// If the request or conversion fails, an error is returned.
204    async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
205        let ping_request = PingRequest::new(None);
206        let response = self.request(ping_request.into(), timeout).await?;
207        Ok(response.try_into()?)
208    }
209
210    /// A request from the server to sample an LLM via the client.
211    /// The client has full discretion over which model to select.
212    /// The client should also inform the user before beginning sampling,
213    /// to allow them to inspect the request (human in the loop)
214    /// and decide whether to approve it.
215    async fn create_message(
216        &self,
217        params: CreateMessageRequestParams,
218    ) -> SdkResult<CreateMessageResult> {
219        let ping_request = CreateMessageRequest::new(params);
220        let response = self.request(ping_request.into(), None).await?;
221        Ok(response.try_into()?)
222    }
223
224    /// Checks if the client supports sampling.
225    ///
226    /// This function retrieves the client information and checks if the
227    /// client has sampling capabilities listed. If the client info has
228    /// not been retrieved yet, it returns `None`. Otherwise, it returns
229    /// `Some(true)` if sampling is supported, or `Some(false)` if not.
230    ///
231    /// # Returns
232    /// - `None` if client information is not yet available.
233    /// - `Some(true)` if sampling is supported by the client.
234    /// - `Some(false)` if sampling is not supported by the client.
235    fn client_supports_sampling(&self) -> Option<bool> {
236        self.client_info()
237            .map(|client_details| client_details.capabilities.sampling.is_some())
238    }
239
240    /// Checks if the client supports listing roots.
241    ///
242    /// This function retrieves the client information and checks if the
243    /// client has listing roots capabilities listed. If the client info has
244    /// not been retrieved yet, it returns `None`. Otherwise, it returns
245    /// `Some(true)` if listing roots is supported, or `Some(false)` if not.
246    ///
247    /// # Returns
248    /// - `None` if client information is not yet available.
249    /// - `Some(true)` if listing roots is supported by the client.
250    /// - `Some(false)` if listing roots is not supported by the client.
251    fn client_supports_root_list(&self) -> Option<bool> {
252        self.client_info()
253            .map(|client_details| client_details.capabilities.roots.is_some())
254    }
255
256    /// Checks if the client has experimental capabilities available.
257    ///
258    /// This function retrieves the client information and checks if the
259    /// client has experimental listed in its capabilities. If the client info
260    /// has not been retrieved yet, it returns `None`. Otherwise, it returns
261    /// `Some(true)` if experimental is available, or `Some(false)` if not.
262    ///
263    /// # Returns
264    /// - `None` if client information is not yet available.
265    /// - `Some(true)` if experimental capabilities are available on the client.
266    /// - `Some(false)` if no experimental capabilities are available on the client.
267    fn client_supports_experimental(&self) -> Option<bool> {
268        self.client_info()
269            .map(|client_details| client_details.capabilities.experimental.is_some())
270    }
271
272    /// Sends a message to the standard error output (stderr) asynchronously.
273    async fn stderr_message(&self, message: String) -> SdkResult<()>;
274
275    /// Asserts that client capabilities are available for a given server request.
276    ///
277    /// This method verifies that the client capabilities required to process the specified
278    /// server request have been retrieved and are accessible. It returns an error if the
279    /// capabilities are not available.
280    ///
281    /// This can be utilized to avoid sending requests when the opposing party lacks support for them.
282    fn assert_client_capabilities(
283        &self,
284        request_method: &String,
285    ) -> std::result::Result<(), RpcError> {
286        let entity = "Client";
287        if *request_method == CreateMessageRequest::method_name()
288            && !self.client_supports_sampling().unwrap_or(false)
289        {
290            return Err(
291                RpcError::internal_error().with_message(format_assertion_message(
292                    entity,
293                    "sampling",
294                    request_method,
295                )),
296            );
297        }
298        if *request_method == ListRootsRequest::method_name()
299            && !self.client_supports_root_list().unwrap_or(false)
300        {
301            return Err(
302                RpcError::internal_error().with_message(format_assertion_message(
303                    entity,
304                    "listing roots",
305                    request_method,
306                )),
307            );
308        }
309        Ok(())
310    }
311
312    fn assert_server_notification_capabilities(
313        &self,
314        notification_method: &String,
315    ) -> std::result::Result<(), RpcError> {
316        let entity = "Server";
317
318        let capabilities = &self.server_info().capabilities;
319
320        if *notification_method == LoggingMessageNotification::method_name()
321            && capabilities.logging.is_none()
322        {
323            return Err(
324                RpcError::internal_error().with_message(format_assertion_message(
325                    entity,
326                    "logging",
327                    notification_method,
328                )),
329            );
330        }
331        if *notification_method == ResourceUpdatedNotification::method_name()
332            && capabilities.resources.is_none()
333        {
334            return Err(
335                RpcError::internal_error().with_message(format_assertion_message(
336                    entity,
337                    "notifying about resources",
338                    notification_method,
339                )),
340            );
341        }
342        if *notification_method == ToolListChangedNotification::method_name()
343            && capabilities.tools.is_none()
344        {
345            return Err(
346                RpcError::internal_error().with_message(format_assertion_message(
347                    entity,
348                    "notifying of tool list changes",
349                    notification_method,
350                )),
351            );
352        }
353        if *notification_method == PromptListChangedNotification::method_name()
354            && capabilities.prompts.is_none()
355        {
356            return Err(
357                RpcError::internal_error().with_message(format_assertion_message(
358                    entity,
359                    "notifying of prompt list changes",
360                    notification_method,
361                )),
362            );
363        }
364
365        Ok(())
366    }
367
368    fn assert_server_request_capabilities(
369        &self,
370        request_method: &String,
371    ) -> std::result::Result<(), RpcError> {
372        let entity = "Server";
373        let capabilities = &self.server_info().capabilities;
374
375        if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
376            return Err(
377                RpcError::internal_error().with_message(format_assertion_message(
378                    entity,
379                    "logging",
380                    request_method,
381                )),
382            );
383        }
384        if [
385            GetPromptRequest::method_name(),
386            ListPromptsRequest::method_name(),
387        ]
388        .contains(request_method)
389            && capabilities.prompts.is_none()
390        {
391            return Err(
392                RpcError::internal_error().with_message(format_assertion_message(
393                    entity,
394                    "prompts",
395                    request_method,
396                )),
397            );
398        }
399        if [
400            ListResourcesRequest::method_name(),
401            ListResourceTemplatesRequest::method_name(),
402            ReadResourceRequest::method_name(),
403        ]
404        .contains(request_method)
405            && capabilities.resources.is_none()
406        {
407            return Err(
408                RpcError::internal_error().with_message(format_assertion_message(
409                    entity,
410                    "resources",
411                    request_method,
412                )),
413            );
414        }
415        if [
416            CallToolRequest::method_name(),
417            ListToolsRequest::method_name(),
418        ]
419        .contains(request_method)
420            && capabilities.tools.is_none()
421        {
422            return Err(
423                RpcError::internal_error().with_message(format_assertion_message(
424                    entity,
425                    "tools",
426                    request_method,
427                )),
428            );
429        }
430        Ok(())
431    }
432
433    #[cfg(feature = "hyper-server")]
434    fn session_id(&self) -> Option<SessionId>;
435}