1use std::convert::Infallible;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use tracing::{debug, error, warn};
16
17use bytes::Bytes;
18use futures::Stream;
19use http_body::{Body, Frame};
20use http_body_util::{BodyExt, Full};
21use hyper::header::{ACCEPT, CONTENT_TYPE};
22use hyper::{Method, Request, Response, StatusCode};
23
24use crate::middleware::bearer::{extract_bearer_token, is_bearer_scheme};
25use chrono;
26use turul_mcp_json_rpc_server::{
27 JsonRpcDispatcher,
28 r#async::SessionContext,
29 dispatch::{JsonRpcMessage, JsonRpcMessageResult, parse_json_rpc_message},
30 error::{JsonRpcError, JsonRpcErrorObject},
31};
32use turul_mcp_protocol::McpError;
33use turul_mcp_protocol::ServerCapabilities;
34use turul_mcp_session_storage::{InMemorySessionStorage, SessionView};
35use uuid::Uuid;
36
37use crate::{
38 Result, ServerConfig, StreamConfig, StreamManager,
39 json_rpc_responses::*,
40 notification_bridge::{SharedNotificationBroadcaster, StreamManagerNotificationBroadcaster},
41 protocol::{
42 extract_last_event_id, extract_protocol_version, extract_session_id, normalize_header_value,
43 },
44};
45use std::collections::HashMap;
46
47pub struct SessionSseStream {
49 stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, Infallible>> + Send>>,
50}
51
52impl SessionSseStream {
53 pub fn new<S>(stream: S) -> Self
54 where
55 S: Stream<Item = std::result::Result<Bytes, Infallible>> + Send + 'static,
56 {
57 Self {
58 stream: Box::pin(stream),
59 }
60 }
61}
62
63impl Drop for SessionSseStream {
64 fn drop(&mut self) {
65 debug!("DROP: SessionSseStream - HTTP response body being cleaned up");
66 debug!("This may indicate early cleanup of SSE response stream");
67 }
68}
69
70impl Body for SessionSseStream {
71 type Data = Bytes;
72 type Error = Infallible;
73
74 fn poll_frame(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
78 match self.stream.as_mut().poll_next(cx) {
79 Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(Ok(Frame::data(data)))),
80 Poll::Ready(Some(Err(never))) => match never {},
81 Poll::Ready(None) => Poll::Ready(None),
82 Poll::Pending => Poll::Pending,
83 }
84 }
85}
86
/// Fully buffered body used by the fixed responses (DELETE, OPTIONS, method not allowed)
/// before they are converted to [`UnifiedMcpBody`].
type JsonRpcBody = Full<Bytes>;

