1use crate::auth::AuthInfo;
2use crate::error::SdkResult;
3use crate::schema::{
4 schema_utils::{
5 ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
6 ResultFromClient, ServerMessage,
7 },
8 CreateMessageRequestParams, CreateMessageResult, ElicitRequestParams, ElicitResult,
9 Implementation, InitializeRequestParams, InitializeResult, ListRootsResult,
10 LoggingMessageNotificationParams, NotificationParams, RequestId, RequestParams,
11 ResourceUpdatedNotificationParams, RpcError, ServerCapabilities,
12};
13use crate::task_store::{ClientTaskStore, CreateTaskOptions, ServerTaskStore};
14use async_trait::async_trait;
15use rust_mcp_schema::schema_utils::{
16 ClientTaskResult, CustomNotification, CustomRequest, ServerJsonrpcRequest,
17};
18use rust_mcp_schema::{
19 CancelTaskParams, CancelTaskResult, CancelledNotificationParams, CreateTaskResult,
20 ElicitCompleteParams, GenericResult, GetTaskParams, GetTaskPayloadParams, GetTaskResult,
21 ListTasksResult, PaginatedRequestParams, ProgressNotificationParams,
22 TaskStatusNotificationParams,
23};
24use rust_mcp_transport::SessionId;
25use std::{sync::Arc, time::Duration};
26use tokio::sync::RwLockReadGuard;
27
28#[async_trait]
29pub trait McpServer: Sync + Send {
30 async fn start(self: Arc<Self>) -> SdkResult<()>;
31 async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
32 fn server_info(&self) -> &InitializeResult;
33 fn client_info(&self) -> Option<InitializeRequestParams>;
34
35 async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>>;
36 async fn auth_info_cloned(&self) -> Option<AuthInfo>;
37 async fn update_auth_info(&self, auth_info: Option<AuthInfo>);
38
39 async fn wait_for_initialization(&self);
40
41 fn task_store(&self) -> Option<Arc<ServerTaskStore>>;
45
46 fn client_task_store(&self) -> Option<Arc<ClientTaskStore>>;
51
52 fn client_supports_sampling(&self) -> Option<bool> {
64 self.client_info()
65 .map(|client_details| client_details.capabilities.sampling.is_some())
66 }
67
68 fn client_supports_root_list(&self) -> Option<bool> {
80 self.client_info()
81 .map(|client_details| client_details.capabilities.roots.is_some())
82 }
83
84 fn client_supports_experimental(&self) -> Option<bool> {
96 self.client_info()
97 .map(|client_details| client_details.capabilities.experimental.is_some())
98 }
99
100 async fn stderr_message(&self, message: String) -> SdkResult<()>;
102
103 #[cfg(feature = "hyper-server")]
104 fn session_id(&self) -> Option<SessionId>;
105
106 async fn send(
107 &self,
108 message: MessageFromServer,
109 request_id: Option<RequestId>,
110 request_timeout: Option<Duration>,
111 ) -> SdkResult<Option<ClientMessage>>;
112
113 async fn send_batch(
114 &self,
115 messages: Vec<ServerMessage>,
116 request_timeout: Option<Duration>,
117 ) -> SdkResult<Option<Vec<ClientMessage>>>;
118
119 fn is_initialized(&self) -> bool {
121 self.client_info().is_some()
122 }
123
124 fn client_version(&self) -> Option<Implementation> {
127 self.client_info()
128 .map(|client_details| client_details.client_info)
129 }
130
131 fn capabilities(&self) -> &ServerCapabilities {
133 &self.server_info().capabilities
134 }
135
136 async fn request(
146 &self,
147 request: RequestFromServer,
148 timeout: Option<Duration>,
149 ) -> SdkResult<ResultFromClient> {
150 let request_clone = if request.is_task_augmented() {
152 Some(request.clone())
153 } else {
154 None
155 };
156 let response = self
158 .send(MessageFromServer::RequestFromServer(request), None, timeout)
159 .await?;
160
161 let client_message = response.ok_or_else(|| {
162 RpcError::internal_error()
163 .with_message("An empty response was received from the client.".to_string())
164 })?;
165
166 if client_message.is_error() {
167 return Err(client_message.as_error()?.error.into());
168 }
169
170 let client_response = client_message.as_response()?;
171
172 if let ResultFromClient::CreateTaskResult(create_task_result) = &client_response.result {
176 if let Some(request_to_store) = request_clone {
177 if let Some(client_task_store) = self.client_task_store() {
178 let session_id = {
179 #[cfg(feature = "hyper-server")]
180 {
181 self.session_id()
182 }
183 #[cfg(not(feature = "hyper-server"))]
184 None
185 };
186 client_task_store
187 .create_task(
188 CreateTaskOptions {
189 ttl: create_task_result.task.ttl,
190 poll_interval: create_task_result.task.poll_interval,
191 meta: create_task_result.meta.clone(),
192 },
193 client_response.id.clone(),
194 ServerJsonrpcRequest::new(client_response.id, request_to_store),
195 session_id,
196 )
197 .await;
198 }
199 } else {
200 return Err(RpcError::internal_error()
201 .with_message("No eligible request found for task storage.".to_string())
202 .into());
203 }
204 }
205
206 return Ok(client_response.result);
207 }
208
209 async fn request_elicitation(&self, params: ElicitRequestParams) -> SdkResult<ElicitResult> {
214 let response = self
215 .request(RequestFromServer::ElicitRequest(params), None)
216 .await?;
217 ElicitResult::try_from(response).map_err(|err| err.into())
218 }
219
220 async fn request_elicitation_task(
221 &self,
222 params: ElicitRequestParams,
223 ) -> SdkResult<CreateTaskResult> {
224 if !params.is_task_augmented() {
225 return Err(RpcError::invalid_params()
226 .with_message(
227 "Invalid parameters: the request is not identified as task-augmented."
228 .to_string(),
229 )
230 .into());
231 }
232 let response = self
233 .request(RequestFromServer::ElicitRequest(params), None)
234 .await?;
235
236 let response = CreateTaskResult::try_from(response)?;
237
238 Ok(response)
239 }
240
241 async fn request_root_list(&self, params: Option<RequestParams>) -> SdkResult<ListRootsResult> {
247 let response = self
248 .request(RequestFromServer::ListRootsRequest(params), None)
249 .await?;
250 ListRootsResult::try_from(response).map_err(|err| err.into())
251 }
252
253 async fn ping(
264 &self,
265 params: Option<RequestParams>,
266 timeout: Option<Duration>,
267 ) -> SdkResult<crate::schema::Result> {
268 let response = self
269 .request(RequestFromServer::PingRequest(params), timeout)
270 .await?;
271 Ok(response.try_into()?)
272 }
273
274 async fn request_message_creation(
280 &self,
281 params: CreateMessageRequestParams,
282 ) -> SdkResult<CreateMessageResult> {
283 let response = self
284 .request(RequestFromServer::CreateMessageRequest(params), None)
285 .await?;
286 Ok(response.try_into()?)
287 }
288
289 async fn request_get_task(&self, params: GetTaskParams) -> SdkResult<GetTaskResult> {
291 let response = self
292 .request(RequestFromServer::GetTaskRequest(params), None)
293 .await?;
294 Ok(response.try_into()?)
295 }
296
297 async fn request_get_task_payload(
299 &self,
300 params: GetTaskPayloadParams,
301 ) -> SdkResult<ClientTaskResult> {
302 let response = self
303 .request(RequestFromServer::GetTaskPayloadRequest(params), None)
304 .await?;
305 Ok(response.try_into()?)
306 }
307
308 async fn request_task_cancellation(
310 &self,
311 params: CancelTaskParams,
312 ) -> SdkResult<CancelTaskResult> {
313 let response = self
314 .request(RequestFromServer::CancelTaskRequest(params), None)
315 .await?;
316 Ok(response.try_into()?)
317 }
318
319 async fn request_task_list(
321 &self,
322 params: Option<PaginatedRequestParams>,
323 ) -> SdkResult<ListTasksResult> {
324 let response = self
325 .request(RequestFromServer::ListTasksRequest(params), None)
326 .await?;
327 Ok(response.try_into()?)
328 }
329
330 async fn request_custom(&self, params: CustomRequest) -> SdkResult<GenericResult> {
332 let response = self
333 .request(RequestFromServer::CustomRequest(params), None)
334 .await?;
335 Ok(response.try_into()?)
336 }
337
338 async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> {
346 self.send(
347 MessageFromServer::NotificationFromServer(notification),
348 None,
349 None,
350 )
351 .await?;
352 Ok(())
353 }
354
355 async fn notify_log_message(&self, params: LoggingMessageNotificationParams) -> SdkResult<()> {
358 self.send_notification(NotificationFromServer::LoggingMessageNotification(params))
359 .await
360 }
361
362 async fn notify_prompt_list_changed(
366 &self,
367 params: Option<NotificationParams>,
368 ) -> SdkResult<()> {
369 self.send_notification(NotificationFromServer::PromptListChangedNotification(
370 params,
371 ))
372 .await
373 }
374
375 async fn notify_resource_list_changed(
379 &self,
380 params: Option<NotificationParams>,
381 ) -> SdkResult<()> {
382 self.send_notification(NotificationFromServer::ResourceListChangedNotification(
383 params,
384 ))
385 .await
386 }
387
388 async fn notify_resource_updated(
392 &self,
393 params: ResourceUpdatedNotificationParams,
394 ) -> SdkResult<()> {
395 self.send_notification(NotificationFromServer::ResourceUpdatedNotification(params))
396 .await
397 }
398
399 async fn notify_tool_list_changed(&self, params: Option<NotificationParams>) -> SdkResult<()> {
403 self.send_notification(NotificationFromServer::ToolListChangedNotification(params))
404 .await
405 }
406
407 async fn notify_cancellation(&self, params: CancelledNotificationParams) -> SdkResult<()> {
413 self.send_notification(NotificationFromServer::CancelledNotification(params))
414 .await
415 }
416
417 async fn notify_progress(&self, params: ProgressNotificationParams) -> SdkResult<()> {
419 self.send_notification(NotificationFromServer::ProgressNotification(params))
420 .await
421 }
422
423 async fn notify_task_status(&self, params: TaskStatusNotificationParams) -> SdkResult<()> {
426 self.send_notification(NotificationFromServer::TaskStatusNotification(params))
427 .await
428 }
429
430 async fn notify_elicitation_completed(&self, params: ElicitCompleteParams) -> SdkResult<()> {
432 self.send_notification(NotificationFromServer::ElicitationCompleteNotification(
433 params,
434 ))
435 .await
436 }
437
438 async fn notify_custom(&self, params: CustomNotification) -> SdkResult<()> {
440 self.send_notification(NotificationFromServer::CustomNotification(params))
441 .await
442 }
443
444 #[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")]
445 async fn list_roots(&self, params: Option<RequestParams>) -> SdkResult<ListRootsResult> {
446 let response = self
447 .request(RequestFromServer::ListRootsRequest(params), None)
448 .await?;
449 ListRootsResult::try_from(response).map_err(|err| err.into())
450 }
451
452 #[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")]
453 async fn elicit_input(&self, params: ElicitRequestParams) -> SdkResult<ElicitResult> {
454 let response = self
455 .request(RequestFromServer::ElicitRequest(params), None)
456 .await?;
457 ElicitResult::try_from(response).map_err(|err| err.into())
458 }
459
460 #[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")]
461 async fn create_message(
462 &self,
463 params: CreateMessageRequestParams,
464 ) -> SdkResult<CreateMessageResult> {
465 let response = self
466 .request(RequestFromServer::CreateMessageRequest(params), None)
467 .await?;
468 Ok(response.try_into()?)
469 }
470
471 #[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")]
472 async fn send_tool_list_changed(&self, params: Option<NotificationParams>) -> SdkResult<()> {
473 self.send_notification(NotificationFromServer::ToolListChangedNotification(params))
474 .await
475 }
476
477 #[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")]
478 async fn send_resource_updated(
479 &self,
480 params: ResourceUpdatedNotificationParams,
481 ) -> SdkResult<()> {
482 self.send_notification(NotificationFromServer::ResourceUpdatedNotification(params))
483 .await
484 }
485
486 #[deprecated(
487 since = "0.8.0",
488 note = "Use `notify_resource_list_changed()` instead."
489 )]
490 async fn send_resource_list_changed(
491 &self,
492 params: Option<NotificationParams>,
493 ) -> SdkResult<()> {
494 self.send_notification(NotificationFromServer::ResourceListChangedNotification(
495 params,
496 ))
497 .await
498 }
499
500 #[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")]
501 async fn send_prompt_list_changed(&self, params: Option<NotificationParams>) -> SdkResult<()> {
502 self.send_notification(NotificationFromServer::PromptListChangedNotification(
503 params,
504 ))
505 .await
506 }
507
508 #[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")]
509 async fn send_logging_message(
510 &self,
511 params: LoggingMessageNotificationParams,
512 ) -> SdkResult<()> {
513 self.send_notification(NotificationFromServer::LoggingMessageNotification(params))
514 .await
515 }
516}