1use std::time::Duration;
2
3use crate::schema::{
4 schema_utils::{
5 ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
6 ResultFromClient,
7 },
8 CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
9 GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
10 ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest,
11 ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification,
12 LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
13 PromptListChangedNotificationParams, ReadResourceRequest, ResourceListChangedNotification,
14 ResourceListChangedNotificationParams, ResourceUpdatedNotification,
15 ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest,
16 ToolListChangedNotification, ToolListChangedNotificationParams,
17};
18use async_trait::async_trait;
19use rust_mcp_transport::{McpDispatch, MessageDispatcher};
20
21use crate::{error::SdkResult, utils::format_assertion_message};
22
23#[async_trait]
25pub trait McpServer: Sync + Send {
26 async fn start(&self) -> SdkResult<()>;
27 fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
28 fn server_info(&self) -> &InitializeResult;
29 fn client_info(&self) -> Option<InitializeRequestParams>;
30
31 #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
32 fn get_client_info(&self) -> Option<InitializeRequestParams> {
33 self.client_info()
34 }
35
36 #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
37 fn get_server_info(&self) -> &InitializeResult {
38 self.server_info()
39 }
40
41 async fn sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>
42 where
43 MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>;
44
45 fn is_initialized(&self) -> bool {
47 self.client_info().is_some()
48 }
49
50 fn client_version(&self) -> Option<Implementation> {
53 self.client_info()
54 .map(|client_details| client_details.client_info)
55 }
56
57 fn capabilities(&self) -> &ServerCapabilities {
59 &self.server_info().capabilities
60 }
61
62 async fn request(
68 &self,
69 request: RequestFromServer,
70 timeout: Option<Duration>,
71 ) -> SdkResult<ResultFromClient> {
72 let sender = self.sender().await;
73 let sender = sender.read().await;
74 let sender = sender.as_ref().unwrap();
75
76 let response = sender
78 .send(MessageFromServer::RequestFromServer(request), None, timeout)
79 .await?;
80 let client_message = response.ok_or_else(|| {
81 RpcError::internal_error()
82 .with_message("An empty response was received from the client.".to_string())
83 })?;
84
85 if client_message.is_error() {
86 return Err(client_message.as_error()?.error.into());
87 }
88
89 return Ok(client_message.as_response()?.result);
90 }
91
92 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
96 let sender = self.sender().await;
97 let sender = sender.read().await;
98 let sender = sender.as_ref().unwrap();
99
100 sender
101 .send(
102 MessageFromServer::NotificationFromServer(notification),
103 None,
104 None,
105 )
106 .await?;
107 Ok(())
108 }
109
110 async fn list_roots(
116 &self,
117 params: Option<ListRootsRequestParams>,
118 ) -> SdkResult<ListRootsResult> {
119 let request: ListRootsRequest = ListRootsRequest::new(params);
120 let response = self.request(request.into(), None).await?;
121 ListRootsResult::try_from(response).map_err(|err| err.into())
122 }
123
124 async fn send_logging_message(
127 &self,
128 params: LoggingMessageNotificationParams,
129 ) -> SdkResult<()> {
130 let notification = LoggingMessageNotification::new(params);
131 self.send_notification(notification.into()).await
132 }
133
134 async fn send_prompt_list_changed(
138 &self,
139 params: Option<PromptListChangedNotificationParams>,
140 ) -> SdkResult<()> {
141 let notification = PromptListChangedNotification::new(params);
142 self.send_notification(notification.into()).await
143 }
144
145 async fn send_resource_list_changed(
149 &self,
150 params: Option<ResourceListChangedNotificationParams>,
151 ) -> SdkResult<()> {
152 let notification = ResourceListChangedNotification::new(params);
153 self.send_notification(notification.into()).await
154 }
155
156 async fn send_resource_updated(
160 &self,
161 params: ResourceUpdatedNotificationParams,
162 ) -> SdkResult<()> {
163 let notification = ResourceUpdatedNotification::new(params);
164 self.send_notification(notification.into()).await
165 }
166
167 async fn send_tool_list_changed(
171 &self,
172 params: Option<ToolListChangedNotificationParams>,
173 ) -> SdkResult<()> {
174 let notification = ToolListChangedNotification::new(params);
175 self.send_notification(notification.into()).await
176 }
177
178 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
189 let ping_request = PingRequest::new(None);
190 let response = self.request(ping_request.into(), timeout).await?;
191 Ok(response.try_into()?)
192 }
193
194 async fn create_message(
200 &self,
201 params: CreateMessageRequestParams,
202 ) -> SdkResult<CreateMessageResult> {
203 let ping_request = CreateMessageRequest::new(params);
204 let response = self.request(ping_request.into(), None).await?;
205 Ok(response.try_into()?)
206 }
207
208 fn client_supports_sampling(&self) -> Option<bool> {
220 self.client_info()
221 .map(|client_details| client_details.capabilities.sampling.is_some())
222 }
223
224 fn client_supports_root_list(&self) -> Option<bool> {
236 self.client_info()
237 .map(|client_details| client_details.capabilities.roots.is_some())
238 }
239
240 fn client_supports_experimental(&self) -> Option<bool> {
252 self.client_info()
253 .map(|client_details| client_details.capabilities.experimental.is_some())
254 }
255
256 async fn stderr_message(&self, message: String) -> SdkResult<()>;
258
259 fn assert_client_capabilities(
267 &self,
268 request_method: &String,
269 ) -> std::result::Result<(), RpcError> {
270 let entity = "Client";
271 if *request_method == CreateMessageRequest::method_name()
272 && !self.client_supports_sampling().unwrap_or(false)
273 {
274 return Err(
275 RpcError::internal_error().with_message(format_assertion_message(
276 entity,
277 "sampling",
278 request_method,
279 )),
280 );
281 }
282 if *request_method == ListRootsRequest::method_name()
283 && !self.client_supports_root_list().unwrap_or(false)
284 {
285 return Err(
286 RpcError::internal_error().with_message(format_assertion_message(
287 entity,
288 "listing roots",
289 request_method,
290 )),
291 );
292 }
293 Ok(())
294 }
295
296 fn assert_server_notification_capabilities(
297 &self,
298 notification_method: &String,
299 ) -> std::result::Result<(), RpcError> {
300 let entity = "Server";
301
302 let capabilities = &self.server_info().capabilities;
303
304 if *notification_method == LoggingMessageNotification::method_name()
305 && capabilities.logging.is_none()
306 {
307 return Err(
308 RpcError::internal_error().with_message(format_assertion_message(
309 entity,
310 "logging",
311 notification_method,
312 )),
313 );
314 }
315 if *notification_method == ResourceUpdatedNotification::method_name()
316 && capabilities.resources.is_none()
317 {
318 return Err(
319 RpcError::internal_error().with_message(format_assertion_message(
320 entity,
321 "notifying about resources",
322 notification_method,
323 )),
324 );
325 }
326 if *notification_method == ToolListChangedNotification::method_name()
327 && capabilities.tools.is_none()
328 {
329 return Err(
330 RpcError::internal_error().with_message(format_assertion_message(
331 entity,
332 "notifying of tool list changes",
333 notification_method,
334 )),
335 );
336 }
337 if *notification_method == PromptListChangedNotification::method_name()
338 && capabilities.prompts.is_none()
339 {
340 return Err(
341 RpcError::internal_error().with_message(format_assertion_message(
342 entity,
343 "notifying of prompt list changes",
344 notification_method,
345 )),
346 );
347 }
348
349 Ok(())
350 }
351
352 fn assert_server_request_capabilities(
353 &self,
354 request_method: &String,
355 ) -> std::result::Result<(), RpcError> {
356 let entity = "Server";
357 let capabilities = &self.server_info().capabilities;
358
359 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
360 return Err(
361 RpcError::internal_error().with_message(format_assertion_message(
362 entity,
363 "logging",
364 request_method,
365 )),
366 );
367 }
368 if [
369 GetPromptRequest::method_name(),
370 ListPromptsRequest::method_name(),
371 ]
372 .contains(request_method)
373 && capabilities.prompts.is_none()
374 {
375 return Err(
376 RpcError::internal_error().with_message(format_assertion_message(
377 entity,
378 "prompts",
379 request_method,
380 )),
381 );
382 }
383 if [
384 ListResourcesRequest::method_name(),
385 ListResourceTemplatesRequest::method_name(),
386 ReadResourceRequest::method_name(),
387 ]
388 .contains(request_method)
389 && capabilities.resources.is_none()
390 {
391 return Err(
392 RpcError::internal_error().with_message(format_assertion_message(
393 entity,
394 "resources",
395 request_method,
396 )),
397 );
398 }
399 if [
400 CallToolRequest::method_name(),
401 ListToolsRequest::method_name(),
402 ]
403 .contains(request_method)
404 && capabilities.tools.is_none()
405 {
406 return Err(
407 RpcError::internal_error().with_message(format_assertion_message(
408 entity,
409 "tools",
410 request_method,
411 )),
412 );
413 }
414 Ok(())
415 }
416}