1use bytes::Bytes;
11use futures::{Stream, StreamExt};
12use http_body_util::{BodyExt, StreamBody};
13use hyper::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
14use hyper::{Response, StatusCode};
15use serde_json::Value;
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19use tokio::sync::{RwLock, mpsc};
20use tracing::{debug, error, warn};
21
22use turul_mcp_session_storage::SseEvent;
23
/// Identifier of a single SSE connection within a session.
pub type ConnectionId = String;
/// Per-session map from connection id to the sender half of that connection's event channel.
pub type SessionConnections = HashMap<ConnectionId, mpsc::Sender<SseEvent>>;
/// Shared, lock-protected map from session id to that session's connections.
pub type ConnectionsMap = Arc<RwLock<HashMap<String, SessionConnections>>>;
28
/// Tracks SSE connections and per-session notification subscriptions, and delivers
/// events to sessions, persisting every broadcast event through the session storage.
pub struct StreamManager {
    /// Session storage used for session lookups and for storing / replaying events.
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// Live connections: session id -> (connection id -> channel sender).
    connections: ConnectionsMap,
    /// Notification types each session has explicitly subscribed to (session id -> types).
    subscriptions: Arc<RwLock<HashMap<String, HashSet<String>>>>,
    /// Channel, replay, keepalive and CORS settings.
    config: StreamConfig,
    /// UUIDv7 string generated per instance; only used to tag debug log lines.
    instance_id: String,
}
42
/// Tuning knobs for [`StreamManager`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of each connection's bounded event channel.
    pub channel_buffer_size: usize,
    /// Upper bound on stored events replayed when a client reconnects with a Last-Event-ID.
    pub max_replay_events: usize,
    /// Seconds between keepalive `ping` events on an open SSE stream.
    pub keepalive_interval_seconds: u64,
    /// Value sent in the `Access-Control-Allow-Origin` response header.
    pub cors_origin: String,
}

