1use crate::mcp_http::types::GenericBody;
2use crate::schema::schema_utils::{ClientMessage, SdkError};
3use crate::{
4 error::SdkResult,
5 hyper_servers::error::{TransportServerError, TransportServerResult},
6 mcp_http::McpAppState,
7 mcp_runtimes::server_runtime::DEFAULT_STREAM_ID,
8 mcp_server::{server_runtime, ServerRuntime},
9 mcp_traits::{mcp_handler::McpServerHandler, IdGenerator},
10 utils::validate_mcp_protocol_version,
11};
12use axum::http::HeaderValue;
13use bytes::Bytes;
14use futures::stream;
15use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE};
16use http_body::Frame;
17use http_body_util::StreamBody;
18use http_body_util::{BodyExt, Full};
19use hyper::{HeaderMap, StatusCode};
20use rust_mcp_transport::{
21 EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR,
22 MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER,
23};
24use std::sync::Arc;
25use tokio::io::{duplex, AsyncBufReadExt, BufReader};
26use tokio_stream::StreamExt;
27
/// Default path of the endpoint serving legacy SSE connections.
pub(crate) const DEFAULT_SSE_ENDPOINT: &str = "/sse";
/// Default path advertised to legacy SSE clients for posting their messages
/// (used by `handle_sse_connection` when no custom endpoint is given).
pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
/// Default path of the Streamable HTTP endpoint.
pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
/// Capacity in bytes of each in-memory `duplex` pipe linking a transport to its HTTP body.
const DUPLEX_BUFFER_SIZE: usize = 8 * 1024;
35
36pub fn empty_response() -> GenericBody {
42 Full::new(Bytes::new())
43 .map_err(|err| TransportServerError::HttpError(err.to_string()))
44 .boxed()
45}
46
47pub fn build_response(
48 status_code: StatusCode,
49 payload: String,
50) -> Result<http::Response<GenericBody>, TransportServerError> {
51 let body = Full::new(Bytes::from(payload))
52 .map_err(|err| TransportServerError::HttpError(err.to_string()))
53 .boxed();
54
55 http::Response::builder()
56 .status(status_code)
57 .body(body)
58 .map_err(|err| TransportServerError::HttpError(err.to_string()))
59}
60
61fn initial_sse_event(endpoint: &str) -> Result<Bytes, TransportServerError> {
71 Ok(SseEvent::default()
72 .with_event("endpoint")
73 .with_data(endpoint.to_string())
74 .as_bytes())
75}
76
77async fn create_sse_stream(
78 runtime: Arc<ServerRuntime>,
79 session_id: SessionId,
80 state: Arc<McpAppState>,
81 payload: Option<&str>,
82 standalone: bool,
83 last_event_id: Option<EventId>,
84) -> TransportServerResult<http::Response<GenericBody>> {
85 let payload_string = payload.map(|p| p.to_string());
86
87 let payload_contains_request = payload_string
89 .as_ref()
90 .map(|json_str| contains_request(json_str))
91 .unwrap_or(Ok(false));
92 let Ok(payload_contains_request) = payload_contains_request else {
93 return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error());
94 };
95
96 let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
98 let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);
100
101 let session_id = Arc::new(session_id);
102 let stream_id: Arc<StreamId> = if standalone {
103 Arc::new(DEFAULT_STREAM_ID.to_string())
104 } else {
105 Arc::new(state.stream_id_gen.generate())
106 };
107
108 let event_store = state.event_store.as_ref().map(Arc::clone);
109 let resumability_enabled = event_store.is_some();
110
111 let mut transport = SseTransport::<ClientMessage>::new(
112 read_rx,
113 write_tx,
114 read_tx,
115 Arc::clone(&state.transport_options),
116 )
117 .map_err(|err| TransportServerError::TransportError(err.to_string()))?;
118 if let Some(event_store) = event_store.clone() {
119 transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store);
120 }
121 let transport = Arc::new(transport);
122
123 let ping_interval = state.ping_interval;
124 let runtime_clone = Arc::clone(&runtime);
125 let stream_id_clone = stream_id.clone();
126 let transport_clone = transport.clone();
127
128 tokio::spawn(async move {
130 match runtime_clone
131 .start_stream(
132 transport_clone,
133 &stream_id_clone,
134 ping_interval,
135 payload_string,
136 )
137 .await
138 {
139 Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone),
140 Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err),
141 }
142 let _ = runtime.remove_transport(&stream_id_clone).await;
143 });
144
145 let reader = BufReader::new(write_rx);
147
148 let message_stream = stream::unfold(reader, move |mut reader| {
150 async move {
151 let mut line = String::new();
152
153 match reader.read_line(&mut line).await {
154 Ok(0) => None, Ok(_) => {
156 let trimmed_line = line.trim_end_matches('\n').to_owned();
157
158 if is_empty_sse_message(&trimmed_line) {
160 return Some((Ok(SseEvent::default().as_bytes()), reader));
161 }
162
163 let (event_id, message) = match (
164 resumability_enabled,
165 trimmed_line.split_once(char::from(ID_SEPARATOR)),
166 ) {
167 (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()),
168 _ => (None, trimmed_line),
169 };
170
171 let event = match event_id {
172 Some(id) => SseEvent::default()
173 .with_data(message)
174 .with_id(id)
175 .as_bytes(),
176 None => SseEvent::default().with_data(message).as_bytes(),
177 };
178
179 Some((Ok(event), reader))
180 }
181 Err(e) => Some((Err(e), reader)),
182 }
183 }
184 });
185
186 let streaming_body: GenericBody =
188 http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| {
189 res.map(Frame::data)
190 .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string()))
191 })));
192
193 let session_id_value = HeaderValue::from_str(&session_id)
194 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
195
196 let status_code = if !payload_contains_request {
197 StatusCode::ACCEPTED
198 } else {
199 StatusCode::OK
200 };
201
202 let response = http::Response::builder()
203 .status(status_code)
204 .header(CONTENT_TYPE, "text/event-stream")
205 .header(MCP_SESSION_ID_HEADER, session_id_value)
206 .header(CONNECTION, "keep-alive")
207 .body(streaming_body)
208 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
209
210 tokio::spawn(async move {
212 if let Some(last_event_id) = last_event_id {
213 if let Some(event_store) = state.event_store.as_ref() {
214 let events = event_store
215 .events_after(last_event_id)
216 .await
217 .unwrap_or_else(|err| {
218 tracing::error!("{err}");
219 None
220 });
221
222 if let Some(events) = events {
223 for message_payload in events.messages {
224 let error = transport.write_str(&message_payload, true).await;
226 if let Err(error) = error {
227 tracing::trace!("Error replaying message: {error}")
228 }
229 }
230 }
231 }
232 }
233 });
234
235 Ok(response)
236}
237
238fn contains_request(json_str: &str) -> Result<bool, serde_json::Error> {
242 let value: serde_json::Value = serde_json::from_str(json_str)?;
243 match value {
244 serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")),
245 serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| {
246 item.as_object()
247 .map(|obj| obj.contains_key("id") && obj.contains_key("method"))
248 .unwrap_or(false)
249 })),
250 _ => Ok(false),
251 }
252}
253
254fn is_result(json_str: &str) -> Result<bool, serde_json::Error> {
255 let value: serde_json::Value = serde_json::from_str(json_str)?;
256 match value {
257 serde_json::Value::Object(obj) => Ok(obj.contains_key("result")),
258 serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| {
259 item.as_object()
260 .map(|obj| obj.contains_key("result"))
261 .unwrap_or(false)
262 })),
263 _ => Ok(false),
264 }
265}
266
267pub(crate) async fn create_standalone_stream(
268 session_id: SessionId,
269 last_event_id: Option<EventId>,
270 state: Arc<McpAppState>,
271) -> TransportServerResult<http::Response<GenericBody>> {
272 let runtime = state.session_store.get(&session_id).await.ok_or(
273 TransportServerError::SessionIdInvalid(session_id.to_string()),
274 )?;
275
276 if runtime.stream_id_exists(DEFAULT_STREAM_ID).await {
277 let error =
278 SdkError::bad_request().with_message("Only one SSE stream is allowed per session");
279 return error_response(StatusCode::CONFLICT, error)
280 .map_err(|err| TransportServerError::HttpError(err.to_string()));
281 }
282
283 if let Some(last_event_id) = last_event_id.as_ref() {
284 tracing::trace!(
285 "SSE stream re-connected with last-event-id: {}",
286 last_event_id
287 );
288 }
289
290 let mut response = create_sse_stream(
291 runtime.clone(),
292 session_id.clone(),
293 state.clone(),
294 None,
295 true,
296 last_event_id,
297 )
298 .await?;
299 *response.status_mut() = StatusCode::OK;
300 Ok(response)
301}
302
303pub(crate) async fn start_new_session(
304 state: Arc<McpAppState>,
305 payload: &str,
306) -> TransportServerResult<http::Response<GenericBody>> {
307 let session_id: SessionId = state.id_generator.generate();
308
309 let h: Arc<dyn McpServerHandler> = state.handler.clone();
310 let runtime: Arc<ServerRuntime> = server_runtime::create_server_instance(
312 Arc::clone(&state.server_details),
313 h,
314 session_id.to_owned(),
315 );
316
317 tracing::info!("a new client joined : {}", &session_id);
318
319 let response = create_sse_stream(
320 runtime.clone(),
321 session_id.clone(),
322 state.clone(),
323 Some(payload),
324 false,
325 None,
326 )
327 .await;
328
329 if response.is_ok() {
330 state
331 .session_store
332 .set(session_id.to_owned(), runtime.clone())
333 .await;
334 }
335 response
336}
337async fn single_shot_stream(
338 runtime: Arc<ServerRuntime>,
339 session_id: SessionId,
340 state: Arc<McpAppState>,
341 payload: Option<&str>,
342 standalone: bool,
343) -> TransportServerResult<http::Response<GenericBody>> {
344 let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
346 let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);
348
349 let transport = SseTransport::<ClientMessage>::new(
350 read_rx,
351 write_tx,
352 read_tx,
353 Arc::clone(&state.transport_options),
354 )
355 .map_err(|err| TransportServerError::TransportError(err.to_string()))?;
356
357 let stream_id = if standalone {
358 DEFAULT_STREAM_ID.to_string()
359 } else {
360 state.id_generator.generate()
361 };
362 let ping_interval = state.ping_interval;
363 let runtime_clone = Arc::clone(&runtime);
364
365 let payload_string = payload.map(|p| p.to_string());
366
367 tokio::spawn(async move {
368 match runtime_clone
369 .start_stream(
370 Arc::new(transport),
371 &stream_id,
372 ping_interval,
373 payload_string,
374 )
375 .await
376 {
377 Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id),
378 Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err),
379 }
380 let _ = runtime.remove_transport(&stream_id).await;
381 });
382
383 let mut reader = BufReader::new(write_rx);
384 let mut line = String::new();
385 let response = match reader.read_line(&mut line).await {
386 Ok(0) => None, Ok(_) => {
388 let trimmed_line = line.trim_end_matches('\n').to_owned();
389 Some(Ok(trimmed_line))
390 }
391 Err(e) => Some(Err(e)),
392 };
393
394 let session_id_value = HeaderValue::from_str(&session_id)
395 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
396
397 match response {
398 Some(response_result) => match response_result {
399 Ok(response_str) => {
400 let body = Full::new(Bytes::from(response_str))
401 .map_err(|err| TransportServerError::HttpError(err.to_string()))
402 .boxed();
403
404 http::Response::builder()
405 .status(StatusCode::OK)
406 .header(CONTENT_TYPE, "application/json")
407 .header(MCP_SESSION_ID_HEADER, session_id_value)
408 .body(body)
409 .map_err(|err| TransportServerError::HttpError(err.to_string()))
410 }
411 Err(err) => {
412 let body = Full::new(Bytes::from(err.to_string()))
413 .map_err(|err| TransportServerError::HttpError(err.to_string()))
414 .boxed();
415 http::Response::builder()
416 .status(StatusCode::INTERNAL_SERVER_ERROR)
417 .header(CONTENT_TYPE, "application/json")
418 .body(body)
419 .map_err(|err| TransportServerError::HttpError(err.to_string()))
420 }
421 },
422 None => {
423 let body = Full::new(Bytes::from(
424 "End of the transport stream reached.".to_string(),
425 ))
426 .map_err(|err| TransportServerError::HttpError(err.to_string()))
427 .boxed();
428 http::Response::builder()
429 .status(StatusCode::UNPROCESSABLE_ENTITY)
430 .header(CONTENT_TYPE, "application/json")
431 .body(body)
432 .map_err(|err| TransportServerError::HttpError(err.to_string()))
433 }
434 }
435}
436
437pub(crate) async fn process_incoming_message_return(
438 session_id: SessionId,
439 state: Arc<McpAppState>,
440 payload: &str,
441) -> TransportServerResult<http::Response<GenericBody>> {
442 match state.session_store.get(&session_id).await {
443 Some(runtime) => {
444 single_shot_stream(
445 runtime.clone(),
446 session_id,
447 state.clone(),
448 Some(payload),
449 false,
450 )
451 .await
452 }
454 None => {
455 let error = SdkError::session_not_found();
456 error_response(StatusCode::NOT_FOUND, error)
457 .map_err(|err| TransportServerError::HttpError(err.to_string()))
458 }
459 }
460}
461
462pub(crate) async fn process_incoming_message(
463 session_id: SessionId,
464 state: Arc<McpAppState>,
465 payload: &str,
466) -> TransportServerResult<http::Response<GenericBody>> {
467 match state.session_store.get(&session_id).await {
468 Some(runtime) => {
469 let Ok(is_result) = is_result(payload) else {
472 return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error());
473 };
474
475 if is_result {
476 match runtime
477 .consume_payload_string(DEFAULT_STREAM_ID, payload)
478 .await
479 {
480 Ok(()) => {
481 let body = Full::new(Bytes::new())
482 .map_err(|err| TransportServerError::HttpError(err.to_string()))
483 .boxed();
484 http::Response::builder()
485 .status(200)
486 .header("Content-Type", "application/json")
487 .body(body)
488 .map_err(|err| TransportServerError::HttpError(err.to_string()))
489 }
490 Err(err) => {
491 let error =
492 SdkError::internal_error().with_message(err.to_string().as_ref());
493 error_response(StatusCode::BAD_REQUEST, error)
494 }
495 }
496 } else {
497 create_sse_stream(
498 runtime.clone(),
499 session_id.clone(),
500 state.clone(),
501 Some(payload),
502 false,
503 None,
504 )
505 .await
506 }
507 }
508 None => {
509 let error = SdkError::session_not_found();
510 error_response(StatusCode::NOT_FOUND, error)
511 }
512 }
513}
514
/// Returns `true` for a line that carries no SSE payload.
///
/// That is either the empty string, or a line that is just `:` once surrounding whitespace
/// is trimmed. A non-empty, whitespace-only line is *not* considered empty.
pub(crate) fn is_empty_sse_message(sse_payload: &str) -> bool {
    if sse_payload.is_empty() {
        return true;
    }
    matches!(sse_payload.trim(), ":")
}
518
519pub(crate) async fn delete_session(
520 session_id: SessionId,
521 state: Arc<McpAppState>,
522) -> TransportServerResult<http::Response<GenericBody>> {
523 match state.session_store.get(&session_id).await {
524 Some(runtime) => {
525 runtime.shutdown().await;
526 state.session_store.delete(&session_id).await;
527 tracing::info!("client disconnected : {}", &session_id);
528
529 let body = Full::new(Bytes::from("ok"))
530 .map_err(|err| TransportServerError::HttpError(err.to_string()))
531 .boxed();
532 http::Response::builder()
533 .status(200)
534 .header("Content-Type", "application/json")
535 .body(body)
536 .map_err(|err| TransportServerError::HttpError(err.to_string()))
537 }
538 None => {
539 let error = SdkError::session_not_found();
540 error_response(StatusCode::NOT_FOUND, error)
541 }
542 }
543}
544
545pub(crate) fn acceptable_content_type(headers: &HeaderMap) -> bool {
546 let accept_header = headers
547 .get("content-type")
548 .and_then(|val| val.to_str().ok())
549 .unwrap_or("");
550 accept_header
551 .split(',')
552 .any(|val| val.trim().starts_with("application/json"))
553}
554
555pub(crate) fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
556 let protocol_version_header = headers
557 .get(MCP_PROTOCOL_VERSION_HEADER)
558 .and_then(|val| val.to_str().ok())
559 .unwrap_or("");
560
561 if protocol_version_header.is_empty() {
563 return Ok(());
564 }
565
566 validate_mcp_protocol_version(protocol_version_header)
567}
568
569pub(crate) fn accepts_event_stream(headers: &HeaderMap) -> bool {
570 let accept_header = headers
571 .get(ACCEPT)
572 .and_then(|val| val.to_str().ok())
573 .unwrap_or("");
574
575 accept_header
576 .split(',')
577 .any(|val| val.trim().starts_with("text/event-stream"))
578}
579
580pub(crate) fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
581 let accept_header = headers
582 .get(ACCEPT)
583 .and_then(|val| val.to_str().ok())
584 .unwrap_or("");
585
586 let types: Vec<_> = accept_header.split(',').map(|v| v.trim()).collect();
587
588 let has_event_stream = types.iter().any(|v| v.starts_with("text/event-stream"));
589 let has_json = types.iter().any(|v| v.starts_with("application/json"));
590 has_event_stream && has_json
591}
592
593pub fn error_response(
594 status_code: StatusCode,
595 error: SdkError,
596) -> TransportServerResult<http::Response<GenericBody>> {
597 let error_string = serde_json::to_string(&error).unwrap_or_default();
598 let body = Full::new(Bytes::from(error_string))
599 .map_err(|err| TransportServerError::HttpError(err.to_string()))
600 .boxed();
601
602 http::Response::builder()
603 .status(status_code)
604 .header(CONTENT_TYPE, "application/json")
605 .body(body)
606 .map_err(|err| TransportServerError::HttpError(err.to_string()))
607}
608
609pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
623 request.uri().query().and_then(|query| {
624 for pair in query.split('&') {
625 let mut split = pair.splitn(2, '=');
626 let k = split.next()?;
627 let v = split.next().unwrap_or("");
628 if k == key {
629 return Some(v.to_string());
630 }
631 }
632 None
633 })
634}
635
/// Accepts a new connection on the legacy SSE transport.
///
/// Creates a session with a fresh server runtime, registers it in the session store, and
/// starts the runtime on the [`DEFAULT_STREAM_ID`] stream. The returned `200 OK`
/// `text/event-stream` response first emits an `endpoint` event. That event carries the
/// session's message endpoint, built from `sse_message_endpoint` or
/// [`DEFAULT_MESSAGES_ENDPOINT`] when `None`. After it comes one SSE `data` event per line
/// written by the runtime. The session is removed from the store once the runtime stream
/// ends.
///
/// # Errors
/// Returns [`TransportServerError::TransportError`], including the underlying cause, if the
/// transport cannot be created. Returns [`TransportServerError::HttpError`] if the response
/// cannot be built.
#[cfg(feature = "sse")]
pub(crate) async fn handle_sse_connection(
    state: Arc<McpAppState>,
    sse_message_endpoint: Option<&str>,
) -> TransportServerResult<http::Response<GenericBody>> {
    let session_id: SessionId = state.id_generator.generate();

    let sse_message_endpoint = sse_message_endpoint.unwrap_or(DEFAULT_MESSAGES_ENDPOINT);
    let messages_endpoint =
        SseTransport::<ClientMessage>::message_endpoint(sse_message_endpoint, &session_id);

    let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
    let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);

    // Keep the underlying cause in the error instead of discarding it.
    let transport = SseTransport::<ClientMessage>::new(
        read_rx,
        write_tx,
        read_tx,
        Arc::clone(&state.transport_options),
    )
    .map_err(|err| {
        TransportServerError::TransportError(format!("Failed to create SSE transport: {err}"))
    })?;

    let handler: Arc<dyn McpServerHandler> = state.handler.clone();
    let server: Arc<ServerRuntime> = server_runtime::create_server_instance(
        Arc::clone(&state.server_details),
        handler,
        session_id.to_owned(),
    );

    state
        .session_store
        .set(session_id.to_owned(), Arc::clone(&server))
        .await;

    tracing::info!("A new client joined : {}", session_id);

    tokio::spawn(async move {
        match server
            .start_stream(
                Arc::new(transport),
                DEFAULT_STREAM_ID,
                state.ping_interval,
                None,
            )
            .await
        {
            Ok(_) => tracing::info!("server {} exited gracefully.", session_id),
            Err(err) => tracing::info!("server {} exited with error : {}", session_id, err),
        }

        // The connection is over; forget the session.
        state.session_store.delete(&session_id).await;
    });

    let endpoint_event = stream::once(async move { initial_sse_event(&messages_endpoint) });

    let reader = BufReader::new(write_rx);
    let message_stream = stream::unfold(reader, |mut reader| async move {
        let mut line = String::new();
        match reader.read_line(&mut line).await {
            // EOF or a read error ends the event stream.
            Ok(0) | Err(_) => None,
            Ok(_) => {
                let data = line.trim_end_matches('\n').to_owned();
                Some((Ok(SseEvent::default().with_data(data).as_bytes()), reader))
            }
        }
    });

    let body_stream = endpoint_event.chain(message_stream);
    let streaming_body: GenericBody =
        BodyExt::boxed(StreamBody::new(body_stream.map(|res| res.map(Frame::data))));

    http::Response::builder()
        .status(StatusCode::OK)
        .header(CONTENT_TYPE, "text/event-stream")
        .header(CONNECTION, "keep-alive")
        .body(streaming_body)
        .map_err(|err| TransportServerError::HttpError(err.to_string()))
}