1use std::sync::Arc;
11use std::collections::HashMap;
12use std::pin::Pin;
13use futures::{Stream, StreamExt};
14use hyper::{Response, StatusCode};
15use http_body_util::{StreamBody, BodyExt};
16use bytes::Bytes;
17use hyper::header::{CONTENT_TYPE, CACHE_CONTROL, ACCESS_CONTROL_ALLOW_ORIGIN};
18use serde_json::Value;
19use tokio::sync::{mpsc, RwLock};
20use tracing::{debug, info, error, warn};
21
22use turul_mcp_session_storage::SseEvent;
23
/// Identifier that distinguishes one SSE connection from the others opened for the
/// same session; it is the inner key of the manager's connection registry.
pub type ConnectionId = String;
26
27pub struct StreamManager {
29 storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
31 connections: Arc<RwLock<HashMap<String, HashMap<ConnectionId, mpsc::Sender<SseEvent>>>>>,
33 config: StreamConfig,
35 instance_id: String,
37}
38
/// Tunable settings for a [`StreamManager`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the bounded channel created for each SSE connection.
    pub channel_buffer_size: usize,
    /// Upper bound on how many stored events are replayed when a client
    /// reconnects with a last-event id.
    pub max_replay_events: usize,
    /// Seconds between `ping` keepalive events on an idle stream.
    pub keepalive_interval_seconds: u64,
    /// Value sent in the `Access-Control-Allow-Origin` response header.
    pub cors_origin: String,
}
51
52impl Default for StreamConfig {
53 fn default() -> Self {
54 Self {
55 channel_buffer_size: 1000,
56 max_replay_events: 100,
57 keepalive_interval_seconds: 30,
58 cors_origin: "*".to_string(),
59 }
60 }
61}
62
63pub struct SseStream {
65 stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
67 session_id: String,
69 connection_id: ConnectionId,
71}
72
73impl SseStream {
74 pub fn session_id(&self) -> &str {
76 &self.session_id
77 }
78
79 pub fn connection_id(&self) -> &str {
81 &self.connection_id
82 }
83
84 pub fn stream_identifier(&self) -> String {
86 format!("{}:{}", self.session_id, self.connection_id)
87 }
88}
89
90impl Drop for SseStream {
91 fn drop(&mut self) {
92 debug!("๐ฅ DROP: SseStream - session={}, connection={}",
93 self.session_id, self.connection_id);
94 if self.stream.is_some() {
95 debug!("๐ฅ Stream still present during drop - this indicates early cleanup");
96 } else {
97 debug!("๐ฅ Stream was properly extracted before drop");
98 }
99 }
100}
101
102#[derive(Debug, thiserror::Error)]
104pub enum StreamError {
105 #[error("Session not found: {0}")]
106 SessionNotFound(String),
107 #[error("Stream not found: session={0}, stream={1}")]
108 StreamNotFound(String, String),
109 #[error("Storage error: {0}")]
110 StorageError(String),
111 #[error("Connection error: {0}")]
112 ConnectionError(String),
113}
114
115impl StreamManager {
116 pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
118 Self::with_config(storage, StreamConfig::default())
119 }
120
121 pub fn with_config(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>, config: StreamConfig) -> Self {
123 use uuid::Uuid;
124 let instance_id = Uuid::now_v7().to_string();
125 debug!("๐ง Creating StreamManager instance: {}", instance_id);
126 Self {
127 storage,
128 connections: Arc::new(RwLock::new(HashMap::new())),
129 config,
130 instance_id,
131 }
132 }
133
134 pub async fn handle_sse_connection(
136 &self,
137 session_id: String,
138 connection_id: ConnectionId,
139 last_event_id: Option<u64>,
140 ) -> Result<Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>, StreamError> {
141 if self.storage.get_session(&session_id).await
143 .map_err(|e| StreamError::StorageError(e.to_string()))?
144 .is_none()
145 {
146 return Err(StreamError::SessionNotFound(session_id));
147 }
148
149 let sse_stream = self.create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id).await?;
151
152 let response = self.stream_to_response(sse_stream).await;
154
155 info!("Created SSE connection: session={}, connection={}, last_event_id={:?}",
156 session_id, connection_id, last_event_id);
157
158 Ok(response)
159 }
160
161 async fn create_sse_stream(
163 &self,
164 session_id: String,
165 connection_id: ConnectionId,
166 last_event_id: Option<u64>,
167 ) -> Result<SseStream, StreamError> {
168 let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);
170
171 self.register_connection(&session_id, connection_id.clone(), sender).await;
173
174 let storage = self.storage.clone();
176 let session_id_clone = session_id.clone();
177 let connection_id_clone = connection_id.clone();
178 let config = self.config.clone();
179
180 let combined_stream = async_stream::stream! {
181 if let Some(after_event_id) = last_event_id {
183 debug!("Replaying events after ID {} for session={}, connection={}",
184 after_event_id, session_id_clone, connection_id_clone);
185
186 match storage.get_events_after(&session_id_clone, after_event_id).await {
187 Ok(events) => {
188 for event in events.into_iter().take(config.max_replay_events) {
189 yield event;
190 }
191 },
192 Err(e) => {
193 error!("Failed to get historical events: {}", e);
194 }
196 }
197 }
198
199 let mut keepalive_interval = tokio::time::interval(
201 tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
202 );
203
204 loop {
205 tokio::select! {
206 event = receiver.recv() => {
208 match event {
209 Some(event) => {
210 debug!("๐จ Received event for connection {}: {}", connection_id_clone, event.event_type);
211 yield event;
212 },
213 None => {
214 debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
215 break;
216 }
217 }
218 },
219
220 _ = keepalive_interval.tick() => {
222 let keepalive_event = SseEvent {
223 id: 0, timestamp: chrono::Utc::now().timestamp_millis() as u64,
225 event_type: "ping".to_string(),
226 data: serde_json::json!({"type": "keepalive"}),
227 retry: None,
228 };
229 yield keepalive_event;
230 }
231 }
232 }
233
234 debug!("๐งน Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
236 };
237
238 Ok(SseStream {
239 stream: Some(Box::pin(combined_stream)),
240 session_id,
241 connection_id,
242 })
243 }
244
245 async fn register_connection(
247 &self,
248 session_id: &str,
249 connection_id: ConnectionId,
250 sender: mpsc::Sender<SseEvent>
251 ) {
252 let mut connections = self.connections.write().await;
253
254 debug!("[{}] ๐ BEFORE registration: HashMap has {} sessions", self.instance_id, connections.len());
255 for (sid, conns) in connections.iter() {
256 debug!("[{}] ๐ Existing session before: {} with {} connections", self.instance_id, sid, conns.len());
257 }
258
259 let session_connections = connections.entry(session_id.to_string())
261 .or_insert_with(HashMap::new);
262
263 session_connections.insert(connection_id.clone(), sender);
265
266 debug!("[{}] ๐ Registered connection: session={}, connection={}, total_connections={}",
267 self.instance_id, session_id, connection_id, session_connections.len());
268
269 debug!("[{}] ๐ AFTER registration: HashMap has {} sessions", self.instance_id, connections.len());
270 for (sid, conns) in connections.iter() {
271 debug!("[{}] ๐ Session after: {} with {} connections", self.instance_id, sid, conns.len());
272 }
273 }
274
275 pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
277 debug!("๐ด UNREGISTER called for session={}, connection={}", session_id, connection_id);
278 let mut connections = self.connections.write().await;
279
280 debug!("๐ BEFORE unregister: HashMap has {} sessions", connections.len());
281
282 if let Some(session_connections) = connections.get_mut(session_id) {
283 if session_connections.remove(connection_id).is_some() {
284 debug!("๐ Unregistered connection: session={}, connection={}", session_id, connection_id);
285
286 if session_connections.is_empty() {
288 connections.remove(session_id);
289 debug!("๐งน Removed empty session: {}", session_id);
290 }
291 }
292 }
293
294 debug!("๐ AFTER unregister: HashMap has {} sessions", connections.len());
295 }
296
297 async fn stream_to_response(&self, mut sse_stream: SseStream) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
299 let session_id = sse_stream.session_id().to_string();
301 let stream_identifier = sse_stream.stream_identifier();
302
303 info!("Converting SSE stream to HTTP response: {}", stream_identifier);
305 debug!("Stream details: session_id={}", session_id);
306
307 let stream = sse_stream.stream.take().expect("Stream should be present in SseStream");
310
311 let formatted_stream = stream.map(|event| {
312 let sse_formatted = event.format();
313 debug!("๐ก Streaming SSE event: id={}, event_type={}", event.id, event.event_type);
314 Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
315 });
316
317 let body = StreamBody::new(formatted_stream).boxed_unsync();
319
320 Response::builder()
322 .status(StatusCode::OK)
323 .header(CONTENT_TYPE, "text/event-stream")
324 .header(CACHE_CONTROL, "no-cache")
325 .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
326 .header("Connection", "keep-alive")
327 .body(body)
328 .unwrap()
329 }
330
331 pub async fn broadcast_to_session(
333 &self,
334 session_id: &str,
335 event_type: String,
336 data: Value,
337 ) -> Result<u64, StreamError> {
338 let event = SseEvent::new(event_type, data);
340
341 let stored_event = self.storage.store_event(session_id, event).await
343 .map_err(|e| StreamError::StorageError(e.to_string()))?;
344
345 let connections = self.connections.read().await;
347 debug!("[{}] ๐ Checking connections for session {}: connections hashmap has {} sessions",
348 self.instance_id, session_id, connections.len());
349
350 if let Some(session_connections) = connections.get(session_id) {
351 debug!("๐ Session {} found with {} connections", session_id, session_connections.len());
352
353 if !session_connections.is_empty() {
354 let (selected_connection_id, selected_sender) = session_connections.iter().next().unwrap();
356
357 if selected_sender.is_closed() {
359 warn!("๐ Sender is closed for connection: session={}, connection={}",
360 session_id, selected_connection_id);
361 debug!("๐ญ Connection sender was closed, event stored for reconnection");
362 } else {
363 debug!("โ
Sender is open, attempting to send to connection: session={}, connection={}",
364 session_id, selected_connection_id);
365
366 match selected_sender.try_send(stored_event.clone()) {
367 Ok(()) => {
368 info!("โ
Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
369 session_id, selected_connection_id, stored_event.id, stored_event.event_type);
370 },
371 Err(mpsc::error::TrySendError::Full(_)) => {
372 warn!("โ ๏ธ Connection buffer full: session={}, connection={}", session_id, selected_connection_id);
373 },
375 Err(mpsc::error::TrySendError::Closed(_)) => {
376 warn!("๐ Connection closed during send: session={}, connection={}", session_id, selected_connection_id);
377 }
379 }
380 }
381 } else {
382 debug!("๐ญ No active connections for session: {} (event stored for reconnection)", session_id);
383 }
384 } else {
385 debug!("๐ญ No connections registered for session: {} (event stored for reconnection)", session_id);
386
387 for (sid, conns) in connections.iter() {
389 debug!("๐ Available session: {} with {} connections", sid, conns.len());
390 }
391 }
392
393 Ok(stored_event.id)
394 }
395
396 pub async fn broadcast_to_all_sessions(
398 &self,
399 event_type: String,
400 data: Value,
401 ) -> Result<Vec<String>, StreamError> {
402 let session_ids = self.storage.list_sessions().await
404 .map_err(|e| StreamError::StorageError(e.to_string()))?;
405
406 let mut failed_sessions = Vec::new();
407
408 for session_id in session_ids {
409 if let Err(e) = self.broadcast_to_session(&session_id, event_type.clone(), data.clone()).await {
410 error!("Failed to broadcast to session {}: {}", session_id, e);
411 failed_sessions.push(session_id);
412 }
413 }
414
415 Ok(failed_sessions)
416 }
417
418 pub async fn cleanup_connections(&self) -> usize {
420 debug!("๐งน CLEANUP_CONNECTIONS called");
421 let mut connections = self.connections.write().await;
422 let mut total_cleaned = 0;
423
424 debug!("๐ BEFORE cleanup: HashMap has {} sessions", connections.len());
425
426 connections.retain(|session_id, session_connections| {
428 let initial_count = session_connections.len();
429
430 session_connections.retain(|connection_id, sender| {
432 if sender.is_closed() {
433 debug!("๐งน Cleaned up closed connection: session={}, connection={}", session_id, connection_id);
434 false
435 } else {
436 true
437 }
438 });
439
440 let cleaned_count = initial_count - session_connections.len();
441 total_cleaned += cleaned_count;
442
443 !session_connections.is_empty()
445 });
446
447 if total_cleaned > 0 {
448 info!("Cleaned up {} inactive connections", total_cleaned);
449 }
450
451 total_cleaned
452 }
453
454 pub async fn create_post_sse_stream(
456 &self,
457 session_id: String,
458 response: turul_mcp_json_rpc_server::JsonRpcResponse,
459 ) -> Result<hyper::Response<http_body_util::Full<bytes::Bytes>>, StreamError> {
460 if self.storage.get_session(&session_id).await
462 .map_err(|e| StreamError::StorageError(e.to_string()))?
463 .is_none()
464 {
465 return Err(StreamError::SessionNotFound(session_id));
466 }
467
468 info!("Creating POST SSE stream for session: {}", session_id);
469
470 let response_json = serde_json::to_string(&response)
472 .map_err(|e| StreamError::StorageError(format!("Failed to serialize response: {}", e)))?;
473
474 let mut sse_content = String::new();
480
481 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
485
486 if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
487 for event in events {
488 if event.event_type != "ping" { let notification_sse = format!(
491 "event: data\ndata: {}\n\n",
492 event.data
493 );
494 sse_content.push_str(¬ification_sse);
495 debug!("๐ค Including notification in POST SSE stream: event_type={}", event.event_type);
496 }
497 }
498 }
499
500 let response_sse = format!(
502 "event: data\ndata: {}\n\n",
503 response_json
504 );
505 sse_content.push_str(&response_sse);
506
507 debug!("๐ก POST SSE response created: session={}, content_length={}", session_id, sse_content.len());
508
509 Ok(hyper::Response::builder()
511 .status(hyper::StatusCode::OK)
512 .header(hyper::header::CONTENT_TYPE, "text/event-stream")
513 .header(hyper::header::CACHE_CONTROL, "no-cache")
514 .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
515 .header("Connection", "keep-alive")
516 .header("Mcp-Session-Id", &session_id)
517 .body(http_body_util::Full::new(bytes::Bytes::from(sse_content)))
518 .unwrap())
519 }
520
521 pub async fn get_stats(&self) -> StreamStats {
523 let connections = self.connections.read().await;
524 let session_count = self.storage.session_count().await.unwrap_or(0);
525 let event_count = self.storage.event_count().await.unwrap_or(0);
526
527 let total_connections: usize = connections.values()
529 .map(|session_connections| session_connections.len())
530 .sum();
531
532 StreamStats {
533 active_broadcasters: total_connections, total_sessions: session_count,
535 total_events: event_count,
536 channel_buffer_size: self.config.channel_buffer_size,
537 }
538 }
539}
540
541impl Drop for StreamManager {
542 fn drop(&mut self) {
543 debug!("๐ฅ DROP: StreamManager instance {} - this may cause connection loss!",
544 self.instance_id);
545 debug!("๐ฅ If this appears during request processing, it indicates architecture problem");
546 }
547}
548
/// Point-in-time counters reported by `StreamManager::get_stats`.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of registered SSE connections across all sessions.
    pub active_broadcasters: usize,
    /// Session count reported by storage (0 if it could not be read).
    pub total_sessions: usize,
    /// Event count reported by storage (0 if it could not be read).
    pub total_events: usize,
    /// Configured per-connection channel capacity.
    pub channel_buffer_size: usize,
}
557
558#[cfg(not(test))]
560use async_stream;
561
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_session_storage::InMemorySessionStorage;
    use turul_mcp_protocol::ServerCapabilities;

    /// A fresh manager over empty storage reports no connections and no sessions.
    #[tokio::test]
    async fn test_stream_manager_creation() {
        let backend = Arc::new(InMemorySessionStorage::new());
        let stream_manager = StreamManager::new(backend);

        let StreamStats { active_broadcasters, total_sessions, .. } = stream_manager.get_stats().await;
        assert_eq!(active_broadcasters, 0);
        assert_eq!(total_sessions, 0);
    }

    /// Broadcasting to a session with no live connection still stores the event,
    /// and the returned id matches the stored one.
    #[tokio::test]
    async fn test_broadcast_to_session() {
        let backend = Arc::new(InMemorySessionStorage::new());
        let stream_manager = StreamManager::new(backend.clone());

        let session_id = backend
            .create_session(ServerCapabilities::default())
            .await
            .unwrap()
            .session_id
            .clone();

        let payload = serde_json::json!({"message": "test"});
        let event_id = stream_manager
            .broadcast_to_session(&session_id, "test".to_string(), payload)
            .await
            .unwrap();

        assert!(event_id > 0);

        let stored = backend.get_events_after(&session_id, 0).await.unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, event_id);
    }
}