1use std::time::Duration;
2
3use crate::schema::{
4 schema_utils::{
5 ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer,
6 RequestFromServer, ResultFromClient, ServerMessage,
7 },
8 CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
9 GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
10 ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest,
11 ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification,
12 LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
13 PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
14 ResourceListChangedNotification, ResourceListChangedNotificationParams,
15 ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
16 SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
17};
18use async_trait::async_trait;
19
20use crate::{error::SdkResult, utils::format_assertion_message};
21
22#[async_trait]
24pub trait McpServer: Sync + Send {
25 async fn start(&self) -> SdkResult<()>;
26 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
27 fn server_info(&self) -> &InitializeResult;
28 fn client_info(&self) -> Option<InitializeRequestParams>;
29
30 async fn wait_for_initialization(&self);
31
32 #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
33 fn get_client_info(&self) -> Option<InitializeRequestParams> {
34 self.client_info()
35 }
36
37 #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")]
38 fn get_server_info(&self) -> &InitializeResult {
39 self.server_info()
40 }
41
42 async fn send(
43 &self,
44 message: MessageFromServer,
45 request_id: Option<RequestId>,
46 request_timeout: Option<Duration>,
47 ) -> SdkResult<Option<ClientMessages>>;
48
49 async fn send_batch(
50 &self,
51 messages: Vec<ServerMessage>,
52 request_timeout: Option<Duration>,
53 ) -> SdkResult<Option<Vec<ClientMessage>>>;
54
55 fn is_initialized(&self) -> bool {
57 self.client_info().is_some()
58 }
59
60 fn client_version(&self) -> Option<Implementation> {
63 self.client_info()
64 .map(|client_details| client_details.client_info)
65 }
66
67 fn capabilities(&self) -> &ServerCapabilities {
69 &self.server_info().capabilities
70 }
71
72 async fn request(
78 &self,
79 request: RequestFromServer,
80 timeout: Option<Duration>,
81 ) -> SdkResult<ResultFromClient> {
82 let response = self
84 .send(MessageFromServer::RequestFromServer(request), None, timeout)
85 .await?;
86
87 let client_messages = response.ok_or_else(|| {
88 RpcError::internal_error()
89 .with_message("An empty response was received from the client.".to_string())
90 })?;
91
92 let client_message = client_messages.as_single()?;
93
94 if client_message.is_error() {
95 return Err(client_message.as_error()?.error.into());
96 }
97
98 return Ok(client_message.as_response()?.result);
99 }
100
101 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
105 self.send(
106 MessageFromServer::NotificationFromServer(notification),
107 None,
108 None,
109 )
110 .await?;
111 Ok(())
112 }
113
114 async fn list_roots(
120 &self,
121 params: Option<ListRootsRequestParams>,
122 ) -> SdkResult<ListRootsResult> {
123 let request: ListRootsRequest = ListRootsRequest::new(params);
124 let response = self.request(request.into(), None).await?;
125 ListRootsResult::try_from(response).map_err(|err| err.into())
126 }
127
128 async fn send_logging_message(
131 &self,
132 params: LoggingMessageNotificationParams,
133 ) -> SdkResult<()> {
134 let notification = LoggingMessageNotification::new(params);
135 self.send_notification(notification.into()).await
136 }
137
138 async fn send_prompt_list_changed(
142 &self,
143 params: Option<PromptListChangedNotificationParams>,
144 ) -> SdkResult<()> {
145 let notification = PromptListChangedNotification::new(params);
146 self.send_notification(notification.into()).await
147 }
148
149 async fn send_resource_list_changed(
153 &self,
154 params: Option<ResourceListChangedNotificationParams>,
155 ) -> SdkResult<()> {
156 let notification = ResourceListChangedNotification::new(params);
157 self.send_notification(notification.into()).await
158 }
159
160 async fn send_resource_updated(
164 &self,
165 params: ResourceUpdatedNotificationParams,
166 ) -> SdkResult<()> {
167 let notification = ResourceUpdatedNotification::new(params);
168 self.send_notification(notification.into()).await
169 }
170
171 async fn send_tool_list_changed(
175 &self,
176 params: Option<ToolListChangedNotificationParams>,
177 ) -> SdkResult<()> {
178 let notification = ToolListChangedNotification::new(params);
179 self.send_notification(notification.into()).await
180 }
181
182 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
193 let ping_request = PingRequest::new(None);
194 let response = self.request(ping_request.into(), timeout).await?;
195 Ok(response.try_into()?)
196 }
197
198 async fn create_message(
204 &self,
205 params: CreateMessageRequestParams,
206 ) -> SdkResult<CreateMessageResult> {
207 let ping_request = CreateMessageRequest::new(params);
208 let response = self.request(ping_request.into(), None).await?;
209 Ok(response.try_into()?)
210 }
211
212 fn client_supports_sampling(&self) -> Option<bool> {
224 self.client_info()
225 .map(|client_details| client_details.capabilities.sampling.is_some())
226 }
227
228 fn client_supports_root_list(&self) -> Option<bool> {
240 self.client_info()
241 .map(|client_details| client_details.capabilities.roots.is_some())
242 }
243
244 fn client_supports_experimental(&self) -> Option<bool> {
256 self.client_info()
257 .map(|client_details| client_details.capabilities.experimental.is_some())
258 }
259
260 async fn stderr_message(&self, message: String) -> SdkResult<()>;
262
263 fn assert_client_capabilities(
271 &self,
272 request_method: &String,
273 ) -> std::result::Result<(), RpcError> {
274 let entity = "Client";
275 if *request_method == CreateMessageRequest::method_name()
276 && !self.client_supports_sampling().unwrap_or(false)
277 {
278 return Err(
279 RpcError::internal_error().with_message(format_assertion_message(
280 entity,
281 "sampling",
282 request_method,
283 )),
284 );
285 }
286 if *request_method == ListRootsRequest::method_name()
287 && !self.client_supports_root_list().unwrap_or(false)
288 {
289 return Err(
290 RpcError::internal_error().with_message(format_assertion_message(
291 entity,
292 "listing roots",
293 request_method,
294 )),
295 );
296 }
297 Ok(())
298 }
299
300 fn assert_server_notification_capabilities(
301 &self,
302 notification_method: &String,
303 ) -> std::result::Result<(), RpcError> {
304 let entity = "Server";
305
306 let capabilities = &self.server_info().capabilities;
307
308 if *notification_method == LoggingMessageNotification::method_name()
309 && capabilities.logging.is_none()
310 {
311 return Err(
312 RpcError::internal_error().with_message(format_assertion_message(
313 entity,
314 "logging",
315 notification_method,
316 )),
317 );
318 }
319 if *notification_method == ResourceUpdatedNotification::method_name()
320 && capabilities.resources.is_none()
321 {
322 return Err(
323 RpcError::internal_error().with_message(format_assertion_message(
324 entity,
325 "notifying about resources",
326 notification_method,
327 )),
328 );
329 }
330 if *notification_method == ToolListChangedNotification::method_name()
331 && capabilities.tools.is_none()
332 {
333 return Err(
334 RpcError::internal_error().with_message(format_assertion_message(
335 entity,
336 "notifying of tool list changes",
337 notification_method,
338 )),
339 );
340 }
341 if *notification_method == PromptListChangedNotification::method_name()
342 && capabilities.prompts.is_none()
343 {
344 return Err(
345 RpcError::internal_error().with_message(format_assertion_message(
346 entity,
347 "notifying of prompt list changes",
348 notification_method,
349 )),
350 );
351 }
352
353 Ok(())
354 }
355
356 fn assert_server_request_capabilities(
357 &self,
358 request_method: &String,
359 ) -> std::result::Result<(), RpcError> {
360 let entity = "Server";
361 let capabilities = &self.server_info().capabilities;
362
363 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
364 return Err(
365 RpcError::internal_error().with_message(format_assertion_message(
366 entity,
367 "logging",
368 request_method,
369 )),
370 );
371 }
372 if [
373 GetPromptRequest::method_name(),
374 ListPromptsRequest::method_name(),
375 ]
376 .contains(request_method)
377 && capabilities.prompts.is_none()
378 {
379 return Err(
380 RpcError::internal_error().with_message(format_assertion_message(
381 entity,
382 "prompts",
383 request_method,
384 )),
385 );
386 }
387 if [
388 ListResourcesRequest::method_name(),
389 ListResourceTemplatesRequest::method_name(),
390 ReadResourceRequest::method_name(),
391 ]
392 .contains(request_method)
393 && capabilities.resources.is_none()
394 {
395 return Err(
396 RpcError::internal_error().with_message(format_assertion_message(
397 entity,
398 "resources",
399 request_method,
400 )),
401 );
402 }
403 if [
404 CallToolRequest::method_name(),
405 ListToolsRequest::method_name(),
406 ]
407 .contains(request_method)
408 && capabilities.tools.is_none()
409 {
410 return Err(
411 RpcError::internal_error().with_message(format_assertion_message(
412 entity,
413 "tools",
414 request_method,
415 )),
416 );
417 }
418 Ok(())
419 }
420}