/// Type-erased body returned by every request path, letting buffered JSON responses and
/// streaming SSE responses share one `Response` type.
type UnifiedMcpBody = http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>;
92
/// Classification of a POST request's `Accept` header (see `parse_mcp_accept_header`).
#[derive(Debug, Clone, PartialEq)]
enum AcceptMode {
    /// Accepts JSON (`application/json` or `*/*`) and `text/event-stream`.
    Compliant,
    /// Accepts JSON (`application/json` or `*/*`) but not `text/event-stream`.
    JsonOnly,
    /// Accepts `text/event-stream` but not JSON.
    SseOnly,
    /// Accepts neither JSON nor `text/event-stream`.
    Invalid,
}
105
106fn parse_mcp_accept_header(accept_header: &str) -> (AcceptMode, bool) {
108 let accepts_json = accept_header.contains("application/json") || accept_header.contains("*/*");
109 let accepts_sse = accept_header.contains("text/event-stream");
110
111 let mode = match (accepts_json, accepts_sse) {
112 (true, true) => AcceptMode::Compliant,
113 (true, false) => AcceptMode::JsonOnly, (false, true) => AcceptMode::SseOnly,
115 (false, false) => AcceptMode::Invalid,
116 };
117
118 let should_use_sse = match mode {
121 AcceptMode::Compliant => true, AcceptMode::JsonOnly => false, AcceptMode::SseOnly => true, AcceptMode::Invalid => false, };
126
127 (mode, should_use_sse)
128}
129
130fn convert_to_unified_body(full_body: Full<Bytes>) -> UnifiedMcpBody {
132 full_body.map_err(|never| match never {}).boxed_unsync()
133}
134
135fn jsonrpc_error_to_unified_body(error: JsonRpcError) -> Result<Response<UnifiedMcpBody>> {
137 let error_json = serde_json::to_string(&error)?;
138 Ok(Response::builder()
139 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
141 .body(convert_to_unified_body(Full::new(Bytes::from(error_json))))
142 .unwrap())
143}
144
/// Session-aware handler for the MCP streamable HTTP endpoint.
///
/// Routes POST (JSON-RPC), GET (SSE), DELETE (session termination) and OPTIONS requests;
/// the entry point is `handle_mcp_request`.
pub struct SessionMcpHandler {
    /// Server settings read by the handler (body size limit, GET/POST SSE toggles).
    pub(crate) config: ServerConfig,
    /// Dispatches parsed JSON-RPC requests and notifications.
    pub(crate) dispatcher: Arc<JsonRpcDispatcher<McpError>>,
    /// Store used to create, look up, update and delete session records.
    pub(crate) session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// Stream configuration supplied at construction.
    pub(crate) stream_config: StreamConfig,
    /// Creates GET SSE connections and POST SSE responses, and closes them on DELETE.
    pub(crate) stream_manager: Arc<StreamManager>,
    /// Middleware run before session handling and around dispatch.
    pub(crate) middleware_stack: Arc<crate::middleware::MiddlewareStack>,
}
157
158impl Clone for SessionMcpHandler {
159 fn clone(&self) -> Self {
160 Self {
161 config: self.config.clone(),
162 dispatcher: Arc::clone(&self.dispatcher),
163 session_storage: Arc::clone(&self.session_storage),
164 stream_config: self.stream_config.clone(),
165 stream_manager: Arc::clone(&self.stream_manager),
166 middleware_stack: Arc::clone(&self.middleware_stack),
167 }
168 }
169}
170
171impl SessionMcpHandler {
172 pub fn new(
174 config: ServerConfig,
175 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
176 stream_config: StreamConfig,
177 ) -> Self {
178 let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
179 Arc::new(InMemorySessionStorage::new());
180 let middleware_stack = Arc::new(crate::middleware::MiddlewareStack::new());
181 Self::with_storage(config, dispatcher, storage, stream_config, middleware_stack)
182 }
183
184 pub fn with_shared_stream_manager(
186 config: ServerConfig,
187 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
188 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
189 stream_config: StreamConfig,
190 stream_manager: Arc<StreamManager>,
191 middleware_stack: Arc<crate::middleware::MiddlewareStack>,
192 ) -> Self {
193 Self {
194 config,
195 dispatcher,
196 session_storage,
197 stream_config,
198 stream_manager,
199 middleware_stack,
200 }
201 }
202
203 pub fn with_storage(
206 config: ServerConfig,
207 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
208 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
209 stream_config: StreamConfig,
210 middleware_stack: Arc<crate::middleware::MiddlewareStack>,
211 ) -> Self {
212 let stream_manager = Arc::new(StreamManager::with_config(
214 Arc::clone(&session_storage),
215 stream_config.clone(),
216 ));
217
218 Self {
219 config,
220 dispatcher,
221 session_storage,
222 stream_config,
223 stream_manager,
224 middleware_stack,
225 }
226 }
227
228 pub fn get_stream_manager(&self) -> &Arc<StreamManager> {
230 &self.stream_manager
231 }
232
233 pub async fn handle_mcp_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
235 where
236 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
237 {
238 debug!(
239 "SESSION HANDLER processing {} {}",
240 req.method(),
241 req.uri().path()
242 );
243 match *req.method() {
244 Method::POST => {
245 let response = self.handle_json_rpc_request(req).await?;
246 Ok(response)
247 }
248 Method::GET => self.handle_sse_request(req).await,
249 Method::DELETE => {
250 let response = self.handle_delete_request(req).await?;
251 Ok(response.map(convert_to_unified_body))
252 }
253 Method::OPTIONS => {
254 let response = self.handle_preflight();
255 Ok(response.map(convert_to_unified_body))
256 }
257 _ => {
258 let response = self.method_not_allowed();
259 Ok(response.map(convert_to_unified_body))
260 }
261 }
262 }
263
264 async fn handle_json_rpc_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
266 where
267 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
268 {
269 let headers: HashMap<String, String> = req
271 .headers()
272 .iter()
273 .filter_map(|(k, v)| {
274 v.to_str()
275 .ok()
276 .map(|s| (k.as_str().to_string(), s.to_string()))
277 })
278 .collect();
279
280 let protocol_version = extract_protocol_version(req.headers());
282 let session_id = extract_session_id(req.headers());
283
284 debug!(
285 "POST request - Protocol: {}, Session: {:?}",
286 protocol_version, session_id
287 );
288
289 let content_type = req
291 .headers()
292 .get(CONTENT_TYPE)
293 .and_then(|ct| ct.to_str().ok())
294 .map(normalize_header_value)
295 .unwrap_or_default();
296
297 if !content_type.starts_with("application/json") {
298 warn!("Invalid content type: {}", content_type);
299 return Ok(
300 bad_request_response("Content-Type must be application/json")
301 .map(convert_to_unified_body),
302 );
303 }
304
305 let accept_header = req
307 .headers()
308 .get(ACCEPT)
309 .and_then(|accept| accept.to_str().ok())
310 .map(normalize_header_value)
311 .unwrap_or_else(|| "application/json".to_string());
312
313 let (accept_mode, accepts_sse) = parse_mcp_accept_header(&accept_header);
314 debug!(
315 "POST request Accept header: '{}', mode: {:?}, will use SSE for tool calls: {}",
316 accept_header, accept_mode, accepts_sse
317 );
318
319 let body = req.into_body();
321 let body_bytes = match body.collect().await {
322 Ok(collected) => collected.to_bytes(),
323 Err(err) => {
324 error!("Failed to read request body: {}", err);
325 return Ok(bad_request_response("Failed to read request body")
326 .map(convert_to_unified_body));
327 }
328 };
329
330 if body_bytes.len() > self.config.max_body_size {
332 warn!("Request body too large: {} bytes", body_bytes.len());
333 return Ok(Response::builder()
334 .status(StatusCode::PAYLOAD_TOO_LARGE)
335 .header(CONTENT_TYPE, "application/json")
336 .body(convert_to_unified_body(Full::new(Bytes::from(
337 "Request body too large",
338 ))))
339 .unwrap());
340 }
341
342 let body_str = match std::str::from_utf8(&body_bytes) {
344 Ok(s) => s,
345 Err(err) => {
346 error!("Invalid UTF-8 in request body: {}", err);
347 return Ok(bad_request_response("Request body must be valid UTF-8")
348 .map(convert_to_unified_body));
349 }
350 };
351
352 debug!("Received JSON-RPC request: {}", body_str);
353
354 let message = match parse_json_rpc_message(body_str) {
356 Ok(msg) => msg,
357 Err(rpc_err) => {
358 error!("JSON-RPC parse error: {}", rpc_err);
359 let error_response =
361 serde_json::to_string(&rpc_err).unwrap_or_else(|_| "{}".to_string());
362 return Ok(Response::builder()
363 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
365 .body(convert_to_unified_body(Full::new(Bytes::from(
366 error_response,
367 ))))
368 .unwrap());
369 }
370 };
371
372 let pre_session_extensions = if self.middleware_stack.has_pre_session_middleware() {
375 let method_name = match &message {
376 JsonRpcMessage::Request(req) => req.method.as_str(),
377 JsonRpcMessage::Notification(notif) => notif.method.as_str(),
378 };
379 let bearer_token = headers
380 .get("authorization")
381 .and_then(|v| extract_bearer_token(v));
382 let mut pre_ctx = crate::middleware::RequestContext::new(method_name, None);
383 if let Some(ref token) = bearer_token {
384 pre_ctx.set_bearer_token(token.clone());
385 }
386 for (k, v) in &headers {
387 if k.eq_ignore_ascii_case("authorization") && is_bearer_scheme(v) {
388 continue;
389 }
390 pre_ctx.add_metadata(k.clone(), serde_json::json!(v));
391 }
392 match self
393 .middleware_stack
394 .execute_before_session(&mut pre_ctx)
395 .await
396 {
397 Ok(()) => Some(pre_ctx.take_extensions()),
398 Err(crate::middleware::MiddlewareError::HttpChallenge {
399 status,
400 www_authenticate,
401 body,
402 }) => {
403 let body_str = body.unwrap_or_default();
405 return Ok(Response::builder()
406 .status(StatusCode::from_u16(status).unwrap_or(StatusCode::UNAUTHORIZED))
407 .header("WWW-Authenticate", &www_authenticate)
408 .header("Cache-Control", "no-store")
409 .header(CONTENT_TYPE, "application/json")
410 .body(convert_to_unified_body(Full::new(Bytes::from(body_str))))
411 .unwrap());
412 }
413 Err(other_err) => {
414 if let JsonRpcMessage::Request(ref req) = message {
416 let response =
417 Self::map_middleware_error_to_jsonrpc(other_err, req.id.clone());
418 let response_json =
419 serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
420 return Ok(Response::builder()
421 .status(StatusCode::OK)
422 .header(CONTENT_TYPE, "application/json")
423 .body(convert_to_unified_body(Full::new(Bytes::from(
424 response_json,
425 ))))
426 .unwrap());
427 } else {
428 return Ok(Response::builder()
429 .status(StatusCode::FORBIDDEN)
430 .body(convert_to_unified_body(Full::new(Bytes::from(
431 other_err.to_string(),
432 ))))
433 .unwrap());
434 }
435 }
436 }
437 } else {
438 None
439 };
440
441 let (message_result, response_session_id, method_name) = match message {
443 JsonRpcMessage::Request(request) => {
444 debug!("Processing JSON-RPC request: method={}", request.method);
445 let method_name = request.method.clone();
446
447 let (response, response_session_id) = if request.method == "initialize" {
449 debug!(
450 "Handling initialize request - creating new session via session storage"
451 );
452
453 let capabilities = ServerCapabilities::default();
455 match self.session_storage.create_session(capabilities).await {
456 Ok(session_info) => {
457 debug!(
458 "Created new session via session storage: {}",
459 session_info.session_id
460 );
461
462 let broadcaster: SharedNotificationBroadcaster =
464 Arc::new(StreamManagerNotificationBroadcaster::new(Arc::clone(
465 &self.stream_manager,
466 )));
467 let broadcaster_any =
468 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
469
470 let session_context = SessionContext {
471 session_id: session_info.session_id.clone(),
472 metadata: std::collections::HashMap::new(),
473 broadcaster: Some(broadcaster_any),
474 timestamp: chrono::Utc::now().timestamp_millis() as u64,
475 extensions: std::collections::HashMap::new(),
476 };
477
478 let (response, _) = self
481 .run_middleware_and_dispatch(
482 request,
483 headers.clone(),
484 session_context,
485 pre_session_extensions.clone(),
486 )
487 .await;
488
489 (response, Some(session_info.session_id))
491 }
492 Err(err) => {
493 error!("Failed to create session during initialize: {}", err);
494 let error_msg = format!("Session creation failed: {}", err);
496 let error_response = turul_mcp_json_rpc_server::JsonRpcMessage::error(
497 turul_mcp_json_rpc_server::JsonRpcError::internal_error(
498 Some(request.id),
499 Some(error_msg),
500 ),
501 );
502 (error_response, None)
503 }
504 }
505 } else {
506 let session_context = if let Some(ref session_id_str) = session_id {
509 debug!("Processing request with session: {}", session_id_str);
510 let broadcaster: SharedNotificationBroadcaster =
511 Arc::new(StreamManagerNotificationBroadcaster::new(Arc::clone(
512 &self.stream_manager,
513 )));
514 let broadcaster_any =
515 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
516 Some(SessionContext {
517 session_id: session_id_str.clone(),
518 metadata: std::collections::HashMap::new(),
519 broadcaster: Some(broadcaster_any),
520 timestamp: chrono::Utc::now().timestamp_millis() as u64,
521 extensions: std::collections::HashMap::new(),
522 })
523 } else {
524 debug!("Processing request without session (lenient mode)");
525 None
526 };
527
528 let (response, _stashed_injection) = if let Some(ctx) = session_context {
530 self.run_middleware_and_dispatch(
531 request,
532 headers.clone(),
533 ctx,
534 pre_session_extensions.clone(),
535 )
536 .await
537 } else {
538 (self.dispatcher.handle_request(request).await, None)
540 };
541 (response, session_id)
542 };
543
544 let message_result = match response {
546 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => {
547 JsonRpcMessageResult::Response(resp)
548 }
549 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
550 JsonRpcMessageResult::Error(err)
551 }
552 };
553 (message_result, response_session_id, Some(method_name))
554 }
555 JsonRpcMessage::Notification(notification) => {
556 debug!(
557 "Processing JSON-RPC notification: method={}",
558 notification.method
559 );
560 let method_name = notification.method.clone();
561
562 let session_context = if let Some(ref session_id_str) = session_id {
565 debug!("Processing notification with session: {}", session_id_str);
566 let broadcaster: SharedNotificationBroadcaster = Arc::new(
567 StreamManagerNotificationBroadcaster::new(Arc::clone(&self.stream_manager)),
568 );
569 let broadcaster_any =
570 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
571
572 Some(SessionContext {
573 session_id: session_id_str.clone(),
574 metadata: std::collections::HashMap::new(),
575 broadcaster: Some(broadcaster_any),
576 timestamp: chrono::Utc::now().timestamp_millis() as u64,
577 extensions: std::collections::HashMap::new(),
578 })
579 } else {
580 debug!("Processing notification without session (lenient mode)");
581 None
582 };
583
584 let result = self
585 .dispatcher
586 .handle_notification_with_context(notification, session_context)
587 .await;
588
589 if let Err(err) = result {
590 error!("Notification handling error: {}", err);
591 }
592 (
593 JsonRpcMessageResult::NoResponse,
594 session_id.clone(),
595 Some(method_name),
596 )
597 }
598 };
599
600 match message_result {
602 JsonRpcMessageResult::Response(response) => {
603 let is_tool_call = method_name.as_ref().is_some_and(|m| m == "tools/call");
606
607 debug!(
608 "Decision point: method={:?}, accept_mode={:?}, accepts_sse={}, server_post_sse_enabled={}, session_id={:?}, is_tool_call={}",
609 method_name,
610 accept_mode,
611 accepts_sse,
612 self.config.enable_post_sse,
613 response_session_id,
614 is_tool_call
615 );
616
617 let should_use_sse = match accept_mode {
619 AcceptMode::JsonOnly => false, AcceptMode::Invalid => false, AcceptMode::Compliant => {
622 self.config.enable_post_sse && accepts_sse && is_tool_call
623 } AcceptMode::SseOnly => self.config.enable_post_sse && accepts_sse, };
626
627 if should_use_sse && response_session_id.is_some() {
628 debug!(
629 "📡 Creating POST SSE stream (mode: {:?}) for tool call with notifications",
630 accept_mode
631 );
632 match self
633 .stream_manager
634 .create_post_sse_stream(
635 response_session_id.clone().unwrap(),
636 response.clone(), )
638 .await
639 {
640 Ok(sse_response) => {
641 debug!("✅ POST SSE stream created successfully");
642 Ok(sse_response
643 .map(|body| body.map_err(|never| match never {}).boxed_unsync()))
644 }
645 Err(e) => {
646 warn!(
647 "Failed to create POST SSE stream, falling back to JSON: {}",
648 e
649 );
650 Ok(
651 jsonrpc_response_with_session(response, response_session_id)?
652 .map(convert_to_unified_body),
653 )
654 }
655 }
656 } else {
657 debug!(
658 "📄 Returning standard JSON response (mode: {:?}) for method: {:?}",
659 accept_mode, method_name
660 );
661 Ok(
662 jsonrpc_response_with_session(response, response_session_id)?
663 .map(convert_to_unified_body),
664 )
665 }
666 }
667 JsonRpcMessageResult::Error(error) => {
668 warn!("Sending JSON-RPC error response");
669 let error_json = serde_json::to_string(&error)?;
671 Ok(Response::builder()
672 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
674 .body(convert_to_unified_body(Full::new(Bytes::from(error_json))))
675 .unwrap())
676 }
677 JsonRpcMessageResult::NoResponse => {
678 Ok(jsonrpc_notification_response()?.map(convert_to_unified_body))
680 }
681 }
682 }
683
684 async fn handle_sse_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
689 where
690 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
691 {
692 let headers = req.headers();
694 let accept = headers
695 .get(ACCEPT)
696 .and_then(|accept| accept.to_str().ok())
697 .map(normalize_header_value)
698 .unwrap_or_default();
699
700 if !accept.contains("text/event-stream") {
701 warn!(
702 "GET request received without SSE support - header does not contain 'text/event-stream'"
703 );
704 let error = JsonRpcError::new(
705 None,
706 JsonRpcErrorObject::server_error(
707 -32001,
708 "SSE not accepted - missing 'text/event-stream' in Accept header",
709 None,
710 ),
711 );
712 return jsonrpc_error_to_unified_body(error);
713 }
714
715 if !self.config.enable_get_sse {
717 warn!("GET SSE request received but GET SSE is disabled on server");
718 let error = JsonRpcError::new(
719 None,
720 JsonRpcErrorObject::server_error(
721 -32003,
722 "GET SSE is disabled on this server",
723 None,
724 ),
725 );
726 return jsonrpc_error_to_unified_body(error);
727 }
728
729 let protocol_version = extract_protocol_version(headers);
731 let session_id = extract_session_id(headers);
732
733 debug!(
734 "GET SSE request - Protocol: {}, Session: {:?}",
735 protocol_version, session_id
736 );
737
738 let session_id = match session_id {
740 Some(id) => id,
741 None => {
742 warn!("Missing Mcp-Session-Id header for SSE request");
743 let error = JsonRpcError::new(
744 None,
745 JsonRpcErrorObject::server_error(-32002, "Missing Mcp-Session-Id header", None),
746 );
747 return jsonrpc_error_to_unified_body(error);
748 }
749 };
750
751 if let Err(err) = self.validate_session_exists(&session_id).await {
753 error!(
754 "Session validation failed for Session ID {}: {}",
755 session_id, err
756 );
757 let error = JsonRpcError::new(
758 None,
759 JsonRpcErrorObject::server_error(
760 -32003,
761 &format!("Session validation failed: {}", err),
762 None,
763 ),
764 );
765 return jsonrpc_error_to_unified_body(error);
766 }
767
768 let last_event_id = extract_last_event_id(headers);
770
771 let connection_id = Uuid::now_v7().as_simple().to_string();
773
774 debug!(
775 "Creating SSE stream for session: {} with connection: {}, last_event_id: {:?}",
776 session_id, connection_id, last_event_id
777 );
778
779 match self
781 .stream_manager
782 .handle_sse_connection(session_id, connection_id, last_event_id)
783 .await
784 {
785 Ok(response) => Ok(response),
786 Err(err) => {
787 error!("Failed to create SSE connection: {}", err);
788 let error = JsonRpcError::new(
789 None,
790 JsonRpcErrorObject::internal_error(Some(format!(
791 "SSE connection failed: {}",
792 err
793 ))),
794 );
795 jsonrpc_error_to_unified_body(error)
796 }
797 }
798 }
799
800 async fn handle_delete_request<B>(&self, req: Request<B>) -> Result<Response<JsonRpcBody>>
802 where
803 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
804 {
805 let session_id = extract_session_id(req.headers());
806
807 debug!("DELETE request - Session: {:?}", session_id);
808
809 if let Some(session_id) = session_id {
810 let closed_connections = self
812 .stream_manager
813 .close_session_connections(&session_id)
814 .await;
815 debug!(
816 "Closed {} SSE connections for session: {}",
817 closed_connections, session_id
818 );
819
820 match self.session_storage.get_session(&session_id).await {
822 Ok(Some(mut session_info)) => {
823 session_info
825 .state
826 .insert("terminated".to_string(), serde_json::Value::Bool(true));
827 session_info.state.insert(
828 "terminated_at".to_string(),
829 serde_json::Value::Number(serde_json::Number::from(
830 chrono::Utc::now().timestamp_millis(),
831 )),
832 );
833 session_info.touch();
834
835 match self.session_storage.update_session(session_info).await {
836 Ok(()) => {
837 debug!(
838 "Session {} marked as terminated (TTL will handle cleanup)",
839 session_id
840 );
841 Ok(Response::builder()
842 .status(StatusCode::OK)
843 .body(Full::new(Bytes::from("Session terminated")))
844 .unwrap())
845 }
846 Err(err) => {
847 error!(
848 "Error marking session {} as terminated: {}",
849 session_id, err
850 );
851 match self.session_storage.delete_session(&session_id).await {
853 Ok(_) => {
854 debug!("Session {} deleted as fallback", session_id);
855 Ok(Response::builder()
856 .status(StatusCode::OK)
857 .body(Full::new(Bytes::from("Session removed")))
858 .unwrap())
859 }
860 Err(delete_err) => {
861 error!(
862 "Error deleting session {} as fallback: {}",
863 session_id, delete_err
864 );
865 Ok(Response::builder()
866 .status(StatusCode::INTERNAL_SERVER_ERROR)
867 .body(Full::new(Bytes::from("Session termination error")))
868 .unwrap())
869 }
870 }
871 }
872 }
873 }
874 Ok(None) => Ok(Response::builder()
875 .status(StatusCode::NOT_FOUND)
876 .body(Full::new(Bytes::from("Session not found")))
877 .unwrap()),
878 Err(err) => {
879 error!(
880 "Error retrieving session {} for termination: {}",
881 session_id, err
882 );
883 Ok(Response::builder()
884 .status(StatusCode::INTERNAL_SERVER_ERROR)
885 .body(Full::new(Bytes::from("Session lookup error")))
886 .unwrap())
887 }
888 }
889 } else {
890 Ok(Response::builder()
891 .status(StatusCode::BAD_REQUEST)
892 .body(Full::new(Bytes::from("Missing Mcp-Session-Id header")))
893 .unwrap())
894 }
895 }
896
897 fn handle_preflight(&self) -> Response<JsonRpcBody> {
899 options_response()
900 }
901
902 fn method_not_allowed(&self) -> Response<JsonRpcBody> {
904 method_not_allowed_response()
905 }
906
907 async fn validate_session_exists(&self, session_id: &str) -> Result<()> {
909 match self.session_storage.get_session(session_id).await {
911 Ok(Some(session_info)) => {
912 if session_info.is_terminated() {
913 error!("Session '{}' has been terminated", session_id);
914 return Err(crate::HttpMcpError::InvalidRequest(format!(
915 "Session '{}' has been terminated. Create a new session to continue.",
916 session_id
917 )));
918 }
919 debug!("Session validation successful: {}", session_id);
920 Ok(())
921 }
922 Ok(None) => {
923 error!("Session not found: {}", session_id);
924 Err(crate::HttpMcpError::InvalidRequest(format!(
925 "Session '{}' not found. Sessions must be created via initialize request first.",
926 session_id
927 )))
928 }
929 Err(err) => {
930 error!("Failed to validate session {}: {}", session_id, err);
931 Err(crate::HttpMcpError::InvalidRequest(format!(
932 "Session validation failed: {}",
933 err
934 )))
935 }
936 }
937 }
938
939 async fn run_middleware_and_dispatch(
942 &self,
943 request: turul_mcp_json_rpc_server::JsonRpcRequest,
944 headers: HashMap<String, String>,
945 session: turul_mcp_json_rpc_server::SessionContext,
946 pre_session_extensions: Option<HashMap<String, serde_json::Value>>,
947 ) -> (
948 turul_mcp_json_rpc_server::JsonRpcMessage,
949 Option<crate::middleware::SessionInjection>,
950 ) {
951 if self.middleware_stack.is_empty() {
953 let result = self
954 .dispatcher
955 .handle_request_with_context(request, session)
956 .await;
957 return (result, None);
958 }
959
960 let normalized_headers: HashMap<String, String> = headers
962 .iter()
963 .map(|(k, v)| (k.to_lowercase(), v.clone()))
964 .collect();
965
966 let method = request.method.clone();
969 let session_id = session.session_id.clone();
970
971 let params = request.params.clone().map(|p| match p {
973 turul_mcp_json_rpc_server::RequestParams::Object(map) => {
974 serde_json::Value::Object(map.into_iter().collect())
975 }
976 turul_mcp_json_rpc_server::RequestParams::Array(arr) => serde_json::Value::Array(arr),
977 });
978 let mut ctx = crate::middleware::RequestContext::new(&method, params);
979
980 if let Some(ext) = pre_session_extensions {
982 for (k, v) in ext {
983 ctx.set_extension(k, v);
984 }
985 }
986
987 for (k, v) in normalized_headers {
988 if k == "authorization" && is_bearer_scheme(&v) {
990 continue;
991 }
992 ctx.add_metadata(k, serde_json::json!(v));
993 }
994
995 let session_view = crate::middleware::StorageBackedSessionView::new(
997 session_id.clone(),
998 Arc::clone(&self.session_storage),
999 );
1000
1001 let injection = match self
1003 .middleware_stack
1004 .execute_before(&mut ctx, Some(&session_view))
1005 .await
1006 {
1007 Ok(inj) => inj,
1008 Err(err) => {
1009 return (Self::map_middleware_error_to_jsonrpc(err, request.id), None);
1011 }
1012 };
1013
1014 if !injection.is_empty() {
1016 for (key, value) in injection.state() {
1017 if let Err(e) = session_view.set_state(key, value.clone()).await {
1018 tracing::warn!("Failed to apply injection state '{}': {}", key, e);
1019 }
1020 }
1021 for (key, value) in injection.metadata() {
1022 if let Err(e) = session_view.set_metadata(key, value.clone()).await {
1023 tracing::warn!("Failed to apply injection metadata '{}': {}", key, e);
1024 }
1025 }
1026 }
1027
1028 let mut session = session;
1030 session.extensions = ctx.extensions().clone();
1031
1032 let result = self
1034 .dispatcher
1035 .handle_request_with_context(request, session)
1036 .await;
1037
1038 let mut dispatcher_result = match &result {
1041 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => match &resp.result {
1042 turul_mcp_json_rpc_server::response::ResponseResult::Success(val) => {
1043 crate::middleware::DispatcherResult::Success(val.clone())
1044 }
1045 turul_mcp_json_rpc_server::response::ResponseResult::Null => {
1046 crate::middleware::DispatcherResult::Success(serde_json::Value::Null)
1047 }
1048 },
1049 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
1050 crate::middleware::DispatcherResult::Error(err.error.message.clone())
1051 }
1052 };
1053
1054 let _ = self
1056 .middleware_stack
1057 .execute_after(&ctx, &mut dispatcher_result)
1058 .await;
1059
1060 (result, None) }
1062
1063 fn map_middleware_error_to_jsonrpc(
1065 err: crate::middleware::MiddlewareError,
1066 request_id: turul_mcp_json_rpc_server::RequestId,
1067 ) -> turul_mcp_json_rpc_server::JsonRpcMessage {
1068 use crate::middleware::MiddlewareError;
1069 use crate::middleware::error::error_codes;
1070
1071 let (code, message, data) = match err {
1072 MiddlewareError::Unauthenticated(msg) => (error_codes::UNAUTHENTICATED, msg, None),
1073 MiddlewareError::Unauthorized(msg) => (error_codes::UNAUTHORIZED, msg, None),
1074 MiddlewareError::RateLimitExceeded {
1075 message,
1076 retry_after,
1077 } => {
1078 let data = retry_after.map(|s| serde_json::json!({"retryAfter": s}));
1079 (error_codes::RATE_LIMIT_EXCEEDED, message, data)
1080 }
1081 MiddlewareError::InvalidRequest(msg) => (error_codes::INVALID_REQUEST, msg, None),
1082 MiddlewareError::Internal(msg) => (error_codes::INTERNAL_ERROR, msg, None),
1083 MiddlewareError::Custom { message, .. } => (error_codes::INTERNAL_ERROR, message, None),
1084 MiddlewareError::HttpChallenge { .. } => {
1085 unreachable!(
1086 "HttpChallenge must be caught at transport level before JSON-RPC dispatch"
1087 )
1088 }
1089 };
1090
1091 let error_obj = if let Some(d) = data {
1092 turul_mcp_json_rpc_server::error::JsonRpcErrorObject::server_error(
1093 code,
1094 &message,
1095 Some(d),
1096 )
1097 } else {
1098 turul_mcp_json_rpc_server::error::JsonRpcErrorObject::server_error(
1099 code,
1100 &message,
1101 None::<serde_json::Value>,
1102 )
1103 };
1104
1105 turul_mcp_json_rpc_server::JsonRpcMessage::Error(
1106 turul_mcp_json_rpc_server::JsonRpcError::new(Some(request_id), error_obj),
1107 )
1108 }
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;

    /// A session stored with `terminated = true` must fail validation, and the error
    /// text must say why.
    #[tokio::test]
    async fn test_validate_session_exists_rejects_terminated() {
        let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
            Arc::new(InMemorySessionStorage::new());

        // Create a session, then persist it again with the termination flag set.
        let mut record = storage
            .create_session(turul_mcp_protocol::ServerCapabilities::default())
            .await
            .unwrap();
        let terminated_id = record.session_id.clone();
        record
            .state
            .insert("terminated".to_string(), serde_json::json!(true));
        storage.update_session(record).await.unwrap();

        let handler = SessionMcpHandler::with_storage(
            crate::server::ServerConfig::default(),
            Arc::new(JsonRpcDispatcher::<McpError>::default()),
            storage,
            crate::stream_manager::StreamConfig::default(),
            Arc::new(crate::middleware::MiddlewareStack::new()),
        );

        let outcome = handler.validate_session_exists(&terminated_id).await;
        assert!(outcome.is_err(), "Expected Err for terminated session");

        let message = outcome.unwrap_err().to_string();
        assert!(
            message.to_lowercase().contains("terminated"),
            "Error must mention 'terminated', got: {}",
            message
        );
    }
}