rust_mcp_sdk/mcp_http/
mcp_http_handler.rs

1#[cfg(feature = "sse")]
2use super::http_utils::handle_sse_connection;
3use super::http_utils::{
4    accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header,
5};
6use super::types::GenericBody;
7use crate::auth::AuthInfo;
8#[cfg(feature = "auth")]
9use crate::auth::AuthProvider;
10use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler};
11use crate::mcp_http::{GenericBodyExt, RequestExt};
12use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
13use crate::mcp_server::error::TransportServerError;
14use crate::schema::schema_utils::SdkError;
15use crate::{
16    error::McpSdkError,
17    mcp_http::{
18        http_utils::{
19            acceptable_content_type, create_standalone_stream, delete_session,
20            process_incoming_message, process_incoming_message_return, start_new_session,
21            valid_streaming_http_accept_header,
22        },
23        McpAppState,
24    },
25    mcp_server::error::TransportServerResult,
26    utils::valid_initialize_method,
27};
28use http::{self, HeaderMap, Method, StatusCode, Uri};
29use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
30use std::sync::Arc;
31
/// Wraps an async handler function in a `RequestHandler` and composes it with
/// the middlewares stored in `$self.middlewares`.
///
/// Two forms are accepted:
/// - `with_middlewares!(self, handler)` composes `handler` with `self.middlewares`.
/// - `with_middlewares!(self, handler, extra1, extra2, ...)` additionally chains the
///   items of each `extraN.iter()` after `self.middlewares`, in the order given.
///
/// `handler` must be a path to an async function that can be called as
/// `handler(req, state)` with an `http::Request<&str>` and an `Arc<McpAppState>`.
///
/// Crate-local items are referenced through `$crate::` so the expansion does not
/// depend on what the calling module has imported. The `http` crate itself must
/// still be nameable as `http` at the call site.
///
/// # Example
/// ```ignore
/// let handle = with_middlewares!(self, Self::internal_handle_sse_message);
/// handle(request, state).await
///
/// // OR, with extra middleware collections:
/// let handle = with_middlewares!(self, Self::internal_handle_sse_message, extra_middlewares1, extra_middlewares2);
/// ```
#[macro_export]
macro_rules! with_middlewares {
    // Handler only
    ($self:ident, $handler:path) => {{
        let final_handler: $crate::mcp_http::RequestHandler = Box::new(
            move |req: http::Request<&str>,
                  state: std::sync::Arc<$crate::mcp_http::McpAppState>|
                  -> $crate::mcp_http::BoxFutureResponse<'_> {
                Box::pin(async move { $handler(req, state).await })
            },
        );
        $crate::mcp_http::middleware::compose(&$self.middlewares, final_handler)
    }};

    // Handler + extra middleware(s)
    ($self:ident, $handler:path, $($extra:expr),+ $(,)?) => {{
        let final_handler: $crate::mcp_http::RequestHandler = Box::new(
            move |req: http::Request<&str>,
                  state: std::sync::Arc<$crate::mcp_http::McpAppState>|
                  -> $crate::mcp_http::BoxFutureResponse<'_> {
                Box::pin(async move { $handler(req, state).await })
            },
        );

        // Chain $self.middlewares with any extra middleware iterators.
        // Each `$extra` is borrowed via `.iter()`, so it should be a binding that
        // outlives this statement rather than a temporary.
        let all = $self.middlewares.iter()
            $(.chain($extra.iter()))+;

        $crate::mcp_http::middleware::compose(all, final_handler)
    }};
}
73
74#[derive(Clone)]
75pub struct McpHttpHandler {
76    #[cfg(feature = "auth")]
77    auth: Option<Arc<dyn AuthProvider>>,
78    middlewares: Vec<Arc<dyn Middleware>>,
79}
80
81impl McpHttpHandler {
/// Creates a handler from an optional authentication provider and an initial
/// middleware chain (variant compiled when the `auth` feature is enabled).
#[cfg(feature = "auth")]
pub fn new(auth: Option<Arc<dyn AuthProvider>>, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
    Self { auth, middlewares }
}
86
87    #[cfg(not(feature = "auth"))]
88    pub fn new(middlewares: Vec<Arc<dyn Middleware>>) -> Self {
89        McpHttpHandler { middlewares }
90    }
91
92    pub fn add_middleware<M: Middleware + 'static>(&mut self, middleware: M) {
93        let m: Arc<dyn Middleware> = Arc::new(middleware);
94        self.middlewares.push(m);
95    }
96
97    /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body.
98    /// If the `body` is `None`, an empty string is used as the default.
99    ///
100    pub fn create_request(
101        method: Method,
102        uri: Uri,
103        headers: HeaderMap,
104        body: Option<&str>,
105    ) -> http::Request<&str> {
106        let mut request = http::Request::default();
107        *request.method_mut() = method;
108        *request.uri_mut() = uri;
109        *request.body_mut() = body.unwrap_or_default();
110        let req_headers = request.headers_mut();
111        for (key, value) in headers {
112            if let Some(k) = key {
113                req_headers.insert(k, value);
114            }
115        }
116        request
117    }
118}
119
120// auth related methods
121#[cfg(feature = "auth")]
122impl McpHttpHandler {
123    pub fn oauth_endppoints(&self) -> Option<Vec<&String>> {
124        self.auth
125            .as_ref()
126            .and_then(|a| a.auth_endpoints().map(|e| e.keys().collect::<Vec<_>>()))
127    }
128
129    pub async fn handle_auth_requests(
130        &self,
131        request: http::Request<&str>,
132        state: Arc<McpAppState>,
133    ) -> TransportServerResult<http::Response<GenericBody>> {
134        let Some(auth_provider) = self.auth.as_ref() else {
135            return Err(TransportServerError::HttpError(
136                "Authentication is not supported by this server.".to_string(),
137            ));
138        };
139
140        let auth_provider = auth_provider.clone();
141        let final_handler: RequestHandler = Box::new(move |req, state| {
142            Box::pin(async move {
143                use futures::TryFutureExt;
144                auth_provider
145                    .handle_request(req, state)
146                    .map_err(|e| e)
147                    .await
148            })
149        });
150
151        let handle = compose(&[], final_handler);
152        handle(request, state).await
153    }
154}
155
156impl McpHttpHandler {
/// Opens an MCP connection over the SSE (Server-Sent Events) transport.
///
/// Any `AuthInfo` carried by the request is extracted with `RequestExt::take`
/// and handed to the connection setup. The request is then run through this
/// handler's middlewares; the final handler itself does not read it.
///
/// # Arguments
/// * `request` - The incoming HTTP request; seen by the middlewares only.
/// * `state` - Shared application state passed on to the connection setup.
/// * `sse_message_endpoint` - Optional message endpoint forwarded to the
///   `handle_sse_connection` helper. What happens for `None` is decided by that
///   helper (previously documented as defaulting to `/messages` — confirm).
///
/// # Features
/// This function is only available when the `sse` feature is enabled.
#[cfg(feature = "sse")]
pub async fn handle_sse_connection(
    &self,
    request: http::Request<&str>,
    state: Arc<McpAppState>,
    sse_message_endpoint: Option<&str>,
) -> TransportServerResult<http::Response<GenericBody>> {
    use crate::auth::AuthInfo;
    use crate::mcp_http::RequestExt;

    let (request, auth_info) = request.take::<AuthInfo>();

    // Own the endpoint string so it can be moved into the boxed handler.
    let endpoint: Option<String> = sse_message_endpoint.map(str::to_owned);

    let final_handler: RequestHandler = Box::new(move |_req, state| {
        Box::pin(async move {
            handle_sse_connection(state, endpoint.as_deref(), auth_info).await
        })
    });

    let pipeline = compose(&self.middlewares, final_handler);
    pipeline(request, state).await
}
190
/// Accepts a client message for an already established SSE session.
///
/// The request is run through this handler's middlewares and then through
/// `internal_handle_sse_message`, which:
/// - reads the `sessionId` query parameter,
/// - looks the session up in `state.session_store`,
/// - forwards the request body to that session.
///
/// # Arguments
/// * `request` - The HTTP request carrying the message body and the `sessionId` query parameter.
/// * `state` - Shared application state, including the session store.
///
/// # Returns
/// A `202 Accepted` response with an empty body once the message has been forwarded.
///
/// # Errors
/// - `SessionIdMissing`: the `sessionId` query parameter is absent.
/// - `SessionIdInvalid`: the session store has no entry for the given id.
/// - `StreamIoError`: forwarding the payload to the session failed.
/// - `HttpError`: building the HTTP response failed.
#[cfg(feature = "sse")]
pub async fn handle_sse_message(
    &self,
    request: http::Request<&str>,
    state: Arc<McpAppState>,
) -> TransportServerResult<http::Response<GenericBody>> {
    let pipeline = with_middlewares!(self, Self::internal_handle_sse_message);
    pipeline(request, state).await
}
220
221    /// Handles incoming MCP messages over the StreamableHTTP transport.
222    ///
223    /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional
224    /// DNS rebinding protection if it is configured.
225    ///
226    /// # Arguments
227    /// * `request` - The HTTP request from the client, including method, headers, and optional body.
228    /// * `state` - Shared application state, including configuration and session management.
229    ///
230    /// # Behavior
231    /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers.
232    ///   If dns protection fails, a `403 Forbidden` response is returned.
233    /// - Dispatches the request to method-specific handlers based on the HTTP method:
234    ///     - `GET` → `handle_http_get`
235    ///     - `POST` → `handle_http_post`
236    ///     - `DELETE` → `handle_http_delete`
237    /// - Returns `405 Method Not Allowed` for unsupported methods.
238    ///
239    /// # Returns
240    /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation.
241    ///
242    pub async fn handle_streamable_http(
243        &self,
244        request: http::Request<&str>,
245        state: Arc<McpAppState>,
246    ) -> TransportServerResult<http::Response<GenericBody>> {
247        let handle = with_middlewares!(self, Self::internal_handle_streamable_http);
248        handle(request, state).await
249    }
250
251    async fn internal_handle_sse_message(
252        request: http::Request<&str>,
253        state: Arc<McpAppState>,
254    ) -> TransportServerResult<http::Response<GenericBody>> {
255        let session_id =
256            query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
257
258        // transmit to the readable stream, that transport is reading from
259        let transmit = state.session_store.get(&session_id).await.ok_or(
260            TransportServerError::SessionIdInvalid(session_id.to_string()),
261        )?;
262
263        let message = request.body();
264
265        transmit
266            .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref())
267            .await
268            .map_err(|err| {
269                tracing::trace!("{}", err);
270                TransportServerError::StreamIoError(err.to_string())
271            })?;
272
273        http::Response::builder()
274            .status(StatusCode::ACCEPTED)
275            .body(GenericBody::empty())
276            .map_err(|err| TransportServerError::HttpError(err.to_string()))
277    }
278
279    async fn internal_handle_streamable_http(
280        request: http::Request<&str>,
281        state: Arc<McpAppState>,
282    ) -> TransportServerResult<http::Response<GenericBody>> {
283        let (request, auth_info) = request.take::<AuthInfo>();
284
285        let method = request.method();
286
287        let response = match method {
288            &http::Method::GET => return Self::handle_http_get(request, state, auth_info).await,
289            &http::Method::POST => return Self::handle_http_post(request, state, auth_info).await,
290            &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
291            other => {
292                let error = SdkError::bad_request().with_message(&format!(
293                    "'{other}' is not a valid HTTP method for StreamableHTTP transport."
294                ));
295                error_response(StatusCode::METHOD_NOT_ALLOWED, error)
296            }
297        };
298
299        response
300    }
301
302    /// Processes POST requests for the Streamable HTTP Protocol
303    async fn handle_http_post(
304        request: http::Request<&str>,
305        state: Arc<McpAppState>,
306        auth_info: Option<AuthInfo>,
307    ) -> TransportServerResult<http::Response<GenericBody>> {
308        let headers = request.headers();
309
310        if !valid_streaming_http_accept_header(headers) {
311            let error = SdkError::bad_request()
312                .with_message(r#"Client must accept both application/json and text/event-stream"#);
313            return error_response(StatusCode::NOT_ACCEPTABLE, error);
314        }
315
316        if !acceptable_content_type(headers) {
317            let error = SdkError::bad_request()
318                .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
319            return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
320        }
321
322        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
323            let error = SdkError::bad_request()
324                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
325            return error_response(StatusCode::BAD_REQUEST, error);
326        }
327
328        let session_id: Option<SessionId> = headers
329            .get(MCP_SESSION_ID_HEADER)
330            .and_then(|value| value.to_str().ok())
331            .map(|s| s.to_string());
332
333        let payload = request.body();
334
335        let response = match session_id {
336            // has session-id => write to the existing stream
337            Some(id) => {
338                if state.enable_json_response {
339                    process_incoming_message_return(id, state, payload, auth_info).await
340                } else {
341                    process_incoming_message(id, state, payload, auth_info).await
342                }
343            }
344            None => match valid_initialize_method(payload) {
345                Ok(_) => {
346                    return start_new_session(state, payload, auth_info).await;
347                }
348                Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
349                Err(error) => {
350                    let error = SdkError::bad_request().with_message(&error.to_string());
351                    error_response(StatusCode::BAD_REQUEST, error)
352                }
353            },
354        };
355
356        response
357    }
358
359    /// Processes GET requests for the Streamable HTTP Protocol
360    async fn handle_http_get(
361        request: http::Request<&str>,
362        state: Arc<McpAppState>,
363        auth_info: Option<AuthInfo>,
364    ) -> TransportServerResult<http::Response<GenericBody>> {
365        let headers = request.headers();
366
367        if !accepts_event_stream(headers) {
368            let error =
369                SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
370            return error_response(StatusCode::NOT_ACCEPTABLE, error);
371        }
372
373        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
374            let error = SdkError::bad_request()
375                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
376            return error_response(StatusCode::BAD_REQUEST, error);
377        }
378
379        let session_id: Option<SessionId> = headers
380            .get(MCP_SESSION_ID_HEADER)
381            .and_then(|value| value.to_str().ok())
382            .map(|s| s.to_string());
383
384        let last_event_id: Option<SessionId> = headers
385            .get(MCP_LAST_EVENT_ID_HEADER)
386            .and_then(|value| value.to_str().ok())
387            .map(|s| s.to_string());
388
389        let response = match session_id {
390            Some(session_id) => {
391                let res =
392                    create_standalone_stream(session_id, last_event_id, state, auth_info).await;
393                res
394            }
395            None => {
396                let error = SdkError::bad_request().with_message("Bad request: session not found");
397                error_response(StatusCode::BAD_REQUEST, error)
398            }
399        };
400
401        response
402    }
403
404    /// Processes DELETE requests for the Streamable HTTP Protocol
405    async fn handle_http_delete(
406        request: http::Request<&str>,
407        state: Arc<McpAppState>,
408    ) -> TransportServerResult<http::Response<GenericBody>> {
409        let headers = request.headers();
410
411        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
412            let error = SdkError::bad_request()
413                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
414            return error_response(StatusCode::BAD_REQUEST, error);
415        }
416
417        let session_id: Option<SessionId> = headers
418            .get(MCP_SESSION_ID_HEADER)
419            .and_then(|value| value.to_str().ok())
420            .map(|s| s.to_string());
421
422        let response = match session_id {
423            Some(id) => delete_session(id, state).await,
424            None => {
425                let error = SdkError::bad_request().with_message("Bad Request: Session not found");
426                error_response(StatusCode::BAD_REQUEST, error)
427            }
428        };
429
430        response
431    }
432}