1#[cfg(feature = "sse")]
2use super::http_utils::handle_sse_connection;
3use super::http_utils::{
4 accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header,
5};
6use super::types::GenericBody;
7use crate::auth::AuthInfo;
8#[cfg(feature = "auth")]
9use crate::auth::AuthProvider;
10use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler};
11use crate::mcp_http::{GenericBodyExt, RequestExt};
12use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
13use crate::mcp_server::error::TransportServerError;
14use crate::schema::schema_utils::SdkError;
15use crate::{
16 error::McpSdkError,
17 mcp_http::{
18 http_utils::{
19 acceptable_content_type, create_standalone_stream, delete_session,
20 process_incoming_message, process_incoming_message_return, start_new_session,
21 valid_streaming_http_accept_header,
22 },
23 McpAppState,
24 },
25 mcp_server::error::TransportServerResult,
26 utils::valid_initialize_method,
27};
28use http::{self, HeaderMap, Method, StatusCode, Uri};
29use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
30use std::sync::Arc;
31
/// Wraps an associated async request handler in a middleware chain and
/// returns the composed handler produced by `mcp_http::middleware::compose`.
///
/// - `$self` must have a `middlewares` field (an iterable of middlewares).
/// - `$handler` is a path to an async function called as
///   `$handler(req, state).await` with an `http::Request<&str>` and an
///   `Arc<McpAppState>`; its output must match what `BoxFutureResponse`
///   resolves to.
/// - Optional extra expressions after the handler must each provide `.iter()`
///   over the same item type as `$self.middlewares.iter()`. They are chained
///   after `$self.middlewares` before composition.
///
/// Crate items are referenced through `$crate::` so the exported macro does not
/// depend on what the call site has imported. The call site still needs the
/// `http` crate to be nameable as `http`.
#[macro_export]
macro_rules! with_middlewares {
    ($self:ident, $handler:path) => {{
        let final_handler: $crate::mcp_http::RequestHandler = Box::new(
            move |req: http::Request<&str>,
                  state: ::std::sync::Arc<$crate::mcp_http::McpAppState>|
                  -> $crate::mcp_http::BoxFutureResponse<'_> {
                Box::pin(async move { $handler(req, state).await })
            },
        );
        $crate::mcp_http::middleware::compose(&$self.middlewares, final_handler)
    }};

    ($self:ident, $handler:path, $($extra:expr),+ $(,)?) => {{
        let final_handler: $crate::mcp_http::RequestHandler = Box::new(
            move |req: http::Request<&str>,
                  state: ::std::sync::Arc<$crate::mcp_http::McpAppState>|
                  -> $crate::mcp_http::BoxFutureResponse<'_> {
                Box::pin(async move { $handler(req, state).await })
            },
        );

        // Handler-owned middlewares first, then each extra collection in order.
        let all = $self.middlewares.iter()
            $(.chain($extra.iter()))+;

        $crate::mcp_http::middleware::compose(all, final_handler)
    }};
}
73
74#[derive(Clone)]
75pub struct McpHttpHandler {
76 #[cfg(feature = "auth")]
77 auth: Option<Arc<dyn AuthProvider>>,
78 middlewares: Vec<Arc<dyn Middleware>>,
79}
80
81impl McpHttpHandler {
82 #[cfg(feature = "auth")]
83 pub fn new(auth: Option<Arc<dyn AuthProvider>>, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
84 McpHttpHandler { auth, middlewares }
85 }
86
87 #[cfg(not(feature = "auth"))]
88 pub fn new(middlewares: Vec<Arc<dyn Middleware>>) -> Self {
89 McpHttpHandler { middlewares }
90 }
91
92 pub fn add_middleware<M: Middleware + 'static>(&mut self, middleware: M) {
93 let m: Arc<dyn Middleware> = Arc::new(middleware);
94 self.middlewares.push(m);
95 }
96
97 pub fn create_request(
101 method: Method,
102 uri: Uri,
103 headers: HeaderMap,
104 body: Option<&str>,
105 ) -> http::Request<&str> {
106 let mut request = http::Request::default();
107 *request.method_mut() = method;
108 *request.uri_mut() = uri;
109 *request.body_mut() = body.unwrap_or_default();
110 let req_headers = request.headers_mut();
111 for (key, value) in headers {
112 if let Some(k) = key {
113 req_headers.insert(k, value);
114 }
115 }
116 request
117 }
118}
119
#[cfg(feature = "auth")]
impl McpHttpHandler {
    /// Returns the keys of the configured auth provider's `auth_endpoints()`
    /// map (presumably endpoint paths — TODO confirm in `AuthProvider`).
    ///
    /// Returns `None` when no provider is configured or the provider reports no
    /// endpoints. The misspelled name is kept for API compatibility.
    pub fn oauth_endppoints(&self) -> Option<Vec<&String>> {
        self.auth
            .as_ref()
            .and_then(|a| a.auth_endpoints().map(|e| e.keys().collect::<Vec<_>>()))
    }

