1use std::convert::Infallible;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use tracing::{debug, error, warn};
16
17use bytes::Bytes;
18use futures::Stream;
19use http_body::{Body, Frame};
20use http_body_util::{BodyExt, Full};
21use hyper::header::{ACCEPT, CONTENT_TYPE};
22use hyper::{Method, Request, Response, StatusCode};
23
24use chrono;
25use turul_mcp_json_rpc_server::{
26 JsonRpcDispatcher,
27 r#async::SessionContext,
28 dispatch::{JsonRpcMessage, JsonRpcMessageResult, parse_json_rpc_message},
29 error::{JsonRpcError, JsonRpcErrorObject},
30};
31use turul_mcp_protocol::McpError;
32use turul_mcp_protocol::ServerCapabilities;
33use turul_mcp_session_storage::InMemorySessionStorage;
34use uuid::Uuid;
35
36use crate::{
37 Result, ServerConfig, StreamConfig, StreamManager,
38 json_rpc_responses::*,
39 notification_bridge::{SharedNotificationBroadcaster, StreamManagerNotificationBroadcaster},
40 protocol::{extract_last_event_id, extract_protocol_version, extract_session_id},
41};
42
43pub struct SessionSseStream {
45 stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, Infallible>> + Send>>,
46}
47
48impl SessionSseStream {
49 pub fn new<S>(stream: S) -> Self
50 where
51 S: Stream<Item = std::result::Result<Bytes, Infallible>> + Send + 'static,
52 {
53 Self {
54 stream: Box::pin(stream),
55 }
56 }
57}
58
59impl Drop for SessionSseStream {
60 fn drop(&mut self) {
61 debug!("DROP: SessionSseStream - HTTP response body being cleaned up");
62 debug!("This may indicate early cleanup of SSE response stream");
63 }
64}
65
66impl Body for SessionSseStream {
67 type Data = Bytes;
68 type Error = Infallible;
69
70 fn poll_frame(
71 mut self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
74 match self.stream.as_mut().poll_next(cx) {
75 Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(Ok(Frame::data(data)))),
76 Poll::Ready(Some(Err(never))) => match never {},
77 Poll::Ready(None) => Poll::Ready(None),
78 Poll::Pending => Poll::Pending,
79 }
80 }
81}
82
83type JsonRpcBody = Full<Bytes>;
85
86type UnifiedMcpBody = http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>;
88
/// Classification of a request's `Accept` header by the media types it mentions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AcceptMode {
    /// Both `application/json` (or `*/*`) and `text/event-stream` are present.
    Compliant,
    /// Only `application/json` (or `*/*`) is present.
    JsonOnly,
    /// Only `text/event-stream` is present.
    SseOnly,
    /// Neither media type is present.
    Invalid,
}

