1#[cfg(feature = "sse")]
2use super::http_utils::{
3 accepts_event_stream, empty_response, error_response, handle_sse_connection, query_param,
4 validate_mcp_protocol_version_header,
5};
6use super::types::GenericBody;
7use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler};
8use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
9use crate::mcp_server::error::TransportServerError;
10use crate::schema::schema_utils::SdkError;
11use crate::{
12 error::McpSdkError,
13 mcp_http::{
14 http_utils::{
15 acceptable_content_type, create_standalone_stream, delete_session,
16 process_incoming_message, process_incoming_message_return, start_new_session,
17 valid_streaming_http_accept_header,
18 },
19 McpAppState,
20 },
21 mcp_server::error::TransportServerResult,
22 utils::valid_initialize_method,
23};
24use http::{self, HeaderMap, Method, StatusCode, Uri};
25use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
26use std::sync::Arc;
27
/// Wraps an async handler function in a middleware chain.
///
/// * `$self` – an identifier whose `middlewares` field holds the middlewares to apply.
/// * `$handler` – path to an async function callable as `handler(req, state)` with
///   `req: http::Request<&str>` and `state: Arc<McpAppState>`.
///
/// Expands to the value returned by `compose`, which callers invoke as
/// `handle(request, state).await`.
///
/// The crate-local types are referenced through `$crate` so the expansion does not
/// depend on what the call site happens to have imported. The `http` crate must still
/// be nameable as `http` where the macro is invoked.
#[macro_export]
macro_rules! with_middlewares {
    ($self:ident, $handler:path) => {{
        let final_handler: $crate::mcp_http::RequestHandler = ::std::sync::Arc::new(
            move |req: http::Request<&str>,
                  state: ::std::sync::Arc<$crate::mcp_http::McpAppState>|
                  -> $crate::mcp_http::BoxFutureResponse<'_> {
                ::std::boxed::Box::pin(async move { $handler(req, state).await })
            },
        );
        $crate::mcp_http::middleware::compose(&$self.middlewares, final_handler)
    }};
}
49
50#[derive(Clone)]
51pub struct McpHttpHandler {
52 middlewares: Vec<Arc<dyn Middleware>>,
53}
54
55impl Default for McpHttpHandler {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl McpHttpHandler {
62 pub fn new() -> Self {
63 McpHttpHandler {
64 middlewares: vec![],
65 }
66 }
67 pub fn add_middleware<M: Middleware + 'static>(&mut self, middleware: M) {
68 let m: Arc<dyn Middleware> = Arc::new(middleware);
69 self.middlewares.push(m);
70 }
71
72 pub fn create_request(
76 method: Method,
77 uri: Uri,
78 headers: HeaderMap,
79 body: Option<&str>,
80 ) -> http::Request<&str> {
81 let mut request = http::Request::default();
82 *request.method_mut() = method;
83 *request.uri_mut() = uri;
84 *request.body_mut() = body.unwrap_or_default();
85 let req_headers = request.headers_mut();
86 for (key, value) in headers {
87 if let Some(k) = key {
88 req_headers.insert(k, value);
89 }
90 }
91 request
92 }
93}
94
95impl McpHttpHandler {
96 #[cfg(feature = "sse")]
109 pub async fn handle_sse_connection(
110 &self,
111 request: http::Request<&str>,
112 state: Arc<McpAppState>,
113 sse_message_endpoint: Option<&str>,
114 ) -> TransportServerResult<http::Response<GenericBody>> {
115 let sse_endpoint = Arc::from(sse_message_endpoint.map(|s| s.to_string()));
116 let final_handler: RequestHandler = Arc::new(move |_req, state| {
117 let sse_endpoint = sse_endpoint.clone();
118 Box::pin(async move { handle_sse_connection(state, sse_endpoint.as_deref()).await })
119 });
120 let handle = compose(&self.middlewares, final_handler);
121 handle(request, state).await
122 }
123
124 #[cfg(feature = "sse")]
145 pub async fn handle_sse_message(
146 &self,
147 request: http::Request<&str>,
148 state: Arc<McpAppState>,
149 ) -> TransportServerResult<http::Response<GenericBody>> {
150 let handle = with_middlewares!(self, Self::internal_handle_sse_message);
151 handle(request, state).await
152 }
153
154 pub async fn handle_streamable_http(
176 &self,
177 request: http::Request<&str>,
178 state: Arc<McpAppState>,
179 ) -> TransportServerResult<http::Response<GenericBody>> {
180 let handle = with_middlewares!(self, Self::internal_handle_streamable_http);
181 handle(request, state).await
182 }
183
184 async fn internal_handle_sse_message(
185 request: http::Request<&str>,
186 state: Arc<McpAppState>,
187 ) -> TransportServerResult<http::Response<GenericBody>> {
188 let session_id =
189 query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
190
191 let transmit = state.session_store.get(&session_id).await.ok_or(
193 TransportServerError::SessionIdInvalid(session_id.to_string()),
194 )?;
195
196 let message = request.body();
197
198 transmit
199 .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref())
200 .await
201 .map_err(|err| {
202 tracing::trace!("{}", err);
203 TransportServerError::StreamIoError(err.to_string())
204 })?;
205
206 http::Response::builder()
207 .status(StatusCode::ACCEPTED)
208 .body(empty_response())
209 .map_err(|err| TransportServerError::HttpError(err.to_string()))
210 }
211
212 async fn internal_handle_streamable_http(
213 request: http::Request<&str>,
214 state: Arc<McpAppState>,
215 ) -> TransportServerResult<http::Response<GenericBody>> {
216 let method = request.method();
217 let response = match method {
218 &http::Method::GET => return Self::handle_http_get(request, state).await,
219 &http::Method::POST => return Self::handle_http_post(request, state).await,
220 &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
221 other => {
222 let error = SdkError::bad_request().with_message(&format!(
223 "'{other}' is not a valid HTTP method for StreamableHTTP transport."
224 ));
225 error_response(StatusCode::METHOD_NOT_ALLOWED, error)
226 }
227 };
228
229 response
230 }
231
232 async fn handle_http_post(
234 request: http::Request<&str>,
235 state: Arc<McpAppState>,
236 ) -> TransportServerResult<http::Response<GenericBody>> {
237 let headers = request.headers();
238
239 if !valid_streaming_http_accept_header(headers) {
240 let error = SdkError::bad_request()
241 .with_message(r#"Client must accept both application/json and text/event-stream"#);
242 return error_response(StatusCode::NOT_ACCEPTABLE, error);
243 }
244
245 if !acceptable_content_type(headers) {
246 let error = SdkError::bad_request()
247 .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
248 return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
249 }
250
251 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
252 let error = SdkError::bad_request()
253 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
254 return error_response(StatusCode::BAD_REQUEST, error);
255 }
256
257 let session_id: Option<SessionId> = headers
258 .get(MCP_SESSION_ID_HEADER)
259 .and_then(|value| value.to_str().ok())
260 .map(|s| s.to_string());
261
262 let payload = request.body();
263
264 let response = match session_id {
265 Some(id) => {
267 if state.enable_json_response {
268 process_incoming_message_return(id, state, payload).await
269 } else {
270 process_incoming_message(id, state, payload).await
271 }
272 }
273 None => match valid_initialize_method(payload) {
274 Ok(_) => {
275 return start_new_session(state, payload).await;
276 }
277 Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
278 Err(error) => {
279 let error = SdkError::bad_request().with_message(&error.to_string());
280 error_response(StatusCode::BAD_REQUEST, error)
281 }
282 },
283 };
284
285 response
286 }
287
288 async fn handle_http_get(
290 request: http::Request<&str>,
291 state: Arc<McpAppState>,
292 ) -> TransportServerResult<http::Response<GenericBody>> {
293 let headers = request.headers();
294
295 if !accepts_event_stream(headers) {
296 let error =
297 SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
298 return error_response(StatusCode::NOT_ACCEPTABLE, error);
299 }
300
301 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
302 let error = SdkError::bad_request()
303 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
304 return error_response(StatusCode::BAD_REQUEST, error);
305 }
306
307 let session_id: Option<SessionId> = headers
308 .get(MCP_SESSION_ID_HEADER)
309 .and_then(|value| value.to_str().ok())
310 .map(|s| s.to_string());
311
312 let last_event_id: Option<SessionId> = headers
313 .get(MCP_LAST_EVENT_ID_HEADER)
314 .and_then(|value| value.to_str().ok())
315 .map(|s| s.to_string());
316
317 let response = match session_id {
318 Some(session_id) => {
319 let res = create_standalone_stream(session_id, last_event_id, state).await;
320 res
321 }
322 None => {
323 let error = SdkError::bad_request().with_message("Bad request: session not found");
324 error_response(StatusCode::BAD_REQUEST, error)
325 }
326 };
327
328 response
329 }
330
331 async fn handle_http_delete(
333 request: http::Request<&str>,
334 state: Arc<McpAppState>,
335 ) -> TransportServerResult<http::Response<GenericBody>> {
336 let headers = request.headers();
337
338 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
339 let error = SdkError::bad_request()
340 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
341 return error_response(StatusCode::BAD_REQUEST, error);
342 }
343
344 let session_id: Option<SessionId> = headers
345 .get(MCP_SESSION_ID_HEADER)
346 .and_then(|value| value.to_str().ok())
347 .map(|s| s.to_string());
348
349 let response = match session_id {
350 Some(id) => delete_session(id, state).await,
351 None => {
352 let error = SdkError::bad_request().with_message("Bad Request: Session not found");
353 error_response(StatusCode::BAD_REQUEST, error)
354 }
355 };
356
357 response
358 }
359}