    /// Serves a request through the configured auth provider's
    /// `handle_request`.
    ///
    /// # Errors
    /// Returns `TransportServerError::HttpError` when no auth provider is
    /// configured. Otherwise returns whatever the composed handler around
    /// `AuthProvider::handle_request` returns.
    pub async fn handle_auth_requests(
        &self,
        request: http::Request<&str>,
        state: Arc<McpAppState>,
    ) -> TransportServerResult<http::Response<GenericBody>> {
        let Some(auth_provider) = self.auth.as_ref() else {
            return Err(TransportServerError::HttpError(
                "Authentication is not supported by this server.".to_string(),
            ));
        };

        // The boxed handler must own its provider handle, so clone the `Arc`.
        let auth_provider = Arc::clone(auth_provider);
        let final_handler: RequestHandler = Box::new(move |req, state| {
            Box::pin(async move { auth_provider.handle_request(req, state).await })
        });

        // NOTE(review): composed with an empty middleware list, so
        // `self.middlewares` is not applied to auth endpoints — confirm intended.
        let handle = compose(&[], final_handler);
        handle(request, state).await
    }
}
155
156impl McpHttpHandler {
157 #[cfg(feature = "sse")]
170 pub async fn handle_sse_connection(
171 &self,
172 request: http::Request<&str>,
173 state: Arc<McpAppState>,
174 sse_message_endpoint: Option<&str>,
175 ) -> TransportServerResult<http::Response<GenericBody>> {
176 use crate::auth::AuthInfo;
177 use crate::mcp_http::RequestExt;
178
179 let (request, auth_info) = request.take::<AuthInfo>();
180
181 let sse_endpoint = sse_message_endpoint.map(|s| s.to_string());
182 let final_handler: RequestHandler = Box::new(move |_req, state| {
183 Box::pin(async move {
184 handle_sse_connection(state, sse_endpoint.as_deref(), auth_info).await
185 })
186 });
187 let handle = compose(&self.middlewares, final_handler);
188 handle(request, state).await
189 }
190
191 #[cfg(feature = "sse")]
212 pub async fn handle_sse_message(
213 &self,
214 request: http::Request<&str>,
215 state: Arc<McpAppState>,
216 ) -> TransportServerResult<http::Response<GenericBody>> {
217 let handle = with_middlewares!(self, Self::internal_handle_sse_message);
218 handle(request, state).await
219 }
220
221 pub async fn handle_streamable_http(
243 &self,
244 request: http::Request<&str>,
245 state: Arc<McpAppState>,
246 ) -> TransportServerResult<http::Response<GenericBody>> {
247 let handle = with_middlewares!(self, Self::internal_handle_streamable_http);
248 handle(request, state).await
249 }
250
251 async fn internal_handle_sse_message(
252 request: http::Request<&str>,
253 state: Arc<McpAppState>,
254 ) -> TransportServerResult<http::Response<GenericBody>> {
255 let session_id =
256 query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
257
258 let transmit = state.session_store.get(&session_id).await.ok_or(
260 TransportServerError::SessionIdInvalid(session_id.to_string()),
261 )?;
262
263 let message = request.body();
264
265 transmit
266 .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref())
267 .await
268 .map_err(|err| {
269 tracing::trace!("{}", err);
270 TransportServerError::StreamIoError(err.to_string())
271 })?;
272
273 http::Response::builder()
274 .status(StatusCode::ACCEPTED)
275 .body(GenericBody::empty())
276 .map_err(|err| TransportServerError::HttpError(err.to_string()))
277 }
278
279 async fn internal_handle_streamable_http(
280 request: http::Request<&str>,
281 state: Arc<McpAppState>,
282 ) -> TransportServerResult<http::Response<GenericBody>> {
283 let (request, auth_info) = request.take::<AuthInfo>();
284
285 let method = request.method();
286
287 let response = match method {
288 &http::Method::GET => return Self::handle_http_get(request, state, auth_info).await,
289 &http::Method::POST => return Self::handle_http_post(request, state, auth_info).await,
290 &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
291 other => {
292 let error = SdkError::bad_request().with_message(&format!(
293 "'{other}' is not a valid HTTP method for StreamableHTTP transport."
294 ));
295 error_response(StatusCode::METHOD_NOT_ALLOWED, error)
296 }
297 };
298
299 response
300 }
301
302 async fn handle_http_post(
304 request: http::Request<&str>,
305 state: Arc<McpAppState>,
306 auth_info: Option<AuthInfo>,
307 ) -> TransportServerResult<http::Response<GenericBody>> {
308 let headers = request.headers();
309
310 if !valid_streaming_http_accept_header(headers) {
311 let error = SdkError::bad_request()
312 .with_message(r#"Client must accept both application/json and text/event-stream"#);
313 return error_response(StatusCode::NOT_ACCEPTABLE, error);
314 }
315
316 if !acceptable_content_type(headers) {
317 let error = SdkError::bad_request()
318 .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
319 return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
320 }
321
322 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
323 let error = SdkError::bad_request()
324 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
325 return error_response(StatusCode::BAD_REQUEST, error);
326 }
327
328 let session_id: Option<SessionId> = headers
329 .get(MCP_SESSION_ID_HEADER)
330 .and_then(|value| value.to_str().ok())
331 .map(|s| s.to_string());
332
333 let payload = request.body();
334
335 let response = match session_id {
336 Some(id) => {
338 if state.enable_json_response {
339 process_incoming_message_return(id, state, payload, auth_info).await
340 } else {
341 process_incoming_message(id, state, payload, auth_info).await
342 }
343 }
344 None => match valid_initialize_method(payload) {
345 Ok(_) => {
346 return start_new_session(state, payload, auth_info).await;
347 }
348 Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
349 Err(error) => {
350 let error = SdkError::bad_request().with_message(&error.to_string());
351 error_response(StatusCode::BAD_REQUEST, error)
352 }
353 },
354 };
355
356 response
357 }
358
359 async fn handle_http_get(
361 request: http::Request<&str>,
362 state: Arc<McpAppState>,
363 auth_info: Option<AuthInfo>,
364 ) -> TransportServerResult<http::Response<GenericBody>> {
365 let headers = request.headers();
366
367 if !accepts_event_stream(headers) {
368 let error =
369 SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
370 return error_response(StatusCode::NOT_ACCEPTABLE, error);
371 }
372
373 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
374 let error = SdkError::bad_request()
375 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
376 return error_response(StatusCode::BAD_REQUEST, error);
377 }
378
379 let session_id: Option<SessionId> = headers
380 .get(MCP_SESSION_ID_HEADER)
381 .and_then(|value| value.to_str().ok())
382 .map(|s| s.to_string());
383
384 let last_event_id: Option<SessionId> = headers
385 .get(MCP_LAST_EVENT_ID_HEADER)
386 .and_then(|value| value.to_str().ok())
387 .map(|s| s.to_string());
388
389 let response = match session_id {
390 Some(session_id) => {
391 let res =
392 create_standalone_stream(session_id, last_event_id, state, auth_info).await;
393 res
394 }
395 None => {
396 let error = SdkError::bad_request().with_message("Bad request: session not found");
397 error_response(StatusCode::BAD_REQUEST, error)
398 }
399 };
400
401 response
402 }
403
404 async fn handle_http_delete(
406 request: http::Request<&str>,
407 state: Arc<McpAppState>,
408 ) -> TransportServerResult<http::Response<GenericBody>> {
409 let headers = request.headers();
410
411 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
412 let error = SdkError::bad_request()
413 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
414 return error_response(StatusCode::BAD_REQUEST, error);
415 }
416
417 let session_id: Option<SessionId> = headers
418 .get(MCP_SESSION_ID_HEADER)
419 .and_then(|value| value.to_str().ok())
420 .map(|s| s.to_string());
421
422 let response = match session_id {
423 Some(id) => delete_session(id, state).await,
424 None => {
425 let error = SdkError::bad_request().with_message("Bad Request: Session not found");
426 error_response(StatusCode::BAD_REQUEST, error)
427 }
428 };
429
430 response
431 }
432}