impl Default for StreamConfig {
    /// 1000-slot channels, up to 100 replayed events, a 30 s keepalive and a wildcard origin.
    fn default() -> Self {
        let cors_origin = String::from("*");
        StreamConfig {
            cors_origin,
            keepalive_interval_seconds: 30,
            max_replay_events: 100,
            channel_buffer_size: 1000,
        }
    }
}
66
/// An SSE event stream bound to one connection of one session.
pub struct SseStream {
    /// The event stream itself; taken (leaving `None`) when it is turned into a response body.
    stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
    /// Session this stream belongs to.
    session_id: String,
    /// Identifier of this connection within the session.
    connection_id: ConnectionId,
}
76
77impl SseStream {
78 pub fn session_id(&self) -> &str {
80 &self.session_id
81 }
82
83 pub fn connection_id(&self) -> &str {
85 &self.connection_id
86 }
87
88 pub fn stream_identifier(&self) -> String {
90 format!("{}:{}", self.session_id, self.connection_id)
91 }
92}
93
94impl Drop for SseStream {
95 fn drop(&mut self) {
96 debug!(
97 "DROP: SseStream - session={}, connection={}",
98 self.session_id, self.connection_id
99 );
100 if self.stream.is_some() {
101 debug!("Stream still present during drop - this indicates early cleanup");
102 } else {
103 debug!("Stream was properly extracted before drop");
104 }
105 }
106}
107
/// Errors returned by [`StreamManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
    /// Storage has no session with this id.
    #[error("Session not found: {0}")]
    SessionNotFound(String),
    /// A specific stream could not be found: (session id, stream id).
    #[error("Stream not found: session={0}, stream={1}")]
    StreamNotFound(String, String),
    /// A storage-backend call failed; carries the backend error rendered as text.
    #[error("Storage error: {0}")]
    StorageError(String),
    /// A connection-level failure, described as text.
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// Delivery was suppressed because the session has no registered connections.
    #[error("No connections available for session: {0}")]
    NoConnections(String),
    /// The session's subscription filter excludes this notification type: (session id, type).
    #[error("Session {0} not subscribed to notification type: {1}")]
    NotSubscribed(String, String),
}
124
125impl StreamManager {
126 pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
128 Self::with_config(storage, StreamConfig::default())
129 }
130
131 pub fn with_config(
133 storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
134 config: StreamConfig,
135 ) -> Self {
136 use uuid::Uuid;
137 let instance_id = Uuid::now_v7().to_string();
138 debug!("Creating StreamManager instance: {}", instance_id);
139 Self {
140 storage,
141 connections: Arc::new(RwLock::new(HashMap::new())),
142 subscriptions: Arc::new(RwLock::new(HashMap::new())),
143 config,
144 instance_id,
145 }
146 }
147
148 pub async fn handle_sse_connection(
150 &self,
151 session_id: String,
152 connection_id: ConnectionId,
153 last_event_id: Option<u64>,
154 ) -> Result<
155 Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>,
156 StreamError,
157 > {
158 if self
160 .storage
161 .get_session(&session_id)
162 .await
163 .map_err(|e| StreamError::StorageError(e.to_string()))?
164 .is_none()
165 {
166 return Err(StreamError::SessionNotFound(session_id));
167 }
168
169 let sse_stream = self
171 .create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id)
172 .await?;
173
174 let response = self.stream_to_response(sse_stream).await;
176
177 debug!(
178 "Created SSE connection: session={}, connection={}, last_event_id={:?}",
179 session_id, connection_id, last_event_id
180 );
181
182 Ok(response)
183 }
184
185 async fn create_sse_stream(
187 &self,
188 session_id: String,
189 connection_id: ConnectionId,
190 last_event_id: Option<u64>,
191 ) -> Result<SseStream, StreamError> {
192 let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);
194
195 self.register_connection(&session_id, connection_id.clone(), sender)
197 .await;
198
199 let storage = self.storage.clone();
201 let session_id_clone = session_id.clone();
202 let connection_id_clone = connection_id.clone();
203 let config = self.config.clone();
204
205 let combined_stream = async_stream::stream! {
206 if let Some(after_event_id) = last_event_id {
208 debug!("Replaying events after ID {} for session={}, connection={}",
209 after_event_id, session_id_clone, connection_id_clone);
210
211 match storage.get_events_after(&session_id_clone, after_event_id).await {
212 Ok(events) => {
213 for event in events.into_iter().take(config.max_replay_events) {
214 yield event;
215 }
216 },
217 Err(e) => {
218 error!("Failed to get historical events: {}", e);
219 }
221 }
222 }
223
224 let mut keepalive_interval = tokio::time::interval(
226 tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
227 );
228
229 loop {
230 tokio::select! {
231 event = receiver.recv() => {
233 match event {
234 Some(event) => {
235 debug!("Received event for connection {}: {}", connection_id_clone, event.event_type);
236 yield event;
237 },
238 None => {
239 debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
240 break;
241 }
242 }
243 },
244
245 _ = keepalive_interval.tick() => {
247 let keepalive_event = SseEvent {
248 id: 0, timestamp: chrono::Utc::now().timestamp_millis() as u64,
250 event_type: "ping".to_string(),
251 data: serde_json::json!({"type": "keepalive"}),
252 retry: None,
253 };
254 yield keepalive_event;
255 }
256 }
257 }
258
259 debug!("Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
261 };
262
263 Ok(SseStream {
264 stream: Some(Box::pin(combined_stream)),
265 session_id,
266 connection_id,
267 })
268 }
269
270 async fn register_connection(
272 &self,
273 session_id: &str,
274 connection_id: ConnectionId,
275 sender: mpsc::Sender<SseEvent>,
276 ) {
277 let mut connections = self.connections.write().await;
278
279 debug!(
280 "[{}] ๐ BEFORE registration: HashMap has {} sessions",
281 self.instance_id,
282 connections.len()
283 );
284 for (sid, conns) in connections.iter() {
285 debug!(
286 "[{}] ๐ Existing session before: {} with {} connections",
287 self.instance_id,
288 sid,
289 conns.len()
290 );
291 }
292
293 let session_connections = connections
295 .entry(session_id.to_string())
296 .or_insert_with(HashMap::new);
297
298 session_connections.insert(connection_id.clone(), sender);
300
301 debug!(
302 "[{}] ๐ Registered connection: session={}, connection={}, total_connections={}",
303 self.instance_id,
304 session_id,
305 connection_id,
306 session_connections.len()
307 );
308
309 debug!(
310 "[{}] ๐ AFTER registration: HashMap has {} sessions",
311 self.instance_id,
312 connections.len()
313 );
314 for (sid, conns) in connections.iter() {
315 debug!(
316 "[{}] ๐ Session after: {} with {} connections",
317 self.instance_id,
318 sid,
319 conns.len()
320 );
321 }
322 }
323
324 pub async fn register_streaming_connection(
326 &self,
327 session_id: &str,
328 connection_id: ConnectionId,
329 sender: mpsc::Sender<SseEvent>,
330 ) -> Result<(), StreamError> {
331 if self
333 .storage
334 .get_session(session_id)
335 .await
336 .map_err(|e| StreamError::StorageError(e.to_string()))?
337 .is_none()
338 {
339 return Err(StreamError::SessionNotFound(session_id.to_string()));
340 }
341
342 self.register_connection(session_id, connection_id, sender)
343 .await;
344 Ok(())
345 }
346
347 pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
349 debug!(
350 "๐ด UNREGISTER called for session={}, connection={}",
351 session_id, connection_id
352 );
353 let mut connections = self.connections.write().await;
354
355 debug!(
356 "๐ BEFORE unregister: HashMap has {} sessions",
357 connections.len()
358 );
359
360 if let Some(session_connections) = connections.get_mut(session_id)
361 && session_connections.remove(connection_id).is_some()
362 {
363 debug!(
364 "๐ Unregistered connection: session={}, connection={}",
365 session_id, connection_id
366 );
367
368 if session_connections.is_empty() {
370 connections.remove(session_id);
371 debug!("๐งน Removed empty session: {}", session_id);
372 }
373 }
374
375 debug!(
376 "๐ AFTER unregister: HashMap has {} sessions",
377 connections.len()
378 );
379 }
380
381 pub async fn close_session_connections(&self, session_id: &str) -> usize {
383 debug!("๐ด Closing all connections for session: {}", session_id);
384 let mut connections = self.connections.write().await;
385
386 let closed_count = if let Some(session_connections) = connections.remove(session_id) {
387 let count = session_connections.len();
388 debug!(
389 "๐ Closed {} SSE connections for session: {}",
390 count, session_id
391 );
392 count
393 } else {
394 debug!("๐ No SSE connections found for session: {}", session_id);
395 0
396 };
397
398 self.clear_subscriptions(session_id).await;
400
401 debug!("๐งน Session {} removed from stream manager", session_id);
402 closed_count
403 }
404
405 async fn stream_to_response(
407 &self,
408 mut sse_stream: SseStream,
409 ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
410 let session_id = sse_stream.session_id().to_string();
412 let stream_identifier = sse_stream.stream_identifier();
413
414 debug!(
416 "Converting SSE stream to HTTP response: {}",
417 stream_identifier
418 );
419 debug!("Stream details: session_id={}", session_id);
420
421 let stream = sse_stream
424 .stream
425 .take()
426 .expect("Stream should be present in SseStream");
427
428 let formatted_stream = stream.map(|event| {
429 let sse_formatted = event.format();
430 debug!(
431 "๐ก Streaming SSE event: id={}, event_type={}",
432 event.id, event.event_type
433 );
434 Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
435 });
436
437 let body = StreamBody::new(formatted_stream).boxed_unsync();
439
440 Response::builder()
442 .status(StatusCode::OK)
443 .header(CONTENT_TYPE, "text/event-stream")
444 .header(CACHE_CONTROL, "no-cache")
445 .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
446 .header("Connection", "keep-alive")
447 .body(body)
448 .unwrap()
449 }
450
451 pub async fn has_connections(&self, session_id: &str) -> bool {
453 let connections = self.connections.read().await;
454 connections
455 .get(session_id)
456 .map(|session_connections| !session_connections.is_empty())
457 .unwrap_or(false)
458 }
459
460 pub async fn broadcast_to_session(
462 &self,
463 session_id: &str,
464 event_type: String,
465 data: Value,
466 ) -> Result<u64, StreamError> {
467 self.broadcast_to_session_with_options(session_id, event_type, data, true)
468 .await
469 }
470
471 pub async fn broadcast_to_session_with_options(
473 &self,
474 session_id: &str,
475 event_type: String,
476 data: Value,
477 store_when_no_connections: bool,
478 ) -> Result<u64, StreamError> {
479 if !self.is_subscribed(session_id, &event_type).await {
481 debug!(
482 "๐ซ Session {} not subscribed to notification type: {}",
483 session_id, event_type
484 );
485 return Err(StreamError::NotSubscribed(
486 session_id.to_string(),
487 event_type,
488 ));
489 }
490
491 if !store_when_no_connections && !self.has_connections(session_id).await {
493 debug!(
494 "๐ซ Suppressing notification for session {} (no connections, store_when_no_connections=false)",
495 session_id
496 );
497 return Err(StreamError::NoConnections(session_id.to_string()));
498 }
499
500 let event = SseEvent::new(event_type.clone(), data);
502
503 let stored_event = self
505 .storage
506 .store_event(session_id, event)
507 .await
508 .map_err(|e| StreamError::StorageError(e.to_string()))?;
509
510 let connections = self.connections.read().await;
512 debug!(
513 "[{}] ๐ Checking connections for session {}: connections hashmap has {} sessions",
514 self.instance_id,
515 session_id,
516 connections.len()
517 );
518
519 if let Some(session_connections) = connections.get(session_id) {
520 debug!(
521 "๐ Session {} found with {} connections",
522 session_id,
523 session_connections.len()
524 );
525
526 if !session_connections.is_empty() {
527 let (selected_connection_id, selected_sender) =
529 session_connections.iter().next().unwrap();
530
531 if selected_sender.is_closed() {
533 warn!(
534 "๐ Sender is closed for connection: session={}, connection={}",
535 session_id, selected_connection_id
536 );
537 debug!("๐ญ Connection sender was closed, event stored for reconnection");
538 } else {
539 debug!(
540 "โ
Sender is open, attempting to send to connection: session={}, connection={}",
541 session_id, selected_connection_id
542 );
543
544 match selected_sender.try_send(stored_event.clone()) {
545 Ok(()) => {
546 debug!(
547 "Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
548 session_id,
549 selected_connection_id,
550 stored_event.id,
551 stored_event.event_type
552 );
553 }
554 Err(mpsc::error::TrySendError::Full(_)) => {
555 warn!(
556 "โ ๏ธ Connection buffer full: session={}, connection={}",
557 session_id, selected_connection_id
558 );
559 }
561 Err(mpsc::error::TrySendError::Closed(_)) => {
562 warn!(
563 "๐ Connection closed during send: session={}, connection={}",
564 session_id, selected_connection_id
565 );
566 }
568 }
569 }
570 } else {
571 debug!(
572 "๐ญ No active connections for session: {} (event stored for reconnection)",
573 session_id
574 );
575 }
576 } else {
577 debug!(
578 "๐ญ No connections registered for session: {} (event stored for reconnection)",
579 session_id
580 );
581
582 for (sid, conns) in connections.iter() {
584 debug!(
585 "๐ Available session: {} with {} connections",
586 sid,
587 conns.len()
588 );
589 }
590 }
591
592 Ok(stored_event.id)
593 }
594
595 pub async fn broadcast_to_all_sessions(
597 &self,
598 event_type: String,
599 data: Value,
600 ) -> Result<Vec<String>, StreamError> {
601 let session_ids = self
603 .storage
604 .list_sessions()
605 .await
606 .map_err(|e| StreamError::StorageError(e.to_string()))?;
607
608 let mut failed_sessions = Vec::new();
609
610 for session_id in session_ids {
611 if let Err(e) = self
612 .broadcast_to_session(&session_id, event_type.clone(), data.clone())
613 .await
614 {
615 error!("Failed to broadcast to session {}: {}", session_id, e);
616 failed_sessions.push(session_id);
617 }
618 }
619
620 Ok(failed_sessions)
621 }
622
623 pub async fn cleanup_connections(&self) -> usize {
625 debug!("๐งน CLEANUP_CONNECTIONS called");
626 let mut connections = self.connections.write().await;
627 let mut total_cleaned = 0;
628
629 debug!(
630 "๐ BEFORE cleanup: HashMap has {} sessions",
631 connections.len()
632 );
633
634 connections.retain(|session_id, session_connections| {
636 let initial_count = session_connections.len();
637
638 session_connections.retain(|connection_id, sender| {
640 if sender.is_closed() {
641 debug!(
642 "๐งน Cleaned up closed connection: session={}, connection={}",
643 session_id, connection_id
644 );
645 false
646 } else {
647 true
648 }
649 });
650
651 let cleaned_count = initial_count - session_connections.len();
652 total_cleaned += cleaned_count;
653
654 !session_connections.is_empty()
656 });
657
658 if total_cleaned > 0 {
659 debug!("Cleaned up {} inactive connections", total_cleaned);
660 }
661
662 total_cleaned
663 }
664
665 pub async fn create_post_sse_stream(
667 &self,
668 session_id: String,
669 response: turul_mcp_json_rpc_server::JsonRpcResponse,
670 ) -> Result<
671 hyper::Response<
672 http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>,
673 >,
674 StreamError,
675 > {
676 if self
678 .storage
679 .get_session(&session_id)
680 .await
681 .map_err(|e| StreamError::StorageError(e.to_string()))?
682 .is_none()
683 {
684 return Err(StreamError::SessionNotFound(session_id));
685 }
686
687 debug!("Creating POST SSE stream for session: {}", session_id);
688
689 let response_json = serde_json::to_string(&response).map_err(|e| {
691 StreamError::StorageError(format!("Failed to serialize response: {}", e))
692 })?;
693
694 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
698
699 let mut sse_frames = Vec::new();
700 let mut event_id_counter = 1;
701
702 if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
703 for event in events {
704 if event.event_type != "ping" {
706 let notification_sse = format!(
708 "id: {}\nevent: {}\ndata: {}\n\n",
709 event_id_counter,
710 event.event_type, event.data
712 );
713 debug!(
714 "๐ค Including notification in POST SSE stream: id={}, event_type={}",
715 event_id_counter, event.event_type
716 );
717 sse_frames.push(http_body::Frame::data(Bytes::from(notification_sse)));
718 event_id_counter += 1;
719 }
720 }
721 }
722
723 let response_sse = format!(
725 "id: {}\nevent: result\ndata: {}\n\n", event_id_counter, response_json
727 );
728 debug!(
729 "๐ค Sending JSON-RPC response as SSE event: id={}, event=result",
730 event_id_counter
731 );
732 sse_frames.push(http_body::Frame::data(Bytes::from(response_sse)));
733
734 let stream = futures::stream::iter(
736 sse_frames
737 .into_iter()
738 .map(Ok::<_, std::convert::Infallible>),
739 );
740
741 let body = StreamBody::new(stream);
743 let boxed_body = http_body_util::combinators::BoxBody::new(body);
744
745 debug!(
746 "๐ก POST SSE streaming response created: session={}",
747 session_id
748 );
749
750 Ok(hyper::Response::builder()
752 .status(hyper::StatusCode::OK)
753 .header(hyper::header::CONTENT_TYPE, "text/event-stream")
754 .header(hyper::header::CACHE_CONTROL, "no-cache")
755 .header(
756 hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN,
757 &self.config.cors_origin,
758 )
759 .header("Connection", "keep-alive")
760 .header("X-Accel-Buffering", "no") .header("Mcp-Session-Id", &session_id)
762 .body(boxed_body)
763 .unwrap())
764 }
765
766 pub async fn subscribe_to_notifications(
768 &self,
769 session_id: &str,
770 notification_types: Vec<String>,
771 ) {
772 let mut subscriptions = self.subscriptions.write().await;
773 let session_subscriptions = subscriptions
774 .entry(session_id.to_string())
775 .or_insert_with(HashSet::new);
776
777 for notification_type in notification_types {
778 session_subscriptions.insert(notification_type.clone());
779 debug!(
780 "๐ Session {} subscribed to notification: {}",
781 session_id, notification_type
782 );
783 }
784
785 debug!(
786 "Session {} now has {} subscriptions",
787 session_id,
788 session_subscriptions.len()
789 );
790 }
791
792 pub async fn unsubscribe_from_notifications(
794 &self,
795 session_id: &str,
796 notification_types: Vec<String>,
797 ) {
798 let mut subscriptions = self.subscriptions.write().await;
799 if let Some(session_subscriptions) = subscriptions.get_mut(session_id) {
800 for notification_type in notification_types {
801 if session_subscriptions.remove(¬ification_type) {
802 debug!(
803 "๐ Session {} unsubscribed from notification: {}",
804 session_id, notification_type
805 );
806 }
807 }
808
809 if session_subscriptions.is_empty() {
811 subscriptions.remove(session_id);
812 debug!(
813 "๐๏ธ Removed subscription entry for session {} (no remaining subscriptions)",
814 session_id
815 );
816 }
817 }
818 }
819
820 pub async fn is_subscribed(&self, session_id: &str, notification_type: &str) -> bool {
822 let subscriptions = self.subscriptions.read().await;
823 subscriptions
824 .get(session_id)
825 .map(|session_subscriptions| session_subscriptions.contains(notification_type))
826 .unwrap_or(true) }
828
829 pub async fn get_subscriptions(&self, session_id: &str) -> HashSet<String> {
831 let subscriptions = self.subscriptions.read().await;
832 subscriptions.get(session_id).cloned().unwrap_or_default()
833 }
834
835 pub async fn clear_subscriptions(&self, session_id: &str) {
837 let mut subscriptions = self.subscriptions.write().await;
838 if subscriptions.remove(session_id).is_some() {
839 debug!("๐๏ธ Cleared all subscriptions for session: {}", session_id);
840 }
841 }
842
843 pub fn get_config(&self) -> &StreamConfig {
845 &self.config
846 }
847
848 pub async fn get_stats(&self) -> StreamStats {
850 let connections = self.connections.read().await;
851 let session_count = self.storage.session_count().await.unwrap_or(0);
852 let event_count = self.storage.event_count().await.unwrap_or(0);
853
854 let total_connections: usize = connections
856 .values()
857 .map(|session_connections| session_connections.len())
858 .sum();
859
860 StreamStats {
861 active_broadcasters: total_connections, total_sessions: session_count,
863 total_events: event_count,
864 channel_buffer_size: self.config.channel_buffer_size,
865 }
866 }
867}
868
869impl Drop for StreamManager {
870 fn drop(&mut self) {
871 debug!(
872 "DROP: StreamManager instance {} - this may cause connection loss!",
873 self.instance_id
874 );
875 debug!("If this appears during request processing, it indicates architecture problem");
876 }
877}
878
/// Point-in-time counters reported by [`StreamManager::get_stats`].
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Registered connection senders across all sessions. This may include senders whose
    /// receiver is already closed but not yet cleaned up.
    pub active_broadcasters: usize,
    /// Session count reported by storage (0 if the query failed).
    pub total_sessions: usize,
    /// Event count reported by storage (0 if the query failed).
    pub total_events: usize,
    /// Configured per-connection channel capacity.
    pub channel_buffer_size: usize,
}
887
888#[cfg(not(test))]
890use async_stream;
891
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::ServerCapabilities;
    use turul_mcp_session_storage::{InMemorySessionStorage, SessionStorage};

    /// A freshly built manager has no registered connections and storage has no sessions.
    #[tokio::test]
    async fn test_stream_manager_creation() {
        let manager = StreamManager::new(Arc::new(InMemorySessionStorage::new()));

        let StreamStats {
            active_broadcasters,
            total_sessions,
            ..
        } = manager.get_stats().await;

        assert_eq!(active_broadcasters, 0);
        assert_eq!(total_sessions, 0);
    }

    /// Broadcasting to an existing session persists exactly one event with the returned id.
    #[tokio::test]
    async fn test_broadcast_to_session() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        let session_id = storage
            .create_session(ServerCapabilities::default())
            .await
            .unwrap()
            .session_id
            .clone();

        let payload = serde_json::json!({"message": "test"});
        let event_id = manager
            .broadcast_to_session(&session_id, "test".to_string(), payload)
            .await
            .expect("broadcast to an existing session should succeed");

        assert!(event_id > 0);

        let events = storage.get_events_after(&session_id, 0).await.unwrap();
        match events.as_slice() {
            [only] => assert_eq!(only.id, event_id),
            other => panic!("expected exactly one stored event, got {}", other.len()),
        }
    }
}