1use crate::auth::AuthInfo;
2use crate::{error::SdkResult, utils::format_assertion_message};
3
4use crate::schema::{
5 schema_utils::{
6 ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
7 ResultFromClient, ServerMessage,
8 },
9 CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
10 ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest,
11 Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest,
12 ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams,
13 ListRootsResult, ListToolsRequest, LoggingMessageNotification,
14 LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification,
15 PromptListChangedNotificationParams, ReadResourceRequest, RequestId,
16 ResourceListChangedNotification, ResourceListChangedNotificationParams,
17 ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
18 SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams,
19};
20use async_trait::async_trait;
21use rust_mcp_transport::SessionId;
22use std::{sync::Arc, time::Duration};
23use tokio::sync::RwLockReadGuard;
24
25#[async_trait]
27pub trait McpServer: Sync + Send {
28 async fn start(self: Arc<Self>) -> SdkResult<()>;
29 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
30 fn server_info(&self) -> &InitializeResult;
31 fn client_info(&self) -> Option<InitializeRequestParams>;
32
33 async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>>;
34 async fn auth_info_cloned(&self) -> Option<AuthInfo>;
35 async fn update_auth_info(&self, auth_info: Option<AuthInfo>);
36
37 async fn wait_for_initialization(&self);
38
39 async fn send(
40 &self,
41 message: MessageFromServer,
42 request_id: Option<RequestId>,
43 request_timeout: Option<Duration>,
44 ) -> SdkResult<Option<ClientMessage>>;
45
46 async fn send_batch(
47 &self,
48 messages: Vec<ServerMessage>,
49 request_timeout: Option<Duration>,
50 ) -> SdkResult<Option<Vec<ClientMessage>>>;
51
52 fn is_initialized(&self) -> bool {
54 self.client_info().is_some()
55 }
56
57 fn client_version(&self) -> Option<Implementation> {
60 self.client_info()
61 .map(|client_details| client_details.client_info)
62 }
63
64 fn capabilities(&self) -> &ServerCapabilities {
66 &self.server_info().capabilities
67 }
68
69 async fn elicit_input(
74 &self,
75 message: String,
76 requested_schema: ElicitRequestedSchema,
77 ) -> SdkResult<ElicitResult> {
78 let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams {
79 message,
80 requested_schema,
81 });
82 let response = self.request(request.into(), None).await?;
83 ElicitResult::try_from(response).map_err(|err| err.into())
84 }
85
86 async fn request(
92 &self,
93 request: RequestFromServer,
94 timeout: Option<Duration>,
95 ) -> SdkResult<ResultFromClient> {
96 let response = self
98 .send(MessageFromServer::RequestFromServer(request), None, timeout)
99 .await?;
100
101 let client_message = response.ok_or_else(|| {
102 RpcError::internal_error()
103 .with_message("An empty response was received from the client.".to_string())
104 })?;
105
106 if client_message.is_error() {
107 return Err(client_message.as_error()?.error.into());
108 }
109
110 return Ok(client_message.as_response()?.result);
111 }
112
113 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
117 self.send(
118 MessageFromServer::NotificationFromServer(notification),
119 None,
120 None,
121 )
122 .await?;
123 Ok(())
124 }
125
126 async fn list_roots(
132 &self,
133 params: Option<ListRootsRequestParams>,
134 ) -> SdkResult<ListRootsResult> {
135 let request: ListRootsRequest = ListRootsRequest::new(params);
136 let response = self.request(request.into(), None).await?;
137 ListRootsResult::try_from(response).map_err(|err| err.into())
138 }
139
140 async fn send_logging_message(
143 &self,
144 params: LoggingMessageNotificationParams,
145 ) -> SdkResult<()> {
146 let notification = LoggingMessageNotification::new(params);
147 self.send_notification(notification.into()).await
148 }
149
150 async fn send_prompt_list_changed(
154 &self,
155 params: Option<PromptListChangedNotificationParams>,
156 ) -> SdkResult<()> {
157 let notification = PromptListChangedNotification::new(params);
158 self.send_notification(notification.into()).await
159 }
160
161 async fn send_resource_list_changed(
165 &self,
166 params: Option<ResourceListChangedNotificationParams>,
167 ) -> SdkResult<()> {
168 let notification = ResourceListChangedNotification::new(params);
169 self.send_notification(notification.into()).await
170 }
171
172 async fn send_resource_updated(
176 &self,
177 params: ResourceUpdatedNotificationParams,
178 ) -> SdkResult<()> {
179 let notification = ResourceUpdatedNotification::new(params);
180 self.send_notification(notification.into()).await
181 }
182
183 async fn send_tool_list_changed(
187 &self,
188 params: Option<ToolListChangedNotificationParams>,
189 ) -> SdkResult<()> {
190 let notification = ToolListChangedNotification::new(params);
191 self.send_notification(notification.into()).await
192 }
193
194 async fn ping(&self, timeout: Option<Duration>) -> SdkResult<crate::schema::Result> {
205 let ping_request = PingRequest::new(None);
206 let response = self.request(ping_request.into(), timeout).await?;
207 Ok(response.try_into()?)
208 }
209
210 async fn create_message(
216 &self,
217 params: CreateMessageRequestParams,
218 ) -> SdkResult<CreateMessageResult> {
219 let ping_request = CreateMessageRequest::new(params);
220 let response = self.request(ping_request.into(), None).await?;
221 Ok(response.try_into()?)
222 }
223
224 fn client_supports_sampling(&self) -> Option<bool> {
236 self.client_info()
237 .map(|client_details| client_details.capabilities.sampling.is_some())
238 }
239
240 fn client_supports_root_list(&self) -> Option<bool> {
252 self.client_info()
253 .map(|client_details| client_details.capabilities.roots.is_some())
254 }
255
256 fn client_supports_experimental(&self) -> Option<bool> {
268 self.client_info()
269 .map(|client_details| client_details.capabilities.experimental.is_some())
270 }
271
272 async fn stderr_message(&self, message: String) -> SdkResult<()>;
274
275 fn assert_client_capabilities(
283 &self,
284 request_method: &String,
285 ) -> std::result::Result<(), RpcError> {
286 let entity = "Client";
287 if *request_method == CreateMessageRequest::method_name()
288 && !self.client_supports_sampling().unwrap_or(false)
289 {
290 return Err(
291 RpcError::internal_error().with_message(format_assertion_message(
292 entity,
293 "sampling",
294 request_method,
295 )),
296 );
297 }
298 if *request_method == ListRootsRequest::method_name()
299 && !self.client_supports_root_list().unwrap_or(false)
300 {
301 return Err(
302 RpcError::internal_error().with_message(format_assertion_message(
303 entity,
304 "listing roots",
305 request_method,
306 )),
307 );
308 }
309 Ok(())
310 }
311
312 fn assert_server_notification_capabilities(
313 &self,
314 notification_method: &String,
315 ) -> std::result::Result<(), RpcError> {
316 let entity = "Server";
317
318 let capabilities = &self.server_info().capabilities;
319
320 if *notification_method == LoggingMessageNotification::method_name()
321 && capabilities.logging.is_none()
322 {
323 return Err(
324 RpcError::internal_error().with_message(format_assertion_message(
325 entity,
326 "logging",
327 notification_method,
328 )),
329 );
330 }
331 if *notification_method == ResourceUpdatedNotification::method_name()
332 && capabilities.resources.is_none()
333 {
334 return Err(
335 RpcError::internal_error().with_message(format_assertion_message(
336 entity,
337 "notifying about resources",
338 notification_method,
339 )),
340 );
341 }
342 if *notification_method == ToolListChangedNotification::method_name()
343 && capabilities.tools.is_none()
344 {
345 return Err(
346 RpcError::internal_error().with_message(format_assertion_message(
347 entity,
348 "notifying of tool list changes",
349 notification_method,
350 )),
351 );
352 }
353 if *notification_method == PromptListChangedNotification::method_name()
354 && capabilities.prompts.is_none()
355 {
356 return Err(
357 RpcError::internal_error().with_message(format_assertion_message(
358 entity,
359 "notifying of prompt list changes",
360 notification_method,
361 )),
362 );
363 }
364
365 Ok(())
366 }
367
368 fn assert_server_request_capabilities(
369 &self,
370 request_method: &String,
371 ) -> std::result::Result<(), RpcError> {
372 let entity = "Server";
373 let capabilities = &self.server_info().capabilities;
374
375 if *request_method == SetLevelRequest::method_name() && capabilities.logging.is_none() {
376 return Err(
377 RpcError::internal_error().with_message(format_assertion_message(
378 entity,
379 "logging",
380 request_method,
381 )),
382 );
383 }
384 if [
385 GetPromptRequest::method_name(),
386 ListPromptsRequest::method_name(),
387 ]
388 .contains(request_method)
389 && capabilities.prompts.is_none()
390 {
391 return Err(
392 RpcError::internal_error().with_message(format_assertion_message(
393 entity,
394 "prompts",
395 request_method,
396 )),
397 );
398 }
399 if [
400 ListResourcesRequest::method_name(),
401 ListResourceTemplatesRequest::method_name(),
402 ReadResourceRequest::method_name(),
403 ]
404 .contains(request_method)
405 && capabilities.resources.is_none()
406 {
407 return Err(
408 RpcError::internal_error().with_message(format_assertion_message(
409 entity,
410 "resources",
411 request_method,
412 )),
413 );
414 }
415 if [
416 CallToolRequest::method_name(),
417 ListToolsRequest::method_name(),
418 ]
419 .contains(request_method)
420 && capabilities.tools.is_none()
421 {
422 return Err(
423 RpcError::internal_error().with_message(format_assertion_message(
424 entity,
425 "tools",
426 request_method,
427 )),
428 );
429 }
430 Ok(())
431 }
432
433 #[cfg(feature = "hyper-server")]
434 fn session_id(&self) -> Option<SessionId>;
435}