rust_mcp_sdk/mcp_http/
mcp_http_handler.rs1#[cfg(feature = "sse")]
2use super::utils::handle_sse_connection;
3use crate::mcp_http::utils::{
4 accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header,
5};
6use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID;
7use crate::mcp_server::error::TransportServerError;
8use crate::schema::schema_utils::SdkError;
9use crate::{
10 error::McpSdkError,
11 mcp_http::{
12 utils::{
13 acceptable_content_type, create_standalone_stream, delete_session,
14 process_incoming_message, process_incoming_message_return, protect_dns_rebinding,
15 start_new_session, valid_streaming_http_accept_header, GenericBody,
16 },
17 McpAppState,
18 },
19 mcp_server::error::TransportServerResult,
20 utils::valid_initialize_method,
21};
22use bytes::Bytes;
23use http::{self, HeaderMap, Method, StatusCode, Uri};
24use http_body_util::{BodyExt, Full};
25use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER};
26use std::sync::Arc;
27
28pub struct McpHttpHandler {}
29
30impl McpHttpHandler {
31 pub fn create_request(
46 method: Method,
47 uri: Uri,
48 headers: HeaderMap,
49 body: Option<&str>,
50 ) -> http::Request<&str> {
51 let mut request = http::Request::default();
52 *request.method_mut() = method;
53 *request.uri_mut() = uri;
54 *request.body_mut() = body.unwrap_or_default();
55 let req_headers = request.headers_mut();
56 for (key, value) in headers {
57 if let Some(k) = key {
58 req_headers.insert(k, value);
59 }
60 }
61 request
62 }
63}
64
65impl McpHttpHandler {
66 #[cfg(feature = "sse")]
79 pub async fn handle_sse_connection(
80 state: Arc<McpAppState>,
81 sse_message_endpoint: Option<&str>,
82 ) -> TransportServerResult<http::Response<GenericBody>> {
83 handle_sse_connection(state, sse_message_endpoint).await
84 }
85
86 pub async fn handle_sse_message(
107 request: http::Request<&str>,
108 state: Arc<McpAppState>,
109 ) -> TransportServerResult<http::Response<GenericBody>> {
110 let session_id =
111 query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?;
112
113 let transmit = state.session_store.get(&session_id).await.ok_or(
115 TransportServerError::SessionIdInvalid(session_id.to_string()),
116 )?;
117
118 let message = *request.body();
119 transmit
120 .consume_payload_string(DEFAULT_STREAM_ID, message)
121 .await
122 .map_err(|err| {
123 tracing::trace!("{}", err);
124 TransportServerError::StreamIoError(err.to_string())
125 })?;
126
127 let body = Full::new(Bytes::new())
128 .map_err(|err| TransportServerError::HttpError(err.to_string()))
129 .boxed();
130
131 http::Response::builder()
132 .status(StatusCode::ACCEPTED)
133 .body(body)
134 .map_err(|err| TransportServerError::HttpError(err.to_string()))
135 }
136
137 pub async fn handle_streamable_http(
159 request: http::Request<&str>,
160 state: Arc<McpAppState>,
161 ) -> TransportServerResult<http::Response<GenericBody>> {
162 if state.needs_dns_protection() {
165 if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await {
166 return error_response(StatusCode::FORBIDDEN, error);
167 }
168 }
169
170 let method = request.method();
171 match method {
172 &http::Method::GET => return Self::handle_http_get(request, state).await,
173 &http::Method::POST => return Self::handle_http_post(request, state).await,
174 &http::Method::DELETE => return Self::handle_http_delete(request, state).await,
175 other => {
176 let error = SdkError::bad_request().with_message(&format!(
177 "'{other}' is not a valid HTTP method for StreamableHTTP transport."
178 ));
179 error_response(StatusCode::METHOD_NOT_ALLOWED, error)
180 }
181 }
182 }
183
184 async fn handle_http_post(
186 request: http::Request<&str>,
187 state: Arc<McpAppState>,
188 ) -> TransportServerResult<http::Response<GenericBody>> {
189 let headers = request.headers();
190
191 if !valid_streaming_http_accept_header(headers) {
192 let error = SdkError::bad_request()
193 .with_message(r#"Client must accept both application/json and text/event-stream"#);
194 return error_response(StatusCode::NOT_ACCEPTABLE, error);
195 }
196
197 if !acceptable_content_type(headers) {
198 let error = SdkError::bad_request()
199 .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#);
200 return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error);
201 }
202
203 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
204 let error = SdkError::bad_request()
205 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
206 return error_response(StatusCode::BAD_REQUEST, error);
207 }
208
209 let session_id: Option<SessionId> = headers
210 .get(MCP_SESSION_ID_HEADER)
211 .and_then(|value| value.to_str().ok())
212 .map(|s| s.to_string());
213
214 let payload = *request.body();
215
216 match session_id {
217 Some(id) => {
219 if state.enable_json_response {
220 process_incoming_message_return(id, state, payload).await
221 } else {
222 process_incoming_message(id, state, payload).await
223 }
224 }
225 None => match valid_initialize_method(payload) {
226 Ok(_) => {
227 return start_new_session(state, payload).await;
228 }
229 Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error),
230 Err(error) => {
231 let error = SdkError::bad_request().with_message(&error.to_string());
232 error_response(StatusCode::BAD_REQUEST, error)
233 }
234 },
235 }
236 }
237
238 async fn handle_http_get(
240 request: http::Request<&str>,
241 state: Arc<McpAppState>,
242 ) -> TransportServerResult<http::Response<GenericBody>> {
243 let headers = request.headers();
244
245 if !accepts_event_stream(headers) {
246 let error =
247 SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#);
248 return error_response(StatusCode::NOT_ACCEPTABLE, error);
249 }
250
251 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
252 let error = SdkError::bad_request()
253 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
254 return error_response(StatusCode::BAD_REQUEST, error);
255 }
256
257 let session_id: Option<SessionId> = headers
258 .get(MCP_SESSION_ID_HEADER)
259 .and_then(|value| value.to_str().ok())
260 .map(|s| s.to_string());
261
262 let last_event_id: Option<SessionId> = headers
263 .get(MCP_LAST_EVENT_ID_HEADER)
264 .and_then(|value| value.to_str().ok())
265 .map(|s| s.to_string());
266
267 match session_id {
268 Some(session_id) => {
269 let res = create_standalone_stream(session_id, last_event_id, state).await;
270 res
271 }
272 None => {
273 let error = SdkError::bad_request().with_message("Bad request: session not found");
274 error_response(StatusCode::BAD_REQUEST, error)
275 }
276 }
277 }
278
279 async fn handle_http_delete(
281 request: http::Request<&str>,
282 state: Arc<McpAppState>,
283 ) -> TransportServerResult<http::Response<GenericBody>> {
284 let headers = request.headers();
285
286 if let Err(parse_error) = validate_mcp_protocol_version_header(headers) {
287 let error = SdkError::bad_request()
288 .with_message(format!(r#"Bad Request: {parse_error}"#).as_str());
289 return error_response(StatusCode::BAD_REQUEST, error);
290 }
291
292 let session_id: Option<SessionId> = headers
293 .get(MCP_SESSION_ID_HEADER)
294 .and_then(|value| value.to_str().ok())
295 .map(|s| s.to_string());
296
297 match session_id {
298 Some(id) => delete_session(id, state).await,
299 None => {
300 let error = SdkError::bad_request().with_message("Bad Request: Session not found");
301 error_response(StatusCode::BAD_REQUEST, error)
302 }
303 }
304 }
305}