// rust_mcp_sdk/mcp_runtimes/server_runtime/mcp_server_runtime.rs
use std::sync::Arc;

use async_trait::async_trait;
#[cfg(feature = "hyper-server")]
use rust_mcp_transport::SessionId;
use rust_mcp_transport::TransportDispatcher;

use super::ServerRuntime;
use crate::schema::{
    schema_utils::{
        self, CallToolError, ClientMessage, ClientMessages, MessageFromServer,
        NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages,
    },
    CallToolResult, ClientNotification, ClientRequest, InitializeResult, RpcError,
};
use crate::{
    error::SdkResult,
    mcp_handlers::mcp_server_handler::ServerHandler,
    mcp_traits::{mcp_handler::McpServerHandler, mcp_server::McpServer},
};

24pub fn create_server(
43    server_details: InitializeResult,
44    transport: impl TransportDispatcher<
45        ClientMessages,
46        MessageFromServer,
47        ClientMessage,
48        ServerMessages,
49        ServerMessage,
50    >,
51    handler: impl ServerHandler,
52) -> Arc<ServerRuntime> {
53    ServerRuntime::new(
54        server_details,
55        transport,
56        Arc::new(ServerRuntimeInternalHandler::new(Box::new(handler))),
57    )
58}
59
/// Creates a [`ServerRuntime`] for the given `session_id` from already-shared server details
/// and an already-wrapped [`McpServerHandler`].
///
/// Only compiled when the `hyper-server` feature is enabled. Unlike [`create_server`], no
/// transport is supplied here; construction is delegated entirely to
/// `ServerRuntime::new_instance`.
#[cfg(feature = "hyper-server")]
pub(crate) fn create_server_instance(
    server_details: Arc<InitializeResult>,
    handler: Arc<dyn McpServerHandler>,
    session_id: SessionId,
) -> Arc<ServerRuntime> {
    ServerRuntime::new_instance(server_details, handler, session_id)
}

69pub(crate) struct ServerRuntimeInternalHandler<H> {
70    handler: H,
71}
72impl ServerRuntimeInternalHandler<Box<dyn ServerHandler>> {
73    pub fn new(handler: Box<dyn ServerHandler>) -> Self {
74        Self { handler }
75    }
76}
77
78#[async_trait]
79impl McpServerHandler for ServerRuntimeInternalHandler<Box<dyn ServerHandler>> {
80    async fn handle_request(
81        &self,
82        client_jsonrpc_request: RequestFromClient,
83        runtime: Arc<dyn McpServer>,
84    ) -> std::result::Result<ResultFromServer, RpcError> {
85        match client_jsonrpc_request {
86            schema_utils::RequestFromClient::ClientRequest(client_request) => {
87                match client_request {
88                    ClientRequest::InitializeRequest(initialize_request) => self
89                        .handler
90                        .handle_initialize_request(initialize_request, runtime)
91                        .await
92                        .map(|value| value.into()),
93                    ClientRequest::PingRequest(ping_request) => self
94                        .handler
95                        .handle_ping_request(ping_request, runtime)
96                        .await
97                        .map(|value| value.into()),
98                    ClientRequest::ListResourcesRequest(list_resources_request) => self
99                        .handler
100                        .handle_list_resources_request(list_resources_request, runtime)
101                        .await
102                        .map(|value| value.into()),
103                    ClientRequest::ListResourceTemplatesRequest(
104                        list_resource_templates_request,
105                    ) => self
106                        .handler
107                        .handle_list_resource_templates_request(
108                            list_resource_templates_request,
109                            runtime,
110                        )
111                        .await
112                        .map(|value| value.into()),
113                    ClientRequest::ReadResourceRequest(read_resource_request) => self
114                        .handler
115                        .handle_read_resource_request(read_resource_request, runtime)
116                        .await
117                        .map(|value| value.into()),
118                    ClientRequest::SubscribeRequest(subscribe_request) => self
119                        .handler
120                        .handle_subscribe_request(subscribe_request, runtime)
121                        .await
122                        .map(|value| value.into()),
123                    ClientRequest::UnsubscribeRequest(unsubscribe_request) => self
124                        .handler
125                        .handle_unsubscribe_request(unsubscribe_request, runtime)
126                        .await
127                        .map(|value| value.into()),
128                    ClientRequest::ListPromptsRequest(list_prompts_request) => self
129                        .handler
130                        .handle_list_prompts_request(list_prompts_request, runtime)
131                        .await
132                        .map(|value| value.into()),
133
134                    ClientRequest::GetPromptRequest(prompt_request) => self
135                        .handler
136                        .handle_get_prompt_request(prompt_request, runtime)
137                        .await
138                        .map(|value| value.into()),
139                    ClientRequest::ListToolsRequest(list_tools_request) => self
140                        .handler
141                        .handle_list_tools_request(list_tools_request, runtime)
142                        .await
143                        .map(|value| value.into()),
144                    ClientRequest::CallToolRequest(call_tool_request) => {
145                        let result = self
146                            .handler
147                            .handle_call_tool_request(call_tool_request, runtime)
148                            .await;
149
150                        Ok(result.map_or_else(
151                            |err| {
152                                let result: CallToolResult = CallToolError::new(err).into();
153                                result.into()
154                            },
155                            Into::into,
156                        ))
157                    }
158                    ClientRequest::SetLevelRequest(set_level_request) => self
159                        .handler
160                        .handle_set_level_request(set_level_request, runtime)
161                        .await
162                        .map(|value| value.into()),
163                    ClientRequest::CompleteRequest(complete_request) => self
164                        .handler
165                        .handle_complete_request(complete_request, runtime)
166                        .await
167                        .map(|value| value.into()),
168                }
169            }
170            schema_utils::RequestFromClient::CustomRequest(value) => self
171                .handler
172                .handle_custom_request(value, runtime)
173                .await
174                .map(|value| value.into()),
175        }
176    }
177
178    async fn handle_error(
179        &self,
180        jsonrpc_error: &RpcError,
181        runtime: Arc<dyn McpServer>,
182    ) -> SdkResult<()> {
183        self.handler.handle_error(jsonrpc_error, runtime).await?;
184        Ok(())
185    }
186
187    async fn handle_notification(
188        &self,
189        client_jsonrpc_notification: NotificationFromClient,
190        runtime: Arc<dyn McpServer>,
191    ) -> SdkResult<()> {
192        match client_jsonrpc_notification {
193            schema_utils::NotificationFromClient::ClientNotification(client_notification) => {
194                match client_notification {
195                    ClientNotification::CancelledNotification(cancelled_notification) => {
196                        self.handler
197                            .handle_cancelled_notification(cancelled_notification, runtime)
198                            .await?;
199                    }
200                    ClientNotification::InitializedNotification(initialized_notification) => {
201                        self.handler
202                            .handle_initialized_notification(
203                                initialized_notification,
204                                runtime.clone(),
205                            )
206                            .await?;
207                        self.handler.on_initialized(runtime).await;
208                    }
209                    ClientNotification::ProgressNotification(progress_notification) => {
210                        self.handler
211                            .handle_progress_notification(progress_notification, runtime)
212                            .await?;
213                    }
214                    ClientNotification::RootsListChangedNotification(
215                        roots_list_changed_notification,
216                    ) => {
217                        self.handler
218                            .handle_roots_list_changed_notification(
219                                roots_list_changed_notification,
220                                runtime,
221                            )
222                            .await?;
223                    }
224                }
225            }
226            schema_utils::NotificationFromClient::CustomNotification(value) => {
227                self.handler.handle_custom_notification(value).await?;
228            }
229        }
230        Ok(())
231    }
232}