rust_mcp_sdk/mcp_http/
mcp_http_handler.rs

1#[cfg(feature = "sse")]
2use super::utils::handle_sse_connection;
3use crate::mcp_http::utils::{
4    accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header,
5};
6use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
7use crate::mcp_server::error::TransportServerError;
8use crate::schema::schema_utils::SdkError;
9use crate::{
10    error::McpSdkError,
11    mcp_http::{
12        utils::{
13            acceptable_content_type, create_standalone_stream, delete_session,
14            process_incoming_message, process_incoming_message_return, protect_dns_rebinding,
15            start_new_session, valid_streaming_http_accept_header, GenericBody,
16        },
17        McpAppState,
18    },
19    mcp_server::error::TransportServerResult,
20    utils::valid_initialize_method,
21};
22use bytes::Bytes;
23use http::{self, HeaderMap, Method, StatusCode, Uri};
24use http_body_util::{BodyExt, Full};
25use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
26use std::sync::Arc;
27
/// Stateless handler type that groups the MCP-over-HTTP entry points.
///
/// It covers request construction, the SSE transport, and the Streamable HTTP transport.
///
/// The type carries no data. The methods defined for it in this module are associated
/// functions that receive everything they need (request, shared state) as arguments.
#[derive(Debug, Default, Clone, Copy)]
pub struct McpHttpHandler {}
30impl McpHttpHandler {
31    /// Creates a new HTTP request with the given method, URI, headers, and optional body.
32    ///
33    /// # Arguments
34    ///
35    /// * `method` - The HTTP method to use (e.g., GET, POST).
36    /// * `uri` - The target URI for the request.
37    /// * `headers` - A map of optional header keys and their corresponding values.
38    /// * `body` - An optional string slice representing the request body.
39    ///
40    /// # Returns
41    ///
42    /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body.
43    /// If the `body` is `None`, an empty string is used as the default.
44    ///
45    pub fn create_request(
46        method: Method,
47        uri: Uri,
48        headers: HeaderMap,
49        body: Option<&str>,
50    ) -> http::Request<&str> {
51        let mut request = http::Request::default();
52        *request.method_mut() = method;
53        *request.uri_mut() = uri;
54        *request.body_mut() = body.unwrap_or_default();
55        let req_headers = request.headers_mut();
56        for (key, value) in headers {
57            if let Some(k) = key {
58                req_headers.insert(k, value);
59            }
60        }
61        request
62    }
63}
64
65impl McpHttpHandler {
66    /// Handles an MCP connection using the SSE (Server-Sent Events) transport.
67    ///
68    /// This function serves as the entry point for initializing and managing a client connection
69    /// over SSE when the `sse` feature is enabled.
70    ///
71    /// # Arguments
72    /// * `state` - Shared application state required to manage the MCP session.
73    /// * `sse_message_endpoint` - Optional message endpoint to override the default SSE route (default: `/messages` ).
74    ///
75    ///
76    /// # Features
77    /// This function is only available when the `sse` feature is enabled.
78    #[cfg(feature = "sse")]
79    pub async fn handle_sse_connection(
80        state: Arc<McpAppState>,
81        sse_message_endpoint: Option<&str>,
82    ) -> TransportServerResult<http::Response<GenericBody>> {
83        handle_sse_connection(state, sse_message_endpoint).await
84    }
85
86    /// Handles incoming MCP messages from the client after an SSE connection is established.
87    ///
88    /// This function processes a message sent by the client as part of an active SSE session. It:
89    /// - Extracts the `sessionId` from the request query parameters.
90    /// - Locates the corresponding session's transmit channel.
91    /// - Forwards the incoming message payload to the MCP transport stream for consumption.
92    /// # Arguments
93    /// * `request` - The HTTP request containing the message body and query parameters (including `sessionId`).
94    /// * `state` - Shared application state, including access to the session store.
95    ///
96    /// # Returns
97    /// * `TransportServerResult<http::Response<GenericBody>>`:
98    ///   - Returns a `202 Accepted` HTTP response if the message is successfully forwarded.
99    ///   - Returns an error if the session ID is missing, invalid, or if any I/O issues occur while processing the message.
100    ///
101    /// # Errors
102    /// - `SessionIdMissing`: if the `sessionId` query parameter is not present.
103    /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store.
104    /// - `StreamIoError`: if an error occurs while writing to the stream.
105    /// - `HttpError`: if constructing the HTTP response fails.
106    pub async fn handle_sse_message(
107        request: http::Request<&str>,
108        state: Arc<McpAppState>,
109    ) -> TransportServerResult<http::Response<GenericBody>> {
110        let session_id =
111            query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
112
113        // transmit to the readable stream, that transport is reading from
114        let transmit = state.session_store.get(&session_id).await.ok_or(
115            TransportServerError::SessionIdInvalid(session_id.to_string()),
116        )?;
117
118        let message = *request.body();
119        transmit
120            .consume_payload_string(DEFAULT_STREAM_ID, message)
121            .await
122            .map_err(|err| {
123                tracing::trace!("{}", err);
124                TransportServerError::StreamIoError(err.to_string())
125            })?;
126
127        let body = Full::new(Bytes::new())
128            .map_err(|err| TransportServerError::HttpError(err.to_string()))
129            .boxed();
130
131        http::Response::builder()
132            .status(StatusCode::ACCEPTED)
133            .body(body)
134            .map_err(|err| TransportServerError::HttpError(err.to_string()))
135    }
136
137    /// Handles incoming MCP messages over the StreamableHTTP transport.
138    ///
139    /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional
140    /// DNS rebinding protection if it is configured.
141    ///
142    /// # Arguments
143    /// * `request` - The HTTP request from the client, including method, headers, and optional body.
144    /// * `state` - Shared application state, including configuration and session management.
145    ///
146    /// # Behavior
147    /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers.
148    ///   If dns protection fails, a `403 Forbidden` response is returned.
149    /// - Dispatches the request to method-specific handlers based on the HTTP method:
150    ///     - `GET` → `handle_http_get`
151    ///     - `POST` → `handle_http_post`
152    ///     - `DELETE` → `handle_http_delete`
153    /// - Returns `405 Method Not Allowed` for unsupported methods.
154    ///
155    /// # Returns
156    /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation.
157    ///
158    pub async fn handle_streamable_http(
159        request: http::Request<&str>,
160        state: Arc<McpAppState>,
161    ) -> TransportServerResult<http::Response<GenericBody>> {
162        // Enforces DNS rebinding protection if required by state.
163        // If protection fails, respond with HTTP 403 Forbidden.
164        if state.needs_dns_protection() {
165            if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await {
166                return error_response(StatusCode::FORBIDDEN, error);
167            }
168        }
169
170        let method = request.method();
171        match method {
172            &http::Method::GET => return Self::handle_http_get(request, state).await,
173            &http::Method::POST => return Self::handle_http_post(request, state).await,
174            &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
175            other => {
176                let error = SdkError::bad_request().with_message(&format!(
177                    "'{other}' is not a valid HTTP method for StreamableHTTP transport."
178                ));
179                error_response(StatusCode::METHOD_NOT_ALLOWED, error)
180            }
181        }
182    }
183
184    /// Processes POST requests for the Streamable HTTP Protocol
185    async fn handle_http_post(
186        request: http::Request<&str>,
187        state: Arc<McpAppState>,
188    ) -> TransportServerResult<http::Response<GenericBody>> {
189        let headers = request.headers();
190
191        if !valid_streaming_http_accept_header(headers) {
192            let error = SdkError::bad_request()
193                .with_message(r#"Client must accept both application/json and text/event-stream"#);
194            return error_response(StatusCode::NOT_ACCEPTABLE, error);
195        }
196
197        if !acceptable_content_type(headers) {
198            let error = SdkError::bad_request()
199                .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
200            return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
201        }
202
203        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
204            let error = SdkError::bad_request()
205                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
206            return error_response(StatusCode::BAD_REQUEST, error);
207        }
208
209        let session_id: Option<SessionId> = headers
210            .get(MCP_SESSION_ID_HEADER)
211            .and_then(|value| value.to_str().ok())
212            .map(|s| s.to_string());
213
214        let payload = *request.body();
215
216        match session_id {
217            // has session-id => write to the existing stream
218            Some(id) => {
219                if state.enable_json_response {
220                    process_incoming_message_return(id, state, payload).await
221                } else {
222                    process_incoming_message(id, state, payload).await
223                }
224            }
225            None => match valid_initialize_method(payload) {
226                Ok(_) => {
227                    return start_new_session(state, payload).await;
228                }
229                Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
230                Err(error) => {
231                    let error = SdkError::bad_request().with_message(&error.to_string());
232                    error_response(StatusCode::BAD_REQUEST, error)
233                }
234            },
235        }
236    }
237
238    /// Processes GET requests for the Streamable HTTP Protocol
239    async fn handle_http_get(
240        request: http::Request<&str>,
241        state: Arc<McpAppState>,
242    ) -> TransportServerResult<http::Response<GenericBody>> {
243        let headers = request.headers();
244
245        if !accepts_event_stream(headers) {
246            let error =
247                SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
248            return error_response(StatusCode::NOT_ACCEPTABLE, error);
249        }
250
251        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
252            let error = SdkError::bad_request()
253                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
254            return error_response(StatusCode::BAD_REQUEST, error);
255        }
256
257        let session_id: Option<SessionId> = headers
258            .get(MCP_SESSION_ID_HEADER)
259            .and_then(|value| value.to_str().ok())
260            .map(|s| s.to_string());
261
262        let last_event_id: Option<SessionId> = headers
263            .get(MCP_LAST_EVENT_ID_HEADER)
264            .and_then(|value| value.to_str().ok())
265            .map(|s| s.to_string());
266
267        match session_id {
268            Some(session_id) => {
269                let res = create_standalone_stream(session_id, last_event_id, state).await;
270                res
271            }
272            None => {
273                let error = SdkError::bad_request().with_message("Bad request: session not found");
274                error_response(StatusCode::BAD_REQUEST, error)
275            }
276        }
277    }
278
279    /// Processes DELETE requests for the Streamable HTTP Protocol
280    async fn handle_http_delete(
281        request: http::Request<&str>,
282        state: Arc<McpAppState>,
283    ) -> TransportServerResult<http::Response<GenericBody>> {
284        let headers = request.headers();
285
286        if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
287            let error = SdkError::bad_request()
288                .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
289            return error_response(StatusCode::BAD_REQUEST, error);
290        }
291
292        let session_id: Option<SessionId> = headers
293            .get(MCP_SESSION_ID_HEADER)
294            .and_then(|value| value.to_str().ok())
295            .map(|s| s.to_string());
296
297        match session_id {
298            Some(id) => delete_session(id, state).await,
299            None => {
300                let error = SdkError::bad_request().with_message("Bad Request: Session not found");
301                error_response(StatusCode::BAD_REQUEST, error)
302            }
303        }
304    }
305}