// rust_mcp_sdk/mcp_http/mcp_http_handler.rs

#[cfg(feature = "sse")]
use super::http_utils::{empty_response, handle_sse_connection, query_param};
use super::http_utils::{
    accepts_event_stream, error_response, validate_mcp_protocol_version_header,
};
use super::types::GenericBody;
#[cfg(feature = "sse")]
use crate::mcp_http::middleware::compose;
use crate::mcp_http::{BoxFutureResponse, Middleware, RequestHandler};
#[cfg(feature = "sse")]
use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
#[cfg(feature = "sse")]
use crate::mcp_server::error::TransportServerError;
use crate::schema::schema_utils::SdkError;
use crate::{
    error::McpSdkError,
    mcp_http::{
        http_utils::{
            acceptable_content_type, create_standalone_stream, delete_session,
            process_incoming_message, process_incoming_message_return, start_new_session,
            valid_streaming_http_accept_header,
        },
        McpAppState,
    },
    mcp_server::error::TransportServerResult,
    utils::valid_initialize_method,
};
use http::{self, HeaderMap, Method, StatusCode, Uri};
use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
use std::sync::Arc;

/// Builds a middleware-wrapped request handler from an async handler function.
///
/// The target function (any path, typically `Self::some_handler`) must accept
/// `(http::Request<&str>, Arc<McpAppState>)` and return a future resolving to the handler's
/// response type. It is boxed into a `RequestHandler` and passed, together with
/// `$this.middlewares`, to `$crate::mcp_http::middleware::compose`. The composed handler is
/// the macro's value.
///
/// The names `RequestHandler`, `McpAppState` and `BoxFutureResponse` are written unqualified, so
/// they must be in scope at the call site.
///
/// # Example
/// ```ignore
/// let handle = with_middlewares!(self, Self::internal_handle_sse_message);
/// handle
/// ```
#[macro_export]
macro_rules! with_middlewares {
    ($this:ident, $target:path) => {{
        let terminal: RequestHandler = std::sync::Arc::new(
            move |incoming: http::Request<&str>,
                  app_state: std::sync::Arc<McpAppState>|
                  -> BoxFutureResponse<'_> {
                Box::pin(async move { $target(incoming, app_state).await })
            },
        );
        $crate::mcp_http::middleware::compose(&$this.middlewares, terminal)
    }};
}

