1use std::collections::HashMap;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use bytes::Bytes;
17use futures::Stream;
18use http_body::Body;
19use http_body_util::{BodyExt, Full};
20use hyper::header::{ACCEPT, CONTENT_TYPE};
21use hyper::{HeaderMap, Method, Request, Response, StatusCode};
22use serde_json::Value;
23use tracing::{debug, error, warn};
24use turul_mcp_session_storage::SessionView;
25
26use crate::ServerConfig;
27use crate::protocol::normalize_header_value;
28
/// MCP protocol revisions this transport recognises, listed oldest first.
///
/// The newest revision is the [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum McpProtocolVersion {
    /// Revision `2024-11-05`; reports no streamable-HTTP support.
    V2024_11_05,
    /// Revision `2025-03-26`; the first revision with streamable HTTP.
    V2025_03_26,
    /// Revision `2025-06-18`; adds `_meta` fields, cursors, progress tokens and elicitation.
    V2025_06_18,
    /// Revision `2025-11-25`; adds tasks and icons.
    #[default]
    V2025_11_25,
}

impl McpProtocolVersion {
    /// Parses an exact wire-format revision string such as `"2025-06-18"`.
    ///
    /// No trimming or case folding is performed; anything other than one of the
    /// known revision strings yields `None`.
    pub fn parse_version(s: &str) -> Option<Self> {
        let known = [
            Self::V2024_11_05,
            Self::V2025_03_26,
            Self::V2025_06_18,
            Self::V2025_11_25,
        ];
        known.iter().copied().find(|candidate| candidate.as_str() == s)
    }

