1use crate::auth::AuthInfo;
2use crate::mcp_http::types::GenericBody;
3use crate::schema::schema_utils::{ClientMessage, SdkError};
4use crate::McpServer;
5use crate::{
6 error::SdkResult,
7 hyper_servers::error::{TransportServerError, TransportServerResult},
8 mcp_http::McpAppState,
9 mcp_runtimes::server_runtime::DEFAULT_STREAM_ID,
10 mcp_server::{server_runtime, ServerRuntime},
11 mcp_traits::{IdGenerator, McpServerHandler},
12 utils::validate_mcp_protocol_version,
13};
14use axum::http::HeaderValue;
15use bytes::Bytes;
16use futures::stream;
17use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE};
18use http_body::Frame;
19use http_body_util::{BodyExt, Full, StreamBody};
20use hyper::{HeaderMap, StatusCode};
21use rust_mcp_transport::{
22 EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR,
23 MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER,
24};
25use serde_json::{Map, Value};
26use std::sync::Arc;
27use tokio::io::{duplex, AsyncBufReadExt, BufReader};
28use tokio_stream::StreamExt;
29
/// Default path of the SSE endpoint.
pub(crate) const DEFAULT_SSE_ENDPOINT: &str = "/sse";
/// Default path of the SSE messages endpoint; `handle_sse_connection` falls back to it
/// when the caller does not supply one.
pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
/// Default path of the streamable HTTP endpoint.
pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
/// Buffer size handed to `tokio::io::duplex` for each in-memory pipe created in this file.
const DUPLEX_BUFFER_SIZE: usize = 8192;
37
38fn initial_sse_event(endpoint: &str) -> Result<Bytes, TransportServerError> {
48 Ok(SseEvent::default()
49 .with_event("endpoint")
50 .with_data(endpoint.to_string())
51 .as_bytes())
52}
53
/// Returns `scheme://host` for `url`.
///
/// Port, path, query and fragment are not included. When the URL has no host the
/// host part is the empty string.
#[cfg(feature = "auth")]
pub fn url_base(url: &url::Url) -> String {
    let scheme = url.scheme();
    let host = url.host_str().unwrap_or_default();
    format!("{scheme}://{host}")
}
58
/// Removes a leading, case-insensitive `Bearer` scheme from a header value.
///
/// * `"Bearer <rest>"` yields `<rest>` with surrounding whitespace trimmed.
/// * A value that is exactly `"Bearer"` (any case) yields `""`.
/// * Anything else is returned trimmed but otherwise untouched.
fn strip_bearer_prefix(header: &str) -> &str {
    // `get` returns `None` when the value is shorter than the prefix or when byte 7
    // is not a character boundary; in both cases the prefix cannot be "bearer ".
    if let (Some(scheme), Some(rest)) = (header.get(..7), header.get(7..)) {
        if scheme.eq_ignore_ascii_case("bearer ") {
            return rest.trim();
        }
    }

    if header.eq_ignore_ascii_case("bearer") {
        ""
    } else {
        header.trim()
    }
}
74
/// Parses the parameters of a `WWW-Authenticate` header value into a JSON map.
///
/// A leading `Bearer` scheme (if any) is removed via `strip_bearer_prefix`. The
/// remainder is split into `key=value` parameters on commas that are *outside*
/// double-quoted values, so a quoted value such as
/// `error_description="expired, please refresh"` is kept intact instead of being
/// truncated at the embedded comma. A backslash inside quotes escapes the next
/// character for the purpose of finding the closing quote.
///
/// For every parameter, whitespace around the key and value is trimmed and
/// surrounding double quotes are stripped from the value (escape sequences are not
/// decoded). Segments without `=` are ignored. Later duplicates overwrite earlier
/// ones.
///
/// If the quotes in the header are unbalanced the input is malformed; in that case
/// the function falls back to a plain split on every comma.
///
/// # Returns
/// `Some(map)` when at least one `key=value` parameter was found, otherwise `None`.
#[cfg(feature = "auth")]
pub fn parse_www_authenticate(header: &str) -> Option<Map<String, Value>> {
    let params_str = strip_bearer_prefix(header);

    // Quote-aware split: remember where each parameter starts and cut at commas
    // that are not inside a quoted string.
    let mut parts: Vec<&str> = Vec::new();
    let mut part_start = 0;
    let mut in_quotes = false;
    let mut escaped = false;
    for (idx, ch) in params_str.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' if in_quotes => escaped = true,
            '"' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                parts.push(&params_str[part_start..idx]);
                // ',' is a single byte, so `idx + 1` is a valid char boundary.
                part_start = idx + 1;
            }
            _ => {}
        }
    }
    parts.push(&params_str[part_start..]);

    // Unbalanced quotes: the quote tracking above is unreliable, so use the plain split.
    if in_quotes {
        parts = params_str.split(',').collect();
    }

    let mut result: Option<Map<String, Value>> = None;
    for part in parts {
        if let Some((key, value)) = part.trim().split_once('=') {
            let cleaned = value.trim().trim_matches('"');
            result.get_or_insert_with(Map::new).insert(
                key.trim_end().to_string(),
                Value::String(cleaned.to_string()),
            );
        }
    }

    result
}
97
/// Derives a human-readable error message from an HTTP response.
///
/// The `WWW-Authenticate` header is consulted first. If it parses into parameters,
/// the message is, in order of preference:
/// 1. the `error_description` parameter,
/// 2. the `error` parameter,
/// 3. all string parameter values joined with `", "`.
///
/// Otherwise the response body text is returned, and `default_message` is used only
/// when reading the body fails.
#[cfg(feature = "auth")]
pub async fn error_message_from_response(
    response: reqwest::Response,
    default_message: &str,
) -> String {
    let challenge = response
        .headers()
        .get(http::header::WWW_AUTHENTICATE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(parse_www_authenticate);

    if let Some(params) = challenge {
        for key in ["error_description", "error"] {
            if let Some(Value::String(text)) = params.get(key) {
                return text.clone();
            }
        }

        let all_values: Vec<&str> = params.values().filter_map(Value::as_str).collect();
        if !all_values.is_empty() {
            return all_values.join(", ");
        }
    }

    match response.text().await {
        Ok(body) => body,
        Err(_) => default_message.to_owned(),
    }
}
143
/// Creates a new stream on `runtime` and returns a `text/event-stream` response whose
/// body relays everything the transport writes for that stream.
///
/// # Parameters
/// * `runtime` - server runtime the stream is started on.
/// * `session_id` - sent back in the `MCP_SESSION_ID_HEADER` response header.
/// * `state` - shared app state (transport options, id generator, event store, ping interval).
/// * `payload` - optional JSON payload forwarded to `start_stream`; it must parse as JSON.
/// * `standalone` - when `true` the stream uses `DEFAULT_STREAM_ID`, otherwise a generated id.
/// * `last_event_id` - when set and an event store is configured, events after this id are
///   replayed through the transport.
///
/// # Returns
/// A streaming response with status `202 Accepted` when the payload contains no request
/// (or there is no payload), `200 OK` otherwise. A payload that is not valid JSON yields a
/// `400 Bad Request` error response.
///
/// # Errors
/// Fails when the transport cannot be constructed, the session id is not a valid header
/// value, or the response cannot be built.
async fn create_sse_stream(
    runtime: Arc<ServerRuntime>,
    session_id: SessionId,
    state: Arc<McpAppState>,
    payload: Option<&str>,
    standalone: bool,
    last_event_id: Option<EventId>,
) -> TransportServerResult<http::Response<GenericBody>> {
    let payload_string = payload.map(|p| p.to_string());

    // Parse the payload once up front: a missing payload counts as "no request",
    // and a payload that is not valid JSON is rejected before any stream is set up.
    let payload_contains_request = payload_string
        .as_ref()
        .map(|json_str| contains_request(json_str))
        .unwrap_or(Ok(false));
    let Ok(payload_contains_request) = payload_contains_request else {
        return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error());
    };

    // Two in-memory pipes. The transport receives `read_rx`, `write_tx` and `read_tx`;
    // `write_rx` is kept here and becomes the source of the HTTP response body.
    let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
    let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);

    let session_id = Arc::new(session_id);
    // The standalone stream always uses the well-known default id.
    let stream_id: Arc<StreamId> = if standalone {
        Arc::new(DEFAULT_STREAM_ID.to_string())
    } else {
        Arc::new(state.stream_id_gen.generate())
    };

    // Resumability is available only when an event store is configured.
    let event_store = state.event_store.as_ref().map(Arc::clone);
    let resumability_enabled = event_store.is_some();

    let mut transport = SseTransport::<ClientMessage>::new(
        read_rx,
        write_tx,
        read_tx,
        Arc::clone(&state.transport_options),
    )
    .map_err(|err| TransportServerError::TransportError(err.to_string()))?;
    if let Some(event_store) = event_store.clone() {
        transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store);
    }
    let transport = Arc::new(transport);

    let ping_interval = state.ping_interval;
    let runtime_clone = Arc::clone(&runtime);
    let stream_id_clone = stream_id.clone();
    let transport_clone = transport.clone();

    // Run the stream in the background; once it ends (successfully or not) the
    // transport registered under this stream id is removed from the runtime.
    tokio::spawn(async move {
        match runtime_clone
            .start_stream(
                transport_clone,
                &stream_id_clone,
                ping_interval,
                payload_string,
            )
            .await
        {
            Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone),
            Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err),
        }
        let _ = runtime.remove_transport(&stream_id_clone).await;
    });

    let reader = BufReader::new(write_rx);

    // Turn the line-oriented output of the transport into a stream of SSE events.
    let message_stream = stream::unfold(reader, move |mut reader| {
        async move {
            let mut line = String::new();

            match reader.read_line(&mut line).await {
                // Zero bytes read: the write side was closed, so the stream ends.
                Ok(0) => None,
                Ok(_) => {
                    let trimmed_line = line.trim_end_matches('\n').to_owned();

                    // Empty lines and bare ":" lines are forwarded as a default (data-less) event.
                    if is_empty_sse_message(&trimmed_line) {
                        return Some((Ok(SseEvent::default().as_bytes()), reader));
                    }

                    // With resumability enabled a line may be prefixed with an event id,
                    // separated from the message by `ID_SEPARATOR`.
                    let (event_id, message) = match (
                        resumability_enabled,
                        trimmed_line.split_once(char::from(ID_SEPARATOR)),
                    ) {
                        (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()),
                        _ => (None, trimmed_line),
                    };

                    let event = match event_id {
                        Some(id) => SseEvent::default()
                            .with_data(message)
                            .with_id(id)
                            .as_bytes(),
                        None => SseEvent::default().with_data(message).as_bytes(),
                    };

                    Some((Ok(event), reader))
                }
                // Read errors are surfaced to the body as an error item; the stream itself continues.
                Err(e) => Some((Err(e), reader)),
            }
        }
    });

    // Wrap each event in a data frame and convert I/O errors to the transport server error type.
    let streaming_body: GenericBody =
        http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| {
            res.map(Frame::data)
                .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string()))
        })));

    let session_id_value = HeaderValue::from_str(&session_id)
        .map_err(|err| TransportServerError::HttpError(err.to_string()))?;

    // Payloads without any request (only notifications/responses, or no payload) get 202.
    let status_code = if !payload_contains_request {
        StatusCode::ACCEPTED
    } else {
        StatusCode::OK
    };

    let response = http::Response::builder()
        .status(status_code)
        .header(CONTENT_TYPE, "text/event-stream")
        .header(MCP_SESSION_ID_HEADER, session_id_value)
        .header(CONNECTION, "keep-alive")
        .body(streaming_body)
        .map_err(|err| TransportServerError::HttpError(err.to_string()))?;

    // Replay stored events after `last_event_id` (if any) in the background. Failures to
    // load or write events are only logged.
    // NOTE(review): the meaning of the `true` flag passed to `write_str` is not visible here — confirm.
    tokio::spawn(async move {
        if let Some(last_event_id) = last_event_id {
            if let Some(event_store) = state.event_store.as_ref() {
                let events = event_store
                    .events_after(last_event_id)
                    .await
                    .unwrap_or_else(|err| {
                        tracing::error!("{err}");
                        None
                    });

                if let Some(events) = events {
                    for message_payload in events.messages {
                        let error = transport.write_str(&message_payload, true).await;
                        if let Err(error) = error {
                            tracing::trace!("Error replaying message: {error}")
                        }
                    }
                }
            }
        }
    });

    Ok(response)
}
304
305fn contains_request(json_str: &str) -> Result<bool, serde_json::Error> {
309 let value: serde_json::Value = serde_json::from_str(json_str)?;
310 match value {
311 serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")),
312 serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| {
313 item.as_object()
314 .map(|obj| obj.contains_key("id") && obj.contains_key("method"))
315 .unwrap_or(false)
316 })),
317 _ => Ok(false),
318 }
319}
320
321fn is_result(json_str: &str) -> Result<bool, serde_json::Error> {
322 let value: serde_json::Value = serde_json::from_str(json_str)?;
323 match value {
324 serde_json::Value::Object(obj) => Ok(obj.contains_key("result")),
325 serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| {
326 item.as_object()
327 .map(|obj| obj.contains_key("result"))
328 .unwrap_or(false)
329 })),
330 _ => Ok(false),
331 }
332}
333
334pub(crate) async fn create_standalone_stream(
335 session_id: SessionId,
336 last_event_id: Option<EventId>,
337 state: Arc<McpAppState>,
338 auth_info: Option<AuthInfo>,
339) -> TransportServerResult<http::Response<GenericBody>> {
340 let runtime = state.session_store.get(&session_id).await.ok_or(
341 TransportServerError::SessionIdInvalid(session_id.to_string()),
342 )?;
343
344 runtime.update_auth_info(auth_info).await;
345
346 if runtime.stream_id_exists(DEFAULT_STREAM_ID).await {
347 let error =
348 SdkError::bad_request().with_message("Only one SSE stream is allowed per session");
349 return error_response(StatusCode::CONFLICT, error)
350 .map_err(|err| TransportServerError::HttpError(err.to_string()));
351 }
352
353 if let Some(last_event_id) = last_event_id.as_ref() {
354 tracing::trace!(
355 "SSE stream re-connected with last-event-id: {}",
356 last_event_id
357 );
358 }
359
360 let mut response = create_sse_stream(
361 runtime.clone(),
362 session_id.clone(),
363 state.clone(),
364 None,
365 true,
366 last_event_id,
367 )
368 .await?;
369 *response.status_mut() = StatusCode::OK;
370 Ok(response)
371}
372
373pub(crate) async fn start_new_session(
374 state: Arc<McpAppState>,
375 payload: &str,
376 auth_info: Option<AuthInfo>,
377) -> TransportServerResult<http::Response<GenericBody>> {
378 let session_id: SessionId = state.id_generator.generate();
379
380 let h: Arc<dyn McpServerHandler> = state.handler.clone();
381 let runtime: Arc<ServerRuntime> = server_runtime::create_server_instance(
383 Arc::clone(&state.server_details),
384 h,
385 session_id.to_owned(),
386 auth_info,
387 );
388
389 tracing::info!("a new client joined : {}", &session_id);
390
391 let response = create_sse_stream(
392 runtime.clone(),
393 session_id.clone(),
394 state.clone(),
395 Some(payload),
396 false,
397 None,
398 )
399 .await;
400
401 if response.is_ok() {
402 state
403 .session_store
404 .set(session_id.to_owned(), runtime.clone())
405 .await;
406 }
407 response
408}
409async fn single_shot_stream(
410 runtime: Arc<ServerRuntime>,
411 session_id: SessionId,
412 state: Arc<McpAppState>,
413 payload: Option<&str>,
414 standalone: bool,
415) -> TransportServerResult<http::Response<GenericBody>> {
416 let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
418 let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);
420
421 let transport = SseTransport::<ClientMessage>::new(
422 read_rx,
423 write_tx,
424 read_tx,
425 Arc::clone(&state.transport_options),
426 )
427 .map_err(|err| TransportServerError::TransportError(err.to_string()))?;
428
429 let stream_id = if standalone {
430 DEFAULT_STREAM_ID.to_string()
431 } else {
432 state.id_generator.generate()
433 };
434 let ping_interval = state.ping_interval;
435 let runtime_clone = Arc::clone(&runtime);
436
437 let payload_string = payload.map(|p| p.to_string());
438
439 tokio::spawn(async move {
440 match runtime_clone
441 .start_stream(
442 Arc::new(transport),
443 &stream_id,
444 ping_interval,
445 payload_string,
446 )
447 .await
448 {
449 Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id),
450 Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err),
451 }
452 let _ = runtime.remove_transport(&stream_id).await;
453 });
454
455 let mut reader = BufReader::new(write_rx);
456 let mut line = String::new();
457 let response = match reader.read_line(&mut line).await {
458 Ok(0) => None, Ok(_) => {
460 let trimmed_line = line.trim_end_matches('\n').to_owned();
461 Some(Ok(trimmed_line))
462 }
463 Err(e) => Some(Err(e)),
464 };
465
466 let session_id_value = HeaderValue::from_str(&session_id)
467 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
468
469 match response {
470 Some(response_result) => match response_result {
471 Ok(response_str) => {
472 let body = Full::new(Bytes::from(response_str))
473 .map_err(|err| TransportServerError::HttpError(err.to_string()))
474 .boxed();
475
476 http::Response::builder()
477 .status(StatusCode::OK)
478 .header(CONTENT_TYPE, "application/json")
479 .header(MCP_SESSION_ID_HEADER, session_id_value)
480 .body(body)
481 .map_err(|err| TransportServerError::HttpError(err.to_string()))
482 }
483 Err(err) => {
484 let body = Full::new(Bytes::from(err.to_string()))
485 .map_err(|err| TransportServerError::HttpError(err.to_string()))
486 .boxed();
487 http::Response::builder()
488 .status(StatusCode::INTERNAL_SERVER_ERROR)
489 .header(CONTENT_TYPE, "application/json")
490 .body(body)
491 .map_err(|err| TransportServerError::HttpError(err.to_string()))
492 }
493 },
494 None => {
495 let body = Full::new(Bytes::from(
496 "End of the transport stream reached.".to_string(),
497 ))
498 .map_err(|err| TransportServerError::HttpError(err.to_string()))
499 .boxed();
500 http::Response::builder()
501 .status(StatusCode::UNPROCESSABLE_ENTITY)
502 .header(CONTENT_TYPE, "application/json")
503 .body(body)
504 .map_err(|err| TransportServerError::HttpError(err.to_string()))
505 }
506 }
507}
508
509pub(crate) async fn process_incoming_message_return(
510 session_id: SessionId,
511 state: Arc<McpAppState>,
512 payload: &str,
513 auth_info: Option<AuthInfo>,
514) -> TransportServerResult<http::Response<GenericBody>> {
515 match state.session_store.get(&session_id).await {
516 Some(runtime) => {
517 runtime.update_auth_info(auth_info).await;
518 single_shot_stream(
519 runtime.clone(),
520 session_id,
521 state.clone(),
522 Some(payload),
523 false,
524 )
525 .await
526 }
528 None => {
529 let error = SdkError::session_not_found();
530 error_response(StatusCode::NOT_FOUND, error)
531 .map_err(|err| TransportServerError::HttpError(err.to_string()))
532 }
533 }
534}
535
536pub(crate) async fn process_incoming_message(
537 session_id: SessionId,
538 state: Arc<McpAppState>,
539 payload: &str,
540 auth_info: Option<AuthInfo>,
541) -> TransportServerResult<http::Response<GenericBody>> {
542 match state.session_store.get(&session_id).await {
543 Some(runtime) => {
544 runtime.update_auth_info(auth_info).await;
545 let Ok(is_result) = is_result(payload) else {
548 return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error());
549 };
550
551 if is_result {
552 match runtime
553 .consume_payload_string(DEFAULT_STREAM_ID, payload)
554 .await
555 {
556 Ok(()) => {
557 let body = Full::new(Bytes::new())
558 .map_err(|err| TransportServerError::HttpError(err.to_string()))
559 .boxed();
560 http::Response::builder()
561 .status(200)
562 .header("Content-Type", "application/json")
563 .body(body)
564 .map_err(|err| TransportServerError::HttpError(err.to_string()))
565 }
566 Err(err) => {
567 let error =
568 SdkError::internal_error().with_message(err.to_string().as_ref());
569 error_response(StatusCode::BAD_REQUEST, error)
570 }
571 }
572 } else {
573 create_sse_stream(
574 runtime.clone(),
575 session_id.clone(),
576 state.clone(),
577 Some(payload),
578 false,
579 None,
580 )
581 .await
582 }
583 }
584 None => {
585 let error = SdkError::session_not_found();
586 error_response(StatusCode::NOT_FOUND, error)
587 }
588 }
589}
590
/// Returns `true` for SSE payloads that carry no message: the empty string, or a
/// lone `:` optionally surrounded by whitespace.
///
/// Note that a whitespace-only payload is *not* considered empty.
pub(crate) fn is_empty_sse_message(sse_payload: &str) -> bool {
    match sse_payload {
        "" => true,
        other => other.trim() == ":",
    }
}
594
595pub(crate) async fn delete_session(
596 session_id: SessionId,
597 state: Arc<McpAppState>,
598) -> TransportServerResult<http::Response<GenericBody>> {
599 match state.session_store.get(&session_id).await {
600 Some(runtime) => {
601 runtime.shutdown().await;
602 state.session_store.delete(&session_id).await;
603 tracing::info!("client disconnected : {}", &session_id);
604
605 let body = Full::new(Bytes::from("ok"))
606 .map_err(|err| TransportServerError::HttpError(err.to_string()))
607 .boxed();
608 http::Response::builder()
609 .status(200)
610 .header("Content-Type", "application/json")
611 .body(body)
612 .map_err(|err| TransportServerError::HttpError(err.to_string()))
613 }
614 None => {
615 let error = SdkError::session_not_found();
616 error_response(StatusCode::NOT_FOUND, error)
617 }
618 }
619}
620
621pub(crate) fn acceptable_content_type(headers: &HeaderMap) -> bool {
622 let accept_header = headers
623 .get("content-type")
624 .and_then(|val| val.to_str().ok())
625 .unwrap_or("");
626 accept_header
627 .split(',')
628 .any(|val| val.trim().starts_with("application/json"))
629}
630
631pub(crate) fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
632 let protocol_version_header = headers
633 .get(MCP_PROTOCOL_VERSION_HEADER)
634 .and_then(|val| val.to_str().ok())
635 .unwrap_or("");
636
637 if protocol_version_header.is_empty() {
639 return Ok(());
640 }
641
642 validate_mcp_protocol_version(protocol_version_header)
643}
644
645pub(crate) fn accepts_event_stream(headers: &HeaderMap) -> bool {
646 let accept_header = headers
647 .get(ACCEPT)
648 .and_then(|val| val.to_str().ok())
649 .unwrap_or("");
650
651 accept_header
652 .split(',')
653 .any(|val| val.trim().starts_with("text/event-stream"))
654}
655
656pub(crate) fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
657 let accept_header = headers
658 .get(ACCEPT)
659 .and_then(|val| val.to_str().ok())
660 .unwrap_or("");
661
662 let types: Vec<_> = accept_header.split(',').map(|v| v.trim()).collect();
663
664 let has_event_stream = types.iter().any(|v| v.starts_with("text/event-stream"));
665 let has_json = types.iter().any(|v| v.starts_with("application/json"));
666 has_event_stream && has_json
667}
668
669pub fn error_response(
670 status_code: StatusCode,
671 error: SdkError,
672) -> TransportServerResult<http::Response<GenericBody>> {
673 let error_string = serde_json::to_string(&error).unwrap_or_default();
674 let body = Full::new(Bytes::from(error_string))
675 .map_err(|err| TransportServerError::HttpError(err.to_string()))
676 .boxed();
677
678 http::Response::builder()
679 .status(status_code)
680 .header(CONTENT_TYPE, "application/json")
681 .body(body)
682 .map_err(|err| TransportServerError::HttpError(err.to_string()))
683}
684
685pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
699 request.uri().query().and_then(|query| {
700 for pair in query.split('&') {
701 let mut split = pair.splitn(2, '=');
702 let k = split.next()?;
703 let v = split.next().unwrap_or("");
704 if k == key {
705 return Some(v.to_string());
706 }
707 }
708 None
709 })
710}
711
/// Opens an SSE connection for a brand-new session.
///
/// A session id is generated, a server runtime is created and stored in the session
/// store, and the runtime's stream (on `DEFAULT_STREAM_ID`) is started in a
/// background task that deletes the session from the store once the stream ends.
///
/// The response body is a `text/event-stream` that first emits an `endpoint` event
/// (the messages endpoint for this session, derived from `sse_message_endpoint` or
/// `DEFAULT_MESSAGES_ENDPOINT`) and then one event per line written by the transport.
///
/// # Errors
/// Returns `TransportServerError::TransportError` (including the underlying cause)
/// when the transport cannot be constructed, and `TransportServerError::HttpError`
/// when the response cannot be built.
#[cfg(feature = "sse")]
pub(crate) async fn handle_sse_connection(
    state: Arc<McpAppState>,
    sse_message_endpoint: Option<&str>,
    auth_info: Option<AuthInfo>,
) -> TransportServerResult<http::Response<GenericBody>> {
    let session_id: SessionId = state.id_generator.generate();

    let sse_message_endpoint = sse_message_endpoint.unwrap_or(DEFAULT_MESSAGES_ENDPOINT);
    let messages_endpoint =
        SseTransport::<ClientMessage>::message_endpoint(sse_message_endpoint, &session_id);

    // Two in-memory pipes; `write_rx` is kept here and feeds the response body.
    let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
    let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);

    // Keep the original message prefix but append the cause so failures are diagnosable.
    let transport = SseTransport::new(
        read_rx,
        write_tx,
        read_tx,
        Arc::clone(&state.transport_options),
    )
    .map_err(|err| {
        TransportServerError::TransportError(format!("Failed to create SSE transport: {err}"))
    })?;

    let h: Arc<dyn McpServerHandler> = state.handler.clone();
    let server: Arc<ServerRuntime> = server_runtime::create_server_instance(
        Arc::clone(&state.server_details),
        h,
        session_id.to_owned(),
        auth_info,
    );

    state
        .session_store
        .set(session_id.to_owned(), Arc::clone(&server))
        .await;

    tracing::info!("A new client joined : {}", &session_id);

    // Run the stream in the background; the session is removed when the stream ends,
    // whether it finished cleanly or with an error.
    tokio::spawn(async move {
        match server
            .start_stream(
                Arc::new(transport),
                DEFAULT_STREAM_ID,
                state.ping_interval,
                None,
            )
            .await
        {
            Ok(_) => tracing::info!("server {} exited gracefully.", &session_id),
            Err(err) => tracing::info!("server {} exited with error : {}", &session_id, err),
        };

        state.session_store.delete(&session_id).await;
    });

    // The first event tells the client where to post its messages.
    let endpoint_event = stream::once(async move { initial_sse_event(&messages_endpoint) });

    let reader = BufReader::new(write_rx);

    // Every line written by the transport becomes one SSE data event.
    let message_stream = stream::unfold(reader, |mut reader| async move {
        let mut line = String::new();

        match reader.read_line(&mut line).await {
            // Zero bytes read: the write side was closed, so the stream ends.
            Ok(0) => None,
            Ok(_) => {
                let trimmed_line = line.trim_end_matches('\n').to_owned();
                Some((
                    Ok(SseEvent::default().with_data(trimmed_line).as_bytes()),
                    reader,
                ))
            }
            // A read error ends the stream silently rather than surfacing to the client.
            Err(_) => None,
        }
    });

    let stream = endpoint_event.chain(message_stream);

    let streaming_body: GenericBody =
        http_body_util::BodyExt::boxed(StreamBody::new(stream.map(|res| res.map(Frame::data))));

    let response = http::Response::builder()
        .status(StatusCode::OK)
        .header(CONTENT_TYPE, "text/event-stream")
        .header(CONNECTION, "keep-alive")
        .body(streaming_body)
        .map_err(|err| TransportServerError::HttpError(err.to_string()))?;

    Ok(response)
}