1use crate::schema::{
2 schema_utils::{
3 ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
4 ResultFromClient, ServerMessage,
5 },
6 CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
7 ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest,
8 Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest,
9 ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams,
10 ListRootsResult, ListToolsRequest, LoggingMessageNotification,
11 LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
12 PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
13 ResourceListChangedNotification, ResourceListChangedNotificationParams,
14 ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
15 SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
16};
17use crate::{error::SdkResult, utils::format_assertion_message};
18use async_trait::async_trait;
19use rust_mcp_transport::SessionId;
20use std::{sync::Arc, time::Duration};
21
22#[async_trait]
24pub trait McpServer: Sync + Send {
25 async fn start(self: Arc<Self>) -> SdkResult<()>;
26 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
27 fn server_info(&self) -> &InitializeResult;
28 fn client_info(&self) -> Option<InitializeRequestParams>;
29
30 async fn wait_for_initialization(&self);
31
32 async fn send(
33 &self,
34 message: MessageFromServer,
35 request_id: Option<RequestId>,
36 request_timeout: Option<Duration>,
37 ) -> SdkResult<Option<ClientMessage>>;
38
39 async fn send_batch(
40 &self,
41 messages: Vec<ServerMessage>,
42 request_timeout: Option<Duration>,
43 ) -> SdkResult<Option<Vec<ClientMessage>>>;
44
45 fn is_initialized(&self) -> bool {
47 self.client_info().is_some()
48 }
49
50 fn client_version(&self) -> Option<Implementation> {
53 self.client_info()
54 .map(|client_details| client_details.client_info)
55 }
56
57 fn capabilities(&self) -> &ServerCapabilities {
59 &self.server_info().capabilities
60 }
61
62 async fn elicit_input(
67 &self,
68 message: String,
69 requested_schema: ElicitRequestedSchema,
70 ) -> SdkResult<ElicitResult> {
71 let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams {
72 message,
73 requested_schema,
74 });
75 let response = self.request(request.into(), None).await?;
76 ElicitResult::try_from(response).map_err(|err| err.into())
77 }
78
79 async fn request(
85 &self,
86 request: RequestFromServer,
87 timeout: Option<Duration>,
88 ) -> SdkResult<ResultFromClient> {
89 let response = self
91 .send(MessageFromServer::RequestFromServer(request), None, timeout)
92 .await?;
93
94 let client_message = response.ok_or_else(|| {
95 RpcError::internal_error()
96 .with_message("An empty response was received from the client.".to_string())
97 })?;
98
99 if client_message.is_error() {
100 return Err(client_message.as_error()?.error.into());
101 }
102
103 return Ok(client_message.as_response()?.result);
104 }
105
106 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
110 self.send(
111 MessageFromServer::NotificationFromServer(notification),
112 None,
113 None,
114 )
115 .await?;
116 Ok(())
117 }
118
119 async fn list_roots(
125 &self,
126 params: Option<ListRootsRequestParams>,
127 ) -> SdkResult<ListRootsResult> {
128 let request: ListRootsRequest = ListRootsRequest::new(params);
129 let response = self.request(request.into(), None).await?;
130 ListRootsResult::try_from(response).map_err(|err| err.into())
131 }
132
133 async fn send_logging_message(
136 &self,
137 params: LoggingMessageNotificationParams,
138 ) -> SdkResult<()> {
139 let notification = LoggingMessageNotification::new(params);
140 self.send_notification(notification.into()).await
141 }
142
143 async fn send_prompt_list_changed(
147 &self,
148 params: Option<PromptListChangedNotificationParams>,
149 ) -> SdkResult<()> {
150 let notification = PromptListChangedNotification::new(params);
151 self.send_notification(notification.into()).await
152 }
153
154 async fn send_resource_list_changed(
158 &self,
159 params: Option<ResourceListChangedNotificationParams>,
160 ) -> SdkResult<()> {
161 let notification = ResourceListChangedNotification::new(params);
162 self.send_notification(notification.into()).await
163 }
164
165 async fn send_resource_updated(
169 &self,
170 params: ResourceUpdatedNotificationParams,
171 ) -> SdkResult<()> {
172 let notification = ResourceUpdatedNotification::new(params);
173 self.send_notification(notification.into()).await
174 }
175
176 async fn send_tool_list_changed(
180 &self,
181 params: Option<ToolListChangedNotificationParams>,
182 ) -> SdkResult<()> {
183 let notification = ToolListChangedNotification::new(params);
184 self.send_notification(notification.into()).await
185 }
186
187 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
198 let ping_request = PingRequest::new(None);
199 let response = self.request(ping_request.into(), timeout).await?;
200 Ok(response.try_into()?)
201 }
202
203 async fn create_message(
209 &self,
210 params: CreateMessageRequestParams,
211 ) -> SdkResult<CreateMessageResult> {
212 let ping_request = CreateMessageRequest::new(params);
213 let response = self.request(ping_request.into(), None).await?;
214 Ok(response.try_into()?)
215 }
216
217 fn client_supports_sampling(&self) -> Option<bool> {
229 self.client_info()
230 .map(|client_details| client_details.capabilities.sampling.is_some())
231 }
232
233 fn client_supports_root_list(&self) -> Option<bool> {
245 self.client_info()
246 .map(|client_details| client_details.capabilities.roots.is_some())
247 }
248
249 fn client_supports_experimental(&self) -> Option<bool> {
261 self.client_info()
262 .map(|client_details| client_details.capabilities.experimental.is_some())
263 }
264
265 async fn stderr_message(&self, message: String) -> SdkResult<()>;
267
268 fn assert_client_capabilities(
276 &self,
277 request_method: &String,
278 ) -> std::result::Result<(), RpcError> {
279 let entity = "Client";
280 if *request_method == CreateMessageRequest::method_name()
281 && !self.client_supports_sampling().unwrap_or(false)
282 {
283 return Err(
284 RpcError::internal_error().with_message(format_assertion_message(
285 entity,
286 "sampling",
287 request_method,
288 )),
289 );
290 }
291 if *request_method == ListRootsRequest::method_name()
292 && !self.client_supports_root_list().unwrap_or(false)
293 {
294 return Err(
295 RpcError::internal_error().with_message(format_assertion_message(
296 entity,
297 "listing roots",
298 request_method,
299 )),
300 );
301 }
302 Ok(())
303 }
304
305 fn assert_server_notification_capabilities(
306 &self,
307 notification_method: &String,
308 ) -> std::result::Result<(), RpcError> {
309 let entity = "Server";
310
311 let capabilities = &self.server_info().capabilities;
312
313 if *notification_method == LoggingMessageNotification::method_name()
314 && capabilities.logging.is_none()
315 {
316 return Err(
317 RpcError::internal_error().with_message(format_assertion_message(
318 entity,
319 "logging",
320 notification_method,
321 )),
322 );
323 }
324 if *notification_method == ResourceUpdatedNotification::method_name()
325 && capabilities.resources.is_none()
326 {
327 return Err(
328 RpcError::internal_error().with_message(format_assertion_message(
329 entity,
330 "notifying about resources",
331 notification_method,
332 )),
333 );
334 }
335 if *notification_method == ToolListChangedNotification::method_name()
336 && capabilities.tools.is_none()
337 {
338 return Err(
339 RpcError::internal_error().with_message(format_assertion_message(
340 entity,
341 "notifying of tool list changes",
342 notification_method,
343 )),
344 );
345 }
346 if *notification_method == PromptListChangedNotification::method_name()
347 && capabilities.prompts.is_none()
348 {
349 return Err(
350 RpcError::internal_error().with_message(format_assertion_message(
351 entity,
352 "notifying of prompt list changes",
353 notification_method,
354 )),
355 );
356 }
357
358 Ok(())
359 }
360
361 fn assert_server_request_capabilities(
362 &self,
363 request_method: &String,
364 ) -> std::result::Result<(), RpcError> {
365 let entity = "Server";
366 let capabilities = &self.server_info().capabilities;
367
368 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
369 return Err(
370 RpcError::internal_error().with_message(format_assertion_message(
371 entity,
372 "logging",
373 request_method,
374 )),
375 );
376 }
377 if [
378 GetPromptRequest::method_name(),
379 ListPromptsRequest::method_name(),
380 ]
381 .contains(request_method)
382 && capabilities.prompts.is_none()
383 {
384 return Err(
385 RpcError::internal_error().with_message(format_assertion_message(
386 entity,
387 "prompts",
388 request_method,
389 )),
390 );
391 }
392 if [
393 ListResourcesRequest::method_name(),
394 ListResourceTemplatesRequest::method_name(),
395 ReadResourceRequest::method_name(),
396 ]
397 .contains(request_method)
398 && capabilities.resources.is_none()
399 {
400 return Err(
401 RpcError::internal_error().with_message(format_assertion_message(
402 entity,
403 "resources",
404 request_method,
405 )),
406 );
407 }
408 if [
409 CallToolRequest::method_name(),
410 ListToolsRequest::method_name(),
411 ]
412 .contains(request_method)
413 && capabilities.tools.is_none()
414 {
415 return Err(
416 RpcError::internal_error().with_message(format_assertion_message(
417 entity,
418 "tools",
419 request_method,
420 )),
421 );
422 }
423 Ok(())
424 }
425
426 #[cfg(feature = "hyper-server")]
427 fn session_id(&self) -> Option<SessionId>;
428}