rust_mcp_sdk/mcp_http/mcp_http_handler.rs

1#[cfg(feature = "sse")]
2use super::utils::handle_sse_connection;
3use crate::mcp_http::mcp_http_middleware::MiddlewareChain;
4use crate::mcp_http::utils::{
5    accepts_event_stream, empty_response, error_response, query_param,
6    validate_mcp_protocol_version_header,
7};
8use crate::mcp_http::Middleware;
9use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
10use crate::mcp_server::error::TransportServerError;
11use crate::schema::schema_utils::SdkError;
12use crate::{
13    error::McpSdkError,
14    mcp_http::{
15        utils::{
16            acceptable_content_type, create_standalone_stream, delete_session,
17            process_incoming_message, process_incoming_message_return, protect_dns_rebinding,
18            start_new_session, valid_streaming_http_accept_header, GenericBody,
19        },
20        McpAppState,
21    },
22    mcp_server::error::TransportServerResult,
23    utils::valid_initialize_method,
24};
25use http::{self, HeaderMap, Method, StatusCode, Uri};
26use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
27use std::sync::Arc;
28
29#[derive(Clone)]
30pub struct McpHttpHandler {
31    middleware_chain: MiddlewareChain,
32}
33
34impl Default for McpHttpHandler {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl McpHttpHandler {
41    pub fn new() -> Self {
42        McpHttpHandler {
43            middleware_chain: MiddlewareChain::new(),
44        }
45    }
46
47    pub fn add_middleware<M: Middleware>(&mut self, middleware: M) {
48        self.middleware_chain.add_middleware(middleware);
49    }
50
51    /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body.
52    /// If the `body` is `None`, an empty string is used as the default.
53    ///
54    pub fn create_request(
55        method: Method,
56        uri: Uri,
57        headers: HeaderMap,
58        body: Option<&str>,
59    ) -> http::Request<&str> {
60        let mut request = http::Request::default();
61        *request.method_mut() = method;
62        *request.uri_mut() = uri;
63        *request.body_mut() = body.unwrap_or_default();
64        let req_headers = request.headers_mut();
65        for (key, value) in headers {
66            if let Some(k) = key {
67                req_headers.insert(k, value);
68            }
69        }
70        request
71    }
72}
73
74impl McpHttpHandler {
75    /// Handles an MCP connection using the SSE (Server-Sent Events) transport.
76    ///
77    /// This function serves as the entry point for initializing and managing a client connection
78    /// over SSE when the `sse` feature is enabled.
79    ///
80    /// # Arguments
81    /// * `state` - Shared application state required to manage the MCP session.
82    /// * `sse_message_endpoint` - Optional message endpoint to override the default SSE route (default: `/messages` ).
83    ///
84    ///
85    /// # Features
86    /// This function is only available when the `sse` feature is enabled.
87    #[cfg(feature = "sse")]
88    pub async fn handle_sse_connection(
89        &self,
90        state: Arc<McpAppState>,
91        sse_message_endpoint: Option<&str>,
92    ) -> TransportServerResult<http::Response<GenericBody>> {
93        handle_sse_connection(state, sse_message_endpoint).await
94    }
95
96    /// Handles incoming MCP messages from the client after an SSE connection is established.
97    ///
98    /// This function processes a message sent by the client as part of an active SSE session. It:
99    /// - Extracts the `sessionId` from the request query parameters.
100    /// - Locates the corresponding session's transmit channel.
101    /// - Forwards the incoming message payload to the MCP transport stream for consumption.
102    /// # Arguments
103    /// * `request` - The HTTP request containing the message body and query parameters (including `sessionId`).
104    /// * `state` - Shared application state, including access to the session store.
105    ///
106    /// # Returns
107    /// * `TransportServerResult<http::Response<GenericBody>>`:
108    ///   - Returns a `202 Accepted` HTTP response if the message is successfully forwarded.
109    ///   - Returns an error if the session ID is missing, invalid, or if any I/O issues occur while processing the message.
110    ///
111    /// # Errors
112    /// - `SessionIdMissing`: if the `sessionId` query parameter is not present.
113    /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store.
114    /// - `StreamIoError`: if an error occurs while writing to the stream.
115    /// - `HttpError`: if constructing the HTTP response fails.
116    pub async fn handle_sse_message(
117        &self,
118        request: http::Request<&str>,
119        state: Arc<McpAppState>,
120    ) -> TransportServerResult<http::Response<GenericBody>> {
121        let session_id =
122            query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
123
124        // transmit to the readable stream, that transport is reading from
125        let transmit = state.session_store.get(&session_id).await.ok_or(
126            TransportServerError::SessionIdInvalid(session_id.to_string()),
127        )?;
128
129        let message = *request.body();
130        transmit
131            .consume_payload_string(DEFAULT_STREAM_ID, message)
132            .await
133            .map_err(|err| {
134                tracing::trace!("{}", err);
135                TransportServerError::StreamIoError(err.to_string())
136            })?;
137
138        http::Response::builder()
139            .status(StatusCode::ACCEPTED)
140            .body(empty_response())
141            .map_err(|err| TransportServerError::HttpError(err.to_string()))
142    }
143
144    /// Handles incoming MCP messages over the StreamableHTTP transport.
145    ///
146    /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional
147    /// DNS rebinding protection if it is configured.
148    ///
149    /// # Arguments
150    /// * `request` - The HTTP request from the client, including method, headers, and optional body.
151    /// * `state` - Shared application state, including configuration and session management.
152    ///
153    /// # Behavior
154    /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers.
155    ///   If dns protection fails, a `403 Forbidden` response is returned.
156    /// - Dispatches the request to method-specific handlers based on the HTTP method:
157    ///     - `GET` → `handle_http_get`
158    ///     - `POST` → `handle_http_post`
159    ///     - `DELETE` → `handle_http_delete`
160    /// - Returns `405 Method Not Allowed` for unsupported methods.
161    ///
162    /// # Returns
163    /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation.
164    ///
165    pub async fn handle_streamable_http(
166        &self,
167        request: http::Request<&str>,
168        state: Arc<McpAppState>,
169    ) -> TransportServerResult<http::Response<GenericBody>> {
170        let request = self
171            .middleware_chain
172            .process_request(request)
173            .await
174            .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
175
176        // Enforces DNS rebinding protection if required by state.
177        // If protection fails, respond with HTTP 403 Forbidden.
178        if state.needs_dns_protection() {
179            if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await {
180                return error_response(StatusCode::FORBIDDEN, error);
181            }
182        }
183
184        let method = request.method();
185        let response = match method {
186            &http::Method::GET => return self.handle_http_get(request, state).await,
187            &http::Method::POST => return self.handle_http_post(request, state).await,
188            &http::Method::DELETE => return self.handle_http_delete(request, state).await,
189            other => {
190                let error = SdkError::bad_request().with_message(&format!(
191                    "'{other}' is not a valid HTTP method for StreamableHTTP transport."
192                ));
193                error_response(StatusCode::METHOD_NOT_ALLOWED, error)
194            }
195        };
196
197        self.middleware_chain
198            .process_response(response?)
199            .await
200            .map_err(|e| TransportServerError::HttpError(e.to_string()))
201    }
202
203    /// Processes POST requests for the Streamable HTTP Protocol
204    async fn handle_http_post(
205        &self,
206        request: http::Request<&str>,
207        state: Arc<McpAppState>,
208    ) -> TransportServerResult<http::Response<GenericBody>> {
209        let request = self
210            .middleware_chain
211            .process_request(request)
212            .await
213            .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
214
215        let headers = request.headers();
216
217        if !valid_streaming_http_accept_header(headers) {
218            let error = SdkError::bad_request()
219                .with_message(r#"Client must accept both application/json and text/event-stream"#);
220            return error_response(StatusCode::NOT_ACCEPTABLE, error);
221        }
222
223        if !acceptable_content_type(headers) {
224            let error = SdkError::bad_request()
225                .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
226            return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
227        }
228
229        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
230            let error = SdkError::bad_request()
231                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
232            return error_response(StatusCode::BAD_REQUEST, error);
233        }
234
235        let session_id: Option<SessionId> = headers
236            .get(MCP_SESSION_ID_HEADER)
237            .and_then(|value| value.to_str().ok())
238            .map(|s| s.to_string());
239
240        let payload = *request.body();
241
242        let response = match session_id {
243            // has session-id => write to the existing stream
244            Some(id) => {
245                if state.enable_json_response {
246                    process_incoming_message_return(id, state, payload).await
247                } else {
248                    process_incoming_message(id, state, payload).await
249                }
250            }
251            None => match valid_initialize_method(payload) {
252                Ok(_) => {
253                    return start_new_session(state, payload).await;
254                }
255                Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
256                Err(error) => {
257                    let error = SdkError::bad_request().with_message(&error.to_string());
258                    error_response(StatusCode::BAD_REQUEST, error)
259                }
260            },
261        };
262
263        self.middleware_chain
264            .process_response(response?)
265            .await
266            .map_err(|e| TransportServerError::HttpError(e.to_string()))
267    }
268
269    /// Processes GET requests for the Streamable HTTP Protocol
270    async fn handle_http_get(
271        &self,
272        request: http::Request<&str>,
273        state: Arc<McpAppState>,
274    ) -> TransportServerResult<http::Response<GenericBody>> {
275        let request = self
276            .middleware_chain
277            .process_request(request)
278            .await
279            .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
280
281        let headers = request.headers();
282
283        if !accepts_event_stream(headers) {
284            let error =
285                SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
286            return error_response(StatusCode::NOT_ACCEPTABLE, error);
287        }
288
289        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
290            let error = SdkError::bad_request()
291                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
292            return error_response(StatusCode::BAD_REQUEST, error);
293        }
294
295        let session_id: Option<SessionId> = headers
296            .get(MCP_SESSION_ID_HEADER)
297            .and_then(|value| value.to_str().ok())
298            .map(|s| s.to_string());
299
300        let last_event_id: Option<SessionId> = headers
301            .get(MCP_LAST_EVENT_ID_HEADER)
302            .and_then(|value| value.to_str().ok())
303            .map(|s| s.to_string());
304
305        let response = match session_id {
306            Some(session_id) => {
307                let res = create_standalone_stream(session_id, last_event_id, state).await;
308                res
309            }
310            None => {
311                let error = SdkError::bad_request().with_message("Bad request: session not found");
312                error_response(StatusCode::BAD_REQUEST, error)
313            }
314        };
315
316        self.middleware_chain
317            .process_response(response?)
318            .await
319            .map_err(|e| TransportServerError::HttpError(e.to_string()))
320    }
321
322    /// Processes DELETE requests for the Streamable HTTP Protocol
323    async fn handle_http_delete(
324        &self,
325        request: http::Request<&str>,
326        state: Arc<McpAppState>,
327    ) -> TransportServerResult<http::Response<GenericBody>> {
328        let request = self
329            .middleware_chain
330            .process_request(request)
331            .await
332            .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
333
334        let headers = request.headers();
335
336        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
337            let error = SdkError::bad_request()
338                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
339            return error_response(StatusCode::BAD_REQUEST, error);
340        }
341
342        let session_id: Option<SessionId> = headers
343            .get(MCP_SESSION_ID_HEADER)
344            .and_then(|value| value.to_str().ok())
345            .map(|s| s.to_string());
346
347        let response = match session_id {
348            Some(id) => delete_session(id, state).await,
349            None => {
350                let error = SdkError::bad_request().with_message("Bad Request: Session not found");
351                error_response(StatusCode::BAD_REQUEST, error)
352            }
353        };
354
355        self.middleware_chain
356            .process_response(response?)
357            .await
358            .map_err(|e| TransportServerError::HttpError(e.to_string()))
359    }
360}