50#[derive(Clone)]
51pub struct McpHttpHandler {
52    middlewares: Vec<Arc<dyn Middleware>>,
53}
54
55impl Default for McpHttpHandler {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl McpHttpHandler {
62    pub fn new() -> Self {
63        McpHttpHandler {
64            middlewares: vec![],
65        }
66    }
67    pub fn add_middleware<M: Middleware + 'static>(&mut self, middleware: M) {
68        let m: Arc<dyn Middleware> = Arc::new(middleware);
69        self.middlewares.push(m);
70    }
71
72    /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body.
73    /// If the `body` is `None`, an empty string is used as the default.
74    ///
75    pub fn create_request(
76        method: Method,
77        uri: Uri,
78        headers: HeaderMap,
79        body: Option<&str>,
80    ) -> http::Request<&str> {
81        let mut request = http::Request::default();
82        *request.method_mut() = method;
83        *request.uri_mut() = uri;
84        *request.body_mut() = body.unwrap_or_default();
85        let req_headers = request.headers_mut();
86        for (key, value) in headers {
87            if let Some(k) = key {
88                req_headers.insert(k, value);
89            }
90        }
91        request
92    }
93}
94
95impl McpHttpHandler {
96    /// Handles an MCP connection using the SSE (Server-Sent Events) transport.
97    ///
98    /// This function serves as the entry point for initializing and managing a client connection
99    /// over SSE when the `sse` feature is enabled.
100    ///
101    /// # Arguments
102    /// * `state` - Shared application state required to manage the MCP session.
103    /// * `sse_message_endpoint` - Optional message endpoint to override the default SSE route (default: `/messages` ).
104    ///
105    ///
106    /// # Features
107    /// This function is only available when the `sse` feature is enabled.
108    #[cfg(feature = "sse")]
109    pub async fn handle_sse_connection(
110        &self,
111        request: http::Request<&str>,
112        state: Arc<McpAppState>,
113        sse_message_endpoint: Option<&str>,
114    ) -> TransportServerResult<http::Response<GenericBody>> {
115        let sse_endpoint = Arc::from(sse_message_endpoint.map(|s| s.to_string()));
116        let final_handler: RequestHandler = Arc::new(move |_req, state| {
117            let sse_endpoint = sse_endpoint.clone();
118            Box::pin(async move { handle_sse_connection(state, sse_endpoint.as_deref()).await })
119        });
120        let handle = compose(&self.middlewares, final_handler);
121        handle(request, state).await
122    }
123
124    /// Handles incoming MCP messages from the client after an SSE connection is established.
125    ///
126    /// This function processes a message sent by the client as part of an active SSE session. It:
127    /// - Extracts the `sessionId` from the request query parameters.
128    /// - Locates the corresponding session's transmit channel.
129    /// - Forwards the incoming message payload to the MCP transport stream for consumption.
130    /// # Arguments
131    /// * `request` - The HTTP request containing the message body and query parameters (including `sessionId`).
132    /// * `state` - Shared application state, including access to the session store.
133    ///
134    /// # Returns
135    /// * `TransportServerResult<http::Response<GenericBody>>`:
136    ///   - Returns a `202 Accepted` HTTP response if the message is successfully forwarded.
137    ///   - Returns an error if the session ID is missing, invalid, or if any I/O issues occur while processing the message.
138    ///
139    /// # Errors
140    /// - `SessionIdMissing`: if the `sessionId` query parameter is not present.
141    /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store.
142    /// - `StreamIoError`: if an error occurs while writing to the stream.
143    /// - `HttpError`: if constructing the HTTP response fails.
144    #[cfg(feature = "sse")]
145    pub async fn handle_sse_message(
146        &self,
147        request: http::Request<&str>,
148        state: Arc<McpAppState>,
149    ) -> TransportServerResult<http::Response<GenericBody>> {
150        let handle = with_middlewares!(self, Self::internal_handle_sse_message);
151        handle(request, state).await
152    }
153
154    /// Handles incoming MCP messages over the StreamableHTTP transport.
155    ///
156    /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional
157    /// DNS rebinding protection if it is configured.
158    ///
159    /// # Arguments
160    /// * `request` - The HTTP request from the client, including method, headers, and optional body.
161    /// * `state` - Shared application state, including configuration and session management.
162    ///
163    /// # Behavior
164    /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers.
165    ///   If dns protection fails, a `403 Forbidden` response is returned.
166    /// - Dispatches the request to method-specific handlers based on the HTTP method:
167    ///     - `GET` → `handle_http_get`
168    ///     - `POST` → `handle_http_post`
169    ///     - `DELETE` → `handle_http_delete`
170    /// - Returns `405 Method Not Allowed` for unsupported methods.
171    ///
172    /// # Returns
173    /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation.
174    ///
175    pub async fn handle_streamable_http(
176        &self,
177        request: http::Request<&str>,
178        state: Arc<McpAppState>,
179    ) -> TransportServerResult<http::Response<GenericBody>> {
180        let handle = with_middlewares!(self, Self::internal_handle_streamable_http);
181        handle(request, state).await
182    }
183
184    async fn internal_handle_sse_message(
185        request: http::Request<&str>,
186        state: Arc<McpAppState>,
187    ) -> TransportServerResult<http::Response<GenericBody>> {
188        let session_id =
189            query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
190
191        // transmit to the readable stream, that transport is reading from
192        let transmit = state.session_store.get(&session_id).await.ok_or(
193            TransportServerError::SessionIdInvalid(session_id.to_string()),
194        )?;
195
196        let message = request.body();
197
198        transmit
199            .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref())
200            .await
201            .map_err(|err| {
202                tracing::trace!("{}", err);
203                TransportServerError::StreamIoError(err.to_string())
204            })?;
205
206        http::Response::builder()
207            .status(StatusCode::ACCEPTED)
208            .body(empty_response())
209            .map_err(|err| TransportServerError::HttpError(err.to_string()))
210    }
211
212    async fn internal_handle_streamable_http(
213        request: http::Request<&str>,
214        state: Arc<McpAppState>,
215    ) -> TransportServerResult<http::Response<GenericBody>> {
216        let method = request.method();
217        let response = match method {
218            &http::Method::GET => return Self::handle_http_get(request, state).await,
219            &http::Method::POST => return Self::handle_http_post(request, state).await,
220            &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
221            other => {
222                let error = SdkError::bad_request().with_message(&format!(
223                    "'{other}' is not a valid HTTP method for StreamableHTTP transport."
224                ));
225                error_response(StatusCode::METHOD_NOT_ALLOWED, error)
226            }
227        };
228
229        response
230    }
231
232    /// Processes POST requests for the Streamable HTTP Protocol
233    async fn handle_http_post(
234        request: http::Request<&str>,
235        state: Arc<McpAppState>,
236    ) -> TransportServerResult<http::Response<GenericBody>> {
237        let headers = request.headers();
238
239        if !valid_streaming_http_accept_header(headers) {
240            let error = SdkError::bad_request()
241                .with_message(r#"Client must accept both application/json and text/event-stream"#);
242            return error_response(StatusCode::NOT_ACCEPTABLE, error);
243        }
244
245        if !acceptable_content_type(headers) {
246            let error = SdkError::bad_request()
247                .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
248            return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
249        }
250
251        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
252            let error = SdkError::bad_request()
253                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
254            return error_response(StatusCode::BAD_REQUEST, error);
255        }
256
257        let session_id: Option<SessionId> = headers
258            .get(MCP_SESSION_ID_HEADER)
259            .and_then(|value| value.to_str().ok())
260            .map(|s| s.to_string());
261
262        let payload = request.body();
263
264        let response = match session_id {
265            // has session-id => write to the existing stream
266            Some(id) => {
267                if state.enable_json_response {
268                    process_incoming_message_return(id, state, payload).await
269                } else {
270                    process_incoming_message(id, state, payload).await
271                }
272            }
273            None => match valid_initialize_method(payload) {
274                Ok(_) => {
275                    return start_new_session(state, payload).await;
276                }
277                Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
278                Err(error) => {
279                    let error = SdkError::bad_request().with_message(&error.to_string());
280                    error_response(StatusCode::BAD_REQUEST, error)
281                }
282            },
283        };
284
285        response
286    }
287
288    /// Processes GET requests for the Streamable HTTP Protocol
289    async fn handle_http_get(
290        request: http::Request<&str>,
291        state: Arc<McpAppState>,
292    ) -> TransportServerResult<http::Response<GenericBody>> {
293        let headers = request.headers();
294
295        if !accepts_event_stream(headers) {
296            let error =
297                SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
298            return error_response(StatusCode::NOT_ACCEPTABLE, error);
299        }
300
301        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
302            let error = SdkError::bad_request()
303                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
304            return error_response(StatusCode::BAD_REQUEST, error);
305        }
306
307        let session_id: Option<SessionId> = headers
308            .get(MCP_SESSION_ID_HEADER)
309            .and_then(|value| value.to_str().ok())
310            .map(|s| s.to_string());
311
312        let last_event_id: Option<SessionId> = headers
313            .get(MCP_LAST_EVENT_ID_HEADER)
314            .and_then(|value| value.to_str().ok())
315            .map(|s| s.to_string());
316
317        let response = match session_id {
318            Some(session_id) => {
319                let res = create_standalone_stream(session_id, last_event_id, state).await;
320                res
321            }
322            None => {
323                let error = SdkError::bad_request().with_message("Bad request: session not found");
324                error_response(StatusCode::BAD_REQUEST, error)
325            }
326        };
327
328        response
329    }
330
331    /// Processes DELETE requests for the Streamable HTTP Protocol
332    async fn handle_http_delete(
333        request: http::Request<&str>,
334        state: Arc<McpAppState>,
335    ) -> TransportServerResult<http::Response<GenericBody>> {
336        let headers = request.headers();
337
338        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
339            let error = SdkError::bad_request()
340                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
341            return error_response(StatusCode::BAD_REQUEST, error);
342        }
343
344        let session_id: Option<SessionId> = headers
345            .get(MCP_SESSION_ID_HEADER)
346            .and_then(|value| value.to_str().ok())
347            .map(|s| s.to_string());
348
349        let response = match session_id {
350            Some(id) => delete_session(id, state).await,
351            None => {
352                let error = SdkError::bad_request().with_message("Bad Request: Session not found");
353                error_response(StatusCode::BAD_REQUEST, error)
354            }
355        };
356
357        response
358    }
359}