1#[cfg(feature = "sse")]
2use super::utils::handle_sse_connection;
3use crate::mcp_http::mcp_http_middleware::MiddlewareChain;
4use crate::mcp_http::utils::{
5 accepts_event_stream, empty_response, error_response, query_param,
6 validate_mcp_protocol_version_header,
7};
8use crate::mcp_http::Middleware;
9use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
10use crate::mcp_server::error::TransportServerError;
11use crate::schema::schema_utils::SdkError;
12use crate::{
13 error::McpSdkError,
14 mcp_http::{
15 utils::{
16 acceptable_content_type, create_standalone_stream, delete_session,
17 process_incoming_message, process_incoming_message_return, protect_dns_rebinding,
18 start_new_session, valid_streaming_http_accept_header, GenericBody,
19 },
20 McpAppState,
21 },
22 mcp_server::error::TransportServerResult,
23 utils::valid_initialize_method,
24};
25use http::{self, HeaderMap, Method, StatusCode, Uri};
26use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
27use std::sync::Arc;
28
29#[derive(Clone)]
30pub struct McpHttpHandler {
31 middleware_chain: MiddlewareChain,
32}
33
34impl Default for McpHttpHandler {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl McpHttpHandler {
41 pub fn new() -> Self {
42 McpHttpHandler {
43 middleware_chain: MiddlewareChain::new(),
44 }
45 }
46
47 pub fn add_middleware<M: Middleware>(&mut self, middleware: M) {
48 self.middleware_chain.add_middleware(middleware);
49 }
50
51 pub fn create_request(
55 method: Method,
56 uri: Uri,
57 headers: HeaderMap,
58 body: Option<&str>,
59 ) -> http::Request<&str> {
60 let mut request = http::Request::default();
61 *request.method_mut() = method;
62 *request.uri_mut() = uri;
63 *request.body_mut() = body.unwrap_or_default();
64 let req_headers = request.headers_mut();
65 for (key, value) in headers {
66 if let Some(k) = key {
67 req_headers.insert(k, value);
68 }
69 }
70 request
71 }
72}
73
74impl McpHttpHandler {
75 #[cfg(feature = "sse")]
88 pub async fn handle_sse_connection(
89 &self,
90 state: Arc<McpAppState>,
91 sse_message_endpoint: Option<&str>,
92 ) -> TransportServerResult<http::Response<GenericBody>> {
93 handle_sse_connection(state, sse_message_endpoint).await
94 }
95
96 pub async fn handle_sse_message(
117 &self,
118 request: http::Request<&str>,
119 state: Arc<McpAppState>,
120 ) -> TransportServerResult<http::Response<GenericBody>> {
121 let session_id =
122 query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
123
124 let transmit = state.session_store.get(&session_id).await.ok_or(
126 TransportServerError::SessionIdInvalid(session_id.to_string()),
127 )?;
128
129 let message = *request.body();
130 transmit
131 .consume_payload_string(DEFAULT_STREAM_ID, message)
132 .await
133 .map_err(|err| {
134 tracing::trace!("{}", err);
135 TransportServerError::StreamIoError(err.to_string())
136 })?;
137
138 http::Response::builder()
139 .status(StatusCode::ACCEPTED)
140 .body(empty_response())
141 .map_err(|err| TransportServerError::HttpError(err.to_string()))
142 }
143
144 pub async fn handle_streamable_http(
166 &self,
167 request: http::Request<&str>,
168 state: Arc<McpAppState>,
169 ) -> TransportServerResult<http::Response<GenericBody>> {
170 let request = self
171 .middleware_chain
172 .process_request(request)
173 .await
174 .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
175
176 if state.needs_dns_protection() {
179 if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await {
180 return error_response(StatusCode::FORBIDDEN, error);
181 }
182 }
183
184 let method = request.method();
185 let response = match method {
186 &http::Method::GET => return self.handle_http_get(request, state).await,
187 &http::Method::POST => return self.handle_http_post(request, state).await,
188 &http::Method::DELETE => return self.handle_http_delete(request, state).await,
189 other => {
190 let error = SdkError::bad_request().with_message(&format!(
191 "'{other}' is not a valid HTTP method for StreamableHTTP transport."
192 ));
193 error_response(StatusCode::METHOD_NOT_ALLOWED, error)
194 }
195 };
196
197 self.middleware_chain
198 .process_response(response?)
199 .await
200 .map_err(|e| TransportServerError::HttpError(e.to_string()))
201 }
202
203 async fn handle_http_post(
205 &self,
206 request: http::Request<&str>,
207 state: Arc<McpAppState>,
208 ) -> TransportServerResult<http::Response<GenericBody>> {
209 let request = self
210 .middleware_chain
211 .process_request(request)
212 .await
213 .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
214
215 let headers = request.headers();
216
217 if !valid_streaming_http_accept_header(headers) {
218 let error = SdkError::bad_request()
219 .with_message(r#"Client must accept both application/json and text/event-stream"#);
220 return error_response(StatusCode::NOT_ACCEPTABLE, error);
221 }
222
223 if !acceptable_content_type(headers) {
224 let error = SdkError::bad_request()
225 .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
226 return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
227 }
228
229 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
230 let error = SdkError::bad_request()
231 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
232 return error_response(StatusCode::BAD_REQUEST, error);
233 }
234
235 let session_id: Option<SessionId> = headers
236 .get(MCP_SESSION_ID_HEADER)
237 .and_then(|value| value.to_str().ok())
238 .map(|s| s.to_string());
239
240 let payload = *request.body();
241
242 let response = match session_id {
243 Some(id) => {
245 if state.enable_json_response {
246 process_incoming_message_return(id, state, payload).await
247 } else {
248 process_incoming_message(id, state, payload).await
249 }
250 }
251 None => match valid_initialize_method(payload) {
252 Ok(_) => {
253 return start_new_session(state, payload).await;
254 }
255 Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
256 Err(error) => {
257 let error = SdkError::bad_request().with_message(&error.to_string());
258 error_response(StatusCode::BAD_REQUEST, error)
259 }
260 },
261 };
262
263 self.middleware_chain
264 .process_response(response?)
265 .await
266 .map_err(|e| TransportServerError::HttpError(e.to_string()))
267 }
268
269 async fn handle_http_get(
271 &self,
272 request: http::Request<&str>,
273 state: Arc<McpAppState>,
274 ) -> TransportServerResult<http::Response<GenericBody>> {
275 let request = self
276 .middleware_chain
277 .process_request(request)
278 .await
279 .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
280
281 let headers = request.headers();
282
283 if !accepts_event_stream(headers) {
284 let error =
285 SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
286 return error_response(StatusCode::NOT_ACCEPTABLE, error);
287 }
288
289 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
290 let error = SdkError::bad_request()
291 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
292 return error_response(StatusCode::BAD_REQUEST, error);
293 }
294
295 let session_id: Option<SessionId> = headers
296 .get(MCP_SESSION_ID_HEADER)
297 .and_then(|value| value.to_str().ok())
298 .map(|s| s.to_string());
299
300 let last_event_id: Option<SessionId> = headers
301 .get(MCP_LAST_EVENT_ID_HEADER)
302 .and_then(|value| value.to_str().ok())
303 .map(|s| s.to_string());
304
305 let response = match session_id {
306 Some(session_id) => {
307 let res = create_standalone_stream(session_id, last_event_id, state).await;
308 res
309 }
310 None => {
311 let error = SdkError::bad_request().with_message("Bad request: session not found");
312 error_response(StatusCode::BAD_REQUEST, error)
313 }
314 };
315
316 self.middleware_chain
317 .process_response(response?)
318 .await
319 .map_err(|e| TransportServerError::HttpError(e.to_string()))
320 }
321
322 async fn handle_http_delete(
324 &self,
325 request: http::Request<&str>,
326 state: Arc<McpAppState>,
327 ) -> TransportServerResult<http::Response<GenericBody>> {
328 let request = self
329 .middleware_chain
330 .process_request(request)
331 .await
332 .map_err(|e| TransportServerError::HttpError(e.to_string()))?;
333
334 let headers = request.headers();
335
336 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
337 let error = SdkError::bad_request()
338 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
339 return error_response(StatusCode::BAD_REQUEST, error);
340 }
341
342 let session_id: Option<SessionId> = headers
343 .get(MCP_SESSION_ID_HEADER)
344 .and_then(|value| value.to_str().ok())
345 .map(|s| s.to_string());
346
347 let response = match session_id {
348 Some(id) => delete_session(id, state).await,
349 None => {
350 let error = SdkError::bad_request().with_message("Bad Request: Session not found");
351 error_response(StatusCode::BAD_REQUEST, error)
352 }
353 };
354
355 self.middleware_chain
356 .process_response(response?)
357 .await
358 .map_err(|e| TransportServerError::HttpError(e.to_string()))
359 }
360}