    /// The wire-format string for this revision (inverse of [`Self::parse_version`]).
    pub fn as_str(&self) -> &'static str {
        match *self {
            McpProtocolVersion::V2024_11_05 => "2024-11-05",
            McpProtocolVersion::V2025_03_26 => "2025-03-26",
            McpProtocolVersion::V2025_06_18 => "2025-06-18",
            McpProtocolVersion::V2025_11_25 => "2025-11-25",
        }
    }

    /// True for every revision from `2025-03-26` onwards.
    pub fn supports_streamable_http(&self) -> bool {
        matches!(
            self,
            Self::V2025_03_26 | Self::V2025_06_18 | Self::V2025_11_25
        )
    }

    /// True for `2025-06-18` and later.
    pub fn supports_meta_fields(&self) -> bool {
        matches!(self, Self::V2025_06_18 | Self::V2025_11_25)
    }

    /// True for `2025-06-18` and later.
    pub fn supports_cursors(&self) -> bool {
        matches!(self, Self::V2025_06_18 | Self::V2025_11_25)
    }

    /// True for `2025-06-18` and later.
    pub fn supports_progress_tokens(&self) -> bool {
        matches!(self, Self::V2025_06_18 | Self::V2025_11_25)
    }

    /// True for `2025-06-18` and later.
    pub fn supports_elicitation(&self) -> bool {
        matches!(self, Self::V2025_06_18 | Self::V2025_11_25)
    }

    /// True only for `2025-11-25`.
    pub fn supports_tasks(&self) -> bool {
        matches!(self, Self::V2025_11_25)
    }

    /// True only for `2025-11-25`.
    pub fn supports_icons(&self) -> bool {
        matches!(self, Self::V2025_11_25)
    }

    /// Feature tags enabled for this revision, in a fixed order.
    ///
    /// The order is: streamable-http, _meta-fields, cursor-pagination,
    /// progress-tokens, elicitation, tasks, icons.
    pub fn supported_features(&self) -> Vec<&'static str> {
        let catalogue: [(bool, &'static str); 7] = [
            (self.supports_streamable_http(), "streamable-http"),
            (self.supports_meta_fields(), "_meta-fields"),
            (self.supports_cursors(), "cursor-pagination"),
            (self.supports_progress_tokens(), "progress-tokens"),
            (self.supports_elicitation(), "elicitation"),
            (self.supports_tasks(), "tasks"),
            (self.supports_icons(), "icons"),
        ];
        catalogue
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .map(|&(_, tag)| tag)
            .collect()
    }
}
130
131#[derive(Debug, Clone)]
133pub struct StreamableHttpContext {
134 pub protocol_version: McpProtocolVersion,
136 pub session_id: Option<String>,
138 pub wants_sse_stream: bool,
140 pub accepts_stream_frames: bool,
142 pub headers: HashMap<String, String>,
144}
145
146impl StreamableHttpContext {
147 pub fn from_request<T>(req: &Request<T>) -> Self {
149 let headers = req.headers();
150
151 let protocol_version = headers
153 .get("MCP-Protocol-Version")
154 .and_then(|h| h.to_str().ok())
155 .and_then(McpProtocolVersion::parse_version)
156 .unwrap_or_default();
157
158 let session_id = headers
160 .get("Mcp-Session-Id")
161 .and_then(|h| h.to_str().ok())
162 .map(|s| s.to_string());
163
164 let accept_header = headers
166 .get(ACCEPT)
167 .and_then(|h| h.to_str().ok())
168 .map(normalize_header_value)
169 .unwrap_or_default();
170
171 let wants_sse_stream = accept_header.contains("text/event-stream");
172 let accepts_stream_frames = accept_header.contains("application/json")
173 || accept_header.contains("text/event-stream")
174 || accept_header.contains("*/*");
175
176 let mut header_map = HashMap::new();
178 for (name, value) in headers.iter() {
179 if let Ok(value_str) = value.to_str() {
180 header_map.insert(name.to_string(), value_str.to_string());
181 }
182 }
183
184 Self {
185 protocol_version,
186 session_id,
187 wants_sse_stream,
188 accepts_stream_frames,
189 headers: header_map,
190 }
191 }
192
193 pub fn wants_sse_stream(&self) -> bool {
195 self.wants_sse_stream
196 }
197
198 pub fn wants_streaming_post(&self) -> bool {
200 self.accepts_stream_frames && self.wants_sse_stream
201 }
202
203 pub fn is_streamable_compatible(&self) -> bool {
205 self.protocol_version.supports_streamable_http() && self.accepts_stream_frames
206 }
207
208 pub fn validate(&self, method: &Method) -> std::result::Result<(), String> {
210 if !self.accepts_stream_frames {
211 return Err(
212 "Accept header must include application/json, text/event-stream, or */*"
213 .to_string(),
214 );
215 }
216
217 if self.wants_sse_stream && !self.protocol_version.supports_streamable_http() {
218 return Err(format!(
219 "Protocol version {} does not support streamable HTTP",
220 self.protocol_version.as_str()
221 ));
222 }
223
224 if *method == Method::GET && self.wants_sse_stream && self.session_id.is_none() {
227 return Err("Mcp-Session-Id header required for SSE streaming connections".to_string());
228 }
229
230 Ok(())
231 }
232
233 pub fn response_headers(&self) -> HeaderMap {
235 let mut headers = HeaderMap::new();
236
237 headers.insert(
239 "MCP-Protocol-Version",
240 self.protocol_version.as_str().parse().unwrap(),
241 );
242
243 if let Some(session_id) = &self.session_id {
245 headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
246 }
247
248 let features = self.protocol_version.supported_features();
250 if !features.is_empty() {
251 headers.insert("MCP-Capabilities", features.join(",").parse().unwrap());
252 }
253
254 headers
255 }
256}
257
258enum SessionValidationError {
263 NotFound(String),
265 StorageError(String),
267}
268
269impl SessionValidationError {
270 fn status_code(&self) -> StatusCode {
271 match self {
272 Self::NotFound(_) => StatusCode::NOT_FOUND,
273 Self::StorageError(_) => StatusCode::INTERNAL_SERVER_ERROR,
274 }
275 }
276
277 fn message(&self) -> &str {
278 match self {
279 Self::NotFound(msg) | Self::StorageError(msg) => msg,
280 }
281 }
282}
283
284pub enum StreamableResponse {
286 Json(Value),
288 Stream(Pin<Box<dyn Stream<Item = std::result::Result<Value, String>> + Send>>),
290 Error { status: StatusCode, message: String },
292}
293
294impl std::fmt::Debug for StreamableResponse {
295 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296 match self {
297 Self::Json(value) => f.debug_tuple("Json").field(value).finish(),
298 Self::Stream(_) => f.debug_tuple("Stream").field(&"<stream>").finish(),
299 Self::Error { status, message } => f
300 .debug_struct("Error")
301 .field("status", status)
302 .field("message", message)
303 .finish(),
304 }
305 }
306}
307
308impl StreamableResponse {
309 pub fn into_response(self, context: &StreamableHttpContext) -> Response<Full<Bytes>> {
311 let mut response_headers = context.response_headers();
312
313 match self {
314 StreamableResponse::Json(json) => {
315 response_headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
316
317 let body = serde_json::to_string(&json)
318 .unwrap_or_else(|_| r#"{"error": "Failed to serialize response"}"#.to_string());
319
320 Response::builder()
321 .status(StatusCode::OK)
322 .body(Full::new(Bytes::from(body)))
323 .unwrap()
324 }
325
326 StreamableResponse::Stream(_stream) => {
327 response_headers.insert(CONTENT_TYPE, "text/event-stream".parse().unwrap());
329 response_headers.insert("Cache-Control", "no-cache, no-transform".parse().unwrap());
330 response_headers.insert("Connection", "keep-alive".parse().unwrap());
331
332 Response::builder()
336 .status(StatusCode::ACCEPTED)
337 .body(Full::new(Bytes::from("Streaming response accepted")))
338 .unwrap()
339 }
340
341 StreamableResponse::Error { status, message } => {
342 response_headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
343
344 let error_json = serde_json::json!({
345 "error": {
346 "code": status.as_u16(),
347 "message": message
348 }
349 });
350
351 let body = serde_json::to_string(&error_json).unwrap_or_else(|_| {
352 r#"{"error": {"code": 500, "message": "Internal server error"}}"#.to_string()
353 });
354
355 Response::builder()
356 .status(status)
357 .body(Full::new(Bytes::from(body)))
358 .unwrap()
359 }
360 }
361 }
362
363 pub fn into_boxed_response(
365 self,
366 context: &StreamableHttpContext,
367 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
368 self.into_response(context)
369 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
370 }
371}
372
373#[derive(Clone)]
375pub struct StreamableHttpHandler {
376 config: Arc<ServerConfig>,
377 dispatcher: Arc<turul_mcp_json_rpc_server::JsonRpcDispatcher<turul_mcp_protocol::McpError>>,
378 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
379 stream_manager: Arc<crate::StreamManager>,
380 server_capabilities: turul_mcp_protocol::ServerCapabilities,
381 pub(crate) middleware_stack: Arc<crate::middleware::MiddlewareStack>,
382}
383
384impl StreamableHttpHandler {
385 pub fn new(
386 config: Arc<ServerConfig>,
387 dispatcher: Arc<turul_mcp_json_rpc_server::JsonRpcDispatcher<turul_mcp_protocol::McpError>>,
388 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
389 stream_manager: Arc<crate::StreamManager>,
390 server_capabilities: turul_mcp_protocol::ServerCapabilities,
391 middleware_stack: Arc<crate::middleware::MiddlewareStack>,
392 ) -> Self {
393 Self {
394 config,
395 dispatcher,
396 session_storage,
397 stream_manager,
398 server_capabilities,
399 middleware_stack,
400 }
401 }
402
403 pub async fn handle_request<T>(
405 &self,
406 req: Request<T>,
407 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
408 where
409 T: Body + Send + 'static,
410 T::Data: Send,
411 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
412 {
413 debug!(
414 "Streamable handler request: method={}, uri={}",
415 req.method(),
416 req.uri()
417 );
418 let context = StreamableHttpContext::from_request(&req);
420
421 debug!(
422 "Streamable handler entry: method={}, protocol={}, session={:?}, accepts_stream_frames={}, wants_sse_stream={}",
423 req.method(),
424 context.protocol_version.as_str(),
425 context.session_id,
426 context.accepts_stream_frames,
427 context.wants_sse_stream()
428 );
429
430 if *req.method() == Method::OPTIONS {
433 return Response::builder()
434 .status(StatusCode::OK)
435 .body(Full::new(Bytes::new()))
436 .unwrap()
437 .map(|body| body.map_err(|never| match never {}).boxed_unsync());
438 }
439
440 if let Err(error) = context.validate(req.method()) {
442 warn!("Invalid streamable HTTP request: {}", error);
443 return StreamableResponse::Error {
444 status: StatusCode::BAD_REQUEST,
445 message: error,
446 }
447 .into_boxed_response(&context);
448 }
449
450 match *req.method() {
452 Method::POST => {
453 self.handle_client_message(req, context).await
456 }
457 Method::GET => {
458 self.handle_get_sse_notifications(req, context).await
460 }
461 Method::DELETE => {
462 self.handle_session_delete(req, context).await
464 }
465 _ => StreamableResponse::Error {
466 status: StatusCode::METHOD_NOT_ALLOWED,
467 message: "Method not allowed for this endpoint".to_string(),
468 }
469 .into_boxed_response(&context),
470 }
471 }
472
473 async fn handle_get_sse_notifications<T>(
479 &self,
480 req: Request<T>,
481 context: StreamableHttpContext,
482 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
483 where
484 T: Body + Send + 'static,
485 {
486 debug!(
487 "Opening streaming connection for session: {:?}",
488 context.session_id
489 );
490
491 let session_id = match context.session_id {
493 Some(ref id) => id.clone(),
494 None => {
495 warn!("Missing session ID for streaming GET request");
496 return StreamableResponse::Error {
497 status: StatusCode::BAD_REQUEST,
498 message: "Mcp-Session-Id header required for streaming connection".to_string(),
499 }
500 .into_boxed_response(&context);
501 }
502 };
503
504 match self.validate_session_exists(&session_id).await {
506 Ok(_) => {
507 debug!(
508 "Session validation successful for streaming GET: {}",
509 session_id
510 );
511 }
512 Err(err) => {
513 error!(
514 "Session validation failed for streaming GET {}: {}",
515 session_id,
516 err.message()
517 );
518 return StreamableResponse::Error {
519 status: err.status_code(),
520 message: format!("Session validation failed: {}", err.message()),
521 }
522 .into_boxed_response(&context);
523 }
524 }
525
526 let last_event_id = req
532 .headers()
533 .get("Last-Event-ID")
534 .and_then(|h| h.to_str().ok())
535 .and_then(|s| s.parse::<u64>().ok());
536
537 let connection_id = uuid::Uuid::now_v7().as_simple().to_string();
539
540 debug!(
541 "Creating streamable HTTP connection: session={}, connection={}, last_event_id={:?}",
542 session_id, connection_id, last_event_id
543 );
544
545 match self
549 .stream_manager
550 .handle_sse_connection(session_id.clone(), connection_id.clone(), last_event_id)
551 .await
552 {
553 Ok(mut streaming_response) => {
554 debug!(
555 "Streamable HTTP connection established: session={}, connection={}",
556 session_id, connection_id
557 );
558
559 let mcp_headers = context.response_headers();
561 for (key, value) in mcp_headers.iter() {
562 streaming_response.headers_mut().insert(key, value.clone());
563 }
564
565 streaming_response
568 }
569 Err(err) => {
570 error!("Failed to create streamable HTTP connection: {}", err);
571 StreamableResponse::Error {
572 status: StatusCode::INTERNAL_SERVER_ERROR,
573 message: format!("Streaming connection failed: {}", err),
574 }
575 .into_boxed_response(&context)
576 }
577 }
578 }
579
580 async fn validate_session_exists(
586 &self,
587 session_id: &str,
588 ) -> std::result::Result<(), SessionValidationError> {
589 match self.session_storage.get_session(session_id).await {
590 Ok(Some(session_info)) => {
591 if session_info.is_terminated() {
592 error!("Session '{}' has been terminated", session_id);
593 return Err(SessionValidationError::NotFound(format!(
594 "Session '{}' has been terminated. Create a new session to continue.",
595 session_id
596 )));
597 }
598 debug!("Session validation successful: {}", session_id);
599 Ok(())
600 }
601 Ok(None) => {
602 error!("Session not found: {}", session_id);
603 Err(SessionValidationError::NotFound(format!(
604 "Session '{}' not found. Sessions must be created via initialize request first.",
605 session_id
606 )))
607 }
608 Err(err) => {
609 error!("Failed to validate session {}: {}", session_id, err);
610 Err(SessionValidationError::StorageError(format!(
611 "Session validation failed: {}",
612 err
613 )))
614 }
615 }
616 }
617
618 #[allow(dead_code)]
620 async fn handle_json_post<T>(
621 &self,
622 req: Request<T>,
623 context: StreamableHttpContext,
624 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
625 where
626 T: Body + Send + 'static,
627 {
628 debug!("Handling JSON POST (non-streaming/legacy)");
629
630 let content_type = req
634 .headers()
635 .get(CONTENT_TYPE)
636 .and_then(|ct| ct.to_str().ok())
637 .map(normalize_header_value)
638 .unwrap_or_default();
639
640 if !content_type.starts_with("application/json") {
641 warn!("Invalid content type for legacy POST: {}", content_type);
642 return StreamableResponse::Error {
643 status: StatusCode::BAD_REQUEST,
644 message: "Content-Type must be application/json".to_string(),
645 }
646 .into_boxed_response(&context);
647 }
648
649 let body_bytes = match req.into_body().collect().await {
651 Ok(collected) => collected.to_bytes(),
652 Err(_err) => {
653 error!("Failed to read legacy POST request body");
654 return StreamableResponse::Error {
655 status: StatusCode::BAD_REQUEST,
656 message: "Failed to read request body".to_string(),
657 }
658 .into_boxed_response(&context);
659 }
660 };
661
662 if body_bytes.len() > self.config.max_body_size {
664 warn!(
665 "Legacy POST request body too large: {} bytes",
666 body_bytes.len()
667 );
668 return StreamableResponse::Error {
669 status: StatusCode::PAYLOAD_TOO_LARGE,
670 message: "Request body too large".to_string(),
671 }
672 .into_boxed_response(&context);
673 }
674
675 let body_str = match std::str::from_utf8(&body_bytes) {
677 Ok(s) => s,
678 Err(err) => {
679 error!("Invalid UTF-8 in legacy POST request body: {}", err);
680 return StreamableResponse::Error {
681 status: StatusCode::BAD_REQUEST,
682 message: "Request body must be valid UTF-8".to_string(),
683 }
684 .into_boxed_response(&context);
685 }
686 };
687
688 debug!("Received legacy POST JSON-RPC request: {}", body_str);
689
690 use turul_mcp_json_rpc_server::dispatch::{
692 JsonRpcMessage, JsonRpcMessageResult, parse_json_rpc_message,
693 };
694
695 let message = match parse_json_rpc_message(body_str) {
696 Ok(msg) => msg,
697 Err(rpc_err) => {
698 error!("JSON-RPC parse error in legacy POST: {}", rpc_err);
699 let error_json =
700 serde_json::to_string(&rpc_err).unwrap_or_else(|_| "{}".to_string());
701 return Response::builder()
702 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
704 .header("MCP-Protocol-Version", context.protocol_version.as_str())
705 .body(Full::new(Bytes::from(error_json)))
706 .unwrap()
707 .map(|body| body.map_err(|never| match never {}).boxed_unsync());
708 }
709 };
710
711 let message_result = match message {
714 JsonRpcMessage::Request(request) => {
715 debug!(
716 "Processing legacy POST JSON-RPC request: method={}",
717 request.method
718 );
719
720 let response = if request.method == "initialize" {
722 debug!("Handling legacy initialize request - creating new session");
723
724 match self
726 .session_storage
727 .create_session(self.server_capabilities.clone())
728 .await
729 {
730 Ok(session_info) => {
731 debug!(
732 "Created new session for legacy client: {}",
733 session_info.session_id
734 );
735
736 use crate::notification_bridge::StreamManagerNotificationBroadcaster;
738 use turul_mcp_json_rpc_server::r#async::SessionContext;
739
740 let broadcaster = Arc::new(StreamManagerNotificationBroadcaster::new(
741 Arc::clone(&self.stream_manager),
742 ));
743 let broadcaster_any =
744 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
745
746 let session_context = SessionContext {
747 session_id: session_info.session_id.clone(),
748 metadata: std::collections::HashMap::new(),
749 broadcaster: Some(broadcaster_any),
750 timestamp: chrono::Utc::now().timestamp_millis() as u64,
751 extensions: std::collections::HashMap::new(),
752 };
753
754 self.dispatcher
755 .handle_request_with_context(request, session_context)
756 .await
757 }
758 Err(err) => {
759 error!("Failed to create session during legacy initialize: {}", err);
760 let error_msg = format!("Session creation failed: {}", err);
761 turul_mcp_json_rpc_server::JsonRpcMessage::error(
762 turul_mcp_json_rpc_server::JsonRpcError::internal_error(
763 Some(request.id),
764 Some(error_msg),
765 ),
766 )
767 }
768 }
769 } else {
770 self.dispatcher.handle_request(request).await
772 };
773
774 match response {
776 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => {
777 JsonRpcMessageResult::Response(resp)
778 }
779 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
780 JsonRpcMessageResult::Error(err)
781 }
782 }
783 }
784 JsonRpcMessage::Notification(notification) => {
785 debug!(
786 "Processing legacy POST JSON-RPC notification: method={}",
787 notification.method
788 );
789
790 let result = self
792 .dispatcher
793 .handle_notification_with_context(notification, None)
794 .await;
795
796 if let Err(err) = result {
797 error!("Legacy POST notification handling error: {}", err);
798 }
799 JsonRpcMessageResult::NoResponse
800 }
801 };
802
803 match message_result {
805 JsonRpcMessageResult::Response(response) => {
806 let response_json = serde_json::to_string(&response)
807 .unwrap_or_else(|_| r#"{"error": "Failed to serialize response"}"#.to_string());
808
809 Response::builder()
810 .status(StatusCode::OK)
811 .header(CONTENT_TYPE, "application/json")
812 .header("MCP-Protocol-Version", context.protocol_version.as_str())
813 .body(Full::new(Bytes::from(response_json)))
814 .unwrap()
815 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
816 }
817 JsonRpcMessageResult::Error(error) => {
818 let error_json = serde_json::to_string(&error)
819 .unwrap_or_else(|_| r#"{"error": "Internal error"}"#.to_string());
820
821 Response::builder()
822 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
824 .header("MCP-Protocol-Version", context.protocol_version.as_str())
825 .body(Full::new(Bytes::from(error_json)))
826 .unwrap()
827 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
828 }
829 JsonRpcMessageResult::NoResponse => {
830 Response::builder()
832 .status(StatusCode::ACCEPTED)
833 .header("MCP-Protocol-Version", context.protocol_version.as_str())
834 .body(Full::new(Bytes::new()))
835 .unwrap()
836 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
837 }
838 }
839 }
840
841 async fn handle_session_delete<T>(
843 &self,
844 _req: Request<T>,
845 context: StreamableHttpContext,
846 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
847 where
848 T: Body + Send + 'static,
849 {
850 if let Some(session_id) = &context.session_id {
851 debug!("Deleting session: {}", session_id);
852
853 let closed_connections = self
856 .stream_manager
857 .close_session_connections(session_id)
858 .await;
859 debug!(
860 "Closed {} streaming connections for session: {}",
861 closed_connections, session_id
862 );
863
864 match self.session_storage.get_session(session_id).await {
866 Ok(Some(mut session_info)) => {
867 session_info
869 .state
870 .insert("terminated".to_string(), serde_json::Value::Bool(true));
871 session_info.state.insert(
872 "terminated_at".to_string(),
873 serde_json::Value::Number(serde_json::Number::from(
874 chrono::Utc::now().timestamp_millis(),
875 )),
876 );
877 session_info.touch();
878
879 match self.session_storage.update_session(session_info).await {
881 Ok(()) => {
882 debug!(
883 "Session {} marked as terminated (TTL will handle cleanup)",
884 session_id
885 );
886
887 Response::builder()
889 .status(StatusCode::OK)
890 .header(CONTENT_TYPE, "application/json")
891 .header("MCP-Protocol-Version", context.protocol_version.as_str())
892 .header("Mcp-Session-Id", session_id)
893 .body(Full::new(Bytes::from(
894 serde_json::to_string(&serde_json::json!({
895 "status": "session_terminated",
896 "session_id": session_id,
897 "closed_connections": closed_connections,
898 "message": "Session marked for cleanup"
899 }))
900 .unwrap_or_else(|_| {
901 r#"{"status":"session_terminated"}"#.to_string()
902 }),
903 )))
904 .unwrap()
905 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
906 }
907 Err(err) => {
908 error!(
909 "Error marking session {} as terminated: {}",
910 session_id, err
911 );
912 match self.session_storage.delete_session(session_id).await {
914 Ok(_) => {
915 debug!("Session {} deleted as fallback", session_id);
916 Response::builder()
917 .status(StatusCode::OK)
918 .header(CONTENT_TYPE, "application/json")
919 .header(
920 "MCP-Protocol-Version",
921 context.protocol_version.as_str(),
922 )
923 .body(Full::new(Bytes::from(
924 serde_json::to_string(&serde_json::json!({
925 "status": "session_deleted",
926 "session_id": session_id,
927 "closed_connections": closed_connections,
928 "message": "Session removed"
929 }))
930 .unwrap_or_else(|_| {
931 r#"{"status":"session_deleted"}"#.to_string()
932 }),
933 )))
934 .unwrap()
935 .map(|body| {
936 body.map_err(|never| match never {}).boxed_unsync()
937 })
938 }
939 Err(delete_err) => {
940 error!(
941 "Error deleting session {} as fallback: {}",
942 session_id, delete_err
943 );
944 StreamableResponse::Error {
945 status: StatusCode::INTERNAL_SERVER_ERROR,
946 message: "Session termination error".to_string(),
947 }
948 .into_boxed_response(&context)
949 }
950 }
951 }
952 }
953 }
954 Ok(None) => {
955 Response::builder()
957 .status(StatusCode::NOT_FOUND)
958 .header(CONTENT_TYPE, "application/json")
959 .header("MCP-Protocol-Version", context.protocol_version.as_str())
960 .body(Full::new(Bytes::from(
961 serde_json::to_string(&serde_json::json!({
962 "status": "session_not_found",
963 "session_id": session_id,
964 "message": "Session not found"
965 }))
966 .unwrap_or_else(|_| r#"{"status":"session_not_found"}"#.to_string()),
967 )))
968 .unwrap()
969 .map(|body| body.map_err(|never| match never {}).boxed_unsync())
970 }
971 Err(err) => {
972 error!(
973 "Error retrieving session {} for termination: {}",
974 session_id, err
975 );
976 StreamableResponse::Error {
977 status: StatusCode::INTERNAL_SERVER_ERROR,
978 message: "Session lookup error".to_string(),
979 }
980 .into_boxed_response(&context)
981 }
982 }
983 } else {
984 StreamableResponse::Error {
985 status: StatusCode::BAD_REQUEST,
986 message: "Mcp-Session-Id header required for session deletion".to_string(),
987 }
988 .into_boxed_response(&context)
989 }
990 }
991
992 async fn handle_post_streamable_http<T>(
1001 &self,
1002 req: Request<T>,
1003 mut context: StreamableHttpContext,
1004 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
1005 where
1006 T: Body + Send + 'static,
1007 {
1008 debug!("Streaming handler called - using true streaming POST");
1009
1010 let body_bytes = match req.into_body().collect().await {
1012 Ok(collected) => collected.to_bytes(),
1013 Err(_err) => {
1014 error!("Failed to read streaming POST request body");
1015 return StreamableResponse::Error {
1016 status: StatusCode::BAD_REQUEST,
1017 message: "Failed to read request body".to_string(),
1018 }
1019 .into_boxed_response(&context);
1020 }
1021 };
1022
1023 if body_bytes.len() > self.config.max_body_size {
1025 warn!(
1026 "Streaming POST request body too large: {} bytes",
1027 body_bytes.len()
1028 );
1029 return StreamableResponse::Error {
1030 status: StatusCode::PAYLOAD_TOO_LARGE,
1031 message: "Request body too large".to_string(),
1032 }
1033 .into_boxed_response(&context);
1034 }
1035
1036 let body_str = match std::str::from_utf8(&body_bytes) {
1038 Ok(s) => s,
1039 Err(err) => {
1040 error!("Invalid UTF-8 in streaming POST request body: {}", err);
1041 return StreamableResponse::Error {
1042 status: StatusCode::BAD_REQUEST,
1043 message: "Request body must be valid UTF-8".to_string(),
1044 }
1045 .into_boxed_response(&context);
1046 }
1047 };
1048
1049 debug!("Streaming POST received JSON-RPC request: {}", body_str);
1050
1051 use turul_mcp_json_rpc_server::dispatch::{JsonRpcMessage, parse_json_rpc_message};
1053 use turul_mcp_json_rpc_server::error::JsonRpcErrorObject;
1054
1055 let message = match parse_json_rpc_message(body_str) {
1056 Ok(msg) => msg,
1057 Err(rpc_err) => {
1058 error!("JSON-RPC parse error in streaming POST: {}", rpc_err);
1059 let error_json =
1060 serde_json::to_string(&rpc_err).unwrap_or_else(|_| "{}".to_string());
1061
1062 return Response::builder()
1064 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
1066 .header("MCP-Protocol-Version", context.protocol_version.as_str())
1067 .body(
1068 Full::new(Bytes::from(error_json))
1069 .map_err(|never| match never {})
1070 .boxed_unsync(),
1071 )
1072 .unwrap();
1073 }
1074 };
1075
1076 let is_sessionless_ping = match &message {
1081 JsonRpcMessage::Request(req) => req.method == "ping",
1082 JsonRpcMessage::Notification(notif) => notif.method == "ping",
1083 } && context.session_id.is_none()
1084 && self.config.allow_unauthenticated_ping;
1085
1086 if is_sessionless_ping {
1087 return match message {
1088 JsonRpcMessage::Request(request) => {
1089 let (response, _) = self
1091 .run_middleware_and_dispatch(request, context.headers.clone(), None, None)
1092 .await;
1093 let response_value =
1094 serde_json::to_value(&response).unwrap_or(serde_json::json!({}));
1095 StreamableResponse::Json(response_value).into_boxed_response(&context)
1096 }
1097 JsonRpcMessage::Notification(notification) => {
1098 let dispatcher = Arc::clone(&self.dispatcher);
1100 tokio::spawn(async move {
1101 if let Err(e) = dispatcher
1102 .handle_notification_with_context(notification, None)
1103 .await
1104 {
1105 error!("Failed to process sessionless ping notification: {}", e);
1106 }
1107 });
1108
1109 Response::builder()
1110 .status(StatusCode::ACCEPTED)
1111 .header("MCP-Protocol-Version", context.protocol_version.as_str())
1112 .body(
1113 Full::new(Bytes::new())
1114 .map_err(|never| match never {})
1115 .boxed_unsync(),
1116 )
1117 .unwrap()
1118 }
1119 };
1120 }
1121
1122 let bearer_token = context
1125 .headers
1126 .get("authorization")
1127 .and_then(|v| extract_bearer_token(v));
1128
1129 let pre_session_extensions = if self.middleware_stack.has_pre_session_middleware() {
1131 let method_name = match &message {
1132 JsonRpcMessage::Request(req) => req.method.as_str(),
1133 JsonRpcMessage::Notification(notif) => notif.method.as_str(),
1134 };
1135 let mut pre_ctx = crate::middleware::RequestContext::new(method_name, None);
1136 if let Some(ref token) = bearer_token {
1137 pre_ctx.set_bearer_token(token.clone());
1138 }
1139 for (k, v) in &context.headers {
1141 if k.eq_ignore_ascii_case("authorization") && is_bearer_scheme(v) {
1142 continue;
1143 }
1144 pre_ctx.add_metadata(k.clone(), serde_json::json!(v));
1145 }
1146 match self
1147 .middleware_stack
1148 .execute_before_session(&mut pre_ctx)
1149 .await
1150 {
1151 Ok(()) => Some(pre_ctx.take_extensions()),
1152 Err(crate::middleware::MiddlewareError::HttpChallenge {
1153 status,
1154 www_authenticate,
1155 body,
1156 }) => {
1157 return build_http_challenge_response(
1158 status,
1159 &www_authenticate,
1160 body.as_deref(),
1161 &context,
1162 );
1163 }
1164 Err(other_err) => {
1165 if let JsonRpcMessage::Request(ref req) = message {
1167 let response =
1168 Self::map_middleware_error_to_jsonrpc(other_err, req.id.clone());
1169 let response_value =
1170 serde_json::to_value(&response).unwrap_or(serde_json::json!({}));
1171 return StreamableResponse::Json(response_value)
1172 .into_boxed_response(&context);
1173 } else {
1174 return Response::builder()
1176 .status(StatusCode::FORBIDDEN)
1177 .body(
1178 Full::new(Bytes::from(other_err.to_string()))
1179 .map_err(|never| match never {})
1180 .boxed_unsync(),
1181 )
1182 .unwrap();
1183 }
1184 }
1185 }
1186 } else {
1187 None
1188 };
1189
1190 let session_id = match &message {
1192 JsonRpcMessage::Request(req) if req.method == "initialize" => {
1193 if let Some(existing_id) = &context.session_id {
1195 if let Err(err) = self.validate_session_exists(existing_id).await {
1197 warn!(
1198 "Invalid session ID {} during initialize: {}",
1199 existing_id,
1200 err.message()
1201 );
1202 return StreamableResponse::Error {
1203 status: err.status_code(),
1204 message: format!("Invalid or expired session: {}", err.message()),
1205 }
1206 .into_boxed_response(&context);
1207 }
1208 existing_id.clone()
1209 } else {
1210 match self
1212 .session_storage
1213 .create_session(self.server_capabilities.clone())
1214 .await
1215 {
1216 Ok(session_info) => {
1217 debug!(
1218 "Created new session for initialize: {}",
1219 session_info.session_id
1220 );
1221 context.session_id = Some(session_info.session_id.clone());
1222 session_info.session_id
1223 }
1224 Err(err) => {
1225 error!("Failed to create session during initialize: {}", err);
1226 return StreamableResponse::Error {
1227 status: StatusCode::INTERNAL_SERVER_ERROR,
1228 message: "Failed to create session".to_string(),
1229 }
1230 .into_boxed_response(&context);
1231 }
1232 }
1233 }
1234 }
1235 JsonRpcMessage::Request(_) | JsonRpcMessage::Notification(_) => {
1236 if let Some(existing_id) = &context.session_id {
1238 if let Err(err) = self.validate_session_exists(existing_id).await {
1240 warn!("Invalid session ID {}: {}", existing_id, err.message());
1241 return StreamableResponse::Error {
1242 status: err.status_code(),
1243 message: format!("Invalid or expired session: {}", err.message()),
1244 }
1245 .into_boxed_response(&context);
1246 }
1247 existing_id.clone()
1248 } else {
1249 let method_name = match &message {
1252 JsonRpcMessage::Request(req) => &req.method,
1253 JsonRpcMessage::Notification(notif) => ¬if.method,
1254 };
1255 let request_id = match &message {
1256 JsonRpcMessage::Request(req) => Some(req.id.clone()),
1257 JsonRpcMessage::Notification(_) => None,
1258 };
1259
1260 warn!("Missing session ID for method: {}", method_name);
1261
1262 let error_response = turul_mcp_json_rpc_server::JsonRpcError::new(
1263 request_id,
1264 JsonRpcErrorObject::server_error(
1265 -32001,
1266 "Missing Mcp-Session-Id header. Call initialize first.",
1267 None::<serde_json::Value>,
1268 ),
1269 );
1270
1271 let error_json =
1272 serde_json::to_string(&error_response).unwrap_or_else(|_| "{}".to_string());
1273
1274 return Response::builder()
1275 .status(StatusCode::UNAUTHORIZED)
1276 .header(CONTENT_TYPE, "application/json")
1277 .header("MCP-Protocol-Version", context.protocol_version.as_str())
1278 .body(
1279 Full::new(Bytes::from(error_json))
1280 .map_err(|never| match never {})
1281 .boxed_unsync(),
1282 )
1283 .unwrap();
1284 }
1285 }
1286 };
1287
1288 debug!("Processing streaming request with session: {}", session_id);
1289
1290 match message {
1292 JsonRpcMessage::Request(request) => {
1293 debug!(
1294 "Processing streaming JSON-RPC request: method={}",
1295 request.method
1296 );
1297 self.create_streaming_response(
1298 request,
1299 session_id,
1300 context,
1301 pre_session_extensions.clone(),
1302 )
1303 .await
1304 }
1305 JsonRpcMessage::Notification(notification) => {
1306 debug!(
1307 "Processing streaming JSON-RPC notification: method={}",
1308 notification.method
1309 );
1310
1311 use crate::notification_bridge::StreamManagerNotificationBroadcaster;
1313 use turul_mcp_json_rpc_server::SessionContext;
1314
1315 let broadcaster = Arc::new(StreamManagerNotificationBroadcaster::new(Arc::clone(
1316 &self.stream_manager,
1317 )));
1318 let broadcaster_any = Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
1319
1320 let session_context = SessionContext {
1321 session_id: session_id.clone(),
1322 metadata: std::collections::HashMap::new(),
1323 broadcaster: Some(broadcaster_any),
1324 timestamp: chrono::Utc::now().timestamp_millis() as u64,
1325 extensions: std::collections::HashMap::new(),
1326 };
1327
1328 let dispatcher = Arc::clone(&self.dispatcher);
1330 let notification_clone = notification.clone();
1331
1332 if notification_clone.method == "notifications/initialized" {
1342 if let Err(e) = dispatcher
1343 .handle_notification_with_context(
1344 notification_clone,
1345 Some(session_context),
1346 )
1347 .await
1348 {
1349 error!(
1350 "Failed to process notifications/initialized: {}. \
1351 Session will remain uninitialized — subsequent requests will fail.",
1352 e
1353 );
1354 }
1355 } else {
1356 tokio::spawn(async move {
1357 if let Err(e) = dispatcher
1358 .handle_notification_with_context(
1359 notification_clone,
1360 Some(session_context),
1361 )
1362 .await
1363 {
1364 error!("Failed to process notification: {}", e);
1365 }
1366 });
1367 }
1368
1369 Response::builder()
1371 .status(StatusCode::ACCEPTED)
1372 .header("MCP-Protocol-Version", context.protocol_version.as_str())
1373 .header("Mcp-Session-Id", &session_id)
1374 .body(
1375 Full::new(Bytes::new())
1376 .map_err(|never| match never {})
1377 .boxed_unsync(),
1378 )
1379 .unwrap()
1380 }
1381 }
1382 }
1383
1384 async fn create_streaming_response(
1387 &self,
1388 request: turul_mcp_json_rpc_server::JsonRpcRequest,
1389 session_id: String,
1390 context: StreamableHttpContext,
1391 pre_session_extensions: Option<HashMap<String, serde_json::Value>>,
1392 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
1393 debug!(
1394 "Creating streaming response for method: {}, session: {}",
1395 request.method, session_id
1396 );
1397 use http_body_util::StreamBody;
1399 use tokio_stream::StreamExt;
1400 use tokio_stream::wrappers::UnboundedReceiverStream; let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<bytes::Bytes, hyper::Error>>();
1403 let body_stream =
1404 UnboundedReceiverStream::new(rx).map(|item| item.map(http_body::Frame::data));
1405 let body = StreamBody::new(body_stream);
1406
1407 use crate::notification_bridge::{
1409 SharedNotificationBroadcaster, StreamManagerNotificationBroadcaster,
1410 };
1411 use turul_mcp_json_rpc_server::SessionContext;
1412
1413 let broadcaster: SharedNotificationBroadcaster = Arc::new(
1414 StreamManagerNotificationBroadcaster::new(Arc::clone(&self.stream_manager)),
1415 );
1416 let broadcaster_any = Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
1417
1418 let session_context = SessionContext {
1419 session_id: session_id.clone(),
1420 metadata: std::collections::HashMap::new(),
1421 broadcaster: Some(broadcaster_any),
1422 timestamp: chrono::Utc::now().timestamp_millis() as u64,
1423 extensions: std::collections::HashMap::new(),
1424 };
1425
1426 let wants_sse = context.wants_sse_stream();
1428 let connection_id = format!("post-{}", uuid::Uuid::now_v7().as_simple());
1429
1430 let (shutdown_tx, completion_rx) = if wants_sse {
1432 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
1434 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel::<()>();
1435 let (progress_tx, mut progress_rx) = tokio::sync::mpsc::channel(100);
1436
1437 let registration_result = self
1439 .stream_manager
1440 .register_streaming_connection(&session_id, connection_id.clone(), progress_tx)
1441 .await;
1442
1443 if let Err(e) = registration_result {
1444 error!("Failed to register POST streaming connection: {}", e);
1445 (None, None)
1447 } else {
1448 debug!(
1449 "Registered SSE streaming connection for session: {}",
1450 session_id
1451 );
1452
1453 let sender_clone = tx.clone();
1455 let session_id_clone = session_id.clone();
1456 let connection_id_clone = connection_id.clone();
1457 let stream_manager_clone = Arc::clone(&self.stream_manager);
1458
1459 tokio::spawn(async move {
1460 debug!(
1461 "Starting progress forwarding task for session: {}",
1462 session_id_clone
1463 );
1464
1465 loop {
1467 debug!(
1468 "🔍 Progress task entering select loop for session: {}",
1469 session_id_clone
1470 );
1471 tokio::select! {
1472 maybe_event = progress_rx.recv() => {
1474 debug!("🔍 Progress task: progress_rx.recv() branch fired for session: {}", session_id_clone);
1475 match maybe_event {
1476 Some(sse_event) => {
1477 debug!("🔍 Forwarding progress event to POST response: session={}, event={:?}", session_id_clone, sse_event.event_type);
1478
1479 let sse_chunk = sse_event.format();
1481
1482 if let Err(e) = sender_clone.send(Ok(Bytes::from(sse_chunk))) {
1483 error!("Failed to send progress event to POST response: {}", e);
1484 break;
1485 }
1486 }
1487 None => {
1488 debug!("🔍 Progress channel closed naturally for session: {}", session_id_clone);
1490 break;
1491 }
1492 }
1493 }
1494 _ = &mut shutdown_rx => {
1496 debug!("🔍 Progress task: shutdown_rx branch fired! Received explicit shutdown signal for session: {}", session_id_clone);
1497 break;
1498 }
1499 }
1500 }
1501
1502 debug!(
1504 "Progress task unregistering connection for session: {}",
1505 session_id_clone
1506 );
1507 stream_manager_clone
1508 .unregister_connection(&session_id_clone, &connection_id_clone)
1509 .await;
1510
1511 debug!(
1513 "🔍 Progress task: dropping sender_clone for session: {}",
1514 session_id_clone
1515 );
1516 drop(sender_clone);
1517
1518 debug!(
1520 "🔍 Progress task: signaling completion for session: {}",
1521 session_id_clone
1522 );
1523 if completion_tx.send(()).is_err() {
1524 debug!(
1525 "🔍 Progress task: main task already dropped completion_rx for session: {}",
1526 session_id_clone
1527 );
1528 }
1529
1530 debug!(
1531 "🔍 Progress forwarding task completed for session: {}",
1532 session_id_clone
1533 );
1534 });
1535
1536 (Some(shutdown_tx), Some(completion_rx))
1538 }
1539 } else {
1540 (None, None)
1542 };
1543
1544 let request_id = request.id.clone();
1546 let sender = tx; let headers = context.headers.clone();
1550 let self_clone = self.clone();
1551
1552 tokio::spawn(async move {
1553 debug!(
1554 "Spawning streaming task for request ID: {:?}, wants_sse: {}",
1555 request_id, wants_sse
1556 );
1557
1558 let (response, _) = self_clone
1561 .run_middleware_and_dispatch(
1562 request,
1563 headers,
1564 Some(session_context),
1565 pre_session_extensions,
1566 )
1567 .await;
1568
1569 if wants_sse {
1571 let final_frame = match response {
1573 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => {
1574 turul_mcp_json_rpc_server::JsonRpcFrame::FinalResult {
1575 request_id: request_id.clone(),
1576 result: match resp.result {
1577 turul_mcp_json_rpc_server::response::ResponseResult::Success(
1578 val,
1579 ) => val,
1580 turul_mcp_json_rpc_server::response::ResponseResult::Null => {
1581 serde_json::Value::Null
1582 }
1583 },
1584 }
1585 }
1586 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
1587 turul_mcp_json_rpc_server::JsonRpcFrame::Error {
1588 request_id: request_id.clone(),
1589 error: turul_mcp_json_rpc_server::error::JsonRpcErrorObject {
1590 code: err.error.code,
1591 message: err.error.message,
1592 data: err.error.data,
1593 },
1594 }
1595 }
1596 };
1597
1598 let final_json = final_frame.to_json();
1599 let final_chunk =
1601 format!("data: {}\n\n", serde_json::to_string(&final_json).unwrap());
1602
1603 if let Err(err) = sender.send(Ok(Bytes::from(final_chunk))) {
1604 error!("Failed to send SSE final chunk: {}", err);
1605 }
1606
1607 if let Some(shutdown_tx) = shutdown_tx {
1610 debug!(
1611 "🔍 Main task sending shutdown signal to progress task for request: {:?}",
1612 request_id
1613 );
1614 match shutdown_tx.send(()) {
1615 Ok(()) => {
1616 debug!(
1617 "🔍 Main task: shutdown signal sent successfully for request: {:?}",
1618 request_id
1619 );
1620
1621 if let Some(completion_rx) = completion_rx {
1624 match tokio::time::timeout(
1625 tokio::time::Duration::from_millis(100),
1626 completion_rx,
1627 )
1628 .await
1629 {
1630 Ok(Ok(())) => {
1631 debug!(
1632 "🔍 Main task: progress task completed successfully for request: {:?}",
1633 request_id
1634 );
1635 }
1636 Ok(Err(_)) => {
1637 debug!(
1638 "🔍 Main task: progress task completion signal dropped for request: {:?}",
1639 request_id
1640 );
1641 }
1642 Err(_) => {
1643 debug!(
1644 "🔍 Main task: progress task completion timeout for request: {:?}",
1645 request_id
1646 );
1647 }
1648 }
1649 }
1650 }
1651 Err(_) => {
1652 debug!(
1653 "🔍 Main task: progress task already completed (shutdown_rx dropped) for request: {:?}",
1654 request_id
1655 );
1656 }
1657 }
1658 } else {
1659 debug!(
1660 "🔍 Main task: no shutdown_tx available (not SSE client) for request: {:?}",
1661 request_id
1662 );
1663 }
1664 } else {
1665 let final_json = serde_json::to_string(&response).unwrap();
1667
1668 if let Err(err) = sender.send(Ok(Bytes::from(final_json))) {
1669 error!("Failed to send final JSON response: {}", err);
1670 }
1671 }
1672
1673 debug!(
1674 "🔍 Main task: streaming task completed for request ID: {:?}",
1675 request_id
1676 );
1677
1678 debug!(
1680 "🔍 Main task: dropping main sender for request ID: {:?}",
1681 request_id
1682 );
1683 drop(sender);
1684 });
1685
1686 let content_type = if context.wants_sse_stream() {
1689 "text/event-stream"
1690 } else {
1691 "application/json"
1692 };
1693
1694 let mut response = Response::builder()
1695 .status(StatusCode::OK)
1696 .header(CONTENT_TYPE, content_type)
1697 .header("Transfer-Encoding", "chunked") .header("Cache-Control", "no-cache")
1699 .body(http_body_util::BodyExt::boxed_unsync(body))
1700 .unwrap();
1701
1702 let mcp_headers = context.response_headers();
1704 for (key, value) in mcp_headers.iter() {
1705 response.headers_mut().insert(key, value.clone());
1706 }
1707
1708 response
1709 }
1710
1711 #[allow(dead_code)]
1713 async fn handle_buffered_post<T>(
1714 &self,
1715 _req: Request<T>,
1716 context: StreamableHttpContext,
1717 session_id: String,
1718 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
1719 where
1720 T: Body + Send + 'static,
1721 {
1722 debug!(
1723 "Using buffered POST for legacy client, session: {}",
1724 session_id
1725 );
1726
1727 Response::builder()
1731 .status(StatusCode::OK)
1732 .header(CONTENT_TYPE, "application/json")
1733 .header("MCP-Protocol-Version", context.protocol_version.as_str())
1734 .header("Mcp-Session-Id", &session_id)
1735 .body(
1736 Full::new(Bytes::from(
1737 r#"{"jsonrpc":"2.0","id":1,"result":"buffered"}"#,
1738 ))
1739 .map_err(|never| match never {})
1740 .boxed_unsync(),
1741 )
1742 .unwrap()
1743 }
1744
1745 async fn handle_client_message<T>(
1749 &self,
1750 req: Request<T>,
1751 context: StreamableHttpContext,
1752 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>
1753 where
1754 T: Body + Send + 'static,
1755 {
1756 debug!("Handling client message via POST (MCP 2025-11-25)");
1757
1758 if !context.accepts_stream_frames {
1761 warn!("Client POST missing application/json in Accept header");
1762 return StreamableResponse::Error {
1763 status: StatusCode::BAD_REQUEST,
1764 message: "Accept header must include application/json, text/event-stream, or */*"
1765 .to_string(),
1766 }
1767 .into_boxed_response(&context);
1768 }
1769
1770 let content_type = req
1772 .headers()
1773 .get(CONTENT_TYPE)
1774 .and_then(|ct| ct.to_str().ok())
1775 .map(normalize_header_value)
1776 .unwrap_or_default();
1777 if !content_type.starts_with("application/json") {
1778 warn!("Invalid content type for POST: {}", content_type);
1779 return StreamableResponse::Error {
1780 status: StatusCode::BAD_REQUEST,
1781 message: "Content-Type must be application/json".to_string(),
1782 }
1783 .into_boxed_response(&context);
1784 }
1785
1786 debug!("Using streaming POST handler for all requests");
1789 return self.handle_post_streamable_http(req, context).await;
1790 }
1791
1792 async fn run_middleware_and_dispatch(
1807 &self,
1808 request: turul_mcp_json_rpc_server::JsonRpcRequest,
1809 headers: HashMap<String, String>,
1810 session: Option<turul_mcp_json_rpc_server::SessionContext>,
1811 pre_session_extensions: Option<HashMap<String, serde_json::Value>>,
1812 ) -> (
1813 turul_mcp_json_rpc_server::JsonRpcMessage,
1814 Option<crate::middleware::SessionInjection>,
1815 ) {
1816 if self.middleware_stack.is_empty() {
1818 let result = if let Some(session_ctx) = session {
1819 self.dispatcher
1820 .handle_request_with_context(request, session_ctx)
1821 .await
1822 } else {
1823 self.dispatcher.handle_request(request).await
1824 };
1825 return (result, None);
1826 }
1827
1828 let normalized_headers: HashMap<String, String> = headers
1830 .iter()
1831 .map(|(k, v)| (k.to_lowercase(), v.clone()))
1832 .collect();
1833
1834 let method = request.method.clone();
1836
1837 let params = request.params.clone().map(|p| match p {
1839 turul_mcp_json_rpc_server::RequestParams::Object(map) => {
1840 serde_json::Value::Object(map.into_iter().collect())
1841 }
1842 turul_mcp_json_rpc_server::RequestParams::Array(arr) => serde_json::Value::Array(arr),
1843 });
1844 let mut ctx = crate::middleware::RequestContext::new(&method, params);
1845
1846 if let Some(ext) = pre_session_extensions {
1848 for (k, v) in ext {
1849 ctx.set_extension(k, v);
1850 }
1851 }
1852
1853 for (k, v) in normalized_headers {
1854 if k == "authorization" && is_bearer_scheme(&v) {
1856 continue;
1857 }
1858 ctx.add_metadata(k, serde_json::json!(v));
1859 }
1860
1861 let session_view = session.as_ref().map(|s| {
1863 crate::middleware::StorageBackedSessionView::new(
1864 s.session_id.clone(),
1865 Arc::clone(&self.session_storage),
1866 )
1867 });
1868
1869 let injection = match self
1871 .middleware_stack
1872 .execute_before(
1873 &mut ctx,
1874 session_view.as_ref().map(|v| v as &dyn SessionView),
1875 )
1876 .await
1877 {
1878 Ok(inj) => inj,
1879 Err(err) => {
1880 return (Self::map_middleware_error_to_jsonrpc(err, request.id), None);
1881 }
1882 };
1883
1884 if !injection.is_empty()
1886 && let Some(ref sv) = session_view
1887 {
1888 for (key, value) in injection.state() {
1889 if let Err(e) = sv.set_state(key, value.clone()).await {
1890 tracing::warn!("Failed to apply injection state '{}': {}", key, e);
1891 }
1892 }
1893 for (key, value) in injection.metadata() {
1894 if let Err(e) = sv.set_metadata(key, value.clone()).await {
1895 tracing::warn!("Failed to apply injection metadata '{}': {}", key, e);
1896 }
1897 }
1898 }
1899
1900 let session = session.map(|mut s| {
1902 s.extensions = ctx.extensions().clone();
1903 s
1904 });
1905
1906 let result = if let Some(session_ctx) = session {
1908 self.dispatcher
1909 .handle_request_with_context(request, session_ctx)
1910 .await
1911 } else {
1912 self.dispatcher.handle_request(request).await
1913 };
1914
1915 let mut dispatcher_result = match &result {
1917 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => match &resp.result {
1918 turul_mcp_json_rpc_server::response::ResponseResult::Success(val) => {
1919 crate::middleware::DispatcherResult::Success(val.clone())
1920 }
1921 turul_mcp_json_rpc_server::response::ResponseResult::Null => {
1922 crate::middleware::DispatcherResult::Success(serde_json::Value::Null)
1923 }
1924 },
1925 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
1926 crate::middleware::DispatcherResult::Error(err.error.message.clone())
1927 }
1928 };
1929
1930 let _ = self
1931 .middleware_stack
1932 .execute_after(&ctx, &mut dispatcher_result)
1933 .await;
1934
1935 (result, None)
1936 }
1937
1938 fn map_middleware_error_to_jsonrpc(
1940 err: crate::middleware::MiddlewareError,
1941 request_id: turul_mcp_json_rpc_server::RequestId,
1942 ) -> turul_mcp_json_rpc_server::JsonRpcMessage {
1943 use crate::middleware::MiddlewareError;
1944 use crate::middleware::error::error_codes;
1945
1946 let (code, message, data) = match err {
1947 MiddlewareError::Unauthenticated(msg) => (error_codes::UNAUTHENTICATED, msg, None),
1948 MiddlewareError::Unauthorized(msg) => (error_codes::UNAUTHORIZED, msg, None),
1949 MiddlewareError::RateLimitExceeded {
1950 message,
1951 retry_after,
1952 } => {
1953 let data = retry_after.map(|s| serde_json::json!({"retryAfter": s}));
1954 (error_codes::RATE_LIMIT_EXCEEDED, message, data)
1955 }
1956 MiddlewareError::InvalidRequest(msg) => (error_codes::INVALID_REQUEST, msg, None),
1957 MiddlewareError::Internal(msg) => (error_codes::INTERNAL_ERROR, msg, None),
1958 MiddlewareError::Custom { message, .. } => (error_codes::INTERNAL_ERROR, message, None),
1959 MiddlewareError::HttpChallenge { .. } => {
1960 unreachable!(
1961 "HttpChallenge must be caught at transport level before JSON-RPC dispatch"
1962 )
1963 }
1964 };
1965
1966 let error_obj = if let Some(d) = data {
1967 turul_mcp_json_rpc_server::error::JsonRpcErrorObject::server_error(
1968 code,
1969 &message,
1970 Some(d),
1971 )
1972 } else {
1973 turul_mcp_json_rpc_server::error::JsonRpcErrorObject::server_error(
1974 code,
1975 &message,
1976 None::<serde_json::Value>,
1977 )
1978 };
1979
1980 turul_mcp_json_rpc_server::JsonRpcMessage::Error(
1981 turul_mcp_json_rpc_server::JsonRpcError::new(Some(request_id), error_obj),
1982 )
1983 }
1984}
1985
1986use crate::middleware::bearer::{extract_bearer_token, is_bearer_scheme};
1987
1988fn build_http_challenge_response(
1992 status: u16,
1993 www_authenticate: &str,
1994 body: Option<&str>,
1995 context: &StreamableHttpContext,
1996) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
1997 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::UNAUTHORIZED);
1998 let body_bytes = body.unwrap_or("").to_string();
1999
2000 Response::builder()
2001 .status(status_code)
2002 .header("WWW-Authenticate", www_authenticate)
2003 .header("Cache-Control", "no-store")
2004 .header("Content-Type", "application/json")
2005 .header("MCP-Protocol-Version", context.protocol_version.as_str())
2006 .body(
2007 http_body_util::Full::new(Bytes::from(body_bytes))
2008 .map_err(|never| match never {})
2009 .boxed_unsync(),
2010 )
2011 .unwrap()
2012}
2013
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_VERSIONS: [McpProtocolVersion; 4] = [
        McpProtocolVersion::V2024_11_05,
        McpProtocolVersion::V2025_03_26,
        McpProtocolVersion::V2025_06_18,
        McpProtocolVersion::V2025_11_25,
    ];

    /// Copy `headers` into `ctx` metadata, skipping a Bearer `authorization`
    /// header. Mirrors the filter applied in `run_middleware_and_dispatch`.
    fn add_filtered_headers(
        ctx: &mut crate::middleware::RequestContext,
        headers: &[(String, String)],
    ) {
        for (k, v) in headers {
            if k == "authorization" && is_bearer_scheme(v) {
                continue;
            }
            ctx.add_metadata(k.clone(), serde_json::json!(v));
        }
    }

    #[test]
    fn test_version_parsing() {
        assert_eq!(
            McpProtocolVersion::parse_version("2024-11-05"),
            Some(McpProtocolVersion::V2024_11_05)
        );
        assert_eq!(
            McpProtocolVersion::parse_version("2025-03-26"),
            Some(McpProtocolVersion::V2025_03_26)
        );
        assert_eq!(
            McpProtocolVersion::parse_version("2025-06-18"),
            Some(McpProtocolVersion::V2025_06_18)
        );
        assert_eq!(
            McpProtocolVersion::parse_version("2025-11-25"),
            Some(McpProtocolVersion::V2025_11_25)
        );
        assert_eq!(McpProtocolVersion::parse_version("invalid"), None);
        assert_eq!(McpProtocolVersion::parse_version(""), None);
    }

    #[test]
    fn test_version_string_round_trip() {
        for version in ALL_VERSIONS {
            assert_eq!(
                McpProtocolVersion::parse_version(version.as_str()),
                Some(version)
            );
        }
    }

    #[test]
    fn test_default_version_is_latest() {
        assert_eq!(
            McpProtocolVersion::default(),
            McpProtocolVersion::V2025_11_25
        );
    }

    #[test]
    fn test_version_capabilities() {
        let v1 = McpProtocolVersion::V2024_11_05;
        assert!(!v1.supports_streamable_http());
        assert!(!v1.supports_meta_fields());
        assert!(!v1.supports_tasks());
        assert!(!v1.supports_icons());

        let v2 = McpProtocolVersion::V2025_03_26;
        assert!(v2.supports_streamable_http());
        assert!(!v2.supports_meta_fields());
        assert!(!v2.supports_tasks());
        assert!(!v2.supports_icons());

        let v3 = McpProtocolVersion::V2025_06_18;
        assert!(v3.supports_streamable_http());
        assert!(v3.supports_meta_fields());
        assert!(v3.supports_cursors());
        assert!(v3.supports_progress_tokens());
        assert!(v3.supports_elicitation());
        assert!(!v3.supports_tasks());
        assert!(!v3.supports_icons());

        let v4 = McpProtocolVersion::V2025_11_25;
        assert!(v4.supports_streamable_http());
        assert!(v4.supports_meta_fields());
        assert!(v4.supports_cursors());
        assert!(v4.supports_progress_tokens());
        assert!(v4.supports_elicitation());
        assert!(v4.supports_tasks());
        assert!(v4.supports_icons());
    }

    #[test]
    fn test_context_validation() {
        let mut context = StreamableHttpContext {
            protocol_version: McpProtocolVersion::V2025_06_18,
            session_id: Some("test-session".to_string()),
            wants_sse_stream: true,
            accepts_stream_frames: true,
            headers: HashMap::new(),
        };

        assert!(context.validate(&Method::POST).is_ok());
        assert!(context.validate(&Method::GET).is_ok());

        context.accepts_stream_frames = false;
        assert!(context.validate(&Method::POST).is_err());

        context.accepts_stream_frames = true;
        context.protocol_version = McpProtocolVersion::V2024_11_05;
        context.wants_sse_stream = true;
        assert!(context.validate(&Method::POST).is_err());

        context.protocol_version = McpProtocolVersion::V2025_06_18;
        context.session_id = None;
        assert!(context.validate(&Method::POST).is_ok());
        assert!(context.validate(&Method::GET).is_err());
    }

    #[test]
    fn test_non_bearer_preserved_in_metadata() {
        let mut ctx = crate::middleware::RequestContext::new("test/method", None);
        add_filtered_headers(
            &mut ctx,
            &[
                (
                    "authorization".to_string(),
                    "Basic dXNlcjpwYXNz".to_string(),
                ),
                ("x-custom".to_string(), "value".to_string()),
            ],
        );
        assert!(ctx.metadata().contains_key("authorization"));
        assert!(ctx.metadata().contains_key("x-custom"));

        let mut ctx2 = crate::middleware::RequestContext::new("test/method", None);
        add_filtered_headers(
            &mut ctx2,
            &[
                ("authorization".to_string(), "Bearer abc123".to_string()),
                ("x-custom".to_string(), "value".to_string()),
            ],
        );
        assert!(!ctx2.metadata().contains_key("authorization"));
        assert!(ctx2.metadata().contains_key("x-custom"));
    }

    #[test]
    fn test_malformed_bearer_excluded_from_metadata() {
        let mut ctx = crate::middleware::RequestContext::new("test/method", None);
        add_filtered_headers(
            &mut ctx,
            &[("authorization".to_string(), "Bearer ".to_string())],
        );
        assert!(!ctx.metadata().contains_key("authorization"));
    }
}