/// Classifies `accept_header` and reports whether it mentions `text/event-stream`.
///
/// Matching is a substring test on the lower-cased header, because media types are
/// case-insensitive; `Application/JSON` is therefore treated like `application/json`.
/// A `*/*` wildcard counts as accepting JSON only. Quality values (`q=`) are not interpreted.
///
/// Returns the [`AcceptMode`] together with a flag that is `true` exactly for
/// [`AcceptMode::Compliant`] and [`AcceptMode::SseOnly`].
fn parse_mcp_accept_header(accept_header: &str) -> (AcceptMode, bool) {
    let normalized = accept_header.to_ascii_lowercase();

    let accepts_json = normalized.contains("application/json") || normalized.contains("*/*");
    let accepts_sse = normalized.contains("text/event-stream");

    let mode = match (accepts_json, accepts_sse) {
        (true, true) => AcceptMode::Compliant,
        (true, false) => AcceptMode::JsonOnly,
        (false, true) => AcceptMode::SseOnly,
        (false, false) => AcceptMode::Invalid,
    };

    let should_use_sse = matches!(mode, AcceptMode::Compliant | AcceptMode::SseOnly);

    (mode, should_use_sse)
}
125
126fn convert_to_unified_body(full_body: Full<Bytes>) -> UnifiedMcpBody {
128 full_body.map_err(|never| match never {}).boxed_unsync()
129}
130
131fn jsonrpc_error_to_unified_body(error: JsonRpcError) -> Result<Response<UnifiedMcpBody>> {
133 let error_json = serde_json::to_string(&error)?;
134 Ok(Response::builder()
135 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
137 .body(convert_to_unified_body(Full::new(Bytes::from(error_json))))
138 .unwrap())
139}
140
141pub struct SessionMcpHandler {
145 pub(crate) config: ServerConfig,
146 pub(crate) dispatcher: Arc<JsonRpcDispatcher<McpError>>,
147 pub(crate) session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
148 pub(crate) stream_config: StreamConfig,
149 pub(crate) stream_manager: Arc<StreamManager>,
151}
152
153impl Clone for SessionMcpHandler {
154 fn clone(&self) -> Self {
155 Self {
156 config: self.config.clone(),
157 dispatcher: Arc::clone(&self.dispatcher),
158 session_storage: Arc::clone(&self.session_storage),
159 stream_config: self.stream_config.clone(),
160 stream_manager: Arc::clone(&self.stream_manager),
161 }
162 }
163}
164
165impl SessionMcpHandler {
166 pub fn new(
168 config: ServerConfig,
169 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
170 stream_config: StreamConfig,
171 ) -> Self {
172 let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
173 Arc::new(InMemorySessionStorage::new());
174 Self::with_storage(config, dispatcher, storage, stream_config)
175 }
176
177 pub fn with_shared_stream_manager(
179 config: ServerConfig,
180 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
181 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
182 stream_config: StreamConfig,
183 stream_manager: Arc<StreamManager>,
184 ) -> Self {
185 Self {
186 config,
187 dispatcher,
188 session_storage,
189 stream_config,
190 stream_manager,
191 }
192 }
193
194 pub fn with_storage(
197 config: ServerConfig,
198 dispatcher: Arc<JsonRpcDispatcher<McpError>>,
199 session_storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
200 stream_config: StreamConfig,
201 ) -> Self {
202 let stream_manager = Arc::new(StreamManager::with_config(
204 Arc::clone(&session_storage),
205 stream_config.clone(),
206 ));
207
208 Self {
209 config,
210 dispatcher,
211 session_storage,
212 stream_config,
213 stream_manager,
214 }
215 }
216
217 pub fn get_stream_manager(&self) -> &Arc<StreamManager> {
219 &self.stream_manager
220 }
221
222 pub async fn handle_mcp_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
224 where
225 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
226 {
227 debug!(
228 "SESSION HANDLER processing {} {}",
229 req.method(),
230 req.uri().path()
231 );
232 match *req.method() {
233 Method::POST => {
234 let response = self.handle_json_rpc_request(req).await?;
235 Ok(response)
236 }
237 Method::GET => self.handle_sse_request(req).await,
238 Method::DELETE => {
239 let response = self.handle_delete_request(req).await?;
240 Ok(response.map(convert_to_unified_body))
241 }
242 Method::OPTIONS => {
243 let response = self.handle_preflight();
244 Ok(response.map(convert_to_unified_body))
245 }
246 _ => {
247 let response = self.method_not_allowed();
248 Ok(response.map(convert_to_unified_body))
249 }
250 }
251 }
252
253 async fn handle_json_rpc_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
255 where
256 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
257 {
258 let protocol_version = extract_protocol_version(req.headers());
260 let session_id = extract_session_id(req.headers());
261
262 debug!(
263 "POST request - Protocol: {}, Session: {:?}",
264 protocol_version, session_id
265 );
266
267 let content_type = req
269 .headers()
270 .get(CONTENT_TYPE)
271 .and_then(|ct| ct.to_str().ok())
272 .unwrap_or("");
273
274 if !content_type.starts_with("application/json") {
275 warn!("Invalid content type: {}", content_type);
276 return Ok(
277 bad_request_response("Content-Type must be application/json")
278 .map(convert_to_unified_body),
279 );
280 }
281
282 let accept_header = req
284 .headers()
285 .get(ACCEPT)
286 .and_then(|accept| accept.to_str().ok())
287 .unwrap_or("application/json");
288
289 let (accept_mode, accepts_sse) = parse_mcp_accept_header(accept_header);
290 debug!(
291 "POST request Accept header: '{}', mode: {:?}, will use SSE for tool calls: {}",
292 accept_header, accept_mode, accepts_sse
293 );
294
295 let body = req.into_body();
297 let body_bytes = match body.collect().await {
298 Ok(collected) => collected.to_bytes(),
299 Err(err) => {
300 error!("Failed to read request body: {}", err);
301 return Ok(bad_request_response("Failed to read request body")
302 .map(convert_to_unified_body));
303 }
304 };
305
306 if body_bytes.len() > self.config.max_body_size {
308 warn!("Request body too large: {} bytes", body_bytes.len());
309 return Ok(Response::builder()
310 .status(StatusCode::PAYLOAD_TOO_LARGE)
311 .header(CONTENT_TYPE, "application/json")
312 .body(convert_to_unified_body(Full::new(Bytes::from(
313 "Request body too large",
314 ))))
315 .unwrap());
316 }
317
318 let body_str = match std::str::from_utf8(&body_bytes) {
320 Ok(s) => s,
321 Err(err) => {
322 error!("Invalid UTF-8 in request body: {}", err);
323 return Ok(bad_request_response("Request body must be valid UTF-8")
324 .map(convert_to_unified_body));
325 }
326 };
327
328 debug!("Received JSON-RPC request: {}", body_str);
329
330 let message = match parse_json_rpc_message(body_str) {
332 Ok(msg) => msg,
333 Err(rpc_err) => {
334 error!("JSON-RPC parse error: {}", rpc_err);
335 let error_response =
337 serde_json::to_string(&rpc_err).unwrap_or_else(|_| "{}".to_string());
338 return Ok(Response::builder()
339 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
341 .body(convert_to_unified_body(Full::new(Bytes::from(
342 error_response,
343 ))))
344 .unwrap());
345 }
346 };
347
348 let (message_result, response_session_id, method_name) = match message {
350 JsonRpcMessage::Request(request) => {
351 debug!("Processing JSON-RPC request: method={}", request.method);
352 let method_name = request.method.clone();
353
354 let (response, response_session_id) = if request.method == "initialize" {
356 debug!(
357 "Handling initialize request - creating new session via session storage"
358 );
359
360 let capabilities = ServerCapabilities::default();
362 match self.session_storage.create_session(capabilities).await {
363 Ok(session_info) => {
364 debug!(
365 "Created new session via session storage: {}",
366 session_info.session_id
367 );
368
369 let broadcaster: SharedNotificationBroadcaster =
371 Arc::new(StreamManagerNotificationBroadcaster::new(Arc::clone(
372 &self.stream_manager,
373 )));
374 let broadcaster_any =
375 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
376
377 let session_context = SessionContext {
378 session_id: session_info.session_id.clone(),
379 metadata: std::collections::HashMap::new(),
380 broadcaster: Some(broadcaster_any),
381 timestamp: chrono::Utc::now().timestamp_millis() as u64,
382 };
383
384 let response = self
385 .dispatcher
386 .handle_request_with_context(request, session_context)
387 .await;
388
389 (response, Some(session_info.session_id))
391 }
392 Err(err) => {
393 error!("Failed to create session during initialize: {}", err);
394 let error_msg = format!("Session creation failed: {}", err);
396 let error_response = turul_mcp_json_rpc_server::JsonRpcMessage::error(
397 turul_mcp_json_rpc_server::JsonRpcError::internal_error(
398 Some(request.id),
399 Some(error_msg),
400 ),
401 );
402 (error_response, None)
403 }
404 }
405 } else {
406 let session_context = if let Some(ref session_id_str) = session_id {
409 debug!("Processing request with session: {}", session_id_str);
410 let broadcaster: SharedNotificationBroadcaster =
411 Arc::new(StreamManagerNotificationBroadcaster::new(Arc::clone(
412 &self.stream_manager,
413 )));
414 let broadcaster_any =
415 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
416 Some(SessionContext {
417 session_id: session_id_str.clone(),
418 metadata: std::collections::HashMap::new(),
419 broadcaster: Some(broadcaster_any),
420 timestamp: chrono::Utc::now().timestamp_millis() as u64,
421 })
422 } else {
423 debug!("Processing request without session (lenient mode)");
424 None
425 };
426
427 let response = if let Some(ctx) = session_context {
428 self.dispatcher
429 .handle_request_with_context(request, ctx)
430 .await
431 } else {
432 self.dispatcher.handle_request(request).await
433 };
434 (response, session_id)
435 };
436
437 let message_result = match response {
439 turul_mcp_json_rpc_server::JsonRpcMessage::Response(resp) => {
440 JsonRpcMessageResult::Response(resp)
441 }
442 turul_mcp_json_rpc_server::JsonRpcMessage::Error(err) => {
443 JsonRpcMessageResult::Error(err)
444 }
445 };
446 (message_result, response_session_id, Some(method_name))
447 }
448 JsonRpcMessage::Notification(notification) => {
449 debug!(
450 "Processing JSON-RPC notification: method={}",
451 notification.method
452 );
453 let method_name = notification.method.clone();
454
455 let session_context = if let Some(ref session_id_str) = session_id {
458 debug!("Processing notification with session: {}", session_id_str);
459 let broadcaster: SharedNotificationBroadcaster = Arc::new(
460 StreamManagerNotificationBroadcaster::new(Arc::clone(&self.stream_manager)),
461 );
462 let broadcaster_any =
463 Arc::new(broadcaster) as Arc<dyn std::any::Any + Send + Sync>;
464
465 Some(SessionContext {
466 session_id: session_id_str.clone(),
467 metadata: std::collections::HashMap::new(),
468 broadcaster: Some(broadcaster_any),
469 timestamp: chrono::Utc::now().timestamp_millis() as u64,
470 })
471 } else {
472 debug!("Processing notification without session (lenient mode)");
473 None
474 };
475
476 let result = self
477 .dispatcher
478 .handle_notification_with_context(notification, session_context)
479 .await;
480
481 if let Err(err) = result {
482 error!("Notification handling error: {}", err);
483 }
484 (
485 JsonRpcMessageResult::NoResponse,
486 session_id.clone(),
487 Some(method_name),
488 )
489 }
490 };
491
492 match message_result {
494 JsonRpcMessageResult::Response(response) => {
495 let is_tool_call = method_name.as_ref().is_some_and(|m| m == "tools/call");
498
499 debug!(
500 "Decision point: method={:?}, accept_mode={:?}, accepts_sse={}, server_post_sse_enabled={}, session_id={:?}, is_tool_call={}",
501 method_name,
502 accept_mode,
503 accepts_sse,
504 self.config.enable_post_sse,
505 response_session_id,
506 is_tool_call
507 );
508
509 let should_use_sse = match accept_mode {
511 AcceptMode::JsonOnly => false, AcceptMode::Invalid => false, AcceptMode::Compliant => {
514 self.config.enable_post_sse && accepts_sse && is_tool_call
515 } AcceptMode::SseOnly => self.config.enable_post_sse && accepts_sse, };
518
519 if should_use_sse && response_session_id.is_some() {
520 debug!(
521 "📡 Creating POST SSE stream (mode: {:?}) for tool call with notifications",
522 accept_mode
523 );
524 match self
525 .stream_manager
526 .create_post_sse_stream(
527 response_session_id.clone().unwrap(),
528 response.clone(), )
530 .await
531 {
532 Ok(sse_response) => {
533 debug!("✅ POST SSE stream created successfully");
534 Ok(sse_response
535 .map(|body| body.map_err(|never| match never {}).boxed_unsync()))
536 }
537 Err(e) => {
538 warn!(
539 "Failed to create POST SSE stream, falling back to JSON: {}",
540 e
541 );
542 Ok(
543 jsonrpc_response_with_session(response, response_session_id)?
544 .map(convert_to_unified_body),
545 )
546 }
547 }
548 } else {
549 debug!(
550 "📄 Returning standard JSON response (mode: {:?}) for method: {:?}",
551 accept_mode, method_name
552 );
553 Ok(
554 jsonrpc_response_with_session(response, response_session_id)?
555 .map(convert_to_unified_body),
556 )
557 }
558 }
559 JsonRpcMessageResult::Error(error) => {
560 warn!("Sending JSON-RPC error response");
561 let error_json = serde_json::to_string(&error)?;
563 Ok(Response::builder()
564 .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json")
566 .body(convert_to_unified_body(Full::new(Bytes::from(error_json))))
567 .unwrap())
568 }
569 JsonRpcMessageResult::NoResponse => {
570 Ok(jsonrpc_notification_response()?.map(convert_to_unified_body))
572 }
573 }
574 }
575
576 async fn handle_sse_request<B>(&self, req: Request<B>) -> Result<Response<UnifiedMcpBody>>
581 where
582 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
583 {
584 let headers = req.headers();
586 let accept = headers
587 .get(ACCEPT)
588 .and_then(|accept| accept.to_str().ok())
589 .unwrap_or("");
590
591 if !accept.contains("text/event-stream") {
592 warn!(
593 "GET request received without SSE support - header does not contain 'text/event-stream'"
594 );
595 let error = JsonRpcError::new(
596 None,
597 JsonRpcErrorObject::server_error(
598 -32001,
599 "SSE not accepted - missing 'text/event-stream' in Accept header",
600 None,
601 ),
602 );
603 return jsonrpc_error_to_unified_body(error);
604 }
605
606 if !self.config.enable_get_sse {
608 warn!("GET SSE request received but GET SSE is disabled on server");
609 let error = JsonRpcError::new(
610 None,
611 JsonRpcErrorObject::server_error(
612 -32003,
613 "GET SSE is disabled on this server",
614 None,
615 ),
616 );
617 return jsonrpc_error_to_unified_body(error);
618 }
619
620 let protocol_version = extract_protocol_version(headers);
622 let session_id = extract_session_id(headers);
623
624 debug!(
625 "GET SSE request - Protocol: {}, Session: {:?}",
626 protocol_version, session_id
627 );
628
629 let session_id = match session_id {
631 Some(id) => id,
632 None => {
633 warn!("Missing Mcp-Session-Id header for SSE request");
634 let error = JsonRpcError::new(
635 None,
636 JsonRpcErrorObject::server_error(-32002, "Missing Mcp-Session-Id header", None),
637 );
638 return jsonrpc_error_to_unified_body(error);
639 }
640 };
641
642 if let Err(err) = self.validate_session_exists(&session_id).await {
644 error!(
645 "Session validation failed for Session ID {}: {}",
646 session_id, err
647 );
648 let error = JsonRpcError::new(
649 None,
650 JsonRpcErrorObject::server_error(
651 -32003,
652 &format!("Session validation failed: {}", err),
653 None,
654 ),
655 );
656 return jsonrpc_error_to_unified_body(error);
657 }
658
659 let last_event_id = extract_last_event_id(headers);
661
662 let connection_id = Uuid::now_v7().to_string();
664
665 debug!(
666 "Creating SSE stream for session: {} with connection: {}, last_event_id: {:?}",
667 session_id, connection_id, last_event_id
668 );
669
670 match self
672 .stream_manager
673 .handle_sse_connection(session_id, connection_id, last_event_id)
674 .await
675 {
676 Ok(response) => Ok(response),
677 Err(err) => {
678 error!("Failed to create SSE connection: {}", err);
679 let error = JsonRpcError::new(
680 None,
681 JsonRpcErrorObject::internal_error(Some(format!(
682 "SSE connection failed: {}",
683 err
684 ))),
685 );
686 jsonrpc_error_to_unified_body(error)
687 }
688 }
689 }
690
691 async fn handle_delete_request<B>(&self, req: Request<B>) -> Result<Response<JsonRpcBody>>
693 where
694 B: http_body::Body<Data = bytes::Bytes, Error = hyper::Error> + Send + 'static,
695 {
696 let session_id = extract_session_id(req.headers());
697
698 debug!("DELETE request - Session: {:?}", session_id);
699
700 if let Some(session_id) = session_id {
701 let closed_connections = self
703 .stream_manager
704 .close_session_connections(&session_id)
705 .await;
706 debug!(
707 "Closed {} SSE connections for session: {}",
708 closed_connections, session_id
709 );
710
711 match self.session_storage.get_session(&session_id).await {
713 Ok(Some(mut session_info)) => {
714 session_info
716 .state
717 .insert("terminated".to_string(), serde_json::Value::Bool(true));
718 session_info.state.insert(
719 "terminated_at".to_string(),
720 serde_json::Value::Number(serde_json::Number::from(
721 chrono::Utc::now().timestamp_millis(),
722 )),
723 );
724 session_info.touch();
725
726 match self.session_storage.update_session(session_info).await {
727 Ok(()) => {
728 debug!(
729 "Session {} marked as terminated (TTL will handle cleanup)",
730 session_id
731 );
732 Ok(Response::builder()
733 .status(StatusCode::OK)
734 .body(Full::new(Bytes::from("Session terminated")))
735 .unwrap())
736 }
737 Err(err) => {
738 error!(
739 "Error marking session {} as terminated: {}",
740 session_id, err
741 );
742 match self.session_storage.delete_session(&session_id).await {
744 Ok(_) => {
745 debug!("Session {} deleted as fallback", session_id);
746 Ok(Response::builder()
747 .status(StatusCode::OK)
748 .body(Full::new(Bytes::from("Session removed")))
749 .unwrap())
750 }
751 Err(delete_err) => {
752 error!(
753 "Error deleting session {} as fallback: {}",
754 session_id, delete_err
755 );
756 Ok(Response::builder()
757 .status(StatusCode::INTERNAL_SERVER_ERROR)
758 .body(Full::new(Bytes::from("Session termination error")))
759 .unwrap())
760 }
761 }
762 }
763 }
764 }
765 Ok(None) => Ok(Response::builder()
766 .status(StatusCode::NOT_FOUND)
767 .body(Full::new(Bytes::from("Session not found")))
768 .unwrap()),
769 Err(err) => {
770 error!(
771 "Error retrieving session {} for termination: {}",
772 session_id, err
773 );
774 Ok(Response::builder()
775 .status(StatusCode::INTERNAL_SERVER_ERROR)
776 .body(Full::new(Bytes::from("Session lookup error")))
777 .unwrap())
778 }
779 }
780 } else {
781 Ok(Response::builder()
782 .status(StatusCode::BAD_REQUEST)
783 .body(Full::new(Bytes::from("Missing Mcp-Session-Id header")))
784 .unwrap())
785 }
786 }
787
788 fn handle_preflight(&self) -> Response<JsonRpcBody> {
790 options_response()
791 }
792
793 fn method_not_allowed(&self) -> Response<JsonRpcBody> {
795 method_not_allowed_response()
796 }
797
798 async fn validate_session_exists(&self, session_id: &str) -> Result<()> {
800 match self.session_storage.get_session(session_id).await {
802 Ok(Some(_)) => {
803 debug!("Session validation successful: {}", session_id);
804 Ok(())
805 }
806 Ok(None) => {
807 error!("Session not found: {}", session_id);
808 Err(crate::HttpMcpError::InvalidRequest(format!(
809 "Session '{}' not found. Sessions must be created via initialize request first.",
810 session_id
811 )))
812 }
813 Err(err) => {
814 error!("Failed to validate session {}: {}", session_id, err);
815 Err(crate::HttpMcpError::InvalidRequest(format!(
816 "Session validation failed: {}",
817 err
818 )))
819 }
820 }
821 }
822}