1use crate::schema::{
2 schema_utils::{
3 ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
4 ResultFromClient, ServerMessage,
5 },
6 CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
7 GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
8 ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest,
9 ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification,
10 LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
11 PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
12 ResourceListChangedNotification, ResourceListChangedNotificationParams,
13 ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
14 SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
15};
16use async_trait::async_trait;
17use rust_mcp_transport::SessionId;
18use std::time::Duration;
19
20use crate::{error::SdkResult, utils::format_assertion_message};
21
22#[async_trait]
24pub trait McpServer: Sync + Send {
25 async fn start(&self) -> SdkResult<()>;
26 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
27 fn server_info(&self) -> &InitializeResult;
28 fn client_info(&self) -> Option<InitializeRequestParams>;
29
30 async fn wait_for_initialization(&self);
31
32 async fn send(
33 &self,
34 message: MessageFromServer,
35 request_id: Option<RequestId>,
36 request_timeout: Option<Duration>,
37 ) -> SdkResult<Option<ClientMessage>>;
38
39 async fn send_batch(
40 &self,
41 messages: Vec<ServerMessage>,
42 request_timeout: Option<Duration>,
43 ) -> SdkResult<Option<Vec<ClientMessage>>>;
44
45 fn is_initialized(&self) -> bool {
47 self.client_info().is_some()
48 }
49
50 fn client_version(&self) -> Option<Implementation> {
53 self.client_info()
54 .map(|client_details| client_details.client_info)
55 }
56
57 fn capabilities(&self) -> &ServerCapabilities {
59 &self.server_info().capabilities
60 }
61
62 async fn request(
68 &self,
69 request: RequestFromServer,
70 timeout: Option<Duration>,
71 ) -> SdkResult<ResultFromClient> {
72 let response = self
74 .send(MessageFromServer::RequestFromServer(request), None, timeout)
75 .await?;
76
77 let client_message = response.ok_or_else(|| {
78 RpcError::internal_error()
79 .with_message("An empty response was received from the client.".to_string())
80 })?;
81
82 if client_message.is_error() {
83 return Err(client_message.as_error()?.error.into());
84 }
85
86 return Ok(client_message.as_response()?.result);
87 }
88
89 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
93 self.send(
94 MessageFromServer::NotificationFromServer(notification),
95 None,
96 None,
97 )
98 .await?;
99 Ok(())
100 }
101
102 async fn list_roots(
108 &self,
109 params: Option<ListRootsRequestParams>,
110 ) -> SdkResult<ListRootsResult> {
111 let request: ListRootsRequest = ListRootsRequest::new(params);
112 let response = self.request(request.into(), None).await?;
113 ListRootsResult::try_from(response).map_err(|err| err.into())
114 }
115
116 async fn send_logging_message(
119 &self,
120 params: LoggingMessageNotificationParams,
121 ) -> SdkResult<()> {
122 let notification = LoggingMessageNotification::new(params);
123 self.send_notification(notification.into()).await
124 }
125
126 async fn send_prompt_list_changed(
130 &self,
131 params: Option<PromptListChangedNotificationParams>,
132 ) -> SdkResult<()> {
133 let notification = PromptListChangedNotification::new(params);
134 self.send_notification(notification.into()).await
135 }
136
137 async fn send_resource_list_changed(
141 &self,
142 params: Option<ResourceListChangedNotificationParams>,
143 ) -> SdkResult<()> {
144 let notification = ResourceListChangedNotification::new(params);
145 self.send_notification(notification.into()).await
146 }
147
148 async fn send_resource_updated(
152 &self,
153 params: ResourceUpdatedNotificationParams,
154 ) -> SdkResult<()> {
155 let notification = ResourceUpdatedNotification::new(params);
156 self.send_notification(notification.into()).await
157 }
158
159 async fn send_tool_list_changed(
163 &self,
164 params: Option<ToolListChangedNotificationParams>,
165 ) -> SdkResult<()> {
166 let notification = ToolListChangedNotification::new(params);
167 self.send_notification(notification.into()).await
168 }
169
170 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
181 let ping_request = PingRequest::new(None);
182 let response = self.request(ping_request.into(), timeout).await?;
183 Ok(response.try_into()?)
184 }
185
186 async fn create_message(
192 &self,
193 params: CreateMessageRequestParams,
194 ) -> SdkResult<CreateMessageResult> {
195 let ping_request = CreateMessageRequest::new(params);
196 let response = self.request(ping_request.into(), None).await?;
197 Ok(response.try_into()?)
198 }
199
200 fn client_supports_sampling(&self) -> Option<bool> {
212 self.client_info()
213 .map(|client_details| client_details.capabilities.sampling.is_some())
214 }
215
216 fn client_supports_root_list(&self) -> Option<bool> {
228 self.client_info()
229 .map(|client_details| client_details.capabilities.roots.is_some())
230 }
231
232 fn client_supports_experimental(&self) -> Option<bool> {
244 self.client_info()
245 .map(|client_details| client_details.capabilities.experimental.is_some())
246 }
247
248 async fn stderr_message(&self, message: String) -> SdkResult<()>;
250
251 fn assert_client_capabilities(
259 &self,
260 request_method: &String,
261 ) -> std::result::Result<(), RpcError> {
262 let entity = "Client";
263 if *request_method == CreateMessageRequest::method_name()
264 && !self.client_supports_sampling().unwrap_or(false)
265 {
266 return Err(
267 RpcError::internal_error().with_message(format_assertion_message(
268 entity,
269 "sampling",
270 request_method,
271 )),
272 );
273 }
274 if *request_method == ListRootsRequest::method_name()
275 && !self.client_supports_root_list().unwrap_or(false)
276 {
277 return Err(
278 RpcError::internal_error().with_message(format_assertion_message(
279 entity,
280 "listing roots",
281 request_method,
282 )),
283 );
284 }
285 Ok(())
286 }
287
288 fn assert_server_notification_capabilities(
289 &self,
290 notification_method: &String,
291 ) -> std::result::Result<(), RpcError> {
292 let entity = "Server";
293
294 let capabilities = &self.server_info().capabilities;
295
296 if *notification_method == LoggingMessageNotification::method_name()
297 && capabilities.logging.is_none()
298 {
299 return Err(
300 RpcError::internal_error().with_message(format_assertion_message(
301 entity,
302 "logging",
303 notification_method,
304 )),
305 );
306 }
307 if *notification_method == ResourceUpdatedNotification::method_name()
308 && capabilities.resources.is_none()
309 {
310 return Err(
311 RpcError::internal_error().with_message(format_assertion_message(
312 entity,
313 "notifying about resources",
314 notification_method,
315 )),
316 );
317 }
318 if *notification_method == ToolListChangedNotification::method_name()
319 && capabilities.tools.is_none()
320 {
321 return Err(
322 RpcError::internal_error().with_message(format_assertion_message(
323 entity,
324 "notifying of tool list changes",
325 notification_method,
326 )),
327 );
328 }
329 if *notification_method == PromptListChangedNotification::method_name()
330 && capabilities.prompts.is_none()
331 {
332 return Err(
333 RpcError::internal_error().with_message(format_assertion_message(
334 entity,
335 "notifying of prompt list changes",
336 notification_method,
337 )),
338 );
339 }
340
341 Ok(())
342 }
343
344 fn assert_server_request_capabilities(
345 &self,
346 request_method: &String,
347 ) -> std::result::Result<(), RpcError> {
348 let entity = "Server";
349 let capabilities = &self.server_info().capabilities;
350
351 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
352 return Err(
353 RpcError::internal_error().with_message(format_assertion_message(
354 entity,
355 "logging",
356 request_method,
357 )),
358 );
359 }
360 if [
361 GetPromptRequest::method_name(),
362 ListPromptsRequest::method_name(),
363 ]
364 .contains(request_method)
365 && capabilities.prompts.is_none()
366 {
367 return Err(
368 RpcError::internal_error().with_message(format_assertion_message(
369 entity,
370 "prompts",
371 request_method,
372 )),
373 );
374 }
375 if [
376 ListResourcesRequest::method_name(),
377 ListResourceTemplatesRequest::method_name(),
378 ReadResourceRequest::method_name(),
379 ]
380 .contains(request_method)
381 && capabilities.resources.is_none()
382 {
383 return Err(
384 RpcError::internal_error().with_message(format_assertion_message(
385 entity,
386 "resources",
387 request_method,
388 )),
389 );
390 }
391 if [
392 CallToolRequest::method_name(),
393 ListToolsRequest::method_name(),
394 ]
395 .contains(request_method)
396 && capabilities.tools.is_none()
397 {
398 return Err(
399 RpcError::internal_error().with_message(format_assertion_message(
400 entity,
401 "tools",
402 request_method,
403 )),
404 );
405 }
406 Ok(())
407 }
408
409 #[cfg(feature = "hyper-server")]
410 fn session_id(&self) -> Option<SessionId>;
411}